summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorDean Deng <deandeng@google.com>2019-12-10 19:14:05 -0800
committergVisor bot <gvisor-bot@google.com>2019-12-10 19:27:42 -0800
commit2e3b9b0a68d8f191d061008feda6e4a4ce202a78 (patch)
tree856007f4e14500289ff970236d46d3964799e9ee /pkg
parent46651a7d26559bdc69d460bdeb4de5968212d615 (diff)
Deduplicate and simplify control message processing for recvmsg and sendmsg.
Also, improve performance by calculating how much space is needed before making an allocation for sendmsg in hostinet. PiperOrigin-RevId: 284898581
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/socket/control/control.go51
-rw-r--r--pkg/sentry/socket/hostinet/socket.go18
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go18
3 files changed, 50 insertions, 37 deletions
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index fa3188d51..af1a4e95f 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -348,43 +348,62 @@ func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte {
)
}
-func addSpaceForCmsg(cmsgDataLen int, buf []byte) []byte {
- newBuf := make([]byte, 0, len(buf)+linux.SizeOfControlMessageHeader+cmsgDataLen)
- return append(newBuf, buf...)
-}
-
-// PackControlMessages converts the given ControlMessages struct into a buffer.
+// PackControlMessages packs control messages into the given buffer.
+//
// We skip control messages specific to Unix domain sockets.
-func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages) []byte {
- var buf []byte
- // The use of t.Arch().Width() is analogous to Linux's use of sizeof(long) in
- // CMSG_ALIGN.
- width := t.Arch().Width()
-
+//
+// Note that some control messages may be truncated if they do not fit under
+// the capacity of buf.
+func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte {
if cmsgs.IP.HasTimestamp {
- buf = addSpaceForCmsg(int(width), buf)
buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf)
}
if cmsgs.IP.HasInq {
// In Linux, TCP_CM_INQ is added after SO_TIMESTAMP.
- buf = addSpaceForCmsg(AlignUp(linux.SizeOfControlMessageInq, width), buf)
buf = PackInq(t, cmsgs.IP.Inq, buf)
}
if cmsgs.IP.HasTOS {
- buf = addSpaceForCmsg(AlignUp(linux.SizeOfControlMessageTOS, width), buf)
buf = PackTOS(t, cmsgs.IP.TOS, buf)
}
if cmsgs.IP.HasTClass {
- buf = addSpaceForCmsg(AlignUp(linux.SizeOfControlMessageTClass, width), buf)
buf = PackTClass(t, cmsgs.IP.TClass, buf)
}
return buf
}
+// cmsgSpace is equivalent to CMSG_SPACE in Linux.
+func cmsgSpace(t *kernel.Task, dataLen int) int {
+ return linux.SizeOfControlMessageHeader + AlignUp(dataLen, t.Arch().Width())
+}
+
+// CmsgsSpace returns the number of bytes needed to fit the control messages
+// represented in cmsgs.
+func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
+ space := 0
+
+ if cmsgs.IP.HasTimestamp {
+ space += cmsgSpace(t, linux.SizeOfTimeval)
+ }
+
+ if cmsgs.IP.HasInq {
+ space += cmsgSpace(t, linux.SizeOfControlMessageInq)
+ }
+
+ if cmsgs.IP.HasTOS {
+ space += cmsgSpace(t, linux.SizeOfControlMessageTOS)
+ }
+
+ if cmsgs.IP.HasTClass {
+ space += cmsgSpace(t, linux.SizeOfControlMessageTClass)
+ }
+
+ return space
+}
+
// Parse parses a raw socket control message into portable objects.
func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) {
var (
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index a8c152b54..c957b0f1d 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -45,7 +45,7 @@ const (
sizeofSockaddr = syscall.SizeofSockaddrInet6 // sizeof(sockaddr_in6) > sizeof(sockaddr_in)
// maxControlLen is the maximum size of a control message buffer used in a
- // recvmsg syscall.
+ // recvmsg or sendmsg syscall.
maxControlLen = 1024
)
@@ -412,9 +412,12 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
msg.Namelen = uint32(len(senderAddrBuf))
}
if controlLen > 0 {
- controlBuf = make([]byte, maxControlLen)
+ if controlLen > maxControlLen {
+ controlLen = maxControlLen
+ }
+ controlBuf = make([]byte, controlLen)
msg.Control = &controlBuf[0]
- msg.Controllen = maxControlLen
+ msg.Controllen = controlLen
}
n, err := recvmsg(s.fd, &msg, sysflags)
if err != nil {
@@ -489,7 +492,14 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
return 0, syserr.ErrInvalidArgument
}
- controlBuf := control.PackControlMessages(t, controlMessages)
+ space := uint64(control.CmsgsSpace(t, controlMessages))
+ if space > maxControlLen {
+ space = maxControlLen
+ }
+ controlBuf := make([]byte, 0, space)
+ // PackControlMessages will append up to space bytes to controlBuf.
+ controlBuf = control.PackControlMessages(t, controlMessages, controlBuf)
+
sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) {
// Refuse to do anything if any part of src.Addrs was unusable.
if uint64(src.NumBytes()) != srcs.NumBytes() {
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index e3faa890b..4b5aafcc0 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -787,29 +787,13 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
defer cms.Release()
controlData := make([]byte, 0, msg.ControlLen)
+ controlData = control.PackControlMessages(t, cms, controlData)
if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() {
creds, _ := cms.Unix.Credentials.(control.SCMCredentials)
controlData, mflags = control.PackCredentials(t, creds, controlData, mflags)
}
- if cms.IP.HasTimestamp {
- controlData = control.PackTimestamp(t, cms.IP.Timestamp, controlData)
- }
-
- if cms.IP.HasInq {
- // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP.
- controlData = control.PackInq(t, cms.IP.Inq, controlData)
- }
-
- if cms.IP.HasTOS {
- controlData = control.PackTOS(t, cms.IP.TOS, controlData)
- }
-
- if cms.IP.HasTClass {
- controlData = control.PackTClass(t, cms.IP.TClass, controlData)
- }
-
if cms.Unix.Rights != nil {
controlData, mflags = control.PackRights(t, cms.Unix.Rights.(control.SCMRights), flags&linux.MSG_CMSG_CLOEXEC != 0, controlData, mflags)
}