summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/abi/linux/linux_abi_autogen_unsafe.go325
-rw-r--r--pkg/abi/linux/msgqueue.go108
-rw-r--r--pkg/sentry/kernel/ipc/ipc_state_autogen.go86
-rw-r--r--pkg/sentry/kernel/ipc/object.go115
-rw-r--r--pkg/sentry/kernel/ipc/registry.go196
-rw-r--r--pkg/sentry/kernel/ipc_namespace.go8
-rw-r--r--pkg/sentry/kernel/kernel_state_autogen.go11
-rw-r--r--pkg/sentry/kernel/msgqueue/message_list.go221
-rw-r--r--pkg/sentry/kernel/msgqueue/msgqueue.go220
-rw-r--r--pkg/sentry/kernel/msgqueue/msgqueue_state_autogen.go194
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go254
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore_state_autogen.go50
-rw-r--r--pkg/sentry/kernel/shm/shm.go259
-rw-r--r--pkg/sentry/kernel/shm/shm_state_autogen.go70
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go16
-rw-r--r--pkg/sentry/syscalls/linux/sys_msgqueue.go57
-rw-r--r--pkg/sentry/syscalls/linux/sys_sem.go43
-rw-r--r--pkg/sentry/syscalls/linux/sys_shm.go11
18 files changed, 1813 insertions, 431 deletions
diff --git a/pkg/abi/linux/linux_abi_autogen_unsafe.go b/pkg/abi/linux/linux_abi_autogen_unsafe.go
index ed00375de..29dadddbf 100644
--- a/pkg/abi/linux/linux_abi_autogen_unsafe.go
+++ b/pkg/abi/linux/linux_abi_autogen_unsafe.go
@@ -83,6 +83,9 @@ var _ marshal.Marshallable = (*KernelIP6TGetEntries)(nil)
var _ marshal.Marshallable = (*KernelIPTEntry)(nil)
var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil)
var _ marshal.Marshallable = (*Linger)(nil)
+var _ marshal.Marshallable = (*MsgBuf)(nil)
+var _ marshal.Marshallable = (*MsgInfo)(nil)
+var _ marshal.Marshallable = (*MsqidDS)(nil)
var _ marshal.Marshallable = (*NFNATRange)(nil)
var _ marshal.Marshallable = (*NetlinkAttrHeader)(nil)
var _ marshal.Marshallable = (*NetlinkErrorMessage)(nil)
@@ -5055,6 +5058,328 @@ func (n *NumaPolicy) WriteTo(w io.Writer) (int64, error) {
return int64(length), err
}
+// Packed implements marshal.Marshallable.Packed.
+//go:nosplit
+func (b *MsgBuf) Packed() bool {
+ // Type MsgBuf is dynamic so it might have slice/string headers. Hence, it is not packed.
+ return false
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (b *MsgBuf) MarshalUnsafe(dst []byte) {
+ // Type MsgBuf doesn't have a packed layout in memory, fallback to MarshalBytes.
+ b.MarshalBytes(dst)
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (b *MsgBuf) UnmarshalUnsafe(src []byte) {
+ // Type MsgBuf doesn't have a packed layout in memory, fallback to UnmarshalBytes.
+ b.UnmarshalBytes(src)
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+//go:nosplit
+func (b *MsgBuf) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) {
+ // Type MsgBuf doesn't have a packed layout in memory, fall back to MarshalBytes.
+ buf := cc.CopyScratchBuffer(b.SizeBytes()) // escapes: okay.
+ b.MarshalBytes(buf) // escapes: fallback.
+ return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.
+}
+
+// CopyOut implements marshal.Marshallable.CopyOut.
+//go:nosplit
+func (b *MsgBuf) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {
+ return b.CopyOutN(cc, addr, b.SizeBytes())
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+//go:nosplit
+func (b *MsgBuf) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {
+ // Type MsgBuf doesn't have a packed layout in memory, fall back to UnmarshalBytes.
+ buf := cc.CopyScratchBuffer(b.SizeBytes()) // escapes: okay.
+ length, err := cc.CopyInBytes(addr, buf) // escapes: okay.
+ // Unmarshal unconditionally. If we had a short copy-in, this results in a
+ // partially unmarshalled struct.
+ b.UnmarshalBytes(buf) // escapes: fallback.
+ return length, err
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+func (b *MsgBuf) WriteTo(writer io.Writer) (int64, error) {
+ // Type MsgBuf doesn't have a packed layout in memory, fall back to MarshalBytes.
+ buf := make([]byte, b.SizeBytes())
+ b.MarshalBytes(buf)
+ length, err := writer.Write(buf)
+ return int64(length), err
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MsgInfo) SizeBytes() int {
+ return 30
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MsgInfo) MarshalBytes(dst []byte) {
+ hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.MsgPool))
+ dst = dst[4:]
+ hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.MsgMap))
+ dst = dst[4:]
+ hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.MsgMax))
+ dst = dst[4:]
+ hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.MsgMnb))
+ dst = dst[4:]
+ hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.MsgMni))
+ dst = dst[4:]
+ hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.MsgSsz))
+ dst = dst[4:]
+ hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.MsgTql))
+ dst = dst[4:]
+ hostarch.ByteOrder.PutUint16(dst[:2], uint16(m.MsgSeg))
+ dst = dst[2:]
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MsgInfo) UnmarshalBytes(src []byte) {
+ m.MsgPool = int32(hostarch.ByteOrder.Uint32(src[:4]))
+ src = src[4:]
+ m.MsgMap = int32(hostarch.ByteOrder.Uint32(src[:4]))
+ src = src[4:]
+ m.MsgMax = int32(hostarch.ByteOrder.Uint32(src[:4]))
+ src = src[4:]
+ m.MsgMnb = int32(hostarch.ByteOrder.Uint32(src[:4]))
+ src = src[4:]
+ m.MsgMni = int32(hostarch.ByteOrder.Uint32(src[:4]))
+ src = src[4:]
+ m.MsgSsz = int32(hostarch.ByteOrder.Uint32(src[:4]))
+ src = src[4:]
+ m.MsgTql = int32(hostarch.ByteOrder.Uint32(src[:4]))
+ src = src[4:]
+ m.MsgSeg = uint16(hostarch.ByteOrder.Uint16(src[:2]))
+ src = src[2:]
+}
+
+// Packed implements marshal.Marshallable.Packed.
+//go:nosplit
+func (m *MsgInfo) Packed() bool {
+ return false
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (m *MsgInfo) MarshalUnsafe(dst []byte) {
+ // Type MsgInfo doesn't have a packed layout in memory, fallback to MarshalBytes.
+ m.MarshalBytes(dst)
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (m *MsgInfo) UnmarshalUnsafe(src []byte) {
+ // Type MsgInfo doesn't have a packed layout in memory, fallback to UnmarshalBytes.
+ m.UnmarshalBytes(src)
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+//go:nosplit
+func (m *MsgInfo) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) {
+ // Type MsgInfo doesn't have a packed layout in memory, fall back to MarshalBytes.
+ buf := cc.CopyScratchBuffer(m.SizeBytes()) // escapes: okay.
+ m.MarshalBytes(buf) // escapes: fallback.
+ return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.
+}
+
+// CopyOut implements marshal.Marshallable.CopyOut.
+//go:nosplit
+func (m *MsgInfo) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {
+ return m.CopyOutN(cc, addr, m.SizeBytes())
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+//go:nosplit
+func (m *MsgInfo) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {
+ // Type MsgInfo doesn't have a packed layout in memory, fall back to UnmarshalBytes.
+ buf := cc.CopyScratchBuffer(m.SizeBytes()) // escapes: okay.
+ length, err := cc.CopyInBytes(addr, buf) // escapes: okay.
+ // Unmarshal unconditionally. If we had a short copy-in, this results in a
+ // partially unmarshalled struct.
+ m.UnmarshalBytes(buf) // escapes: fallback.
+ return length, err
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+func (m *MsgInfo) WriteTo(writer io.Writer) (int64, error) {
+ // Type MsgInfo doesn't have a packed layout in memory, fall back to MarshalBytes.
+ buf := make([]byte, m.SizeBytes())
+ m.MarshalBytes(buf)
+ length, err := writer.Write(buf)
+ return int64(length), err
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MsqidDS) SizeBytes() int {
+ return 48 +
+ (*IPCPerm)(nil).SizeBytes() +
+ (*TimeT)(nil).SizeBytes() +
+ (*TimeT)(nil).SizeBytes() +
+ (*TimeT)(nil).SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MsqidDS) MarshalBytes(dst []byte) {
+ m.MsgPerm.MarshalBytes(dst[:m.MsgPerm.SizeBytes()])
+ dst = dst[m.MsgPerm.SizeBytes():]
+ m.MsgStime.MarshalBytes(dst[:m.MsgStime.SizeBytes()])
+ dst = dst[m.MsgStime.SizeBytes():]
+ m.MsgRtime.MarshalBytes(dst[:m.MsgRtime.SizeBytes()])
+ dst = dst[m.MsgRtime.SizeBytes():]
+ m.MsgCtime.MarshalBytes(dst[:m.MsgCtime.SizeBytes()])
+ dst = dst[m.MsgCtime.SizeBytes():]
+ hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.MsgCbytes))
+ dst = dst[8:]
+ hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.MsgQnum))
+ dst = dst[8:]
+ hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.MsgQbytes))
+ dst = dst[8:]
+ hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.MsgLspid))
+ dst = dst[4:]
+ hostarch.ByteOrder.PutUint32(dst[:4], uint32(m.MsgLrpid))
+ dst = dst[4:]
+ hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.unused4))
+ dst = dst[8:]
+ hostarch.ByteOrder.PutUint64(dst[:8], uint64(m.unused5))
+ dst = dst[8:]
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MsqidDS) UnmarshalBytes(src []byte) {
+ m.MsgPerm.UnmarshalBytes(src[:m.MsgPerm.SizeBytes()])
+ src = src[m.MsgPerm.SizeBytes():]
+ m.MsgStime.UnmarshalBytes(src[:m.MsgStime.SizeBytes()])
+ src = src[m.MsgStime.SizeBytes():]
+ m.MsgRtime.UnmarshalBytes(src[:m.MsgRtime.SizeBytes()])
+ src = src[m.MsgRtime.SizeBytes():]
+ m.MsgCtime.UnmarshalBytes(src[:m.MsgCtime.SizeBytes()])
+ src = src[m.MsgCtime.SizeBytes():]
+ m.MsgCbytes = uint64(hostarch.ByteOrder.Uint64(src[:8]))
+ src = src[8:]
+ m.MsgQnum = uint64(hostarch.ByteOrder.Uint64(src[:8]))
+ src = src[8:]
+ m.MsgQbytes = uint64(hostarch.ByteOrder.Uint64(src[:8]))
+ src = src[8:]
+ m.MsgLspid = int32(hostarch.ByteOrder.Uint32(src[:4]))
+ src = src[4:]
+ m.MsgLrpid = int32(hostarch.ByteOrder.Uint32(src[:4]))
+ src = src[4:]
+ m.unused4 = uint64(hostarch.ByteOrder.Uint64(src[:8]))
+ src = src[8:]
+ m.unused5 = uint64(hostarch.ByteOrder.Uint64(src[:8]))
+ src = src[8:]
+}
+
+// Packed implements marshal.Marshallable.Packed.
+//go:nosplit
+func (m *MsqidDS) Packed() bool {
+ return m.MsgCtime.Packed() && m.MsgPerm.Packed() && m.MsgRtime.Packed() && m.MsgStime.Packed()
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (m *MsqidDS) MarshalUnsafe(dst []byte) {
+ if m.MsgCtime.Packed() && m.MsgPerm.Packed() && m.MsgRtime.Packed() && m.MsgStime.Packed() {
+ gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(m), uintptr(m.SizeBytes()))
+ } else {
+ // Type MsqidDS doesn't have a packed layout in memory, fallback to MarshalBytes.
+ m.MarshalBytes(dst)
+ }
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (m *MsqidDS) UnmarshalUnsafe(src []byte) {
+ if m.MsgCtime.Packed() && m.MsgPerm.Packed() && m.MsgRtime.Packed() && m.MsgStime.Packed() {
+ gohacks.Memmove(unsafe.Pointer(m), unsafe.Pointer(&src[0]), uintptr(m.SizeBytes()))
+ } else {
+ // Type MsqidDS doesn't have a packed layout in memory, fallback to UnmarshalBytes.
+ m.UnmarshalBytes(src)
+ }
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+//
+// It copies the first limit bytes of m's serialized form to addr through cc and
+// returns whatever cc.CopyOutBytes reports. limit is expected to be at most
+// m.SizeBytes(), because the source buffer is sliced as buf[:limit].
+//go:nosplit
+func (m *MsqidDS) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) {
+	// The raw-memory path below is only valid when every nested field is packed,
+	// so the negation has to cover the whole conjunction (the same predicate used
+	// by Packed and MarshalUnsafe), not just its first operand.
+	if !(m.MsgCtime.Packed() && m.MsgPerm.Packed() && m.MsgRtime.Packed() && m.MsgStime.Packed()) {
+		// Type MsqidDS doesn't have a packed layout in memory, fall back to MarshalBytes.
+		buf := cc.CopyScratchBuffer(m.SizeBytes()) // escapes: okay.
+		m.MarshalBytes(buf) // escapes: fallback.
+		return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.
+	}
+
+	// Construct a slice backed by m's underlying memory.
+	var buf []byte
+	hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))
+	hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(m)))
+	hdr.Len = m.SizeBytes()
+	hdr.Cap = m.SizeBytes()
+
+	length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.
+	// Since we bypassed the compiler's escape analysis, indicate that m
+	// must live until the use above.
+	runtime.KeepAlive(m) // escapes: replaced by intrinsic.
+	return length, err
+}
+
+// CopyOut implements marshal.Marshallable.CopyOut.
+//go:nosplit
+func (m *MsqidDS) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {
+ return m.CopyOutN(cc, addr, m.SizeBytes())
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+//
+// It reads m.SizeBytes() bytes from addr through cc into m and returns the
+// length and error reported by cc.CopyInBytes.
+//go:nosplit
+func (m *MsqidDS) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {
+	// Copying straight into m's memory is only valid when every nested field is
+	// packed, so the negation has to cover the whole conjunction (the same
+	// predicate used by Packed and UnmarshalUnsafe), not just its first operand.
+	if !(m.MsgCtime.Packed() && m.MsgPerm.Packed() && m.MsgRtime.Packed() && m.MsgStime.Packed()) {
+		// Type MsqidDS doesn't have a packed layout in memory, fall back to UnmarshalBytes.
+		buf := cc.CopyScratchBuffer(m.SizeBytes()) // escapes: okay.
+		length, err := cc.CopyInBytes(addr, buf) // escapes: okay.
+		// Unmarshal unconditionally. If we had a short copy-in, this results in a
+		// partially unmarshalled struct.
+		m.UnmarshalBytes(buf) // escapes: fallback.
+		return length, err
+	}
+
+	// Construct a slice backed by m's underlying memory.
+	var buf []byte
+	hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))
+	hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(m)))
+	hdr.Len = m.SizeBytes()
+	hdr.Cap = m.SizeBytes()
+
+	length, err := cc.CopyInBytes(addr, buf) // escapes: okay.
+	// Since we bypassed the compiler's escape analysis, indicate that m
+	// must live until the use above.
+	runtime.KeepAlive(m) // escapes: replaced by intrinsic.
+	return length, err
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+//
+// It writes m's serialized form (m.SizeBytes() bytes) to writer and returns the
+// byte count and error reported by writer.Write.
+func (m *MsqidDS) WriteTo(writer io.Writer) (int64, error) {
+	// Writing m's memory image directly is only valid when every nested field is
+	// packed, so the negation has to cover the whole conjunction (the same
+	// predicate used by Packed and MarshalUnsafe), not just its first operand.
+	if !(m.MsgCtime.Packed() && m.MsgPerm.Packed() && m.MsgRtime.Packed() && m.MsgStime.Packed()) {
+		// Type MsqidDS doesn't have a packed layout in memory, fall back to MarshalBytes.
+		buf := make([]byte, m.SizeBytes())
+		m.MarshalBytes(buf)
+		length, err := writer.Write(buf)
+		return int64(length), err
+	}
+
+	// Construct a slice backed by m's underlying memory.
+	var buf []byte
+	hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))
+	hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(m)))
+	hdr.Len = m.SizeBytes()
+	hdr.Cap = m.SizeBytes()
+
+	length, err := writer.Write(buf)
+	// Since we bypassed the compiler's escape analysis, indicate that m
+	// must live until the use above.
+	runtime.KeepAlive(m) // escapes: replaced by intrinsic.
+	return int64(length), err
+}
+
// SizeBytes implements marshal.Marshallable.SizeBytes.
func (i *IFConf) SizeBytes() int {
return 12 +
diff --git a/pkg/abi/linux/msgqueue.go b/pkg/abi/linux/msgqueue.go
new file mode 100644
index 000000000..e1e8d0357
--- /dev/null
+++ b/pkg/abi/linux/msgqueue.go
@@ -0,0 +1,108 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+)
+
+// Linux-specific control commands. Source: include/uapi/linux/msg.h
+const (
+ MSG_STAT = 11
+ MSG_INFO = 12
+ MSG_STAT_ANY = 13
+)
+
+// msgrcv(2) options. Source: include/uapi/linux/msg.h
+const (
+ MSG_NOERROR = 010000 // No error if message is too big.
+ MSG_EXCEPT = 020000 // Receive any message except of specified type.
+ MSG_COPY = 040000 // Copy (not remove) all queue messages.
+)
+
+// System-wide limits for message queues. Source: include/uapi/linux/msg.h
+const (
+ MSGMNI = 32000 // Maximum number of message queue identifiers.
+ MSGMAX = 8192 // Maximum size of message (bytes).
+ MSGMNB = 16384 // Default max size of a message queue.
+)
+
+// System-wide limits. Unused. Source: include/uapi/linux/msg.h
+const (
+ MSGPOOL = (MSGMNI * MSGMNB / 1024)
+ MSGTQL = MSGMNB
+ MSGMAP = MSGMNB
+ MSGSSZ = 16
+
+ // MSGSEG is simplified because Go has no ternary operator: the Linux
+ // header additionally caps this value at 0xffff, which is not done here.
+ MSGSEG = (MSGPOOL * 1024) / MSGSSZ
+)
+
+// MsqidDS is equivalent to struct msqid64_ds. Source:
+// include/uapi/asm-generic/msgbuf.h
+//
+// +marshal
+type MsqidDS struct {
+ MsgPerm IPCPerm // IPC permissions.
+ MsgStime TimeT // Last msgsnd time.
+ MsgRtime TimeT // Last msgrcv time.
+ MsgCtime TimeT // Last change time.
+ MsgCbytes uint64 // Current number of bytes on the queue.
+ MsgQnum uint64 // Number of messages in the queue.
+ MsgQbytes uint64 // Max number of bytes in the queue.
+ MsgLspid int32 // PID of last msgsnd.
+ MsgLrpid int32 // PID of last msgrcv.
+ unused4 uint64
+ unused5 uint64
+}
+
+// MsgBuf is equivalent to struct msgbuf. Source: include/uapi/linux/msg.h
+//
+// +marshal dynamic
+type MsgBuf struct {
+ Type primitive.Int64
+ Text primitive.ByteSlice
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (b *MsgBuf) SizeBytes() int {
+ return b.Type.SizeBytes() + b.Text.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (b *MsgBuf) MarshalBytes(dst []byte) {
+ b.Type.MarshalUnsafe(dst)
+ b.Text.MarshalBytes(dst[b.Type.SizeBytes():])
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (b *MsgBuf) UnmarshalBytes(src []byte) {
+ b.Type.UnmarshalUnsafe(src)
+ b.Text.UnmarshalBytes(src[b.Type.SizeBytes():])
+}
+
+// MsgInfo is equivalent to struct msginfo. Source: include/uapi/linux/msg.h
+//
+// +marshal
+type MsgInfo struct {
+ MsgPool int32
+ MsgMap int32
+ MsgMax int32
+ MsgMnb int32
+ MsgMni int32
+ MsgSsz int32
+ MsgTql int32
+ MsgSeg uint16 `marshal:"unaligned"`
+}
diff --git a/pkg/sentry/kernel/ipc/ipc_state_autogen.go b/pkg/sentry/kernel/ipc/ipc_state_autogen.go
new file mode 100644
index 000000000..b74f23a21
--- /dev/null
+++ b/pkg/sentry/kernel/ipc/ipc_state_autogen.go
@@ -0,0 +1,86 @@
+// automatically generated by stateify.
+
+package ipc
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (o *Object) StateTypeName() string {
+ return "pkg/sentry/kernel/ipc.Object"
+}
+
+func (o *Object) StateFields() []string {
+ return []string{
+ "UserNS",
+ "ID",
+ "Key",
+ "Creator",
+ "Owner",
+ "Perms",
+ }
+}
+
+func (o *Object) beforeSave() {}
+
+// +checklocksignore
+func (o *Object) StateSave(stateSinkObject state.Sink) {
+ o.beforeSave()
+ stateSinkObject.Save(0, &o.UserNS)
+ stateSinkObject.Save(1, &o.ID)
+ stateSinkObject.Save(2, &o.Key)
+ stateSinkObject.Save(3, &o.Creator)
+ stateSinkObject.Save(4, &o.Owner)
+ stateSinkObject.Save(5, &o.Perms)
+}
+
+func (o *Object) afterLoad() {}
+
+// +checklocksignore
+func (o *Object) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &o.UserNS)
+ stateSourceObject.Load(1, &o.ID)
+ stateSourceObject.Load(2, &o.Key)
+ stateSourceObject.Load(3, &o.Creator)
+ stateSourceObject.Load(4, &o.Owner)
+ stateSourceObject.Load(5, &o.Perms)
+}
+
+func (r *Registry) StateTypeName() string {
+ return "pkg/sentry/kernel/ipc.Registry"
+}
+
+func (r *Registry) StateFields() []string {
+ return []string{
+ "UserNS",
+ "objects",
+ "keysToIDs",
+ "lastIDUsed",
+ }
+}
+
+func (r *Registry) beforeSave() {}
+
+// +checklocksignore
+func (r *Registry) StateSave(stateSinkObject state.Sink) {
+ r.beforeSave()
+ stateSinkObject.Save(0, &r.UserNS)
+ stateSinkObject.Save(1, &r.objects)
+ stateSinkObject.Save(2, &r.keysToIDs)
+ stateSinkObject.Save(3, &r.lastIDUsed)
+}
+
+func (r *Registry) afterLoad() {}
+
+// +checklocksignore
+func (r *Registry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &r.UserNS)
+ stateSourceObject.Load(1, &r.objects)
+ stateSourceObject.Load(2, &r.keysToIDs)
+ stateSourceObject.Load(3, &r.lastIDUsed)
+}
+
+func init() {
+ state.Register((*Object)(nil))
+ state.Register((*Registry)(nil))
+}
diff --git a/pkg/sentry/kernel/ipc/object.go b/pkg/sentry/kernel/ipc/object.go
new file mode 100644
index 000000000..387b35e7e
--- /dev/null
+++ b/pkg/sentry/kernel/ipc/object.go
@@ -0,0 +1,115 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ipc defines functionality and utilities common to sysvipc mechanisms.
+//
+// Lock ordering: [shm/semaphore/msgqueue].Registry.mu -> Mechanism
+package ipc
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// Key is a user-provided identifier for IPC objects.
+type Key int32
+
+// ID is a kernel identifier for IPC objects.
+type ID int32
+
+// Object represents an abstract IPC object with fields common to all IPC
+// mechanisms.
+//
+// +stateify savable
+type Object struct {
+ // User namespace which owns the IPC namespace which owns the IPC object.
+ // Immutable.
+ UserNS *auth.UserNamespace
+
+ // ID is a kernel identifier for the IPC object. It is assigned by
+ // Registry.Register rather than by NewObject. Immutable once assigned.
+ ID ID
+
+ // Key is a user-provided identifier for the IPC object. Immutable.
+ Key Key
+
+ // Creator is the user who created the IPC object. Immutable.
+ Creator fs.FileOwner
+
+ // Owner is the current owner of the IPC object.
+ Owner fs.FileOwner
+
+ // Perms holds the access permissions of the IPC object.
+ Perms fs.FilePermissions
+}
+
+// Mechanism represents a SysV mechanism that holds an IPC object. It can also
+// be looked at as a container for an ipc.Object, which is by definition a fully
+// functional SysV object.
+type Mechanism interface {
+ // Lock behaves the same as Mutex.Lock on the mechanism.
+ Lock()
+
+ // Unlock behaves the same as Mutex.Unlock on the mechanism.
+ Unlock()
+
+ // Object returns a pointer to the mechanism's ipc.Object. Mechanism.Lock,
+ // and Mechanism.Unlock should be used when the object is used.
+ Object() *Object
+
+ // Destroy destroys the mechanism.
+ Destroy()
+}
+
+// NewObject returns a new, initialized ipc.Object. The newly returned object
+// doesn't have a valid ID. When the object is registered, the registry assigns
+// it a new unique ID.
+func NewObject(un *auth.UserNamespace, key Key, creator, owner fs.FileOwner, perms fs.FilePermissions) *Object {
+ return &Object{
+ UserNS: un,
+ Key: key,
+ Creator: creator,
+ Owner: owner,
+ Perms: perms,
+ }
+}
+
+// CheckOwnership verifies whether an IPC object may be accessed using creds as
+// an owner. See ipc/util.c:ipcctl_obtain_check() in Linux.
+func (o *Object) CheckOwnership(creds *auth.Credentials) bool {
+ if o.Owner.UID == creds.EffectiveKUID || o.Creator.UID == creds.EffectiveKUID {
+ return true
+ }
+
+ // Tasks with CAP_SYS_ADMIN may bypass ownership checks. Strangely, Linux
+ // doesn't use CAP_IPC_OWNER for this despite CAP_IPC_OWNER being documented
+ // for use to "override IPC ownership checks".
+ return creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, o.UserNS)
+}
+
+// CheckPermissions verifies whether an IPC object is accessible using creds for
+// access described by req. See ipc/util.c:ipcperms() in Linux.
+func (o *Object) CheckPermissions(creds *auth.Credentials, req fs.PermMask) bool {
+ p := o.Perms.Other
+ if o.Owner.UID == creds.EffectiveKUID {
+ p = o.Perms.User
+ } else if creds.InGroup(o.Owner.GID) {
+ p = o.Perms.Group
+ }
+
+ if p.SupersetOf(req) {
+ return true
+ }
+ return creds.HasCapabilityIn(linux.CAP_IPC_OWNER, o.UserNS)
+}
diff --git a/pkg/sentry/kernel/ipc/registry.go b/pkg/sentry/kernel/ipc/registry.go
new file mode 100644
index 000000000..91de19070
--- /dev/null
+++ b/pkg/sentry/kernel/ipc/registry.go
@@ -0,0 +1,196 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipc
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// Registry is similar to Object, but for registries. It represents an abstract
+// SysV IPC registry with fields common to all SysV registries. Registry is not
+// thread-safe, and should be protected using a mutex.
+//
+// +stateify savable
+type Registry struct {
+ // UserNS owning the IPC namespace this registry belongs to. Immutable.
+ UserNS *auth.UserNamespace
+
+ // objects is a map of IDs to IPC mechanisms.
+ objects map[ID]Mechanism
+
+ // keysToIDs maps a lookup key to an ID.
+ keysToIDs map[Key]ID
+
+ // lastIDUsed is used to find the next available ID for object creation.
+ lastIDUsed ID
+}
+
+// NewRegistry returns a new, initialized ipc.Registry owned by userNS, with
+// empty object and key maps.
+func NewRegistry(userNS *auth.UserNamespace) *Registry {
+ return &Registry{
+ UserNS: userNS,
+ objects: make(map[ID]Mechanism),
+ keysToIDs: make(map[Key]ID),
+ }
+}
+
+// Find uses key to search for and return a SysV mechanism. Find returns an
+// error if an object is found but shouldn't be (EEXIST when both create and
+// exclusive are set), or if the user doesn't have permission to use the object
+// (EACCES). If no object is found, Find checks the create flag: it returns
+// ENOENT if create is false, and (nil, nil) if create is true.
+func (r *Registry) Find(ctx context.Context, key Key, mode linux.FileMode, create, exclusive bool) (Mechanism, error) {
+ if id, ok := r.keysToIDs[key]; ok {
+ mech := r.objects[id]
+ mech.Lock()
+ defer mech.Unlock()
+
+ obj := mech.Object()
+ creds := auth.CredentialsFromContext(ctx)
+ if !obj.CheckPermissions(creds, fs.PermsFromMode(mode)) {
+ // The [calling process / user] does not have permission to access
+ // the set, and does not have the CAP_IPC_OWNER capability in the
+ // user namespace that governs its IPC namespace.
+ return nil, linuxerr.EACCES
+ }
+
+ if create && exclusive {
+ // IPC_CREAT and IPC_EXCL were specified, but an object already
+ // exists for key.
+ return nil, linuxerr.EEXIST
+ }
+ return mech, nil
+ }
+
+ if !create {
+ // No object exists for key and msgflg did not specify IPC_CREAT.
+ return nil, linuxerr.ENOENT
+ }
+
+ // No object exists for key, but creation was requested: report "nothing
+ // found" with a nil mechanism and no error.
+ return nil, nil
+}
+
+// Register adds the given mechanism to r.objects under a newly allocated ID,
+// stores that ID in the mechanism's Object, and maps the object's key to it in
+// r.keysToIDs. It returns an error if all IDs are exhausted.
+func (r *Registry) Register(m Mechanism) error {
+ id, err := r.newID()
+ if err != nil {
+ return err
+ }
+
+ obj := m.Object()
+ obj.ID = id
+
+ r.objects[id] = m
+ r.keysToIDs[obj.Key] = id
+
+ return nil
+}
+
+// newID finds the next unused ID in the registry, searching upwards from
+// r.lastIDUsed+1 and restarting from the low end when the ID overflows to a
+// negative value. It records the ID it returns in r.lastIDUsed, and returns
+// ENOSPC if none is found.
+func (r *Registry) newID() (ID, error) {
+ // Find the next available ID.
+ for id := r.lastIDUsed + 1; id != r.lastIDUsed; id++ {
+ // Handle wrap around.
+ if id < 0 {
+ id = 0
+ continue
+ }
+ if r.objects[id] == nil {
+ r.lastIDUsed = id
+ return id, nil
+ }
+ }
+
+ log.Warningf("ids exhausted, they may be leaking")
+
+ // The man pages for shmget(2) mention that ENOSPC should be used if "All
+ // possible shared memory IDs have been taken (SHMMNI)". Other SysV
+ // mechanisms don't have a specific errno for running out of IDs, but they
+ // return ENOSPC if the max number of objects is exceeded, so we assume that
+ // it's the same case.
+ return 0, linuxerr.ENOSPC
+}
+
+// Remove removes the mechanism with the given id from the registry, and calls
+// mechanism.Destroy to perform mechanism-specific removal.
+func (r *Registry) Remove(id ID, creds *auth.Credentials) error {
+ mech := r.objects[id]
+ if mech == nil {
+ return linuxerr.EINVAL
+ }
+
+ mech.Lock()
+ defer mech.Unlock()
+
+ obj := mech.Object()
+
+ // The effective user ID of the calling process must match the creator or
+ // owner of the [mechanism], or the caller must be privileged.
+ if !obj.CheckOwnership(creds) {
+ return linuxerr.EPERM
+ }
+
+ delete(r.objects, obj.ID)
+ delete(r.keysToIDs, obj.Key)
+ mech.Destroy()
+
+ return nil
+}
+
+// ForAllObjects calls f once for every mechanism currently stored in the
+// registry. The order follows Go map iteration and is therefore unspecified.
+func (r *Registry) ForAllObjects(f func(o Mechanism)) {
+ for _, o := range r.objects {
+ f(o)
+ }
+}
+
+// FindByID returns the mechanism with the given ID, or nil if none exists.
+func (r *Registry) FindByID(id ID) Mechanism {
+ return r.objects[id]
+}
+
+// DissociateKey removes the association between a mechanism and its key
+// (deletes it from r.keysToIDs), preventing it from being discovered by any new
+// process, but not necessarily destroying it. If the given key doesn't exist,
+// nothing is changed.
+func (r *Registry) DissociateKey(key Key) {
+ delete(r.keysToIDs, key)
+}
+
+// DissociateID removes the association between a mechanism and its ID (deletes
+// it from r.objects). An ID can't be removed unless the associated key is
+// removed already; this is done to prevent users from acquiring a nil
+// Mechanism.
+//
+// Precondition: must be preceded by a call to r.DissociateKey.
+func (r *Registry) DissociateID(id ID) {
+ delete(r.objects, id)
+}
+
+// ObjectCount returns the number of registered objects.
+func (r *Registry) ObjectCount() int {
+ return len(r.objects)
+}
+
+// LastIDUsed returns the last used ID.
+func (r *Registry) LastIDUsed() ID {
+ return r.lastIDUsed
+}
diff --git a/pkg/sentry/kernel/ipc_namespace.go b/pkg/sentry/kernel/ipc_namespace.go
index 9545bb5ef..0b101b1bb 100644
--- a/pkg/sentry/kernel/ipc_namespace.go
+++ b/pkg/sentry/kernel/ipc_namespace.go
@@ -17,6 +17,7 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/msgqueue"
"gvisor.dev/gvisor/pkg/sentry/kernel/semaphore"
"gvisor.dev/gvisor/pkg/sentry/kernel/shm"
)
@@ -30,6 +31,7 @@ type IPCNamespace struct {
// User namespace which owns this IPC namespace. Immutable.
userNS *auth.UserNamespace
+ queues *msgqueue.Registry
semaphores *semaphore.Registry
shms *shm.Registry
}
@@ -38,6 +40,7 @@ type IPCNamespace struct {
func NewIPCNamespace(userNS *auth.UserNamespace) *IPCNamespace {
ns := &IPCNamespace{
userNS: userNS,
+ queues: msgqueue.NewRegistry(userNS),
semaphores: semaphore.NewRegistry(userNS),
shms: shm.NewRegistry(userNS),
}
@@ -45,6 +48,11 @@ func NewIPCNamespace(userNS *auth.UserNamespace) *IPCNamespace {
return ns
}
+// MsgqueueRegistry returns the message queue registry for this namespace.
+func (i *IPCNamespace) MsgqueueRegistry() *msgqueue.Registry {
+ return i.queues
+}
+
// SemaphoreRegistry returns the semaphore set registry for this namespace.
func (i *IPCNamespace) SemaphoreRegistry() *semaphore.Registry {
return i.semaphores
diff --git a/pkg/sentry/kernel/kernel_state_autogen.go b/pkg/sentry/kernel/kernel_state_autogen.go
index 860599e73..dd4ecfe0f 100644
--- a/pkg/sentry/kernel/kernel_state_autogen.go
+++ b/pkg/sentry/kernel/kernel_state_autogen.go
@@ -348,6 +348,7 @@ func (i *IPCNamespace) StateFields() []string {
return []string{
"IPCNamespaceRefs",
"userNS",
+ "queues",
"semaphores",
"shms",
}
@@ -360,8 +361,9 @@ func (i *IPCNamespace) StateSave(stateSinkObject state.Sink) {
i.beforeSave()
stateSinkObject.Save(0, &i.IPCNamespaceRefs)
stateSinkObject.Save(1, &i.userNS)
- stateSinkObject.Save(2, &i.semaphores)
- stateSinkObject.Save(3, &i.shms)
+ stateSinkObject.Save(2, &i.queues)
+ stateSinkObject.Save(3, &i.semaphores)
+ stateSinkObject.Save(4, &i.shms)
}
func (i *IPCNamespace) afterLoad() {}
@@ -370,8 +372,9 @@ func (i *IPCNamespace) afterLoad() {}
func (i *IPCNamespace) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(0, &i.IPCNamespaceRefs)
stateSourceObject.Load(1, &i.userNS)
- stateSourceObject.Load(2, &i.semaphores)
- stateSourceObject.Load(3, &i.shms)
+ stateSourceObject.Load(2, &i.queues)
+ stateSourceObject.Load(3, &i.semaphores)
+ stateSourceObject.Load(4, &i.shms)
}
func (r *IPCNamespaceRefs) StateTypeName() string {
diff --git a/pkg/sentry/kernel/msgqueue/message_list.go b/pkg/sentry/kernel/msgqueue/message_list.go
new file mode 100644
index 000000000..f2f2292e7
--- /dev/null
+++ b/pkg/sentry/kernel/msgqueue/message_list.go
@@ -0,0 +1,221 @@
+package msgqueue
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type msgElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (msgElementMapper) linkerFor(elem *Message) *Message { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type msgList struct {
+ head *Message
+ tail *Message
+}
+
+// Reset resets list l to the empty state.
+func (l *msgList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *msgList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *msgList) Front() *Message {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *msgList) Back() *Message {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *msgList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (msgElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *msgList) PushFront(e *Message) {
+ linker := msgElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ msgElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *msgList) PushBack(e *Message) {
+ linker := msgElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ msgElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *msgList) PushBackList(m *msgList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ msgElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ msgElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *msgList) InsertAfter(b, e *Message) {
+ bLinker := msgElementMapper{}.linkerFor(b)
+ eLinker := msgElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ msgElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *msgList) InsertBefore(a, e *Message) {
+ aLinker := msgElementMapper{}.linkerFor(a)
+ eLinker := msgElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ msgElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *msgList) Remove(e *Message) {
+ linker := msgElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ msgElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ msgElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type msgEntry struct {
+ next *Message
+ prev *Message
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *msgEntry) Next() *Message {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *msgEntry) Prev() *Message {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *msgEntry) SetNext(elem *Message) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *msgEntry) SetPrev(elem *Message) {
+ e.prev = elem
+}
diff --git a/pkg/sentry/kernel/msgqueue/msgqueue.go b/pkg/sentry/kernel/msgqueue/msgqueue.go
new file mode 100644
index 000000000..3ce926950
--- /dev/null
+++ b/pkg/sentry/kernel/msgqueue/msgqueue.go
@@ -0,0 +1,220 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package msgqueue implements System V message queues.
+package msgqueue
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/ipc"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ // System-wide limit for maximum number of queues.
+ maxQueues = linux.MSGMNI
+
+ // Maximum size of a queue in bytes.
+ maxQueueBytes = linux.MSGMNB
+
+ // Maximum size of a message in bytes.
+ maxMessageBytes = linux.MSGMAX
+)
+
+// Registry contains a set of message queues that can be referenced using keys
+// or IDs.
+//
+// +stateify savable
+type Registry struct {
+ // mu protects all the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // reg defines basic fields and operations needed for all SysV registries.
+ reg *ipc.Registry
+}
+
+// NewRegistry returns a new Registry ready to be used.
+func NewRegistry(userNS *auth.UserNamespace) *Registry {
+ return &Registry{
+ reg: ipc.NewRegistry(userNS),
+ }
+}
+
+// Queue represents a SysV message queue, described by sysvipc(7).
+//
+// +stateify savable
+type Queue struct {
+ // registry is the registry owning this queue. Immutable.
+ registry *Registry
+
+ // mu protects all the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // dead is set to true when a queue is removed from the registry and should
+ // not be used. Operations on the queue should check dead, and return
+ // EIDRM if set to true.
+ dead bool
+
+ // obj defines basic fields that should be included in all SysV IPC objects.
+ obj *ipc.Object
+
+ // senders holds a queue of blocked message senders. Senders are notified
+ // when enough space is available in the queue to insert their message.
+ senders waiter.Queue
+
+ // receivers holds a queue of blocked receivers. Receivers are notified
+ // when a new message is inserted into the queue and can be received.
+ receivers waiter.Queue
+
+ // messages is a list of sent messages.
+ messages msgList
+
+ // sendTime is the last time a msgsnd was performed.
+ sendTime ktime.Time
+
+ // receiveTime is the last time a msgrcv was performed.
+ receiveTime ktime.Time
+
+ // changeTime is the last time the queue was modified using msgctl.
+ changeTime ktime.Time
+
+ // byteCount is the current number of message bytes in the queue.
+ byteCount uint64
+
+ // messageCount is the current number of messages in the queue.
+ messageCount uint64
+
+ // maxBytes is the maximum allowed number of bytes in the queue, and is also
+ // used as a limit for the number of total possible messages.
+ maxBytes uint64
+
+ // sendPID is the PID of the process that performed the last msgsnd.
+ sendPID int32
+
+ // receivePID is the PID of the process that performed the last msgrcv.
+ receivePID int32
+}
+
+// Message represents a message exchanged through a Queue via msgsnd(2) and
+// msgrcv(2).
+//
+// +stateify savable
+type Message struct {
+ msgEntry
+
+ // mType is an integer representing the type of the sent message.
+ mType int64
+
+ // mText is an untyped block of memory.
+ mText []byte
+
+ // mSize is the size of mText.
+ mSize uint64
+}
+
+// FindOrCreate creates a new message queue or returns an existing one. See
+// msgget(2).
+func (r *Registry) FindOrCreate(ctx context.Context, key ipc.Key, mode linux.FileMode, private, create, exclusive bool) (*Queue, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if !private {
+ queue, err := r.reg.Find(ctx, key, mode, create, exclusive)
+ if err != nil {
+ return nil, err
+ }
+
+ if queue != nil {
+ return queue.(*Queue), nil
+ }
+ }
+
+ // Check system-wide limits.
+ if r.reg.ObjectCount() >= maxQueues {
+ return nil, linuxerr.ENOSPC
+ }
+
+ return r.newQueueLocked(ctx, key, fs.FileOwnerFromContext(ctx), fs.FilePermsFromMode(mode))
+}
+
+// newQueueLocked creates a new queue using the given fields. An error is
+// returned if there are no more available identifiers.
+//
+// Precondition: r.mu must be held.
+func (r *Registry) newQueueLocked(ctx context.Context, key ipc.Key, creator fs.FileOwner, perms fs.FilePermissions) (*Queue, error) {
+ q := &Queue{
+ registry: r,
+ obj: ipc.NewObject(r.reg.UserNS, key, creator, creator, perms),
+ sendTime: ktime.ZeroTime,
+ receiveTime: ktime.ZeroTime,
+ changeTime: ktime.NowFromContext(ctx),
+ maxBytes: maxQueueBytes,
+ }
+
+ err := r.reg.Register(q)
+ if err != nil {
+ return nil, err
+ }
+ return q, nil
+}
+
+// Remove removes the queue with the specified ID. All waiters (readers and
+// writers) will be awakened and fail. Remove will return an error if the ID
+// is invalid, or the user doesn't have privileges.
+func (r *Registry) Remove(id ipc.ID, creds *auth.Credentials) error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ r.reg.Remove(id, creds)
+ return nil
+}
+
+// Lock implements ipc.Mechanism.Lock.
+func (q *Queue) Lock() {
+ q.mu.Lock()
+}
+
+// Unlock implements ipc.Mechanism.Unlock.
+//
+// +checklocksignore
+func (q *Queue) Unlock() {
+ q.mu.Unlock()
+}
+
+// Object implements ipc.Mechanism.Object.
+func (q *Queue) Object() *ipc.Object {
+ return q.obj
+}
+
+// Destroy implements ipc.Mechanism.Destroy.
+func (q *Queue) Destroy() {
+ q.dead = true
+
+ // Notify waiters. Senders and receivers will try to run, and return an
+ // error (EIDRM). Waiters should remove themselves from the queue after
+ // waking up.
+ q.senders.Notify(waiter.EventOut)
+ q.receivers.Notify(waiter.EventIn)
+}
+
+// ID returns queue's ID.
+func (q *Queue) ID() ipc.ID {
+ return q.obj.ID
+}
diff --git a/pkg/sentry/kernel/msgqueue/msgqueue_state_autogen.go b/pkg/sentry/kernel/msgqueue/msgqueue_state_autogen.go
new file mode 100644
index 000000000..3dfcd09cb
--- /dev/null
+++ b/pkg/sentry/kernel/msgqueue/msgqueue_state_autogen.go
@@ -0,0 +1,194 @@
+// automatically generated by stateify.
+
+package msgqueue
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (l *msgList) StateTypeName() string {
+ return "pkg/sentry/kernel/msgqueue.msgList"
+}
+
+func (l *msgList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *msgList) beforeSave() {}
+
+// +checklocksignore
+func (l *msgList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *msgList) afterLoad() {}
+
+// +checklocksignore
+func (l *msgList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *msgEntry) StateTypeName() string {
+ return "pkg/sentry/kernel/msgqueue.msgEntry"
+}
+
+func (e *msgEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *msgEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *msgEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *msgEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *msgEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func (r *Registry) StateTypeName() string {
+ return "pkg/sentry/kernel/msgqueue.Registry"
+}
+
+func (r *Registry) StateFields() []string {
+ return []string{
+ "reg",
+ }
+}
+
+func (r *Registry) beforeSave() {}
+
+// +checklocksignore
+func (r *Registry) StateSave(stateSinkObject state.Sink) {
+ r.beforeSave()
+ stateSinkObject.Save(0, &r.reg)
+}
+
+func (r *Registry) afterLoad() {}
+
+// +checklocksignore
+func (r *Registry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &r.reg)
+}
+
+func (q *Queue) StateTypeName() string {
+ return "pkg/sentry/kernel/msgqueue.Queue"
+}
+
+func (q *Queue) StateFields() []string {
+ return []string{
+ "registry",
+ "dead",
+ "obj",
+ "senders",
+ "receivers",
+ "messages",
+ "sendTime",
+ "receiveTime",
+ "changeTime",
+ "byteCount",
+ "messageCount",
+ "maxBytes",
+ "sendPID",
+ "receivePID",
+ }
+}
+
+func (q *Queue) beforeSave() {}
+
+// +checklocksignore
+func (q *Queue) StateSave(stateSinkObject state.Sink) {
+ q.beforeSave()
+ stateSinkObject.Save(0, &q.registry)
+ stateSinkObject.Save(1, &q.dead)
+ stateSinkObject.Save(2, &q.obj)
+ stateSinkObject.Save(3, &q.senders)
+ stateSinkObject.Save(4, &q.receivers)
+ stateSinkObject.Save(5, &q.messages)
+ stateSinkObject.Save(6, &q.sendTime)
+ stateSinkObject.Save(7, &q.receiveTime)
+ stateSinkObject.Save(8, &q.changeTime)
+ stateSinkObject.Save(9, &q.byteCount)
+ stateSinkObject.Save(10, &q.messageCount)
+ stateSinkObject.Save(11, &q.maxBytes)
+ stateSinkObject.Save(12, &q.sendPID)
+ stateSinkObject.Save(13, &q.receivePID)
+}
+
+func (q *Queue) afterLoad() {}
+
+// +checklocksignore
+func (q *Queue) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &q.registry)
+ stateSourceObject.Load(1, &q.dead)
+ stateSourceObject.Load(2, &q.obj)
+ stateSourceObject.Load(3, &q.senders)
+ stateSourceObject.Load(4, &q.receivers)
+ stateSourceObject.Load(5, &q.messages)
+ stateSourceObject.Load(6, &q.sendTime)
+ stateSourceObject.Load(7, &q.receiveTime)
+ stateSourceObject.Load(8, &q.changeTime)
+ stateSourceObject.Load(9, &q.byteCount)
+ stateSourceObject.Load(10, &q.messageCount)
+ stateSourceObject.Load(11, &q.maxBytes)
+ stateSourceObject.Load(12, &q.sendPID)
+ stateSourceObject.Load(13, &q.receivePID)
+}
+
+func (m *Message) StateTypeName() string {
+ return "pkg/sentry/kernel/msgqueue.Message"
+}
+
+func (m *Message) StateFields() []string {
+ return []string{
+ "msgEntry",
+ "mType",
+ "mText",
+ "mSize",
+ }
+}
+
+func (m *Message) beforeSave() {}
+
+// +checklocksignore
+func (m *Message) StateSave(stateSinkObject state.Sink) {
+ m.beforeSave()
+ stateSinkObject.Save(0, &m.msgEntry)
+ stateSinkObject.Save(1, &m.mType)
+ stateSinkObject.Save(2, &m.mText)
+ stateSinkObject.Save(3, &m.mSize)
+}
+
+func (m *Message) afterLoad() {}
+
+// +checklocksignore
+func (m *Message) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &m.msgEntry)
+ stateSourceObject.Load(1, &m.mType)
+ stateSourceObject.Load(2, &m.mText)
+ stateSourceObject.Load(3, &m.mSize)
+}
+
+func init() {
+ state.Register((*msgList)(nil))
+ state.Register((*msgEntry)(nil))
+ state.Register((*Registry)(nil))
+ state.Register((*Queue)(nil))
+ state.Register((*Message)(nil))
+}
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index 485c3a788..b7879d284 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -21,9 +21,9 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
- "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/ipc"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -47,15 +47,15 @@ const (
//
// +stateify savable
type Registry struct {
- // userNS owning the ipc name this registry belongs to. Immutable.
- userNS *auth.UserNamespace
// mu protects all fields below.
- mu sync.Mutex `state:"nosave"`
- semaphores map[int32]*Set
- lastIDUsed int32
+ mu sync.Mutex `state:"nosave"`
+
+ // reg defines basic fields and operations needed for all SysV registries.
+ reg *ipc.Registry
+
// indexes maintains a mapping between a set's index in virtual array and
// its identifier.
- indexes map[int32]int32
+ indexes map[int32]ipc.ID
}
// Set represents a set of semaphores that can be operated atomically.
@@ -65,19 +65,11 @@ type Set struct {
// registry owning this sem set. Immutable.
registry *Registry
- // Id is a handle that identifies the set.
- ID int32
-
- // key is an user provided key that can be shared between processes.
- key int32
+ // mu protects all fields below.
+ mu sync.Mutex `state:"nosave"`
- // creator is the user that created the set. Immutable.
- creator fs.FileOwner
+ obj *ipc.Object
- // mu protects all fields below.
- mu sync.Mutex `state:"nosave"`
- owner fs.FileOwner
- perms fs.FilePermissions
opTime ktime.Time
changeTime ktime.Time
@@ -115,9 +107,8 @@ type waiter struct {
// NewRegistry creates a new semaphore set registry.
func NewRegistry(userNS *auth.UserNamespace) *Registry {
return &Registry{
- userNS: userNS,
- semaphores: make(map[int32]*Set),
- indexes: make(map[int32]int32),
+ reg: ipc.NewRegistry(userNS),
+ indexes: make(map[int32]ipc.ID),
}
}
@@ -126,7 +117,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry {
// a new set is always created. If create is false, it fails if a set cannot
// be found. If exclusive is true, it fails if a set with the same key already
// exists.
-func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linux.FileMode, private, create, exclusive bool) (*Set, error) {
+func (r *Registry) FindOrCreate(ctx context.Context, key ipc.Key, nsems int32, mode linux.FileMode, private, create, exclusive bool) (*Set, error) {
if nsems < 0 || nsems > semsMax {
return nil, linuxerr.EINVAL
}
@@ -135,31 +126,19 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu
defer r.mu.Unlock()
if !private {
- // Look up an existing semaphore.
- if set := r.findByKey(key); set != nil {
- set.mu.Lock()
- defer set.mu.Unlock()
-
- // Check that caller can access semaphore set.
- creds := auth.CredentialsFromContext(ctx)
- if !set.checkPerms(creds, fs.PermsFromMode(mode)) {
- return nil, linuxerr.EACCES
- }
+ set, err := r.reg.Find(ctx, key, mode, create, exclusive)
+ if err != nil {
+ return nil, err
+ }
- // Validate parameters.
+ // Validate semaphore-specific parameters.
+ if set != nil {
+ set := set.(*Set)
if nsems > int32(set.Size()) {
return nil, linuxerr.EINVAL
}
- if create && exclusive {
- return nil, linuxerr.EEXIST
- }
return set, nil
}
-
- if !create {
- // Semaphore not found and should not be created.
- return nil, syserror.ENOENT
- }
}
// Zero is only valid if an existing set is found.
@@ -169,9 +148,9 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu
// Apply system limits.
//
- // Map semaphores and map indexes in a registry are of the same size,
- // check map semaphores only here for the system limit.
- if len(r.semaphores) >= setsMax {
+ // The reg.objects and indexes maps in a registry are of the same size,
+ // so only reg.objects is checked against the system limit here.
+ if r.reg.ObjectCount() >= setsMax {
return nil, syserror.ENOSPC
}
if r.totalSems() > int(semsTotalMax-nsems) {
@@ -179,9 +158,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu
}
// Finally create a new set.
- owner := fs.FileOwnerFromContext(ctx)
- perms := fs.FilePermsFromMode(mode)
- return r.newSet(ctx, key, owner, owner, perms, nsems)
+ return r.newSetLocked(ctx, key, fs.FileOwnerFromContext(ctx), fs.FilePermsFromMode(mode), nsems)
}
// IPCInfo returns information about system-wide semaphore limits and parameters.
@@ -208,7 +185,7 @@ func (r *Registry) SemInfo() *linux.SemInfo {
defer r.mu.Unlock()
info := r.IPCInfo()
- info.SemUsz = uint32(len(r.semaphores))
+ info.SemUsz = uint32(r.reg.ObjectCount())
info.SemAem = uint32(r.totalSems())
return info
@@ -231,77 +208,59 @@ func (r *Registry) HighestIndex() int32 {
return highestIndex
}
-// RemoveID removes set with give 'id' from the registry and marks the set as
+// Remove removes the set with the given 'id' from the registry and marks the set as
// dead. All waiters will be awakened and fail.
-func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error {
+func (r *Registry) Remove(id ipc.ID, creds *auth.Credentials) error {
r.mu.Lock()
defer r.mu.Unlock()
- set := r.semaphores[id]
- if set == nil {
- return linuxerr.EINVAL
- }
+ r.reg.Remove(id, creds)
+
index, found := r.findIndexByID(id)
if !found {
// Inconsistent state.
panic(fmt.Sprintf("unable to find an index for ID: %d", id))
}
-
- set.mu.Lock()
- defer set.mu.Unlock()
-
- // "The effective user ID of the calling process must match the creator or
- // owner of the semaphore set, or the caller must be privileged."
- if !set.checkCredentials(creds) && !set.checkCapability(creds) {
- return linuxerr.EACCES
- }
-
- delete(r.semaphores, set.ID)
delete(r.indexes, index)
- set.destroy()
+
return nil
}
-func (r *Registry) newSet(ctx context.Context, key int32, owner, creator fs.FileOwner, perms fs.FilePermissions, nsems int32) (*Set, error) {
+// newSetLocked creates a new Set using given fields. An error is returned if there
+// are no more available identifiers.
+//
+// Precondition: r.mu must be held.
+func (r *Registry) newSetLocked(ctx context.Context, key ipc.Key, creator fs.FileOwner, perms fs.FilePermissions, nsems int32) (*Set, error) {
set := &Set{
registry: r,
- key: key,
- owner: owner,
- creator: owner,
- perms: perms,
+ obj: ipc.NewObject(r.reg.UserNS, ipc.Key(key), creator, creator, perms),
changeTime: ktime.NowFromContext(ctx),
sems: make([]sem, nsems),
}
- // Find the next available ID.
- for id := r.lastIDUsed + 1; id != r.lastIDUsed; id++ {
- // Handle wrap around.
- if id < 0 {
- id = 0
- continue
- }
- if r.semaphores[id] == nil {
- index, found := r.findFirstAvailableIndex()
- if !found {
- panic("unable to find an available index")
- }
- r.indexes[index] = id
- r.lastIDUsed = id
- r.semaphores[id] = set
- set.ID = id
- return set, nil
- }
+ err := r.reg.Register(set)
+ if err != nil {
+ return nil, err
}
- log.Warningf("Semaphore map is full, they must be leaking")
- return nil, syserror.ENOMEM
+ index, found := r.findFirstAvailableIndex()
+ if !found {
+ panic("unable to find an available index")
+ }
+ r.indexes[index] = set.obj.ID
+
+ return set, nil
}
// FindByID looks up a set given an ID.
-func (r *Registry) FindByID(id int32) *Set {
+func (r *Registry) FindByID(id ipc.ID) *Set {
r.mu.Lock()
defer r.mu.Unlock()
- return r.semaphores[id]
+ mech := r.reg.FindByID(id)
+ if mech == nil {
+ return nil
+ }
+ return mech.(*Set)
}
// FindByIndex looks up a set given an index.
@@ -313,19 +272,10 @@ func (r *Registry) FindByIndex(index int32) *Set {
if !present {
return nil
}
- return r.semaphores[id]
+ return r.reg.FindByID(id).(*Set)
}
-func (r *Registry) findByKey(key int32) *Set {
- for _, v := range r.semaphores {
- if v.key == key {
- return v
- }
- }
- return nil
-}
-
-func (r *Registry) findIndexByID(id int32) (int32, bool) {
+func (r *Registry) findIndexByID(id ipc.ID) (int32, bool) {
for k, v := range r.indexes {
if v == id {
return k, true
@@ -345,12 +295,36 @@ func (r *Registry) findFirstAvailableIndex() (int32, bool) {
func (r *Registry) totalSems() int {
totalSems := 0
- for _, v := range r.semaphores {
- totalSems += v.Size()
- }
+ r.reg.ForAllObjects(
+ func(o ipc.Mechanism) {
+ totalSems += o.(*Set).Size()
+ },
+ )
return totalSems
}
+// ID returns semaphore's ID.
+func (s *Set) ID() ipc.ID {
+ return s.obj.ID
+}
+
+// Object implements ipc.Mechanism.Object.
+func (s *Set) Object() *ipc.Object {
+ return s.obj
+}
+
+// Lock implements ipc.Mechanism.Lock.
+func (s *Set) Lock() {
+ s.mu.Lock()
+}
+
+// Unlock implements ipc.Mechanism.Unlock.
+//
+// +checklocksignore
+func (s *Set) Unlock() {
+ s.mu.Unlock()
+}
+
func (s *Set) findSem(num int32) *sem {
if num < 0 || int(num) >= s.Size() {
return nil
@@ -370,12 +344,12 @@ func (s *Set) Change(ctx context.Context, creds *auth.Credentials, owner fs.File
// "The effective UID of the calling process must match the owner or creator
// of the semaphore set, or the caller must be privileged."
- if !s.checkCredentials(creds) && !s.checkCapability(creds) {
+ if !s.obj.CheckOwnership(creds) {
return linuxerr.EACCES
}
- s.owner = owner
- s.perms = perms
+ s.obj.Owner = owner
+ s.obj.Perms = perms
s.changeTime = ktime.NowFromContext(ctx)
return nil
}
@@ -395,18 +369,18 @@ func (s *Set) semStat(creds *auth.Credentials, permMask fs.PermMask) (*linux.Sem
s.mu.Lock()
defer s.mu.Unlock()
- if !s.checkPerms(creds, permMask) {
+ if !s.obj.CheckPermissions(creds, permMask) {
return nil, linuxerr.EACCES
}
return &linux.SemidDS{
SemPerm: linux.IPCPerm{
- Key: uint32(s.key),
- UID: uint32(creds.UserNamespace.MapFromKUID(s.owner.UID)),
- GID: uint32(creds.UserNamespace.MapFromKGID(s.owner.GID)),
- CUID: uint32(creds.UserNamespace.MapFromKUID(s.creator.UID)),
- CGID: uint32(creds.UserNamespace.MapFromKGID(s.creator.GID)),
- Mode: uint16(s.perms.LinuxMode()),
+ Key: uint32(s.obj.Key),
+ UID: uint32(creds.UserNamespace.MapFromKUID(s.obj.Owner.UID)),
+ GID: uint32(creds.UserNamespace.MapFromKGID(s.obj.Owner.GID)),
+ CUID: uint32(creds.UserNamespace.MapFromKUID(s.obj.Creator.UID)),
+ CGID: uint32(creds.UserNamespace.MapFromKGID(s.obj.Creator.GID)),
+ Mode: uint16(s.obj.Perms.LinuxMode()),
Seq: 0, // IPC sequence not supported.
},
SemOTime: s.opTime.TimeT(),
@@ -425,7 +399,7 @@ func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Cred
defer s.mu.Unlock()
// "The calling process must have alter permission on the semaphore set."
- if !s.checkPerms(creds, fs.PermMask{Write: true}) {
+ if !s.obj.CheckPermissions(creds, fs.PermMask{Write: true}) {
return linuxerr.EACCES
}
@@ -461,7 +435,7 @@ func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credenti
defer s.mu.Unlock()
// "The calling process must have alter permission on the semaphore set."
- if !s.checkPerms(creds, fs.PermMask{Write: true}) {
+ if !s.obj.CheckPermissions(creds, fs.PermMask{Write: true}) {
return linuxerr.EACCES
}
@@ -483,7 +457,7 @@ func (s *Set) GetVal(num int32, creds *auth.Credentials) (int16, error) {
defer s.mu.Unlock()
// "The calling process must have read permission on the semaphore set."
- if !s.checkPerms(creds, fs.PermMask{Read: true}) {
+ if !s.obj.CheckPermissions(creds, fs.PermMask{Read: true}) {
return 0, linuxerr.EACCES
}
@@ -500,7 +474,7 @@ func (s *Set) GetValAll(creds *auth.Credentials) ([]uint16, error) {
defer s.mu.Unlock()
// "The calling process must have read permission on the semaphore set."
- if !s.checkPerms(creds, fs.PermMask{Read: true}) {
+ if !s.obj.CheckPermissions(creds, fs.PermMask{Read: true}) {
return nil, linuxerr.EACCES
}
@@ -517,7 +491,7 @@ func (s *Set) GetPID(num int32, creds *auth.Credentials) (int32, error) {
defer s.mu.Unlock()
// "The calling process must have read permission on the semaphore set."
- if !s.checkPerms(creds, fs.PermMask{Read: true}) {
+ if !s.obj.CheckPermissions(creds, fs.PermMask{Read: true}) {
return 0, linuxerr.EACCES
}
@@ -533,7 +507,7 @@ func (s *Set) countWaiters(num int32, creds *auth.Credentials, pred func(w *wait
defer s.mu.Unlock()
// The calling process must have read permission on the semaphore set.
- if !s.checkPerms(creds, fs.PermMask{Read: true}) {
+ if !s.obj.CheckPermissions(creds, fs.PermMask{Read: true}) {
return 0, linuxerr.EACCES
}
@@ -589,7 +563,7 @@ func (s *Set) ExecuteOps(ctx context.Context, ops []linux.Sembuf, creds *auth.Cr
}
}
- if !s.checkPerms(creds, fs.PermMask{Read: readOnly, Write: !readOnly}) {
+ if !s.obj.CheckPermissions(creds, fs.PermMask{Read: readOnly, Write: !readOnly}) {
return nil, 0, linuxerr.EACCES
}
@@ -675,38 +649,10 @@ func (s *Set) AbortWait(num int32, ch chan struct{}) {
// Waiter may not be found in case it raced with wakeWaiters().
}
-func (s *Set) checkCredentials(creds *auth.Credentials) bool {
- return s.owner.UID == creds.EffectiveKUID ||
- s.owner.GID == creds.EffectiveKGID ||
- s.creator.UID == creds.EffectiveKUID ||
- s.creator.GID == creds.EffectiveKGID
-}
-
-func (s *Set) checkCapability(creds *auth.Credentials) bool {
- return creds.HasCapabilityIn(linux.CAP_IPC_OWNER, s.registry.userNS) && creds.UserNamespace.MapFromKUID(s.owner.UID).Ok()
-}
-
-func (s *Set) checkPerms(creds *auth.Credentials, reqPerms fs.PermMask) bool {
- // Are we owner, or in group, or other?
- p := s.perms.Other
- if s.owner.UID == creds.EffectiveKUID {
- p = s.perms.User
- } else if creds.InGroup(s.owner.GID) {
- p = s.perms.Group
- }
-
- // Are permissions satisfied without capability checks?
- if p.SupersetOf(reqPerms) {
- return true
- }
-
- return s.checkCapability(creds)
-}
-
-// destroy destroys the set.
+// Destroy implements ipc.Mechanism.Destroy.
//
// Preconditions: Caller must hold 's.mu'.
-func (s *Set) destroy() {
+func (s *Set) Destroy() {
// Notify all waiters. They will fail on the next attempt to execute
// operations and return error.
s.dead = true
diff --git a/pkg/sentry/kernel/semaphore/semaphore_state_autogen.go b/pkg/sentry/kernel/semaphore/semaphore_state_autogen.go
index f90fbff34..7ea96b30d 100644
--- a/pkg/sentry/kernel/semaphore/semaphore_state_autogen.go
+++ b/pkg/sentry/kernel/semaphore/semaphore_state_autogen.go
@@ -12,9 +12,7 @@ func (r *Registry) StateTypeName() string {
func (r *Registry) StateFields() []string {
return []string{
- "userNS",
- "semaphores",
- "lastIDUsed",
+ "reg",
"indexes",
}
}
@@ -24,20 +22,16 @@ func (r *Registry) beforeSave() {}
// +checklocksignore
func (r *Registry) StateSave(stateSinkObject state.Sink) {
r.beforeSave()
- stateSinkObject.Save(0, &r.userNS)
- stateSinkObject.Save(1, &r.semaphores)
- stateSinkObject.Save(2, &r.lastIDUsed)
- stateSinkObject.Save(3, &r.indexes)
+ stateSinkObject.Save(0, &r.reg)
+ stateSinkObject.Save(1, &r.indexes)
}
func (r *Registry) afterLoad() {}
// +checklocksignore
func (r *Registry) StateLoad(stateSourceObject state.Source) {
- stateSourceObject.Load(0, &r.userNS)
- stateSourceObject.Load(1, &r.semaphores)
- stateSourceObject.Load(2, &r.lastIDUsed)
- stateSourceObject.Load(3, &r.indexes)
+ stateSourceObject.Load(0, &r.reg)
+ stateSourceObject.Load(1, &r.indexes)
}
func (s *Set) StateTypeName() string {
@@ -47,11 +41,7 @@ func (s *Set) StateTypeName() string {
func (s *Set) StateFields() []string {
return []string{
"registry",
- "ID",
- "key",
- "creator",
- "owner",
- "perms",
+ "obj",
"opTime",
"changeTime",
"sems",
@@ -65,15 +55,11 @@ func (s *Set) beforeSave() {}
func (s *Set) StateSave(stateSinkObject state.Sink) {
s.beforeSave()
stateSinkObject.Save(0, &s.registry)
- stateSinkObject.Save(1, &s.ID)
- stateSinkObject.Save(2, &s.key)
- stateSinkObject.Save(3, &s.creator)
- stateSinkObject.Save(4, &s.owner)
- stateSinkObject.Save(5, &s.perms)
- stateSinkObject.Save(6, &s.opTime)
- stateSinkObject.Save(7, &s.changeTime)
- stateSinkObject.Save(8, &s.sems)
- stateSinkObject.Save(9, &s.dead)
+ stateSinkObject.Save(1, &s.obj)
+ stateSinkObject.Save(2, &s.opTime)
+ stateSinkObject.Save(3, &s.changeTime)
+ stateSinkObject.Save(4, &s.sems)
+ stateSinkObject.Save(5, &s.dead)
}
func (s *Set) afterLoad() {}
@@ -81,15 +67,11 @@ func (s *Set) afterLoad() {}
// +checklocksignore
func (s *Set) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(0, &s.registry)
- stateSourceObject.Load(1, &s.ID)
- stateSourceObject.Load(2, &s.key)
- stateSourceObject.Load(3, &s.creator)
- stateSourceObject.Load(4, &s.owner)
- stateSourceObject.Load(5, &s.perms)
- stateSourceObject.Load(6, &s.opTime)
- stateSourceObject.Load(7, &s.changeTime)
- stateSourceObject.Load(8, &s.sems)
- stateSourceObject.Load(9, &s.dead)
+ stateSourceObject.Load(1, &s.obj)
+ stateSourceObject.Load(2, &s.opTime)
+ stateSourceObject.Load(3, &s.changeTime)
+ stateSourceObject.Load(4, &s.sems)
+ stateSourceObject.Load(5, &s.dead)
}
func (s *sem) StateTypeName() string {
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index f7ac4c2b2..2abf467d7 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -43,6 +43,7 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/ipc"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -51,12 +52,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-// Key represents a shm segment key. Analogous to a file name.
-type Key int32
-
-// ID represents the opaque handle for a shm segment. Analogous to an fd.
-type ID int32
-
// Registry tracks all shared memory segments in an IPC namespace. The registry
// provides the mechanisms for creating and finding segments, and reporting
// global shm parameters.
@@ -69,50 +64,51 @@ type Registry struct {
// mu protects all fields below.
mu sync.Mutex `state:"nosave"`
- // shms maps segment ids to segments.
+ // reg defines basic fields and operations needed for all SysV registries.
//
- // shms holds all referenced segments, which are removed on the last
+ // Within reg, there are two maps, Objects and KeysToIDs.
+ //
+ // reg.objects holds all referenced segments, which are removed on the last
// DecRef. Thus, it cannot itself hold a reference on the Shm.
//
// Since removal only occurs after the last (unlocked) DecRef, there
// exists a short window during which a Shm still exists in Shm, but is
// unreferenced. Users must use TryIncRef to determine if the Shm is
// still valid.
- shms map[ID]*Shm
-
- // keysToShms maps segment keys to segments.
//
- // Shms in keysToShms are guaranteed to be referenced, as they are
+ // keysToIDs maps segment keys to IDs.
+ //
+ // Shms in keysToIDs are guaranteed to be referenced, as they are
// removed by disassociateKey before the last DecRef.
- keysToShms map[Key]*Shm
+ reg *ipc.Registry
// Sum of the sizes of all existing segments rounded up to page size, in
// units of page size.
totalPages uint64
-
- // ID assigned to the last created segment. Used to quickly find the next
- // unused ID.
- lastIDUsed ID
}
// NewRegistry creates a new shm registry.
func NewRegistry(userNS *auth.UserNamespace) *Registry {
return &Registry{
- userNS: userNS,
- shms: make(map[ID]*Shm),
- keysToShms: make(map[Key]*Shm),
+ userNS: userNS,
+ reg: ipc.NewRegistry(userNS),
}
}
// FindByID looks up a segment given an ID.
//
// FindByID returns a reference on Shm.
-func (r *Registry) FindByID(id ID) *Shm {
+func (r *Registry) FindByID(id ipc.ID) *Shm {
r.mu.Lock()
defer r.mu.Unlock()
- s := r.shms[id]
+ mech := r.reg.FindByID(id)
+ if mech == nil {
+ return nil
+ }
+ s := mech.(*Shm)
+
// Take a reference on s. If TryIncRef fails, s has reached the last
- // DecRef, but hasn't quite been removed from r.shms yet.
+ // DecRef, but hasn't quite been removed from r.reg.objects yet.
if s != nil && s.TryIncRef() {
return s
}
@@ -129,9 +125,9 @@ func (r *Registry) dissociateKey(s *Shm) {
defer r.mu.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
- if s.key != linux.IPC_PRIVATE {
- delete(r.keysToShms, s.key)
- s.key = linux.IPC_PRIVATE
+ if s.obj.Key != linux.IPC_PRIVATE {
+ r.reg.DissociateKey(s.obj.Key)
+ s.obj.Key = linux.IPC_PRIVATE
}
}
@@ -139,7 +135,7 @@ func (r *Registry) dissociateKey(s *Shm) {
// analogous to open(2).
//
// FindOrCreate returns a reference on Shm.
-func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size uint64, mode linux.FileMode, private, create, exclusive bool) (*Shm, error) {
+func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key ipc.Key, size uint64, mode linux.FileMode, private, create, exclusive bool) (*Shm, error) {
if (create || private) && (size < linux.SHMMIN || size > linux.SHMMAX) {
// "A new segment was to be created and size is less than SHMMIN or
// greater than SHMMAX." - man shmget(2)
@@ -152,49 +148,29 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui
r.mu.Lock()
defer r.mu.Unlock()
- if len(r.shms) >= linux.SHMMNI {
+ if r.reg.ObjectCount() >= linux.SHMMNI {
// "All possible shared memory IDs have been taken (SHMMNI) ..."
// - man shmget(2)
return nil, syserror.ENOSPC
}
if !private {
- // Look up an existing segment.
- if shm := r.keysToShms[key]; shm != nil {
- shm.mu.Lock()
- defer shm.mu.Unlock()
-
- // Check that caller can access the segment.
- if !shm.checkPermissions(ctx, fs.PermsFromMode(mode)) {
- // "The user does not have permission to access the shared
- // memory segment, and does not have the CAP_IPC_OWNER
- // capability in the user namespace that governs its IPC
- // namespace." - man shmget(2)
- return nil, linuxerr.EACCES
- }
+ shm, err := r.reg.Find(ctx, key, mode, create, exclusive)
+ if err != nil {
+ return nil, err
+ }
+ // Validate shm-specific parameters.
+ if shm != nil {
+ shm := shm.(*Shm)
if size > shm.size {
// "A segment for the given key exists, but size is greater than
// the size of that segment." - man shmget(2)
return nil, linuxerr.EINVAL
}
-
- if create && exclusive {
- // "IPC_CREAT and IPC_EXCL were specified in shmflg, but a
- // shared memory segment already exists for key."
- // - man shmget(2)
- return nil, linuxerr.EEXIST
- }
-
shm.IncRef()
return shm, nil
}
-
- if !create {
- // "No segment exists for the given key, and IPC_CREAT was not
- // specified." - man shmget(2)
- return nil, syserror.ENOENT
- }
}
var sizeAligned uint64
@@ -212,9 +188,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui
}
// Need to create a new segment.
- creator := fs.FileOwnerFromContext(ctx)
- perms := fs.FilePermsFromMode(mode)
- s, err := r.newShm(ctx, pid, key, creator, perms, size)
+ s, err := r.newShmLocked(ctx, pid, key, fs.FileOwnerFromContext(ctx), fs.FilePermsFromMode(mode), size)
if err != nil {
return nil, err
}
@@ -224,10 +198,10 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui
return s, nil
}
-// newShm creates a new segment in the registry.
+// newShmLocked creates a new segment in the registry.
//
// Precondition: Caller must hold r.mu.
-func (r *Registry) newShm(ctx context.Context, pid int32, key Key, creator fs.FileOwner, perms fs.FilePermissions, size uint64) (*Shm, error) {
+func (r *Registry) newShmLocked(ctx context.Context, pid int32, key ipc.Key, creator fs.FileOwner, perms fs.FilePermissions, size uint64) (*Shm, error) {
mfp := pgalloc.MemoryFileProviderFromContext(ctx)
if mfp == nil {
panic(fmt.Sprintf("context.Context %T lacks non-nil value for key %T", ctx, pgalloc.CtxMemoryFileProvider))
@@ -242,40 +216,21 @@ func (r *Registry) newShm(ctx context.Context, pid int32, key Key, creator fs.Fi
shm := &Shm{
mfp: mfp,
registry: r,
- creator: creator,
size: size,
effectiveSize: effectiveSize,
+ obj: ipc.NewObject(r.reg.UserNS, ipc.Key(key), creator, creator, perms),
fr: fr,
- key: key,
- perms: perms,
- owner: creator,
creatorPID: pid,
changeTime: ktime.NowFromContext(ctx),
}
shm.InitRefs()
- // Find the next available ID.
- for id := r.lastIDUsed + 1; id != r.lastIDUsed; id++ {
- // Handle wrap around.
- if id < 0 {
- id = 0
- continue
- }
- if r.shms[id] == nil {
- r.lastIDUsed = id
-
- shm.ID = id
- r.shms[id] = shm
- r.keysToShms[key] = shm
-
- r.totalPages += effectiveSize / hostarch.PageSize
-
- return shm, nil
- }
+ if err := r.reg.Register(shm); err != nil {
+ return nil, err
}
+ r.totalPages += effectiveSize / hostarch.PageSize
- log.Warningf("Shm ids exhuasted, they may be leaking")
- return nil, syserror.ENOSPC
+ return shm, nil
}
// IPCInfo reports global parameters for sysv shared memory segments on this
@@ -297,7 +252,7 @@ func (r *Registry) ShmInfo() *linux.ShmInfo {
defer r.mu.Unlock()
return &linux.ShmInfo{
- UsedIDs: int32(r.lastIDUsed),
+ UsedIDs: int32(r.reg.LastIDUsed()),
ShmTot: r.totalPages,
ShmRss: r.totalPages, // We could probably get a better estimate from memory accounting.
ShmSwp: 0, // No reclaim at the moment.
@@ -314,11 +269,11 @@ func (r *Registry) remove(s *Shm) {
s.mu.Lock()
defer s.mu.Unlock()
- if s.key != linux.IPC_PRIVATE {
+ if s.obj.Key != linux.IPC_PRIVATE {
panic(fmt.Sprintf("Attempted to remove %s from the registry whose key is still associated", s.debugLocked()))
}
- delete(r.shms, s.ID)
+ r.reg.DissociateID(s.obj.ID)
r.totalPages -= s.effectiveSize / hostarch.PageSize
}
@@ -330,13 +285,16 @@ func (r *Registry) Release(ctx context.Context) {
// the IPC namespace containing it has no more references.
toRelease := make([]*Shm, 0)
r.mu.Lock()
- for _, s := range r.keysToShms {
- s.mu.Lock()
- if !s.pendingDestruction {
- toRelease = append(toRelease, s)
- }
- s.mu.Unlock()
- }
+ r.reg.ForAllObjects(
+ func(o ipc.Mechanism) {
+ s := o.(*Shm)
+ s.mu.Lock()
+ if !s.pendingDestruction {
+ toRelease = append(toRelease, s)
+ }
+ s.mu.Unlock()
+ },
+ )
r.mu.Unlock()
for _, s := range toRelease {
@@ -374,12 +332,6 @@ type Shm struct {
// registry points to the shm registry containing this segment. Immutable.
registry *Registry
- // ID is the kernel identifier for this segment. Immutable.
- ID ID
-
- // creator is the user that created the segment. Immutable.
- creator fs.FileOwner
-
// size is the requested size of the segment at creation, in
// bytes. Immutable.
size uint64
@@ -397,14 +349,8 @@ type Shm struct {
// mu protects all fields below.
mu sync.Mutex `state:"nosave"`
- // key is the public identifier for this segment.
- key Key
-
- // perms is the access permissions for the segment.
- perms fs.FilePermissions
+ obj *ipc.Object
- // owner of this segment.
- owner fs.FileOwner
// attachTime is updated on every successful shmat.
attachTime ktime.Time
// detachTime is updated on every successful shmdt.
@@ -426,17 +372,44 @@ type Shm struct {
pendingDestruction bool
}
+// ID returns object's ID.
+func (s *Shm) ID() ipc.ID {
+ return s.obj.ID
+}
+
+// Object implements ipc.Mechanism.Object.
+func (s *Shm) Object() *ipc.Object {
+ return s.obj
+}
+
+// Destroy implements ipc.Mechanism.Destroy. No work is performed on shm.Destroy
+// because a different removal mechanism is used in shm. See Shm.MarkDestroyed.
+func (s *Shm) Destroy() {
+}
+
+// Lock implements ipc.Mechanism.Lock.
+func (s *Shm) Lock() {
+ s.mu.Lock()
+}
+
+// Unlock implements ipc.Mechanism.Unlock.
+//
+// +checklocksignore
+func (s *Shm) Unlock() {
+ s.mu.Unlock()
+}
+
// Precondition: Caller must hold s.mu.
func (s *Shm) debugLocked() string {
return fmt.Sprintf("Shm{id: %d, key: %d, size: %d bytes, refs: %d, destroyed: %v}",
- s.ID, s.key, s.size, s.ReadRefs(), s.pendingDestruction)
+ s.obj.ID, s.obj.Key, s.size, s.ReadRefs(), s.pendingDestruction)
}
// MappedName implements memmap.MappingIdentity.MappedName.
func (s *Shm) MappedName(ctx context.Context) string {
s.mu.Lock()
defer s.mu.Unlock()
- return fmt.Sprintf("SYSV%08d", s.key)
+ return fmt.Sprintf("SYSV%08d", s.obj.Key)
}
// DeviceID implements memmap.MappingIdentity.DeviceID.
@@ -448,7 +421,7 @@ func (s *Shm) DeviceID() uint64 {
func (s *Shm) InodeID() uint64 {
// "shmid gets reported as "inode#" in /proc/pid/maps. proc-ps tools use
// this. Changing this will break them." -- Linux, ipc/shm.c:newseg()
- return uint64(s.ID)
+ return uint64(s.obj.ID)
}
// DecRef drops a reference on s.
@@ -551,7 +524,8 @@ func (s *Shm) ConfigureAttach(ctx context.Context, addr hostarch.Addr, opts Atta
return memmap.MMapOpts{}, syserror.EIDRM
}
- if !s.checkPermissions(ctx, fs.PermMask{
+ creds := auth.CredentialsFromContext(ctx)
+ if !s.obj.CheckPermissions(creds, fs.PermMask{
Read: true,
Write: !opts.Readonly,
Execute: opts.Execute,
@@ -591,7 +565,8 @@ func (s *Shm) IPCStat(ctx context.Context) (*linux.ShmidDS, error) {
// "The caller must have read permission on the shared memory segment."
// - man shmctl(2)
- if !s.checkPermissions(ctx, fs.PermMask{Read: true}) {
+ creds := auth.CredentialsFromContext(ctx)
+ if !s.obj.CheckPermissions(creds, fs.PermMask{Read: true}) {
// "IPC_STAT or SHM_STAT is requested and shm_perm.mode does not allow
// read access for shmid, and the calling process does not have the
// CAP_IPC_OWNER capability in the user namespace that governs its IPC
@@ -603,7 +578,6 @@ func (s *Shm) IPCStat(ctx context.Context) (*linux.ShmidDS, error) {
if s.pendingDestruction {
mode |= linux.SHM_DEST
}
- creds := auth.CredentialsFromContext(ctx)
// Use the reference count as a rudimentary count of the number of
// attaches. We exclude:
@@ -620,12 +594,12 @@ func (s *Shm) IPCStat(ctx context.Context) (*linux.ShmidDS, error) {
ds := &linux.ShmidDS{
ShmPerm: linux.IPCPerm{
- Key: uint32(s.key),
- UID: uint32(creds.UserNamespace.MapFromKUID(s.owner.UID)),
- GID: uint32(creds.UserNamespace.MapFromKGID(s.owner.GID)),
- CUID: uint32(creds.UserNamespace.MapFromKUID(s.creator.UID)),
- CGID: uint32(creds.UserNamespace.MapFromKGID(s.creator.GID)),
- Mode: mode | uint16(s.perms.LinuxMode()),
+ Key: uint32(s.obj.Key),
+ UID: uint32(creds.UserNamespace.MapFromKUID(s.obj.Owner.UID)),
+ GID: uint32(creds.UserNamespace.MapFromKGID(s.obj.Owner.GID)),
+ CUID: uint32(creds.UserNamespace.MapFromKUID(s.obj.Creator.UID)),
+ CGID: uint32(creds.UserNamespace.MapFromKGID(s.obj.Creator.GID)),
+ Mode: mode | uint16(s.obj.Perms.LinuxMode()),
Seq: 0, // IPC sequences not supported.
},
ShmSegsz: s.size,
@@ -645,11 +619,11 @@ func (s *Shm) Set(ctx context.Context, ds *linux.ShmidDS) error {
s.mu.Lock()
defer s.mu.Unlock()
- if !s.checkOwnership(ctx) {
+ creds := auth.CredentialsFromContext(ctx)
+ if !s.obj.CheckOwnership(creds) {
return linuxerr.EPERM
}
- creds := auth.CredentialsFromContext(ctx)
uid := creds.UserNamespace.MapToKUID(auth.UID(ds.ShmPerm.UID))
gid := creds.UserNamespace.MapToKGID(auth.GID(ds.ShmPerm.GID))
if !uid.Ok() || !gid.Ok() {
@@ -659,10 +633,10 @@ func (s *Shm) Set(ctx context.Context, ds *linux.ShmidDS) error {
// User may only modify the lower 9 bits of the mode. All the other bits are
// always 0 for the underlying inode.
mode := linux.FileMode(ds.ShmPerm.Mode & 0x1ff)
- s.perms = fs.FilePermsFromMode(mode)
+ s.obj.Perms = fs.FilePermsFromMode(mode)
- s.owner.UID = uid
- s.owner.GID = gid
+ s.obj.Owner.UID = uid
+ s.obj.Owner.GID = gid
s.changeTime = ktime.NowFromContext(ctx)
return nil
@@ -691,40 +665,3 @@ func (s *Shm) MarkDestroyed(ctx context.Context) {
s.DecRef(ctx)
return
}
-
-// checkOwnership verifies whether a segment may be accessed by ctx as an
-// owner. See ipc/util.c:ipcctl_pre_down_nolock() in Linux.
-//
-// Precondition: Caller must hold s.mu.
-func (s *Shm) checkOwnership(ctx context.Context) bool {
- creds := auth.CredentialsFromContext(ctx)
- if s.owner.UID == creds.EffectiveKUID || s.creator.UID == creds.EffectiveKUID {
- return true
- }
-
- // Tasks with CAP_SYS_ADMIN may bypass ownership checks. Strangely, Linux
- // doesn't use CAP_IPC_OWNER for this despite CAP_IPC_OWNER being documented
- // for use to "override IPC ownership checks".
- return creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, s.registry.userNS)
-}
-
-// checkPermissions verifies whether a segment is accessible by ctx for access
-// described by req. See ipc/util.c:ipcperms() in Linux.
-//
-// Precondition: Caller must hold s.mu.
-func (s *Shm) checkPermissions(ctx context.Context, req fs.PermMask) bool {
- creds := auth.CredentialsFromContext(ctx)
-
- p := s.perms.Other
- if s.owner.UID == creds.EffectiveKUID {
- p = s.perms.User
- } else if creds.InGroup(s.owner.GID) {
- p = s.perms.Group
- }
- if p.SupersetOf(req) {
- return true
- }
-
- // Tasks with CAP_IPC_OWNER may bypass permission checks.
- return creds.HasCapabilityIn(linux.CAP_IPC_OWNER, s.registry.userNS)
-}
diff --git a/pkg/sentry/kernel/shm/shm_state_autogen.go b/pkg/sentry/kernel/shm/shm_state_autogen.go
index 30ceeaa03..9dde9122d 100644
--- a/pkg/sentry/kernel/shm/shm_state_autogen.go
+++ b/pkg/sentry/kernel/shm/shm_state_autogen.go
@@ -13,10 +13,8 @@ func (r *Registry) StateTypeName() string {
func (r *Registry) StateFields() []string {
return []string{
"userNS",
- "shms",
- "keysToShms",
+ "reg",
"totalPages",
- "lastIDUsed",
}
}
@@ -26,10 +24,8 @@ func (r *Registry) beforeSave() {}
func (r *Registry) StateSave(stateSinkObject state.Sink) {
r.beforeSave()
stateSinkObject.Save(0, &r.userNS)
- stateSinkObject.Save(1, &r.shms)
- stateSinkObject.Save(2, &r.keysToShms)
- stateSinkObject.Save(3, &r.totalPages)
- stateSinkObject.Save(4, &r.lastIDUsed)
+ stateSinkObject.Save(1, &r.reg)
+ stateSinkObject.Save(2, &r.totalPages)
}
func (r *Registry) afterLoad() {}
@@ -37,10 +33,8 @@ func (r *Registry) afterLoad() {}
// +checklocksignore
func (r *Registry) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(0, &r.userNS)
- stateSourceObject.Load(1, &r.shms)
- stateSourceObject.Load(2, &r.keysToShms)
- stateSourceObject.Load(3, &r.totalPages)
- stateSourceObject.Load(4, &r.lastIDUsed)
+ stateSourceObject.Load(1, &r.reg)
+ stateSourceObject.Load(2, &r.totalPages)
}
func (s *Shm) StateTypeName() string {
@@ -52,14 +46,10 @@ func (s *Shm) StateFields() []string {
"ShmRefs",
"mfp",
"registry",
- "ID",
- "creator",
"size",
"effectiveSize",
"fr",
- "key",
- "perms",
- "owner",
+ "obj",
"attachTime",
"detachTime",
"changeTime",
@@ -77,20 +67,16 @@ func (s *Shm) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(0, &s.ShmRefs)
stateSinkObject.Save(1, &s.mfp)
stateSinkObject.Save(2, &s.registry)
- stateSinkObject.Save(3, &s.ID)
- stateSinkObject.Save(4, &s.creator)
- stateSinkObject.Save(5, &s.size)
- stateSinkObject.Save(6, &s.effectiveSize)
- stateSinkObject.Save(7, &s.fr)
- stateSinkObject.Save(8, &s.key)
- stateSinkObject.Save(9, &s.perms)
- stateSinkObject.Save(10, &s.owner)
- stateSinkObject.Save(11, &s.attachTime)
- stateSinkObject.Save(12, &s.detachTime)
- stateSinkObject.Save(13, &s.changeTime)
- stateSinkObject.Save(14, &s.creatorPID)
- stateSinkObject.Save(15, &s.lastAttachDetachPID)
- stateSinkObject.Save(16, &s.pendingDestruction)
+ stateSinkObject.Save(3, &s.size)
+ stateSinkObject.Save(4, &s.effectiveSize)
+ stateSinkObject.Save(5, &s.fr)
+ stateSinkObject.Save(6, &s.obj)
+ stateSinkObject.Save(7, &s.attachTime)
+ stateSinkObject.Save(8, &s.detachTime)
+ stateSinkObject.Save(9, &s.changeTime)
+ stateSinkObject.Save(10, &s.creatorPID)
+ stateSinkObject.Save(11, &s.lastAttachDetachPID)
+ stateSinkObject.Save(12, &s.pendingDestruction)
}
func (s *Shm) afterLoad() {}
@@ -100,20 +86,16 @@ func (s *Shm) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(0, &s.ShmRefs)
stateSourceObject.Load(1, &s.mfp)
stateSourceObject.Load(2, &s.registry)
- stateSourceObject.Load(3, &s.ID)
- stateSourceObject.Load(4, &s.creator)
- stateSourceObject.Load(5, &s.size)
- stateSourceObject.Load(6, &s.effectiveSize)
- stateSourceObject.Load(7, &s.fr)
- stateSourceObject.Load(8, &s.key)
- stateSourceObject.Load(9, &s.perms)
- stateSourceObject.Load(10, &s.owner)
- stateSourceObject.Load(11, &s.attachTime)
- stateSourceObject.Load(12, &s.detachTime)
- stateSourceObject.Load(13, &s.changeTime)
- stateSourceObject.Load(14, &s.creatorPID)
- stateSourceObject.Load(15, &s.lastAttachDetachPID)
- stateSourceObject.Load(16, &s.pendingDestruction)
+ stateSourceObject.Load(3, &s.size)
+ stateSourceObject.Load(4, &s.effectiveSize)
+ stateSourceObject.Load(5, &s.fr)
+ stateSourceObject.Load(6, &s.obj)
+ stateSourceObject.Load(7, &s.attachTime)
+ stateSourceObject.Load(8, &s.detachTime)
+ stateSourceObject.Load(9, &s.changeTime)
+ stateSourceObject.Load(10, &s.creatorPID)
+ stateSourceObject.Load(11, &s.lastAttachDetachPID)
+ stateSourceObject.Load(12, &s.pendingDestruction)
}
func (r *ShmRefs) StateTypeName() string {
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index f1cb5a2c8..6f44d767b 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -121,10 +121,10 @@ var AMD64 = &kernel.SyscallTable{
65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
66: syscalls.Supported("semctl", Semctl),
67: syscalls.Supported("shmdt", Shmdt),
- 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 70: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 71: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 68: syscalls.Supported("msgget", Msgget),
+ 69: syscalls.ErrorWithEvent("msgsnd", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 70: syscalls.ErrorWithEvent("msgrcv", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 71: syscalls.PartiallySupported("msgctl", Msgctl, "Only supports IPC_RMID option.", []string{"gvisor.dev/issue/135"}),
72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil),
73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil),
74: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil),
@@ -616,10 +616,10 @@ var ARM64 = &kernel.SyscallTable{
183: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
184: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
185: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 186: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 187: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 186: syscalls.Supported("msgget", Msgget),
+ 187: syscalls.PartiallySupported("msgctl", Msgctl, "Only supports IPC_RMID option.", []string{"gvisor.dev/issue/135"}),
+ 188: syscalls.ErrorWithEvent("msgrcv", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 189: syscalls.ErrorWithEvent("msgsnd", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
190: syscalls.Supported("semget", Semget),
191: syscalls.Supported("semctl", Semctl),
192: syscalls.Supported("semtimedop", Semtimedop),
diff --git a/pkg/sentry/syscalls/linux/sys_msgqueue.go b/pkg/sentry/syscalls/linux/sys_msgqueue.go
new file mode 100644
index 000000000..3476e218d
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_msgqueue.go
@@ -0,0 +1,57 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/ipc"
+)
+
+// Msgget implements msgget(2).
+func Msgget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ key := ipc.Key(args[0].Int())
+ flag := args[1].Int()
+
+ private := key == linux.IPC_PRIVATE
+ create := flag&linux.IPC_CREAT == linux.IPC_CREAT
+ exclusive := flag&linux.IPC_EXCL == linux.IPC_EXCL
+ mode := linux.FileMode(flag & 0777)
+
+ r := t.IPCNamespace().MsgqueueRegistry()
+ queue, err := r.FindOrCreate(t, key, mode, private, create, exclusive)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(queue.ID()), nil, nil
+}
+
+// Msgctl implements msgctl(2).
+func Msgctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := ipc.ID(args[0].Int())
+ cmd := args[1].Int()
+
+ creds := auth.CredentialsFromContext(t)
+
+ switch cmd {
+ case linux.IPC_RMID:
+ return 0, nil, t.IPCNamespace().MsgqueueRegistry().Remove(id, creds)
+ default:
+ return 0, nil, linuxerr.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
index 30919eb2f..f61cc466c 100644
--- a/pkg/sentry/syscalls/linux/sys_sem.go
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -26,13 +26,14 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/ipc"
)
const opsMax = 500 // SEMOPM
// Semget handles: semget(key_t key, int nsems, int semflg)
func Semget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- key := args[0].Int()
+ key := ipc.Key(args[0].Int())
nsems := args[1].Int()
flag := args[2].Int()
@@ -46,7 +47,7 @@ func Semget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, err
}
- return uintptr(set.ID), nil, nil
+ return uintptr(set.ID()), nil, nil
}
// Semtimedop handles: semop(int semid, struct sembuf *sops, size_t nsops, const struct timespec *timeout)
@@ -56,7 +57,7 @@ func Semtimedop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return Semop(t, args)
}
- id := args[0].Int()
+ id := ipc.ID(args[0].Int())
sembufAddr := args[1].Pointer()
nsops := args[2].SizeT()
timespecAddr := args[3].Pointer()
@@ -91,7 +92,7 @@ func Semtimedop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// Semop handles: semop(int semid, struct sembuf *sops, size_t nsops)
func Semop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- id := args[0].Int()
+ id := ipc.ID(args[0].Int())
sembufAddr := args[1].Pointer()
nsops := args[2].SizeT()
@@ -109,7 +110,7 @@ func Semop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, semTimedOp(t, id, ops, false, time.Second)
}
-func semTimedOp(t *kernel.Task, id int32, ops []linux.Sembuf, haveTimeout bool, timeout time.Duration) error {
+func semTimedOp(t *kernel.Task, id ipc.ID, ops []linux.Sembuf, haveTimeout bool, timeout time.Duration) error {
set := t.IPCNamespace().SemaphoreRegistry().FindByID(id)
if set == nil {
@@ -131,7 +132,7 @@ func semTimedOp(t *kernel.Task, id int32, ops []linux.Sembuf, haveTimeout bool,
// Semctl handles: semctl(int semid, int semnum, int cmd, ...)
func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- id := args[0].Int()
+ id := ipc.ID(args[0].Int())
num := args[1].Int()
cmd := args[2].Int()
@@ -210,7 +211,7 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
case linux.SEM_STAT:
arg := args[3].Pointer()
// id is an index in SEM_STAT.
- semid, ds, err := semStat(t, id)
+ semid, ds, err := semStat(t, int32(id))
if err != nil {
return 0, nil, err
}
@@ -222,7 +223,7 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
case linux.SEM_STAT_ANY:
arg := args[3].Pointer()
// id is an index in SEM_STAT.
- semid, ds, err := semStatAny(t, id)
+ semid, ds, err := semStatAny(t, int32(id))
if err != nil {
return 0, nil, err
}
@@ -236,13 +237,13 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
}
-func remove(t *kernel.Task, id int32) error {
+func remove(t *kernel.Task, id ipc.ID) error {
r := t.IPCNamespace().SemaphoreRegistry()
creds := auth.CredentialsFromContext(t)
- return r.RemoveID(id, creds)
+ return r.Remove(id, creds)
}
-func ipcSet(t *kernel.Task, id int32, uid auth.UID, gid auth.GID, perms fs.FilePermissions) error {
+func ipcSet(t *kernel.Task, id ipc.ID, uid auth.UID, gid auth.GID, perms fs.FilePermissions) error {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
@@ -262,7 +263,7 @@ func ipcSet(t *kernel.Task, id int32, uid auth.UID, gid auth.GID, perms fs.FileP
return set.Change(t, creds, owner, perms)
}
-func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) {
+func ipcStat(t *kernel.Task, id ipc.ID) (*linux.SemidDS, error) {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
@@ -283,7 +284,7 @@ func semStat(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) {
if err != nil {
return 0, ds, err
}
- return set.ID, ds, nil
+ return int32(set.ID()), ds, nil
}
func semStatAny(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) {
@@ -296,10 +297,10 @@ func semStatAny(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) {
if err != nil {
return 0, ds, err
}
- return set.ID, ds, nil
+ return int32(set.ID()), ds, nil
}
-func setVal(t *kernel.Task, id int32, num int32, val int16) error {
+func setVal(t *kernel.Task, id ipc.ID, num int32, val int16) error {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
@@ -310,7 +311,7 @@ func setVal(t *kernel.Task, id int32, num int32, val int16) error {
return set.SetVal(t, num, val, creds, int32(pid))
}
-func setValAll(t *kernel.Task, id int32, array hostarch.Addr) error {
+func setValAll(t *kernel.Task, id ipc.ID, array hostarch.Addr) error {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
@@ -325,7 +326,7 @@ func setValAll(t *kernel.Task, id int32, array hostarch.Addr) error {
return set.SetValAll(t, vals, creds, int32(pid))
}
-func getVal(t *kernel.Task, id int32, num int32) (int16, error) {
+func getVal(t *kernel.Task, id ipc.ID, num int32) (int16, error) {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
@@ -335,7 +336,7 @@ func getVal(t *kernel.Task, id int32, num int32) (int16, error) {
return set.GetVal(num, creds)
}
-func getValAll(t *kernel.Task, id int32, array hostarch.Addr) error {
+func getValAll(t *kernel.Task, id ipc.ID, array hostarch.Addr) error {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
@@ -350,7 +351,7 @@ func getValAll(t *kernel.Task, id int32, array hostarch.Addr) error {
return err
}
-func getPID(t *kernel.Task, id int32, num int32) (int32, error) {
+func getPID(t *kernel.Task, id ipc.ID, num int32) (int32, error) {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
@@ -369,7 +370,7 @@ func getPID(t *kernel.Task, id int32, num int32) (int32, error) {
return int32(tg.ID()), nil
}
-func getZCnt(t *kernel.Task, id int32, num int32) (uint16, error) {
+func getZCnt(t *kernel.Task, id ipc.ID, num int32) (uint16, error) {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
@@ -379,7 +380,7 @@ func getZCnt(t *kernel.Task, id int32, num int32) (uint16, error) {
return set.CountZeroWaiters(num, creds)
}
-func getNCnt(t *kernel.Task, id int32, num int32) (uint16, error) {
+func getNCnt(t *kernel.Task, id ipc.ID, num int32) (uint16, error) {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go
index 3e3a952ce..840540506 100644
--- a/pkg/sentry/syscalls/linux/sys_shm.go
+++ b/pkg/sentry/syscalls/linux/sys_shm.go
@@ -19,12 +19,13 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/ipc"
"gvisor.dev/gvisor/pkg/sentry/kernel/shm"
)
// Shmget implements shmget(2).
func Shmget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- key := shm.Key(args[0].Int())
+ key := ipc.Key(args[0].Int())
size := uint64(args[1].SizeT())
flag := args[2].Int()
@@ -40,13 +41,13 @@ func Shmget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, err
}
defer segment.DecRef(t)
- return uintptr(segment.ID), nil, nil
+ return uintptr(segment.ID()), nil, nil
}
// findSegment retrives a shm segment by the given id.
//
// findSegment returns a reference on Shm.
-func findSegment(t *kernel.Task, id shm.ID) (*shm.Shm, error) {
+func findSegment(t *kernel.Task, id ipc.ID) (*shm.Shm, error) {
r := t.IPCNamespace().ShmRegistry()
segment := r.FindByID(id)
if segment == nil {
@@ -58,7 +59,7 @@ func findSegment(t *kernel.Task, id shm.ID) (*shm.Shm, error) {
// Shmat implements shmat(2).
func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- id := shm.ID(args[0].Int())
+ id := ipc.ID(args[0].Int())
addr := args[1].Pointer()
flag := args[2].Int()
@@ -89,7 +90,7 @@ func Shmdt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Shmctl implements shmctl(2).
func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- id := shm.ID(args[0].Int())
+ id := ipc.ID(args[0].Int())
cmd := args[1].Int()
buf := args[2].Pointer()