Diffstat (limited to 'pkg/sentry/kernel')
-rw-r--r--   pkg/sentry/kernel/BUILD                  |   1
-rw-r--r--   pkg/sentry/kernel/fd_table.go            | 196
-rw-r--r--   pkg/sentry/kernel/fd_table_unsafe.go     |  11
-rw-r--r--   pkg/sentry/kernel/msgqueue/msgqueue.go   | 278
4 files changed, 372 insertions, 114 deletions
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index c613f4932..e4e0dc04f 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -220,6 +220,7 @@ go_library(
         "//pkg/abi/linux",
         "//pkg/abi/linux/errno",
         "//pkg/amutex",
+        "//pkg/bitmap",
         "//pkg/bits",
         "//pkg/bpf",
         "//pkg/cleanup",
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 8786a70b5..eff556a0c 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -18,10 +18,10 @@ import (
     "fmt"
     "math"
     "strings"
-    "sync/atomic"
 
     "golang.org/x/sys/unix"
 
     "gvisor.dev/gvisor/pkg/abi/linux"
+    "gvisor.dev/gvisor/pkg/bitmap"
     "gvisor.dev/gvisor/pkg/context"
     "gvisor.dev/gvisor/pkg/errors/linuxerr"
     "gvisor.dev/gvisor/pkg/sentry/fs"
@@ -84,13 +84,8 @@ type FDTable struct {
     // mu protects below.
     mu sync.Mutex `state:"nosave"`
 
-    // next is the start position for finding a free fd.
-    next int32
-
-    // used contains the number of non-nil entries. It must be accessed
-    // atomically. It may be read atomically without holding mu (but not
-    // written).
-    used int32
+    // fdBitmap shows which fds are already in use.
+    fdBitmap bitmap.Bitmap `state:"nosave"`
 
     // descriptorTable holds descriptors.
     descriptorTable `state:".(map[int32]descriptor)"`
@@ -98,6 +93,8 @@ type FDTable struct {
 
 func (f *FDTable) saveDescriptorTable() map[int32]descriptor {
     m := make(map[int32]descriptor)
+    f.mu.Lock()
+    defer f.mu.Unlock()
     f.forEach(context.Background(), func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
         m[fd] = descriptor{
             file:     file,
@@ -111,12 +108,16 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor {
 func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) {
     ctx := context.Background()
     f.initNoLeakCheck() // Initialize table.
-    f.used = 0
+    f.fdBitmap = bitmap.New(uint32(math.MaxUint16))
     for fd, d := range m {
+        if fd < 0 {
+            panic(fmt.Sprintf("FD is not supposed to be negative. FD: %d", fd))
+        }
+
         if file, fileVFS2 := f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags); file != nil || fileVFS2 != nil {
             panic("VFS1 or VFS2 files set")
         }
-
+        f.fdBitmap.Add(uint32(fd))
         // Note that we do _not_ need to acquire an extra table reference here. The
         // table reference will already be accounted for in the file, so we drop the
         // reference taken by set above.
@@ -189,8 +190,10 @@ func (f *FDTable) DecRef(ctx context.Context) {
 func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags)) {
     // retries tracks the number of failed TryIncRef attempts for the same FD.
     retries := 0
-    fd := int32(0)
-    for {
+    fds := f.fdBitmap.ToSlice()
+    // Iterate through the fdBitmap.
+    for _, ufd := range fds {
+        fd := int32(ufd)
         file, fileVFS2, flags, ok := f.getAll(fd)
         if !ok {
             break
@@ -218,7 +221,6 @@ func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File,
             fileVFS2.DecRef(ctx)
         }
         retries = 0
-        fd++
     }
 }
 
@@ -226,6 +228,8 @@ func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File,
 func (f *FDTable) String() string {
     var buf strings.Builder
     ctx := context.Background()
+    f.mu.Lock()
+    defer f.mu.Unlock()
     f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
         switch {
         case file != nil:
@@ -250,10 +254,10 @@ func (f *FDTable) String() string {
 }
 
 // NewFDs allocates new FDs guaranteed to be the lowest number available
-// greater than or equal to the fd parameter. All files will share the set
+// greater than or equal to the minFD parameter. All files will share the set
 // flags. Success is guaranteed to be all or none.
-func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags FDFlags) (fds []int32, err error) {
-    if fd < 0 {
+func (f *FDTable) NewFDs(ctx context.Context, minFD int32, files []*fs.File, flags FDFlags) (fds []int32, err error) {
+    if minFD < 0 {
         // Don't accept negative FDs.
         return nil, unix.EINVAL
     }
@@ -267,31 +271,48 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
         if lim.Cur != limits.Infinity {
             end = int32(lim.Cur)
         }
-        if fd >= end {
+        if minFD+int32(len(files)) > end {
             return nil, unix.EMFILE
         }
     }
 
     f.mu.Lock()
 
-    // From f.next to find available fd.
-    if fd < f.next {
-        fd = f.next
+    // max is one past the largest fd currently in fdBitmap.
+    max := int32(0)
+
+    if !f.fdBitmap.IsEmpty() {
+        max = int32(f.fdBitmap.Maximum())
+        max++
     }
 
+    // Adjust max in case it is less than minFD.
+    if max < minFD {
+        max = minFD
+    }
     // Install all entries.
-    for i := fd; i < end && len(fds) < len(files); i++ {
-        if d, _, _ := f.get(i); d == nil {
-            // Set the descriptor.
-            f.set(ctx, i, files[len(fds)], flags)
-            fds = append(fds, i) // Record the file descriptor.
+    for len(fds) < len(files) {
+        // Try to use a free bit in fdBitmap.
+        // If all bits in fdBitmap are used, fall back to max.
+        fd := f.fdBitmap.FirstZero(uint32(minFD))
+        if fd == math.MaxInt32 {
+            fd = uint32(max)
+            max++
+        }
+        if fd >= uint32(end) {
+            break
         }
+        f.fdBitmap.Add(fd)
+        f.set(ctx, int32(fd), files[len(fds)], flags)
+        fds = append(fds, int32(fd))
+        minFD = int32(fd)
+    }
 
     // Failure? Unwind existing FDs.
     if len(fds) < len(files) {
         for _, i := range fds {
             f.set(ctx, i, nil, FDFlags{})
+            f.fdBitmap.Remove(uint32(i))
         }
         f.mu.Unlock()
@@ -305,20 +326,15 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
         return nil, unix.EMFILE
     }
 
-    if fd == f.next {
-        // Update next search start position.
-        f.next = fds[len(fds)-1] + 1
-    }
-
     f.mu.Unlock()
     return fds, nil
 }
 
 // NewFDsVFS2 allocates new FDs guaranteed to be the lowest number available
-// greater than or equal to the fd parameter. All files will share the set
+// greater than or equal to the minFD parameter. All files will share the set
 // flags. Success is guaranteed to be all or none.
-func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDescription, flags FDFlags) (fds []int32, err error) {
-    if fd < 0 {
+func (f *FDTable) NewFDsVFS2(ctx context.Context, minFD int32, files []*vfs.FileDescription, flags FDFlags) (fds []int32, err error) {
+    if minFD < 0 {
         // Don't accept negative FDs.
         return nil, unix.EINVAL
     }
@@ -332,31 +348,47 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes
         if lim.Cur != limits.Infinity {
             end = int32(lim.Cur)
         }
-        if fd >= end {
+        if minFD >= end {
             return nil, unix.EMFILE
         }
     }
 
     f.mu.Lock()
 
-    // From f.next to find available fd.
-    if fd < f.next {
-        fd = f.next
+    // max is one past the largest fd currently in fdBitmap.
+    max := int32(0)
+
+    if !f.fdBitmap.IsEmpty() {
+        max = int32(f.fdBitmap.Maximum())
+        max++
     }
 
-    // Install all entries.
-    for i := fd; i < end && len(fds) < len(files); i++ {
-        if d, _, _ := f.getVFS2(i); d == nil {
-            // Set the descriptor.
-            f.setVFS2(ctx, i, files[len(fds)], flags)
-            fds = append(fds, i) // Record the file descriptor.
-        }
+    // Adjust max in case it is less than minFD.
+    if max < minFD {
+        max = minFD
     }
 
+    for len(fds) < len(files) {
+        // Try to use a free bit in fdBitmap.
+        // If all bits in fdBitmap are used, fall back to max.
+        fd := f.fdBitmap.FirstZero(uint32(minFD))
+        if fd == math.MaxInt32 {
+            fd = uint32(max)
+            max++
+        }
+        if fd >= uint32(end) {
+            break
+        }
+        f.fdBitmap.Add(fd)
+        f.setVFS2(ctx, int32(fd), files[len(fds)], flags)
+        fds = append(fds, int32(fd))
+        minFD = int32(fd)
+    }
 
     // Failure? Unwind existing FDs.
     if len(fds) < len(files) {
         for _, i := range fds {
             f.setVFS2(ctx, i, nil, FDFlags{})
+            f.fdBitmap.Remove(uint32(i))
         }
         f.mu.Unlock()
@@ -370,57 +402,19 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes
         return nil, unix.EMFILE
     }
 
-    if fd == f.next {
-        // Update next search start position.
-        f.next = fds[len(fds)-1] + 1
-    }
-
     f.mu.Unlock()
     return fds, nil
 }
 
-// NewFDVFS2 allocates a file descriptor greater than or equal to minfd for
+// NewFDVFS2 allocates a file descriptor greater than or equal to minFD for
 // the given file description. If it succeeds, it takes a reference on file.
-func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
-    if minfd < 0 {
-        // Don't accept negative FDs.
-        return -1, unix.EINVAL
-    }
-
-    // Default limit.
-    end := int32(math.MaxInt32)
-
-    // Ensure we don't get past the provided limit.
-    if limitSet := limits.FromContext(ctx); limitSet != nil {
-        lim := limitSet.Get(limits.NumberOfFiles)
-        if lim.Cur != limits.Infinity {
-            end = int32(lim.Cur)
-        }
-        if minfd >= end {
-            return -1, unix.EMFILE
-        }
-    }
-
-    f.mu.Lock()
-    defer f.mu.Unlock()
-
-    // From f.next to find available fd.
-    fd := minfd
-    if fd < f.next {
-        fd = f.next
-    }
-    for fd < end {
-        if d, _, _ := f.getVFS2(fd); d == nil {
-            f.setVFS2(ctx, fd, file, flags)
-            if fd == f.next {
-                // Update next search start position.
-                f.next = fd + 1
-            }
-            return fd, nil
-        }
-        fd++
+func (f *FDTable) NewFDVFS2(ctx context.Context, minFD int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
+    files := []*vfs.FileDescription{file}
+    fileSlice, err := f.NewFDsVFS2(ctx, minFD, files, flags)
+    if err != nil {
+        return -1, err
     }
-    return -1, unix.EMFILE
+    return fileSlice[0], nil
 }
 
 // NewFDAt sets the file reference for the given FD. If there is an active
@@ -469,6 +463,11 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2
     defer f.mu.Unlock()
 
     df, dfVFS2 := f.setAll(ctx, fd, file, fileVFS2, flags)
+    // Add fd to fdBitmap.
+    if file != nil || fileVFS2 != nil {
+        f.fdBitmap.Add(uint32(fd))
+    }
+
     return df, dfVFS2, nil
 }
 
@@ -573,7 +572,9 @@ func (f *FDTable) GetVFS2(fd int32) (*vfs.FileDescription, FDFlags) {
 // Precondition: The caller must be running on the task goroutine, or Task.mu
 // must be locked.
 func (f *FDTable) GetFDs(ctx context.Context) []int32 {
-    fds := make([]int32, 0, int(atomic.LoadInt32(&f.used)))
+    f.mu.Lock()
+    defer f.mu.Unlock()
+    fds := make([]int32, 0, int(f.fdBitmap.GetNumOnes()))
     f.forEach(ctx, func(fd int32, _ *fs.File, _ *vfs.FileDescription, _ FDFlags) {
         fds = append(fds, fd)
     })
@@ -583,13 +584,15 @@ func (f *FDTable) GetFDs(ctx context.Context) []int32 {
 // Fork returns an independent FDTable.
 func (f *FDTable) Fork(ctx context.Context) *FDTable {
     clone := f.k.NewFDTable()
-
+    f.mu.Lock()
+    defer f.mu.Unlock()
     f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
         // The set function here will acquire an appropriate table
         // reference for the clone. We don't need anything else.
         if df, dfVFS2 := clone.setAll(ctx, fd, file, fileVFS2, flags); df != nil || dfVFS2 != nil {
             panic("VFS1 or VFS2 files set")
         }
+        clone.fdBitmap.Add(uint32(fd))
     })
     return clone
 }
@@ -604,11 +607,6 @@ func (f *FDTable) Remove(ctx context.Context, fd int32) (*fs.File, *vfs.FileDesc
 
     f.mu.Lock()
 
-    // Update current available position.
-    if fd < f.next {
-        f.next = fd
-    }
-
     orig, orig2, _, _ := f.getAll(fd)
 
     // Add reference for caller.
@@ -621,6 +619,7 @@ func (f *FDTable) Remove(ctx context.Context, fd int32) (*fs.File, *vfs.FileDesc
 
     if orig != nil || orig2 != nil {
         orig, orig2 = f.setAll(ctx, fd, nil, nil, FDFlags{}) // Zap entry.
+        f.fdBitmap.Remove(uint32(fd))
     }
     f.mu.Unlock()
@@ -644,16 +643,13 @@ func (f *FDTable) RemoveIf(ctx context.Context, cond func(*fs.File, *vfs.FileDes
     f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
         if cond(file, fileVFS2, flags) {
             df, dfVFS2 := f.setAll(ctx, fd, nil, nil, FDFlags{}) // Clear from table.
+            f.fdBitmap.Remove(uint32(fd))
             if df != nil {
                 files = append(files, df)
             }
             if dfVFS2 != nil {
                 filesVFS2 = append(filesVFS2, dfVFS2)
             }
-            // Update current available position.
-            if fd < f.next {
-                f.next = fd
-            }
         }
     })
     f.mu.Unlock()
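The fd_table.go changes above replace the old next-fd hint with two bitmap operations: FirstZero to find the lowest free descriptor at or above minFD, and Add to claim it. The sketch below mirrors that loop against a deliberately tiny, hypothetical toyBitmap type; it illustrates the allocation strategy only and is not the FDTable implementation (the real code uses gvisor.dev/gvisor/pkg/bitmap and holds f.mu around the loop).

// Sketch: lowest-available fd allocation with a bitmap, assuming only
// FirstZero/Add-style semantics like those used by NewFDs above.
// toyBitmap is hypothetical; the real table uses pkg/bitmap.
package main

import "fmt"

type toyBitmap map[uint32]bool

// firstZero returns the first unset bit at or above min.
func (b toyBitmap) firstZero(min uint32) uint32 {
	fd := min
	for b[fd] {
		fd++
	}
	return fd
}

// allocate mirrors the NewFDs loop: take the first free bit at or
// above minFD, and fail once the RLIMIT_NOFILE-style limit is reached.
func allocate(b toyBitmap, minFD, limit uint32) (uint32, bool) {
	fd := b.firstZero(minFD)
	if fd >= limit {
		return 0, false // EMFILE in the real code
	}
	b[fd] = true
	return fd, true
}

func main() {
	b := toyBitmap{0: true, 1: true, 3: true}
	fd, ok := allocate(b, 0, 1024)
	fmt.Println(fd, ok) // 2 true: the lowest free descriptor wins
}

Allocating the lowest free bit is what preserves the POSIX rule that a new descriptor is the lowest-numbered one available, which the old linear scan from f.next provided implicitly.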
diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go
index f17f9c59c..2b3e6ef71 100644
--- a/pkg/sentry/kernel/fd_table_unsafe.go
+++ b/pkg/sentry/kernel/fd_table_unsafe.go
@@ -15,9 +15,11 @@
 package kernel
 
 import (
+    "math"
     "sync/atomic"
     "unsafe"
 
+    "gvisor.dev/gvisor/pkg/bitmap"
     "gvisor.dev/gvisor/pkg/context"
     "gvisor.dev/gvisor/pkg/sentry/fs"
     "gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -44,6 +46,7 @@ func (f *FDTable) initNoLeakCheck() {
 func (f *FDTable) init() {
     f.initNoLeakCheck()
     f.InitRefs()
+    f.fdBitmap = bitmap.New(uint32(math.MaxUint16))
 }
 
 // get gets a file entry.
@@ -162,14 +165,6 @@ func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2
         }
     }
 
-    // Adjust used.
-    switch {
-    case orig == nil && desc != nil:
-        atomic.AddInt32(&f.used, 1)
-    case orig != nil && desc == nil:
-        atomic.AddInt32(&f.used, -1)
-    }
-
     if orig != nil {
         switch {
         case orig.file != nil:
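With the atomic used counter gone from fd_table_unsafe.go, the table's size is now derived from the bitmap itself while holding f.mu (see GetNumOnes in GetFDs above). A count like that is just a popcount over the bitmap's backing words; the following is a minimal sketch of the idea using the standard library, assuming a word-backed layout, and is not the pkg/bitmap implementation.

package main

import (
	"fmt"
	"math/bits"
)

// countOnes sums the popcount of each word, which is all a
// GetNumOnes-style helper needs to do for a word-backed bitmap.
func countOnes(words []uint64) uint32 {
	var n uint32
	for _, w := range words {
		n += uint32(bits.OnesCount64(w))
	}
	return n
}

func main() {
	words := []uint64{0b1011, 0, 1 << 63} // bits 0, 1, 3 and 191 set
	fmt.Println(countOnes(words))         // 4
}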
diff --git a/pkg/sentry/kernel/msgqueue/msgqueue.go b/pkg/sentry/kernel/msgqueue/msgqueue.go
index 3ce926950..c111297d7 100644
--- a/pkg/sentry/kernel/msgqueue/msgqueue.go
+++ b/pkg/sentry/kernel/msgqueue/msgqueue.go
@@ -119,14 +119,21 @@ type Queue struct {
 type Message struct {
     msgEntry
 
-    // mType is an integer representing the type of the sent message.
-    mType int64
+    // Type is an integer representing the type of the sent message.
+    Type int64
 
-    // mText is an untyped block of memory.
-    mText []byte
+    // Text is an untyped block of memory.
+    Text []byte
 
-    // mSize is the size of mText.
-    mSize uint64
+    // Size is the size of Text.
+    Size uint64
+}
+
+// Blocker is used for blocking Queue.Send and Queue.Receive calls. It serves
+// as an abstracted version of kernel.Task; kernel.Task is not used directly
+// to prevent circular dependencies.
+type Blocker interface {
+    Block(C <-chan struct{}) error
 }
 
 // FindOrCreate creates a new message queue or returns an existing one. See
@@ -186,6 +193,265 @@ func (r *Registry) Remove(id ipc.ID, creds *auth.Credentials) error {
     return nil
 }
 
+// FindByID returns the queue with the specified ID and an error if the ID
+// doesn't exist.
+func (r *Registry) FindByID(id ipc.ID) (*Queue, error) {
+    r.mu.Lock()
+    defer r.mu.Unlock()
+
+    mech := r.reg.FindByID(id)
+    if mech == nil {
+        return nil, linuxerr.EINVAL
+    }
+    return mech.(*Queue), nil
+}
+
+// Send appends a message to the message queue, and returns an error if sending
+// fails. See msgsnd(2).
+func (q *Queue) Send(ctx context.Context, m Message, b Blocker, wait bool, pid int32) (err error) {
+    // Try to perform a non-blocking send using queue.append. If EWOULDBLOCK
+    // is returned, start the blocking procedure. Otherwise, return normally.
+    creds := auth.CredentialsFromContext(ctx)
+    if err := q.append(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK {
+        return err
+    }
+
+    if !wait {
+        return linuxerr.EAGAIN
+    }
+
+    e, ch := waiter.NewChannelEntry(nil)
+    q.senders.EventRegister(&e, waiter.EventOut)
+
+    for {
+        if err = q.append(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK {
+            break
+        }
+        if err = b.Block(ch); err != nil {
+            break
+        }
+    }
+
+    q.senders.EventUnregister(&e)
+    return err
+}
+
+// append appends a message to the queue's message list and notifies waiting
+// receivers that a message has been inserted. It returns an error if adding
+// the message would cause the queue to exceed its maximum capacity, which can
+// be used as a signal to block the task. Other errors should be returned as is.
+func (q *Queue) append(ctx context.Context, m Message, creds *auth.Credentials, pid int32) error {
+    if m.Type <= 0 {
+        return linuxerr.EINVAL
+    }
+
+    q.mu.Lock()
+    defer q.mu.Unlock()
+
+    if !q.obj.CheckPermissions(creds, fs.PermMask{Write: true}) {
+        // The calling process does not have write permission on the message
+        // queue, and does not have the CAP_IPC_OWNER capability in the user
+        // namespace that governs its IPC namespace.
+        return linuxerr.EACCES
+    }
+
+    // Queue was removed while the process was waiting.
+    if q.dead {
+        return linuxerr.EIDRM
+    }
+
+    // Check if sufficient space is available (the queue isn't full). From
+    // the man pages:
+    //
+    // "A message queue is considered to be full if either of the following
+    // conditions is true:
+    //
+    //  • Adding a new message to the queue would cause the total number
+    //    of bytes in the queue to exceed the queue's maximum size (the
+    //    msg_qbytes field).
+    //
+    //  • Adding another message to the queue would cause the total
+    //    number of messages in the queue to exceed the queue's maximum
+    //    size (the msg_qbytes field). This check is necessary to
+    //    prevent an unlimited number of zero-length messages being
+    //    placed on the queue. Although such messages contain no data,
+    //    they nevertheless consume (locked) kernel memory."
+    //
+    // The msg_qbytes field in our implementation is q.maxBytes.
+    if m.Size+q.byteCount > q.maxBytes || q.messageCount+1 > q.maxBytes {
+        return linuxerr.EWOULDBLOCK
+    }
+
+    // Copy the message into the queue.
+    q.messages.PushBack(&m)
+
+    q.byteCount += m.Size
+    q.messageCount++
+    q.sendPID = pid
+    q.sendTime = ktime.NowFromContext(ctx)
+
+    // Notify receivers about the new message.
+    q.receivers.Notify(waiter.EventIn)
+
+    return nil
+}
+
+// Receive removes a message from the queue and returns it. See msgrcv(2).
+func (q *Queue) Receive(ctx context.Context, b Blocker, mType int64, maxSize int64, wait, truncate, except bool, pid int32) (msg *Message, err error) {
+    if maxSize < 0 || maxSize > maxMessageBytes {
+        return nil, linuxerr.EINVAL
+    }
+    max := uint64(maxSize)
+
+    // Try to perform a non-blocking receive using queue.pop. If EWOULDBLOCK
+    // is returned, start the blocking procedure. Otherwise, return normally.
+    creds := auth.CredentialsFromContext(ctx)
+    if msg, err := q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK {
+        return msg, err
+    }
+
+    if !wait {
+        return nil, linuxerr.ENOMSG
+    }
+
+    e, ch := waiter.NewChannelEntry(nil)
+    q.receivers.EventRegister(&e, waiter.EventIn)
+
+    for {
+        if msg, err = q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK {
+            break
+        }
+        if err = b.Block(ch); err != nil {
+            break
+        }
+    }
+    q.receivers.EventUnregister(&e)
+    return msg, err
+}
+
+// pop pops the first message from the queue that matches the given type. It
+// returns an error for all the cases specified in msgrcv(2). If the queue is
+// empty or no message of the specified type is available, an EWOULDBLOCK error
+// is returned, which can then be used as a signal to block the process or fail.
+func (q *Queue) pop(ctx context.Context, creds *auth.Credentials, mType int64, maxSize uint64, truncate, except bool, pid int32) (msg *Message, _ error) {
+    q.mu.Lock()
+    defer q.mu.Unlock()
+
+    if !q.obj.CheckPermissions(creds, fs.PermMask{Read: true}) {
+        // The calling process does not have read permission on the message
+        // queue, and does not have the CAP_IPC_OWNER capability in the user
+        // namespace that governs its IPC namespace.
+        return nil, linuxerr.EACCES
+    }
+
+    // Queue was removed while the process was waiting.
+    if q.dead {
+        return nil, linuxerr.EIDRM
+    }
+
+    if q.messages.Empty() {
+        return nil, linuxerr.EWOULDBLOCK
+    }
+
+    // Get a message from the queue.
+    switch {
+    case mType == 0:
+        msg = q.messages.Front()
+    case mType > 0:
+        msg = q.msgOfType(mType, except)
+    case mType < 0:
+        msg = q.msgOfTypeLessThan(-1 * mType)
+    }
+
+    // If no message exists, return a blocking signal.
+    if msg == nil {
+        return nil, linuxerr.EWOULDBLOCK
+    }
+
+    // Check that the message's size is acceptable.
+    if maxSize < msg.Size {
+        if !truncate {
+            return nil, linuxerr.E2BIG
+        }
+        msg.Size = maxSize
+        msg.Text = msg.Text[:maxSize]
+    }
+
+    q.messages.Remove(msg)
+
+    q.byteCount -= msg.Size
+    q.messageCount--
+    q.receivePID = pid
+    q.receiveTime = ktime.NowFromContext(ctx)
+
+    // Notify senders about available space.
+    q.senders.Notify(waiter.EventOut)
+
+    return msg, nil
+}
+
+// Copy copies a message from the queue without deleting it. If no message
+// exists, an error is returned. See msgrcv(MSG_COPY).
+func (q *Queue) Copy(mType int64) (*Message, error) {
+    q.mu.Lock()
+    defer q.mu.Unlock()
+
+    if mType < 0 || q.messages.Empty() {
+        return nil, linuxerr.ENOMSG
+    }
+
+    msg := q.msgAtIndex(mType)
+    if msg == nil {
+        return nil, linuxerr.ENOMSG
+    }
+    return msg, nil
+}
+
+// msgOfType returns the first message with the specified type, nil if no
+// message is found. If except is true, the first message of a type not equal
+// to mType will be returned.
+//
+// Precondition: caller must hold q.mu.
+func (q *Queue) msgOfType(mType int64, except bool) *Message {
+    if except {
+        for msg := q.messages.Front(); msg != nil; msg = msg.Next() {
+            if msg.Type != mType {
+                return msg
+            }
+        }
+        return nil
+    }
+
+    for msg := q.messages.Front(); msg != nil; msg = msg.Next() {
+        if msg.Type == mType {
+            return msg
+        }
+    }
+    return nil
+}
+
+// msgOfTypeLessThan returns the first message with the lowest type less than
+// or equal to mType, nil if no such message exists.
+//
+// Precondition: caller must hold q.mu.
+func (q *Queue) msgOfTypeLessThan(mType int64) (m *Message) {
+    for msg := q.messages.Front(); msg != nil; msg = msg.Next() {
+        if msg.Type <= mType && (m == nil || msg.Type < m.Type) {
+            m = msg
+        }
+    }
+    return m
+}
+
+// msgAtIndex returns a pointer to the message at the given index, nil if none
+// exists.
+//
+// Precondition: caller must hold q.mu.
+func (q *Queue) msgAtIndex(mType int64) *Message {
+    msg := q.messages.Front()
+    for ; mType != 0 && msg != nil; mType-- {
+        msg = msg.Next()
+    }
+    return msg
+}
+
 // Lock implements ipc.Mechanism.Lock.
 func (q *Queue) Lock() {
     q.mu.Lock()
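Send and Receive above share one shape: try the operation once, and on EWOULDBLOCK register a waiter entry, then alternate blocking and retrying until the operation succeeds or the Blocker reports an interruption. The toy program below condenses that retry-then-block pattern into a one-slot queue; every name in it is illustrative, and the channel plumbing stands in for the waiter package and kernel.Task.Block.

package main

import (
	"errors"
	"fmt"
	"sync"
)

var errWouldBlock = errors.New("EWOULDBLOCK")

// toyQueue is a hypothetical stand-in for msgqueue.Queue with a
// one-slot buffer; it only shows the retry-then-block shape of Send.
type toyQueue struct {
	mu      sync.Mutex
	slot    *string
	readers chan struct{} // buffered: a wakeup sent here is not lost
}

// tryAppend is the non-blocking fast path, like Queue.append.
func (q *toyQueue) tryAppend(m string) error {
	q.mu.Lock()
	defer q.mu.Unlock()
	if q.slot != nil {
		return errWouldBlock // full: caller may block and retry
	}
	q.slot = &m
	return nil
}

// send mirrors Queue.Send: fast path first, then block-and-retry.
func (q *toyQueue) send(m string, wait bool, block func(<-chan struct{}) error) error {
	if err := q.tryAppend(m); err != errWouldBlock {
		return err
	}
	if !wait {
		return errors.New("EAGAIN")
	}
	for {
		if err := q.tryAppend(m); err != errWouldBlock {
			return err
		}
		if err := block(q.readers); err != nil {
			return err // interrupted while blocked
		}
	}
}

func main() {
	q := &toyQueue{readers: make(chan struct{}, 1)}
	_ = q.send("first", false, nil) // succeeds: the slot is empty

	go func() { // a receiver frees the slot, then notifies senders
		q.mu.Lock()
		q.slot = nil
		q.mu.Unlock()
		q.readers <- struct{}{}
	}()

	err := q.send("second", true, func(ch <-chan struct{}) error {
		<-ch // kernel.Task.Block would also watch for signals here
		return nil
	})
	fmt.Println(err) // <nil>
}

The reason Block takes a channel is visible here: a blocked task must be woken by whichever side changes the queue's state, which is what q.receivers.Notify and q.senders.Notify do in the real code.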