summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/control/control.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/control/control.go')
-rw-r--r--pkg/sentry/socket/control/control.go300
1 files changed, 225 insertions, 75 deletions
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 4e95101b7..70ccf77a7 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -19,13 +19,15 @@ package control
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
)
const maxInt = int(^uint(0) >> 1)
@@ -39,6 +41,8 @@ type SCMCredentials interface {
Credentials(t *kernel.Task) (kernel.ThreadID, auth.UID, auth.GID)
}
+// LINT.IfChange
+
// SCMRights represents a SCM_RIGHTS socket control message.
type SCMRights interface {
transport.RightsControlMessage
@@ -64,7 +68,7 @@ func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) {
for _, fd := range fds {
file := t.GetFile(fd)
if file == nil {
- files.Release()
+ files.Release(t)
return nil, syserror.EBADF
}
files = append(files, file)
@@ -96,9 +100,9 @@ func (fs *RightsFiles) Clone() transport.RightsControlMessage {
}
// Release implements transport.RightsControlMessage.Release.
-func (fs *RightsFiles) Release() {
+func (fs *RightsFiles) Release(ctx context.Context) {
for _, f := range *fs {
- f.DecRef()
+ f.DecRef(ctx)
}
*fs = nil
}
@@ -111,7 +115,7 @@ func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32
fd, err := t.NewFDFrom(0, files[0], kernel.FDFlags{
CloseOnExec: cloexec,
})
- files[0].DecRef()
+ files[0].DecRef(t)
files = files[1:]
if err != nil {
t.Warningf("Error inserting FD: %v", err)
@@ -140,6 +144,8 @@ func PackRights(t *kernel.Task, rights SCMRights, cloexec bool, buf []byte, flag
return putCmsg(buf, flags, linux.SCM_RIGHTS, align, fds)
}
+// LINT.ThenChange(./control_vfs2.go)
+
// scmCredentials represents an SCM_CREDENTIALS socket control message.
//
// +stateify savable
@@ -188,21 +194,21 @@ func putUint32(buf []byte, n uint32) []byte {
// putCmsg writes a control message header and as much data as will fit into
// the unused capacity of a buffer.
func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) {
- space := AlignDown(cap(buf)-len(buf), 4)
+ space := binary.AlignDown(cap(buf)-len(buf), 4)
// We can't write to space that doesn't exist, so if we are going to align
// the available space, we must align down.
//
// align must be >= 4 and each data int32 is 4 bytes. The length of the
- // header is already aligned, so if we align to the with of the data there
+ // header is already aligned, so if we align to the width of the data there
// are two cases:
// 1. The aligned length is less than the length of the header. The
// unaligned length was also less than the length of the header, so we
// can't write anything.
// 2. The aligned length is greater than or equal to the length of the
- // header. We can write the header plus zero or more datas. We can't write
- // a partial int32, so the length of the message will be
- // min(aligned length, header + datas).
+ // header. We can write the header plus zero or more bytes of data. We can't
+ // write a partial int32, so the length of the message will be
+ // min(aligned length, header + data).
if space < linux.SizeOfControlMessageHeader {
flags |= linux.MSG_CTRUNC
return buf, flags
@@ -239,12 +245,12 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf
buf = binary.Marshal(buf, usermem.ByteOrder, data)
- // Check if we went over.
+ // If the control message data brought us over capacity, omit it.
if cap(buf) != cap(ob) {
return hdrBuf
}
- // Fix up length.
+ // Update control message length to include data.
putUint64(ob, uint64(len(buf)-len(ob)))
return alignSlice(buf, align)
@@ -281,19 +287,9 @@ func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int
return putCmsg(buf, flags, linux.SCM_CREDENTIALS, align, c)
}
-// AlignUp rounds a length up to an alignment. align must be a power of 2.
-func AlignUp(length int, align uint) int {
- return (length + int(align) - 1) & ^(int(align) - 1)
-}
-
-// AlignDown rounds a down to an alignment. align must be a power of 2.
-func AlignDown(length int, align uint) int {
- return length & ^(int(align) - 1)
-}
-
// alignSlice extends a slice's length (up to the capacity) to align it.
func alignSlice(buf []byte, align uint) []byte {
- aligned := AlignUp(len(buf), align)
+ aligned := binary.AlignUp(len(buf), align)
if aligned > cap(buf) {
// Linux allows unaligned data if there isn't room for alignment.
// Since there isn't room for alignment, there isn't room for any
@@ -320,35 +316,139 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte {
buf,
linux.SOL_TCP,
linux.TCP_INQ,
- 4,
+ t.Arch().Width(),
inq,
)
}
+// PackTOS packs an IP_TOS socket control message.
+func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IP,
+ linux.IP_TOS,
+ t.Arch().Width(),
+ tos,
+ )
+}
+
+// PackTClass packs an IPV6_TCLASS socket control message.
+func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IPV6,
+ linux.IPV6_TCLASS,
+ t.Arch().Width(),
+ tClass,
+ )
+}
+
+// PackIPPacketInfo packs an IP_PKTINFO socket control message.
+func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte {
+ var p linux.ControlMessageIPPacketInfo
+ p.NIC = int32(packetInfo.NIC)
+ copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
+ copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IP,
+ linux.IP_PKTINFO,
+ t.Arch().Width(),
+ p,
+ )
+}
+
+// PackControlMessages packs control messages into the given buffer.
+//
+// We skip control messages specific to Unix domain sockets.
+//
+// Note that some control messages may be truncated if they do not fit under
+// the capacity of buf.
+func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte {
+ if cmsgs.IP.HasTimestamp {
+ buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf)
+ }
+
+ if cmsgs.IP.HasInq {
+ // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP.
+ buf = PackInq(t, cmsgs.IP.Inq, buf)
+ }
+
+ if cmsgs.IP.HasTOS {
+ buf = PackTOS(t, cmsgs.IP.TOS, buf)
+ }
+
+ if cmsgs.IP.HasTClass {
+ buf = PackTClass(t, cmsgs.IP.TClass, buf)
+ }
+
+ if cmsgs.IP.HasIPPacketInfo {
+ buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
+ }
+
+ return buf
+}
+
+// cmsgSpace is equivalent to CMSG_SPACE in Linux.
+func cmsgSpace(t *kernel.Task, dataLen int) int {
+ return linux.SizeOfControlMessageHeader + binary.AlignUp(dataLen, t.Arch().Width())
+}
+
+// CmsgsSpace returns the number of bytes needed to fit the control messages
+// represented in cmsgs.
+func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
+ space := 0
+
+ if cmsgs.IP.HasTimestamp {
+ space += cmsgSpace(t, linux.SizeOfTimeval)
+ }
+
+ if cmsgs.IP.HasInq {
+ space += cmsgSpace(t, linux.SizeOfControlMessageInq)
+ }
+
+ if cmsgs.IP.HasTOS {
+ space += cmsgSpace(t, linux.SizeOfControlMessageTOS)
+ }
+
+ if cmsgs.IP.HasTClass {
+ space += cmsgSpace(t, linux.SizeOfControlMessageTClass)
+ }
+
+ return space
+}
+
+// NewIPPacketInfo returns the IPPacketInfo struct.
+func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo {
+ var p tcpip.IPPacketInfo
+ p.NIC = tcpip.NICID(packetInfo.NIC)
+ copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:])
+ copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:])
+
+ return p
+}
+
// Parse parses a raw socket control message into portable objects.
-func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport.ControlMessages, error) {
+func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) {
var (
- fds linux.ControlMessageRights
- haveCreds bool
- creds linux.ControlMessageCredentials
+ cmsgs socket.ControlMessages
+ fds linux.ControlMessageRights
)
for i := 0; i < len(buf); {
if i+linux.SizeOfControlMessageHeader > len(buf) {
- return transport.ControlMessages{}, syserror.EINVAL
+ return cmsgs, syserror.EINVAL
}
var h linux.ControlMessageHeader
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h)
if h.Length < uint64(linux.SizeOfControlMessageHeader) {
- return transport.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, syserror.EINVAL
}
if h.Length > uint64(len(buf)-i) {
- return transport.ControlMessages{}, syserror.EINVAL
- }
- if h.Level != linux.SOL_SOCKET {
- return transport.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, syserror.EINVAL
}
i += linux.SizeOfControlMessageHeader
@@ -358,59 +458,105 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport.
// sizeof(long) in CMSG_ALIGN.
width := t.Arch().Width()
- switch h.Type {
- case linux.SCM_RIGHTS:
- rightsSize := AlignDown(length, linux.SizeOfControlMessageRight)
- numRights := rightsSize / linux.SizeOfControlMessageRight
-
- if len(fds)+numRights > linux.SCM_MAX_FD {
- return transport.ControlMessages{}, syserror.EINVAL
+ switch h.Level {
+ case linux.SOL_SOCKET:
+ switch h.Type {
+ case linux.SCM_RIGHTS:
+ rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight)
+ numRights := rightsSize / linux.SizeOfControlMessageRight
+
+ if len(fds)+numRights > linux.SCM_MAX_FD {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight {
+ fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
+ }
+
+ i += binary.AlignUp(length, width)
+
+ case linux.SCM_CREDENTIALS:
+ if length < linux.SizeOfControlMessageCredentials {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ var creds linux.ControlMessageCredentials
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds)
+ scmCreds, err := NewSCMCredentials(t, creds)
+ if err != nil {
+ return socket.ControlMessages{}, err
+ }
+ cmsgs.Unix.Credentials = scmCreds
+ i += binary.AlignUp(length, width)
+
+ default:
+ // Unknown message type.
+ return socket.ControlMessages{}, syserror.EINVAL
}
-
- for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight {
- fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
+ case linux.SOL_IP:
+ switch h.Type {
+ case linux.IP_TOS:
+ if length < linux.SizeOfControlMessageTOS {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ cmsgs.IP.HasTOS = true
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS)
+ i += binary.AlignUp(length, width)
+
+ case linux.IP_PKTINFO:
+ if length < linux.SizeOfControlMessageIPPacketInfo {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ cmsgs.IP.HasIPPacketInfo = true
+ var packetInfo linux.ControlMessageIPPacketInfo
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+
+ cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo)
+ i += binary.AlignUp(length, width)
+
+ default:
+ return socket.ControlMessages{}, syserror.EINVAL
}
-
- i += AlignUp(length, width)
-
- case linux.SCM_CREDENTIALS:
- if length < linux.SizeOfControlMessageCredentials {
- return transport.ControlMessages{}, syserror.EINVAL
+ case linux.SOL_IPV6:
+ switch h.Type {
+ case linux.IPV6_TCLASS:
+ if length < linux.SizeOfControlMessageTClass {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ cmsgs.IP.HasTClass = true
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
+ i += binary.AlignUp(length, width)
+
+ default:
+ return socket.ControlMessages{}, syserror.EINVAL
}
-
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds)
- haveCreds = true
- i += AlignUp(length, width)
-
default:
- // Unknown message type.
- return transport.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, syserror.EINVAL
}
}
- var credentials SCMCredentials
- if haveCreds {
- var err error
- if credentials, err = NewSCMCredentials(t, creds); err != nil {
- return transport.ControlMessages{}, err
- }
- } else {
- credentials = makeCreds(t, socketOrEndpoint)
+ if cmsgs.Unix.Credentials == nil {
+ cmsgs.Unix.Credentials = makeCreds(t, socketOrEndpoint)
}
- var rights SCMRights
if len(fds) > 0 {
- var err error
- if rights, err = NewSCMRights(t, fds); err != nil {
- return transport.ControlMessages{}, err
+ if kernel.VFS2Enabled {
+ rights, err := NewSCMRightsVFS2(t, fds)
+ if err != nil {
+ return socket.ControlMessages{}, err
+ }
+ cmsgs.Unix.Rights = rights
+ } else {
+ rights, err := NewSCMRights(t, fds)
+ if err != nil {
+ return socket.ControlMessages{}, err
+ }
+ cmsgs.Unix.Rights = rights
}
}
- if credentials == nil && rights == nil {
- return transport.ControlMessages{}, nil
- }
-
- return transport.ControlMessages{Credentials: credentials, Rights: rights}, nil
+ return cmsgs, nil
}
func makeCreds(t *kernel.Task, socketOrEndpoint interface{}) SCMCredentials {
@@ -432,6 +578,8 @@ func MakeCreds(t *kernel.Task) SCMCredentials {
return &scmCredentials{t, tcred.EffectiveKUID, tcred.EffectiveKGID}
}
+// LINT.IfChange
+
// New creates default control messages if needed.
func New(t *kernel.Task, socketOrEndpoint interface{}, rights SCMRights) transport.ControlMessages {
return transport.ControlMessages{
@@ -439,3 +587,5 @@ func New(t *kernel.Task, socketOrEndpoint interface{}, rights SCMRights) transpo
Rights: rights,
}
}
+
+// LINT.ThenChange(./control_vfs2.go)