diff options
author | Ian Gudger <igudger@google.com> | 2019-04-29 21:20:05 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2019-04-29 21:21:08 -0700 |
commit | 81ecd8b6eab7457b331762626f8c210fec3504e6 (patch) | |
tree | 4c5e5aaf3ac8ff475657c66671dd6828938ae45e /pkg/sentry/fs | |
parent | 2843f2a956f5ef23e621f571f5c3e6a1e4a8223a (diff) |
Implement the MSG_CTRUNC msghdr flag for Unix sockets.
Updates google/gvisor#206
PiperOrigin-RevId: 245880573
Change-Id: Ifa715e98d47f64b8a32b04ae9378d6cd6bd4025e
Diffstat (limited to 'pkg/sentry/fs')
-rw-r--r-- | pkg/sentry/fs/host/control.go | 7 | ||||
-rw-r--r-- | pkg/sentry/fs/host/socket.go | 16 | ||||
-rw-r--r-- | pkg/sentry/fs/host/socket_test.go | 2 | ||||
-rw-r--r-- | pkg/sentry/fs/host/socket_unsafe.go | 12 |
4 files changed, 21 insertions, 16 deletions
diff --git a/pkg/sentry/fs/host/control.go b/pkg/sentry/fs/host/control.go index 480f0c8f4..9ebb9bbb3 100644 --- a/pkg/sentry/fs/host/control.go +++ b/pkg/sentry/fs/host/control.go @@ -32,17 +32,20 @@ func newSCMRights(fds []int) control.SCMRights { } // Files implements control.SCMRights.Files. -func (c *scmRights) Files(ctx context.Context, max int) control.RightsFiles { +func (c *scmRights) Files(ctx context.Context, max int) (control.RightsFiles, bool) { n := max + var trunc bool if l := len(c.fds); n > l { n = l + } else if n < l { + trunc = true } rf := control.RightsFiles(fdsToFiles(ctx, c.fds[:n])) // Only consume converted FDs (fdsToFiles may convert fewer than n FDs). c.fds = c.fds[len(rf):] - return rf + return rf, trunc } // Clone implements transport.RightsControlMessage.Clone. diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 3034e9441..3ed137006 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -282,11 +282,11 @@ func (c *ConnectedEndpoint) EventUpdate() { } // Recv implements transport.Receiver.Recv. -func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, transport.ControlMessages, tcpip.FullAddress, bool, *syserr.Error) { +func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { c.mu.RLock() defer c.mu.RUnlock() if c.readClosed { - return 0, 0, transport.ControlMessages{}, tcpip.FullAddress{}, false, syserr.ErrClosedForReceive + return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.ErrClosedForReceive } var cm unet.ControlMessage @@ -296,7 +296,7 @@ func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, p // N.B. Unix sockets don't have a receive buffer, the send buffer // serves both purposes. - rl, ml, cl, err := fdReadVec(c.file.FD(), data, []byte(cm), peek, c.sndbuf) + rl, ml, cl, cTrunc, err := fdReadVec(c.file.FD(), data, []byte(cm), peek, c.sndbuf) if rl > 0 && err != nil { // We got some data, so all we need to do on error is return // the data that we got. Short reads are fine, no need to @@ -304,7 +304,7 @@ func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, p err = nil } if err != nil { - return 0, 0, transport.ControlMessages{}, tcpip.FullAddress{}, false, syserr.FromError(err) + return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err) } // There is no need for the callee to call RecvNotify because fdReadVec uses @@ -317,18 +317,18 @@ func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, p // Avoid extra allocations in the case where there isn't any control data. if len(cm) == 0 { - return rl, ml, transport.ControlMessages{}, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil + return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil } fds, err := cm.ExtractFDs() if err != nil { - return 0, 0, transport.ControlMessages{}, tcpip.FullAddress{}, false, syserr.FromError(err) + return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err) } if len(fds) == 0 { - return rl, ml, transport.ControlMessages{}, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil + return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil } - return rl, ml, control.New(nil, nil, newSCMRights(fds)), tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil + return rl, ml, control.New(nil, nil, newSCMRights(fds)), cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil } // close releases all resources related to the endpoint. diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go index cc760a7e1..06392a65a 100644 --- a/pkg/sentry/fs/host/socket_test.go +++ b/pkg/sentry/fs/host/socket_test.go @@ -207,7 +207,7 @@ func TestSend(t *testing.T) { func TestRecv(t *testing.T) { e := ConnectedEndpoint{readClosed: true} - if _, _, _, _, _, err := e.Recv(nil, false, 0, false); err != syserr.ErrClosedForReceive { + if _, _, _, _, _, _, err := e.Recv(nil, false, 0, false); err != syserr.ErrClosedForReceive { t.Errorf("Got %#v.Recv() = %v, want = %v", e, err, syserr.ErrClosedForReceive) } } diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go index 8873705c0..e57be0506 100644 --- a/pkg/sentry/fs/host/socket_unsafe.go +++ b/pkg/sentry/fs/host/socket_unsafe.go @@ -23,7 +23,7 @@ import ( // // If the total length of bufs is > maxlen, fdReadVec will do a partial read // and err will indicate why the message was truncated. -func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (readLen uintptr, msgLen uintptr, controlLen uint64, err error) { +func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (readLen uintptr, msgLen uintptr, controlLen uint64, controlTrunc bool, err error) { flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC) if peek { flags |= syscall.MSG_PEEK @@ -34,7 +34,7 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (re length, iovecs, intermediate, err := buildIovec(bufs, maxlen, true) if err != nil && len(iovecs) == 0 { // No partial write to do, return error immediately. - return 0, 0, 0, err + return 0, 0, 0, false, err } var msg syscall.Msghdr @@ -51,7 +51,7 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (re n, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags) if e != 0 { // N.B. prioritize the syscall error over the buildIovec error. - return 0, 0, 0, e + return 0, 0, 0, false, e } // Copy data back to bufs. @@ -59,11 +59,13 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (re copyToMulti(bufs, intermediate) } + controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC + if n > length { - return length, n, msg.Controllen, err + return length, n, msg.Controllen, controlTrunc, err } - return n, n, msg.Controllen, err + return n, n, msg.Controllen, controlTrunc, err } // fdWriteVec sends from bufs to fd. |