summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/fsimpl/fuse
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/fsimpl/fuse')
-rw-r--r--pkg/sentry/fsimpl/fuse/BUILD76
-rw-r--r--pkg/sentry/fsimpl/fuse/dev_test.go428
-rw-r--r--pkg/sentry/fsimpl/fuse/fuse_state_autogen.go184
-rw-r--r--pkg/sentry/fsimpl/fuse/inode_refs.go118
-rw-r--r--pkg/sentry/fsimpl/fuse/request_list.go193
5 files changed, 495 insertions, 504 deletions
diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD
deleted file mode 100644
index 53a4f3012..000000000
--- a/pkg/sentry/fsimpl/fuse/BUILD
+++ /dev/null
@@ -1,76 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-licenses(["notice"])
-
-go_template_instance(
- name = "request_list",
- out = "request_list.go",
- package = "fuse",
- prefix = "request",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*Request",
- "Linker": "*Request",
- },
-)
-
-go_template_instance(
- name = "inode_refs",
- out = "inode_refs.go",
- package = "fuse",
- prefix = "inode",
- template = "//pkg/refs_vfs2:refs_template",
- types = {
- "T": "inode",
- },
-)
-
-go_library(
- name = "fuse",
- srcs = [
- "connection.go",
- "dev.go",
- "fusefs.go",
- "init.go",
- "inode_refs.go",
- "register.go",
- "request_list.go",
- ],
- visibility = ["//pkg/sentry:internal"],
- deps = [
- "//pkg/abi/linux",
- "//pkg/context",
- "//pkg/log",
- "//pkg/refs",
- "//pkg/sentry/fsimpl/devtmpfs",
- "//pkg/sentry/fsimpl/kernfs",
- "//pkg/sentry/kernel",
- "//pkg/sentry/kernel/auth",
- "//pkg/sentry/vfs",
- "//pkg/sync",
- "//pkg/syserror",
- "//pkg/usermem",
- "//pkg/waiter",
- "//tools/go_marshal/marshal",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
-
-go_test(
- name = "fuse_test",
- size = "small",
- srcs = ["dev_test.go"],
- library = ":fuse",
- deps = [
- "//pkg/abi/linux",
- "//pkg/sentry/fsimpl/testutil",
- "//pkg/sentry/kernel",
- "//pkg/sentry/kernel/auth",
- "//pkg/sentry/vfs",
- "//pkg/syserror",
- "//pkg/usermem",
- "//pkg/waiter",
- "//tools/go_marshal/marshal",
- ],
-)
diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go
deleted file mode 100644
index 1ffe7ccd2..000000000
--- a/pkg/sentry/fsimpl/fuse/dev_test.go
+++ /dev/null
@@ -1,428 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fuse
-
-import (
- "fmt"
- "io"
- "math/rand"
- "testing"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
-)
-
-// echoTestOpcode is the Opcode used during testing. The server used in tests
-// will simply echo the payload back with the appropriate headers.
-const echoTestOpcode linux.FUSEOpcode = 1000
-
-type testPayload struct {
- data uint32
-}
-
-// TestFUSECommunication tests that the communication layer between the Sentry and the
-// FUSE server daemon works as expected.
-func TestFUSECommunication(t *testing.T) {
- s := setup(t)
- defer s.Destroy()
-
- k := kernel.KernelFromContext(s.Ctx)
- creds := auth.CredentialsFromContext(s.Ctx)
-
- // Create test cases with different number of concurrent clients and servers.
- testCases := []struct {
- Name string
- NumClients int
- NumServers int
- MaxActiveRequests uint64
- }{
- {
- Name: "SingleClientSingleServer",
- NumClients: 1,
- NumServers: 1,
- MaxActiveRequests: maxActiveRequestsDefault,
- },
- {
- Name: "SingleClientMultipleServers",
- NumClients: 1,
- NumServers: 10,
- MaxActiveRequests: maxActiveRequestsDefault,
- },
- {
- Name: "MultipleClientsSingleServer",
- NumClients: 10,
- NumServers: 1,
- MaxActiveRequests: maxActiveRequestsDefault,
- },
- {
- Name: "MultipleClientsMultipleServers",
- NumClients: 10,
- NumServers: 10,
- MaxActiveRequests: maxActiveRequestsDefault,
- },
- {
- Name: "RequestCapacityFull",
- NumClients: 10,
- NumServers: 1,
- MaxActiveRequests: 1,
- },
- {
- Name: "RequestCapacityContinuouslyFull",
- NumClients: 100,
- NumServers: 2,
- MaxActiveRequests: 2,
- },
- }
-
- for _, testCase := range testCases {
- t.Run(testCase.Name, func(t *testing.T) {
- conn, fd, err := newTestConnection(s, k, testCase.MaxActiveRequests)
- if err != nil {
- t.Fatalf("newTestConnection: %v", err)
- }
-
- clientsDone := make([]chan struct{}, testCase.NumClients)
- serversDone := make([]chan struct{}, testCase.NumServers)
- serversKill := make([]chan struct{}, testCase.NumServers)
-
- // FUSE clients.
- for i := 0; i < testCase.NumClients; i++ {
- clientsDone[i] = make(chan struct{})
- go func(i int) {
- fuseClientRun(t, s, k, conn, creds, uint32(i), uint64(i), clientsDone[i])
- }(i)
- }
-
- // FUSE servers.
- for j := 0; j < testCase.NumServers; j++ {
- serversDone[j] = make(chan struct{})
- serversKill[j] = make(chan struct{}, 1) // The kill command shouldn't block.
- go func(j int) {
- fuseServerRun(t, s, k, fd, serversDone[j], serversKill[j])
- }(j)
- }
-
- // Tear down.
- //
- // Make sure all the clients are done.
- for i := 0; i < testCase.NumClients; i++ {
- <-clientsDone[i]
- }
-
- // Kill any server that is potentially waiting.
- for j := 0; j < testCase.NumServers; j++ {
- serversKill[j] <- struct{}{}
- }
-
- // Make sure all the servers are done.
- for j := 0; j < testCase.NumServers; j++ {
- <-serversDone[j]
- }
- })
- }
-}
-
-// CallTest makes a request to the server and blocks the invoking
-// goroutine until a server responds with a response. Doesn't block
-// a kernel.Task. Analogous to Connection.Call but used for testing.
-func CallTest(conn *connection, t *kernel.Task, r *Request, i uint32) (*Response, error) {
- conn.fd.mu.Lock()
-
- // Wait until we're certain that a new request can be processed.
- for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests {
- conn.fd.mu.Unlock()
- select {
- case <-conn.fd.fullQueueCh:
- }
- conn.fd.mu.Lock()
- }
-
- fut, err := conn.callFutureLocked(t, r) // No task given.
- conn.fd.mu.Unlock()
-
- if err != nil {
- return nil, err
- }
-
- // Resolve the response.
- //
- // Block without a task.
- select {
- case <-fut.ch:
- }
-
- // A response is ready. Resolve and return it.
- return fut.getResponse(), nil
-}
-
-// ReadTest is analogous to vfs.FileDescription.Read and reads from the FUSE
-// device. However, it does so by - not blocking the task that is calling - and
-// instead just waits on a channel. The behaviour is essentially the same as
-// DeviceFD.Read except it guarantees that the task is not blocked.
-func ReadTest(serverTask *kernel.Task, fd *vfs.FileDescription, inIOseq usermem.IOSequence, killServer chan struct{}) (int64, bool, error) {
- var err error
- var n, total int64
-
- dev := fd.Impl().(*DeviceFD)
-
- // Register for notifications.
- w, ch := waiter.NewChannelEntry(nil)
- dev.EventRegister(&w, waiter.EventIn)
- for {
- // Issue the request and break out if it completes with anything other than
- // "would block".
- n, err = dev.Read(serverTask, inIOseq, vfs.ReadOptions{})
- total += n
- if err != syserror.ErrWouldBlock {
- break
- }
-
- // Wait for a notification that we should retry.
- // Emulate the blocking for when no requests are available
- select {
- case <-ch:
- case <-killServer:
- // Server killed by the main program.
- return 0, true, nil
- }
- }
-
- dev.EventUnregister(&w)
- return total, false, err
-}
-
-// fuseClientRun emulates all the actions of a normal FUSE request. It creates
-// a header, a payload, calls the server, waits for the response, and processes
-// the response.
-func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *connection, creds *auth.Credentials, pid uint32, inode uint64, clientDone chan struct{}) {
- defer func() { clientDone <- struct{}{} }()
-
- tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
- clientTask, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("fuse-client-%v", pid), tc, s.MntNs, s.Root, s.Root)
- if err != nil {
- t.Fatal(err)
- }
- testObj := &testPayload{
- data: rand.Uint32(),
- }
-
- req, err := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj)
- if err != nil {
- t.Fatalf("NewRequest creation failed: %v", err)
- }
-
- // Queue up a request.
- // Analogous to Call except it doesn't block on the task.
- resp, err := CallTest(conn, clientTask, req, pid)
- if err != nil {
- t.Fatalf("CallTaskNonBlock failed: %v", err)
- }
-
- if err = resp.Error(); err != nil {
- t.Fatalf("Server responded with an error: %v", err)
- }
-
- var respTestPayload testPayload
- if err := resp.UnmarshalPayload(&respTestPayload); err != nil {
- t.Fatalf("Unmarshalling payload error: %v", err)
- }
-
- if resp.hdr.Unique != req.hdr.Unique {
- t.Fatalf("got response for another request. Expected response for req %v but got response for req %v",
- req.hdr.Unique, resp.hdr.Unique)
- }
-
- if respTestPayload.data != testObj.data {
- t.Fatalf("read incorrect data. Data expected: %v, but got %v", testObj.data, respTestPayload.data)
- }
-
-}
-
-// fuseServerRun creates a task and emulates all the actions of a simple FUSE server
-// that simply reads a request and echos the same struct back as a response using the
-// appropriate headers.
-func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.FileDescription, serverDone, killServer chan struct{}) {
- defer func() { serverDone <- struct{}{} }()
-
- // Create the tasks that the server will be using.
- tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
- var readPayload testPayload
-
- serverTask, err := testutil.CreateTask(s.Ctx, "fuse-server", tc, s.MntNs, s.Root, s.Root)
- if err != nil {
- t.Fatal(err)
- }
-
- // Read the request.
- for {
- inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes())
- payloadLen := uint32(readPayload.SizeBytes())
-
- // The raed buffer must meet some certain size criteria.
- buffSize := inHdrLen + payloadLen
- if buffSize < linux.FUSE_MIN_READ_BUFFER {
- buffSize = linux.FUSE_MIN_READ_BUFFER
- }
- inBuf := make([]byte, buffSize)
- inIOseq := usermem.BytesIOSequence(inBuf)
-
- n, serverKilled, err := ReadTest(serverTask, fd, inIOseq, killServer)
- if err != nil {
- t.Fatalf("Read failed :%v", err)
- }
-
- // Server should shut down. No new requests are going to be made.
- if serverKilled {
- break
- }
-
- if n <= 0 {
- t.Fatalf("Read read no bytes")
- }
-
- var readFUSEHeaderIn linux.FUSEHeaderIn
- readFUSEHeaderIn.UnmarshalUnsafe(inBuf[:inHdrLen])
- readPayload.UnmarshalUnsafe(inBuf[inHdrLen : inHdrLen+payloadLen])
-
- if readFUSEHeaderIn.Opcode != echoTestOpcode {
- t.Fatalf("read incorrect data. Header: %v, Payload: %v", readFUSEHeaderIn, readPayload)
- }
-
- // Write the response.
- outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
- outBuf := make([]byte, outHdrLen+payloadLen)
- outHeader := linux.FUSEHeaderOut{
- Len: outHdrLen + payloadLen,
- Error: 0,
- Unique: readFUSEHeaderIn.Unique,
- }
-
- // Echo the payload back.
- outHeader.MarshalUnsafe(outBuf[:outHdrLen])
- readPayload.MarshalUnsafe(outBuf[outHdrLen:])
- outIOseq := usermem.BytesIOSequence(outBuf)
-
- n, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{})
- if err != nil {
- t.Fatalf("Write failed :%v", err)
- }
- }
-}
-
-func setup(t *testing.T) *testutil.System {
- k, err := testutil.Boot()
- if err != nil {
- t.Fatalf("Error creating kernel: %v", err)
- }
-
- ctx := k.SupervisorContext()
- creds := auth.CredentialsFromContext(ctx)
-
- k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
- AllowUserList: true,
- AllowUserMount: true,
- })
-
- mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
- if err != nil {
- t.Fatalf("NewMountNamespace(): %v", err)
- }
-
- return testutil.NewSystem(ctx, t, k.VFS(), mntns)
-}
-
-// newTestConnection creates a fuse connection that the sentry can communicate with
-// and the FD for the server to communicate with.
-func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveRequests uint64) (*connection, *vfs.FileDescription, error) {
- vfsObj := &vfs.VirtualFilesystem{}
- fuseDev := &DeviceFD{}
-
- if err := vfsObj.Init(system.Ctx); err != nil {
- return nil, nil, err
- }
-
- vd := vfsObj.NewAnonVirtualDentry("genCountFD")
- defer vd.DecRef(system.Ctx)
- if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil {
- return nil, nil, err
- }
-
- fsopts := filesystemOptions{
- maxActiveRequests: maxActiveRequests,
- }
- fs, err := NewFUSEFilesystem(system.Ctx, 0, &fsopts, &fuseDev.vfsfd)
- if err != nil {
- return nil, nil, err
- }
-
- return fs.conn, &fuseDev.vfsfd, nil
-}
-
-// SizeBytes implements marshal.Marshallable.SizeBytes.
-func (t *testPayload) SizeBytes() int {
- return 4
-}
-
-// MarshalBytes implements marshal.Marshallable.MarshalBytes.
-func (t *testPayload) MarshalBytes(dst []byte) {
- usermem.ByteOrder.PutUint32(dst[:4], t.data)
-}
-
-// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
-func (t *testPayload) UnmarshalBytes(src []byte) {
- *t = testPayload{data: usermem.ByteOrder.Uint32(src[:4])}
-}
-
-// Packed implements marshal.Marshallable.Packed.
-func (t *testPayload) Packed() bool {
- return true
-}
-
-// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
-func (t *testPayload) MarshalUnsafe(dst []byte) {
- t.MarshalBytes(dst)
-}
-
-// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
-func (t *testPayload) UnmarshalUnsafe(src []byte) {
- t.UnmarshalBytes(src)
-}
-
-// CopyOutN implements marshal.Marshallable.CopyOutN.
-func (t *testPayload) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
- panic("not implemented")
-}
-
-// CopyOut implements marshal.Marshallable.CopyOut.
-func (t *testPayload) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
- panic("not implemented")
-}
-
-// CopyIn implements marshal.Marshallable.CopyIn.
-func (t *testPayload) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
- panic("not implemented")
-}
-
-// WriteTo implements io.WriterTo.WriteTo.
-func (t *testPayload) WriteTo(w io.Writer) (int64, error) {
- panic("not implemented")
-}
diff --git a/pkg/sentry/fsimpl/fuse/fuse_state_autogen.go b/pkg/sentry/fsimpl/fuse/fuse_state_autogen.go
new file mode 100644
index 000000000..f72fe342e
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/fuse_state_autogen.go
@@ -0,0 +1,184 @@
+// automatically generated by stateify.
+
+package fuse
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (x *Request) StateTypeName() string {
+ return "pkg/sentry/fsimpl/fuse.Request"
+}
+
+func (x *Request) StateFields() []string {
+ return []string{
+ "requestEntry",
+ "id",
+ "hdr",
+ "data",
+ }
+}
+
+func (x *Request) beforeSave() {}
+
+func (x *Request) StateSave(m state.Sink) {
+ x.beforeSave()
+ m.Save(0, &x.requestEntry)
+ m.Save(1, &x.id)
+ m.Save(2, &x.hdr)
+ m.Save(3, &x.data)
+}
+
+func (x *Request) afterLoad() {}
+
+func (x *Request) StateLoad(m state.Source) {
+ m.Load(0, &x.requestEntry)
+ m.Load(1, &x.id)
+ m.Load(2, &x.hdr)
+ m.Load(3, &x.data)
+}
+
+func (x *Response) StateTypeName() string {
+ return "pkg/sentry/fsimpl/fuse.Response"
+}
+
+func (x *Response) StateFields() []string {
+ return []string{
+ "opcode",
+ "hdr",
+ "data",
+ }
+}
+
+func (x *Response) beforeSave() {}
+
+func (x *Response) StateSave(m state.Sink) {
+ x.beforeSave()
+ m.Save(0, &x.opcode)
+ m.Save(1, &x.hdr)
+ m.Save(2, &x.data)
+}
+
+func (x *Response) afterLoad() {}
+
+func (x *Response) StateLoad(m state.Source) {
+ m.Load(0, &x.opcode)
+ m.Load(1, &x.hdr)
+ m.Load(2, &x.data)
+}
+
+func (x *futureResponse) StateTypeName() string {
+ return "pkg/sentry/fsimpl/fuse.futureResponse"
+}
+
+func (x *futureResponse) StateFields() []string {
+ return []string{
+ "opcode",
+ "ch",
+ "hdr",
+ "data",
+ }
+}
+
+func (x *futureResponse) beforeSave() {}
+
+func (x *futureResponse) StateSave(m state.Sink) {
+ x.beforeSave()
+ m.Save(0, &x.opcode)
+ m.Save(1, &x.ch)
+ m.Save(2, &x.hdr)
+ m.Save(3, &x.data)
+}
+
+func (x *futureResponse) afterLoad() {}
+
+func (x *futureResponse) StateLoad(m state.Source) {
+ m.Load(0, &x.opcode)
+ m.Load(1, &x.ch)
+ m.Load(2, &x.hdr)
+ m.Load(3, &x.data)
+}
+
+func (x *inodeRefs) StateTypeName() string {
+ return "pkg/sentry/fsimpl/fuse.inodeRefs"
+}
+
+func (x *inodeRefs) StateFields() []string {
+ return []string{
+ "refCount",
+ }
+}
+
+func (x *inodeRefs) beforeSave() {}
+
+func (x *inodeRefs) StateSave(m state.Sink) {
+ x.beforeSave()
+ m.Save(0, &x.refCount)
+}
+
+func (x *inodeRefs) afterLoad() {}
+
+func (x *inodeRefs) StateLoad(m state.Source) {
+ m.Load(0, &x.refCount)
+}
+
+func (x *requestList) StateTypeName() string {
+ return "pkg/sentry/fsimpl/fuse.requestList"
+}
+
+func (x *requestList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (x *requestList) beforeSave() {}
+
+func (x *requestList) StateSave(m state.Sink) {
+ x.beforeSave()
+ m.Save(0, &x.head)
+ m.Save(1, &x.tail)
+}
+
+func (x *requestList) afterLoad() {}
+
+func (x *requestList) StateLoad(m state.Source) {
+ m.Load(0, &x.head)
+ m.Load(1, &x.tail)
+}
+
+func (x *requestEntry) StateTypeName() string {
+ return "pkg/sentry/fsimpl/fuse.requestEntry"
+}
+
+func (x *requestEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (x *requestEntry) beforeSave() {}
+
+func (x *requestEntry) StateSave(m state.Sink) {
+ x.beforeSave()
+ m.Save(0, &x.next)
+ m.Save(1, &x.prev)
+}
+
+func (x *requestEntry) afterLoad() {}
+
+func (x *requestEntry) StateLoad(m state.Source) {
+ m.Load(0, &x.next)
+ m.Load(1, &x.prev)
+}
+
+func init() {
+ state.Register((*Request)(nil))
+ state.Register((*Response)(nil))
+ state.Register((*futureResponse)(nil))
+ state.Register((*inodeRefs)(nil))
+ state.Register((*requestList)(nil))
+ state.Register((*requestEntry)(nil))
+}
diff --git a/pkg/sentry/fsimpl/fuse/inode_refs.go b/pkg/sentry/fsimpl/fuse/inode_refs.go
new file mode 100644
index 000000000..6b9456e1d
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/inode_refs.go
@@ -0,0 +1,118 @@
+package fuse
+
+import (
+ "fmt"
+ "runtime"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/log"
+ refs_vfs1 "gvisor.dev/gvisor/pkg/refs"
+)
+
+// ownerType is used to customize logging. Note that we use a pointer to T so
+// that we do not copy the entire object when passed as a format parameter.
+var inodeownerType *inode
+
+// Refs implements refs.RefCounter. It keeps a reference count using atomic
+// operations and calls the destructor when the count reaches zero.
+//
+// Note that the number of references is actually refCount + 1 so that a default
+// zero-value Refs object contains one reference.
+//
+// TODO(gvisor.dev/issue/1486): Store stack traces when leak check is enabled in
+// a map with 16-bit hashes, and store the hash in the top 16 bits of refCount.
+// This will allow us to add stack trace information to the leak messages
+// without growing the size of Refs.
+//
+// +stateify savable
+type inodeRefs struct {
+ // refCount is composed of two fields:
+ //
+ // [32-bit speculative references]:[32-bit real references]
+ //
+ // Speculative references are used for TryIncRef, to avoid a CompareAndSwap
+ // loop. See IncRef, DecRef and TryIncRef for details of how these fields are
+ // used.
+ refCount int64
+}
+
+func (r *inodeRefs) finalize() {
+ var note string
+ switch refs_vfs1.GetLeakMode() {
+ case refs_vfs1.NoLeakChecking:
+ return
+ case refs_vfs1.UninitializedLeakChecking:
+ note = "(Leak checker uninitialized): "
+ }
+ if n := r.ReadRefs(); n != 0 {
+ log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, inodeownerType, n)
+ }
+}
+
+// EnableLeakCheck checks for reference leaks when Refs gets garbage collected.
+func (r *inodeRefs) EnableLeakCheck() {
+ if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking {
+ runtime.SetFinalizer(r, (*inodeRefs).finalize)
+ }
+}
+
+// ReadRefs returns the current number of references. The returned count is
+// inherently racy and is unsafe to use without external synchronization.
+func (r *inodeRefs) ReadRefs() int64 {
+
+ return atomic.LoadInt64(&r.refCount) + 1
+}
+
+// IncRef implements refs.RefCounter.IncRef.
+//
+//go:nosplit
+func (r *inodeRefs) IncRef() {
+ if v := atomic.AddInt64(&r.refCount, 1); v <= 0 {
+ panic(fmt.Sprintf("Incrementing non-positive ref count %p owned by %T", r, inodeownerType))
+ }
+}
+
+// TryIncRef implements refs.RefCounter.TryIncRef.
+//
+// To do this safely without a loop, a speculative reference is first acquired
+// on the object. This allows multiple concurrent TryIncRef calls to distinguish
+// other TryIncRef calls from genuine references held.
+//
+//go:nosplit
+func (r *inodeRefs) TryIncRef() bool {
+ const speculativeRef = 1 << 32
+ v := atomic.AddInt64(&r.refCount, speculativeRef)
+ if int32(v) < 0 {
+
+ atomic.AddInt64(&r.refCount, -speculativeRef)
+ return false
+ }
+
+ atomic.AddInt64(&r.refCount, -speculativeRef+1)
+ return true
+}
+
+// DecRef implements refs.RefCounter.DecRef.
+//
+// Note that speculative references are counted here. Since they were added
+// prior to real references reaching zero, they will successfully convert to
+// real references. In other words, we see speculative references only in the
+// following case:
+//
+// A: TryIncRef [speculative increase => sees non-negative references]
+// B: DecRef [real decrease]
+// A: TryIncRef [transform speculative to real]
+//
+//go:nosplit
+func (r *inodeRefs) DecRef(destroy func()) {
+ switch v := atomic.AddInt64(&r.refCount, -1); {
+ case v < -1:
+ panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %T", r, inodeownerType))
+
+ case v == -1:
+
+ if destroy != nil {
+ destroy()
+ }
+ }
+}
diff --git a/pkg/sentry/fsimpl/fuse/request_list.go b/pkg/sentry/fsimpl/fuse/request_list.go
new file mode 100644
index 000000000..002262f23
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/request_list.go
@@ -0,0 +1,193 @@
+package fuse
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type requestElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (requestElementMapper) linkerFor(elem *Request) *Request { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type requestList struct {
+ head *Request
+ tail *Request
+}
+
+// Reset resets list l to the empty state.
+func (l *requestList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *requestList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *requestList) Front() *Request {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *requestList) Back() *Request {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+func (l *requestList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (requestElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *requestList) PushFront(e *Request) {
+ linker := requestElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ requestElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *requestList) PushBack(e *Request) {
+ linker := requestElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ requestElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *requestList) PushBackList(m *requestList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ requestElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ requestElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *requestList) InsertAfter(b, e *Request) {
+ bLinker := requestElementMapper{}.linkerFor(b)
+ eLinker := requestElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ requestElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *requestList) InsertBefore(a, e *Request) {
+ aLinker := requestElementMapper{}.linkerFor(a)
+ eLinker := requestElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ requestElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *requestList) Remove(e *Request) {
+ linker := requestElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ requestElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ requestElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type requestEntry struct {
+ next *Request
+ prev *Request
+}
+
+// Next returns the entry that follows e in the list.
+func (e *requestEntry) Next() *Request {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *requestEntry) Prev() *Request {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *requestEntry) SetNext(elem *Request) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *requestEntry) SetPrev(elem *Request) {
+ e.prev = elem
+}