diff options
Diffstat (limited to 'pkg/sentry/kernel/kernel.go')
-rw-r--r-- | pkg/sentry/kernel/kernel.go | 26 |
1 files changed, 21 insertions, 5 deletions
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index cf8bf3ecd..55a9d3d29 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -590,11 +590,17 @@ func (k *Kernel) UniqueID() uint64 { // CreateProcessArgs holds arguments to kernel.CreateProcess. type CreateProcessArgs struct { - // Filename is the filename to load. + // Filename is the filename to load as the init binary. // - // If this is provided as "", then the file will be guessed via Argv[0]. + // If this is provided as "", File will be checked, then the file will be + // guessed via Argv[0]. Filename string + // File is a passed host FD pointing to a file to load as the init binary. + // + // This is checked if and only if Filename is "". + File *fs.File + // Argvv is a list of arguments. Argv []string @@ -779,8 +785,16 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, defer wd.DecRef() } - if args.Filename == "" { - // Was anything provided? + // Check which file to start from. + switch { + case args.Filename != "": + // If a filename is given, take that. + // Set File to nil so we resolve the path in LoadTaskImage. + args.File = nil + case args.File != nil: + // If File is set, take the File provided directly. + default: + // Otherwise look at Argv and see if the first argument is a valid path. if len(args.Argv) == 0 { return nil, 0, fmt.Errorf("no filename or command provided") } @@ -792,7 +806,9 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, // Create a fresh task context. remainingTraversals = uint(args.MaxSymlinkTraversals) - tc, se := k.LoadTaskImage(ctx, k.mounts, root, wd, &remainingTraversals, args.Filename, args.Argv, args.Envv, k.featureSet) + + tc, se := k.LoadTaskImage(ctx, k.mounts, root, wd, &remainingTraversals, args.Filename, args.File, args.Argv, args.Envv, k.featureSet) + if se != nil { return nil, 0, errors.New(se.String()) } |