diff options
Diffstat (limited to 'pkg/sentry/kernel')
-rw-r--r-- | pkg/sentry/kernel/BUILD | 4 | ||||
-rw-r--r-- | pkg/sentry/kernel/kernel.go | 63 |
2 files changed, 51 insertions, 16 deletions
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index a868ee7fe..083071b5e 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -69,8 +69,8 @@ go_template_instance( prefix = "socket", template = "//pkg/ilist:generic_list", types = { - "Element": "*SocketRecord", - "Linker": "*SocketRecord", + "Element": "*SocketRecordVFS1", + "Linker": "*SocketRecordVFS1", }, ) diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index ebd2ec3df..d9c62ff91 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -220,10 +220,15 @@ type Kernel struct { // danglingEndpoints is used to save / restore tcpip.DanglingEndpoints. danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"` - // sockets is the list of all network sockets the system. Protected by - // extMu. + // sockets is the list of all network sockets in the system. + // Protected by extMu. + // TODO(gvisor.dev/issue/1624): Only used by VFS1. sockets socketList + // socketsVFS2 records all network sockets in the system. Protected by + // extMu. + socketsVFS2 map[*vfs.FileDescription]*SocketRecord + // nextSocketRecord is the next entry number to use in sockets. Protected // by extMu. nextSocketRecord uint64 @@ -414,6 +419,8 @@ func (k *Kernel) Init(args InitKernelArgs) error { return fmt.Errorf("failed to create sockfs mount: %v", err) } k.socketMount = socketMount + + k.socketsVFS2 = make(map[*vfs.FileDescription]*SocketRecord) } return nil @@ -1509,20 +1516,27 @@ func (k *Kernel) SupervisorContext() context.Context { } } -// SocketRecord represents a socket recorded in Kernel.sockets. It implements -// refs.WeakRefUser for sockets stored in the socket table. +// SocketRecord represents a socket recorded in Kernel.socketsVFS2. // // +stateify savable type SocketRecord struct { - socketEntry k *Kernel - Sock *refs.WeakRef - SockVFS2 *vfs.FileDescription - ID uint64 // Socket table entry number. + Sock *refs.WeakRef // TODO(gvisor.dev/issue/1624): Only used by VFS1. + SockVFS2 *vfs.FileDescription // Only used by VFS2. + ID uint64 // Socket table entry number. +} + +// SocketRecordVFS1 represents a socket recorded in Kernel.sockets. It implements +// refs.WeakRefUser for sockets stored in the socket table. +// +// +stateify savable +type SocketRecordVFS1 struct { + socketEntry + SocketRecord } // WeakRefGone implements refs.WeakRefUser.WeakRefGone. -func (s *SocketRecord) WeakRefGone(context.Context) { +func (s *SocketRecordVFS1) WeakRefGone(context.Context) { s.k.extMu.Lock() s.k.sockets.Remove(s) s.k.extMu.Unlock() @@ -1535,7 +1549,12 @@ func (k *Kernel) RecordSocket(sock *fs.File) { k.extMu.Lock() id := k.nextSocketRecord k.nextSocketRecord++ - s := &SocketRecord{k: k, ID: id} + s := &SocketRecordVFS1{ + SocketRecord: SocketRecord{ + k: k, + ID: id, + }, + } s.Sock = refs.NewWeakRef(sock, s) k.sockets.PushBack(s) k.extMu.Unlock() @@ -1547,9 +1566,12 @@ func (k *Kernel) RecordSocket(sock *fs.File) { // Precondition: Caller must hold a reference to sock. // // Note that the socket table will not hold a reference on the -// vfs.FileDescription, because we do not support weak refs on VFS2 files. +// vfs.FileDescription. func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) { k.extMu.Lock() + if _, ok := k.socketsVFS2[sock]; ok { + panic(fmt.Sprintf("Socket %p added twice", sock)) + } id := k.nextSocketRecord k.nextSocketRecord++ s := &SocketRecord{ @@ -1557,7 +1579,14 @@ func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) { ID: id, SockVFS2: sock, } - k.sockets.PushBack(s) + k.socketsVFS2[sock] = s + k.extMu.Unlock() +} + +// DeleteSocketVFS2 removes a VFS2 socket from the system-wide socket table. +func (k *Kernel) DeleteSocketVFS2(sock *vfs.FileDescription) { + k.extMu.Lock() + delete(k.socketsVFS2, sock) k.extMu.Unlock() } @@ -1568,8 +1597,14 @@ func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) { func (k *Kernel) ListSockets() []*SocketRecord { k.extMu.Lock() var socks []*SocketRecord - for s := k.sockets.Front(); s != nil; s = s.Next() { - socks = append(socks, s) + if VFS2Enabled { + for _, s := range k.socketsVFS2 { + socks = append(socks, s) + } + } else { + for s := k.sockets.Front(); s != nil; s = s.Next() { + socks = append(socks, &s.SocketRecord) + } } k.extMu.Unlock() return socks |