summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/kernel/kernel.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/kernel/kernel.go')
-rw-r--r--pkg/sentry/kernel/kernel.go63
1 files changed, 49 insertions, 14 deletions
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index ebd2ec3df..d9c62ff91 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -220,10 +220,15 @@ type Kernel struct {
// danglingEndpoints is used to save / restore tcpip.DanglingEndpoints.
danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"`
- // sockets is the list of all network sockets the system. Protected by
- // extMu.
+ // sockets is the list of all network sockets in the system.
+ // Protected by extMu.
+ // TODO(gvisor.dev/issue/1624): Only used by VFS1.
sockets socketList
+ // socketsVFS2 records all network sockets in the system. Protected by
+ // extMu.
+ socketsVFS2 map[*vfs.FileDescription]*SocketRecord
+
// nextSocketRecord is the next entry number to use in sockets. Protected
// by extMu.
nextSocketRecord uint64
@@ -414,6 +419,8 @@ func (k *Kernel) Init(args InitKernelArgs) error {
return fmt.Errorf("failed to create sockfs mount: %v", err)
}
k.socketMount = socketMount
+
+ k.socketsVFS2 = make(map[*vfs.FileDescription]*SocketRecord)
}
return nil
@@ -1509,20 +1516,27 @@ func (k *Kernel) SupervisorContext() context.Context {
}
}
-// SocketRecord represents a socket recorded in Kernel.sockets. It implements
-// refs.WeakRefUser for sockets stored in the socket table.
+// SocketRecord represents a socket recorded in Kernel.socketsVFS2.
//
// +stateify savable
type SocketRecord struct {
- socketEntry
k *Kernel
- Sock *refs.WeakRef
- SockVFS2 *vfs.FileDescription
- ID uint64 // Socket table entry number.
+ Sock *refs.WeakRef // TODO(gvisor.dev/issue/1624): Only used by VFS1.
+ SockVFS2 *vfs.FileDescription // Only used by VFS2.
+ ID uint64 // Socket table entry number.
+}
+
+// SocketRecordVFS1 represents a socket recorded in Kernel.sockets. It implements
+// refs.WeakRefUser for sockets stored in the socket table.
+//
+// +stateify savable
+type SocketRecordVFS1 struct {
+ socketEntry
+ SocketRecord
}
// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
-func (s *SocketRecord) WeakRefGone(context.Context) {
+func (s *SocketRecordVFS1) WeakRefGone(context.Context) {
s.k.extMu.Lock()
s.k.sockets.Remove(s)
s.k.extMu.Unlock()
@@ -1535,7 +1549,12 @@ func (k *Kernel) RecordSocket(sock *fs.File) {
k.extMu.Lock()
id := k.nextSocketRecord
k.nextSocketRecord++
- s := &SocketRecord{k: k, ID: id}
+ s := &SocketRecordVFS1{
+ SocketRecord: SocketRecord{
+ k: k,
+ ID: id,
+ },
+ }
s.Sock = refs.NewWeakRef(sock, s)
k.sockets.PushBack(s)
k.extMu.Unlock()
@@ -1547,9 +1566,12 @@ func (k *Kernel) RecordSocket(sock *fs.File) {
// Precondition: Caller must hold a reference to sock.
//
// Note that the socket table will not hold a reference on the
-// vfs.FileDescription, because we do not support weak refs on VFS2 files.
+// vfs.FileDescription.
func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) {
k.extMu.Lock()
+ if _, ok := k.socketsVFS2[sock]; ok {
+ panic(fmt.Sprintf("Socket %p added twice", sock))
+ }
id := k.nextSocketRecord
k.nextSocketRecord++
s := &SocketRecord{
@@ -1557,7 +1579,14 @@ func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) {
ID: id,
SockVFS2: sock,
}
- k.sockets.PushBack(s)
+ k.socketsVFS2[sock] = s
+ k.extMu.Unlock()
+}
+
+// DeleteSocketVFS2 removes a VFS2 socket from the system-wide socket table.
+func (k *Kernel) DeleteSocketVFS2(sock *vfs.FileDescription) {
+ k.extMu.Lock()
+ delete(k.socketsVFS2, sock)
k.extMu.Unlock()
}
@@ -1568,8 +1597,14 @@ func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) {
func (k *Kernel) ListSockets() []*SocketRecord {
k.extMu.Lock()
var socks []*SocketRecord
- for s := k.sockets.Front(); s != nil; s = s.Next() {
- socks = append(socks, s)
+ if VFS2Enabled {
+ for _, s := range k.socketsVFS2 {
+ socks = append(socks, s)
+ }
+ } else {
+ for s := k.sockets.Front(); s != nil; s = s.Next() {
+ socks = append(socks, &s.SocketRecord)
+ }
}
k.extMu.Unlock()
return socks