-rw-r--r--  pkg/sentry/kernel/BUILD                 1
-rw-r--r--  pkg/sentry/kernel/task.go              21
-rw-r--r--  pkg/sentry/kernel/task_clone.go        36
-rw-r--r--  pkg/sentry/seccheck/BUILD              54
-rw-r--r--  pkg/sentry/seccheck/clone.go           53
-rw-r--r--  pkg/sentry/seccheck/seccheck.go       136
-rw-r--r--  pkg/sentry/seccheck/seccheck_test.go  157
-rw-r--r--  pkg/sentry/seccheck/task.go            39
-rw-r--r--  tools/go_fieldenum/BUILD               15
-rw-r--r--  tools/go_fieldenum/defs.bzl            29
-rw-r--r--  tools/go_fieldenum/main.go            310
11 files changed, 850 insertions(+), 1 deletion(-)
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 816e60329..c0f13bf52 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -268,6 +268,7 @@ go_library(
"//pkg/sentry/mm",
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
+ "//pkg/sentry/seccheck",
"//pkg/sentry/socket/netlink/port",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/time",
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index 59eeb253d..9a95bf44c 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -30,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/seccheck"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
@@ -874,3 +875,23 @@ func (t *Task) ResetKcov() {
t.kcov = nil
}
}
+
+// Preconditions: The TaskSet mutex must be locked.
+func (t *Task) loadSeccheckInfoLocked(req seccheck.TaskFieldSet, mask *seccheck.TaskFieldSet, info *seccheck.TaskInfo) {
+ if req.Contains(seccheck.TaskFieldThreadID) {
+ info.ThreadID = int32(t.k.tasks.Root.tids[t])
+ mask.Add(seccheck.TaskFieldThreadID)
+ }
+ if req.Contains(seccheck.TaskFieldThreadStartTime) {
+ info.ThreadStartTime = t.startTime
+ mask.Add(seccheck.TaskFieldThreadStartTime)
+ }
+ if req.Contains(seccheck.TaskFieldThreadGroupID) {
+ info.ThreadGroupID = int32(t.k.tasks.Root.tgids[t.tg])
+ mask.Add(seccheck.TaskFieldThreadGroupID)
+ }
+ if req.Contains(seccheck.TaskFieldThreadGroupStartTime) {
+ info.ThreadGroupStartTime = t.tg.leader.startTime
+ mask.Add(seccheck.TaskFieldThreadGroupStartTime)
+ }
+}
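
For illustration, a minimal sketch (not part of this change) of how a registered Checker might consume the masked fields that loadSeccheckInfoLocked populates; the logCloneChecker type, its package name, and the logging are hypothetical:

package examplechecker

import (
    "log"

    "gvisor.dev/gvisor/pkg/context"
    "gvisor.dev/gvisor/pkg/sentry/seccheck"
)

// logCloneChecker is a hypothetical Checker that logs the created thread's ID.
type logCloneChecker struct {
    seccheck.CheckerDefaults
}

// Clone implements seccheck.Checker.Clone.
func (logCloneChecker) Clone(ctx context.Context, mask seccheck.CloneFieldSet, info seccheck.CloneInfo) error {
    // Only read fields whose bits are set in mask; fields that were not
    // requested are left at their zero values by loadSeccheckInfoLocked.
    if mask.Created.Contains(seccheck.TaskFieldThreadID) {
        log.Printf("clone: created thread %d", info.Created.ThreadID)
    }
    return nil
}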
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index da4b77ca2..26a981f36 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/seccheck"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -235,7 +236,23 @@ func (t *Task) Clone(args *linux.CloneArgs) (ThreadID, *SyscallControl, error) {
// nt that it must receive before its task goroutine starts running.
tid := nt.k.tasks.Root.IDOfTask(nt)
defer nt.Start(tid)
- t.traceCloneEvent(tid)
+
+ if seccheck.Global.Enabled(seccheck.PointClone) {
+ mask, info := getCloneSeccheckInfo(t, nt, args)
+ if err := seccheck.Global.Clone(t, mask, &info); err != nil {
+ // nt has been visible to the rest of the system since NewTask, so
+ // it may be blocking execve or a group stop, have been notified
+ // for group signal delivery, had children reparented to it, etc.
+ // Thus we can't just drop it on the floor. Instead, instruct the
+ // task goroutine to exit immediately, as quietly as possible.
+ nt.exitTracerNotified = true
+ nt.exitTracerAcked = true
+ nt.exitParentNotified = true
+ nt.exitParentAcked = true
+ nt.runState = (*runExitMain)(nil)
+ return 0, nil, err
+ }
+ }
// "If fork/clone and execve are allowed by @prog, any child processes will
// be constrained to the same filters and system call ABI as the parent." -
@@ -260,6 +277,7 @@ func (t *Task) Clone(args *linux.CloneArgs) (ThreadID, *SyscallControl, error) {
ntid.CopyOut(t, hostarch.Addr(args.ParentTID))
}
+ t.traceCloneEvent(tid)
kind := ptraceCloneKindClone
if args.Flags&linux.CLONE_VFORK != 0 {
kind = ptraceCloneKindVfork
@@ -279,6 +297,22 @@ func (t *Task) Clone(args *linux.CloneArgs) (ThreadID, *SyscallControl, error) {
return ntid, nil, nil
}
+func getCloneSeccheckInfo(t, nt *Task, args *linux.CloneArgs) (seccheck.CloneFieldSet, seccheck.CloneInfo) {
+ req := seccheck.Global.CloneReq()
+ info := seccheck.CloneInfo{
+ Credentials: t.Credentials(),
+ Args: *args,
+ }
+ var mask seccheck.CloneFieldSet
+ mask.Add(seccheck.CloneFieldCredentials)
+ mask.Add(seccheck.CloneFieldArgs)
+ t.k.tasks.mu.RLock()
+ defer t.k.tasks.mu.RUnlock()
+ t.loadSeccheckInfoLocked(req.Invoker, &mask.Invoker, &info.Invoker)
+ nt.loadSeccheckInfoLocked(req.Created, &mask.Created, &info.Created)
+ return mask, info
+}
+
// maybeBeginVforkStop checks if a previously-started vfork child is still
// running and has not yet released its MM, such that its parent t should enter
// a vforkStop.
diff --git a/pkg/sentry/seccheck/BUILD b/pkg/sentry/seccheck/BUILD
new file mode 100644
index 000000000..943fa180d
--- /dev/null
+++ b/pkg/sentry/seccheck/BUILD
@@ -0,0 +1,54 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_fieldenum:defs.bzl", "go_fieldenum")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_fieldenum(
+ name = "seccheck_fieldenum",
+ srcs = [
+ "clone.go",
+ "task.go",
+ ],
+ out = "seccheck_fieldenum.go",
+ package = "seccheck",
+)
+
+go_template_instance(
+ name = "seqatomic_checkerslice",
+ out = "seqatomic_checkerslice_unsafe.go",
+ package = "seccheck",
+ suffix = "CheckerSlice",
+ template = "//pkg/sync/seqatomic:generic_seqatomic",
+ types = {
+ "Value": "[]Checker",
+ },
+)
+
+go_library(
+ name = "seccheck",
+ srcs = [
+ "clone.go",
+ "seccheck.go",
+ "seccheck_fieldenum.go",
+ "seqatomic_checkerslice_unsafe.go",
+ "task.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/gohacks",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "seccheck_test",
+ size = "small",
+ srcs = ["seccheck_test.go"],
+ library = ":seccheck",
+ deps = ["//pkg/context"],
+)
diff --git a/pkg/sentry/seccheck/clone.go b/pkg/sentry/seccheck/clone.go
new file mode 100644
index 000000000..7546fa021
--- /dev/null
+++ b/pkg/sentry/seccheck/clone.go
@@ -0,0 +1,53 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccheck
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// CloneInfo contains information used by the Clone checkpoint.
+//
+// +fieldenum Clone
+type CloneInfo struct {
+ // Invoker identifies the invoking thread.
+ Invoker TaskInfo
+
+ // Credentials are the invoking thread's credentials.
+ Credentials *auth.Credentials
+
+ // Args contains the arguments to kernel.Task.Clone().
+ Args linux.CloneArgs
+
+ // Created identifies the created thread.
+ Created TaskInfo
+}
+
+// CloneReq returns fields required by the Clone checkpoint.
+func (s *state) CloneReq() CloneFieldSet {
+ return s.cloneReq.Load()
+}
+
+// Clone is called at the Clone checkpoint.
+func (s *state) Clone(ctx context.Context, mask CloneFieldSet, info *CloneInfo) error {
+ for _, c := range s.getCheckers() {
+ if err := c.Clone(ctx, mask, *info); err != nil {
+ return err
+ }
+ }
+ return nil
+}
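
As a hedged illustration of how the Clone checkpoint's error return can be used, the following hypothetical Checker vetoes clones that create a new user namespace. The denyUserNSChecker type and its package are invented for this sketch, and a real checker would likely return a linuxerr error rather than one from fmt.Errorf:

package examplechecker

import (
    "fmt"

    "gvisor.dev/gvisor/pkg/abi/linux"
    "gvisor.dev/gvisor/pkg/context"
    "gvisor.dev/gvisor/pkg/sentry/seccheck"
)

// denyUserNSChecker is a hypothetical Checker that vetoes clones creating new
// user namespaces; a non-nil error from Clone aborts the checked clone(2)
// before the new task starts running.
type denyUserNSChecker struct {
    seccheck.CheckerDefaults
}

// Clone implements seccheck.Checker.Clone.
func (denyUserNSChecker) Clone(ctx context.Context, mask seccheck.CloneFieldSet, info seccheck.CloneInfo) error {
    // Trust info.Args only if the Args field was actually populated.
    if mask.Contains(seccheck.CloneFieldArgs) && info.Args.Flags&linux.CLONE_NEWUSER != 0 {
        return fmt.Errorf("clone with CLONE_NEWUSER denied by policy")
    }
    return nil
}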
diff --git a/pkg/sentry/seccheck/seccheck.go b/pkg/sentry/seccheck/seccheck.go
new file mode 100644
index 000000000..b6c9d44ce
--- /dev/null
+++ b/pkg/sentry/seccheck/seccheck.go
@@ -0,0 +1,136 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package seccheck defines a structure for dynamically-configured security
+// checks in the sentry.
+package seccheck
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// A Point represents a checkpoint, a point at which a security check occurs.
+type Point uint
+
+// PointX represents the checkpoint X.
+const (
+ PointClone Point = iota
+ // Add new Points above this line.
+ pointLength
+
+ numPointBitmaskUint32s = (int(pointLength)-1)/32 + 1
+)
+
+// A Checker performs security checks at checkpoints.
+//
+// Each Checker method X is called at checkpoint X; if the method may return a
+// non-nil error and does so, it causes the checked operation to fail
+// immediately (without calling subsequent Checkers) and return the error. The
+// info argument contains information relevant to the check. The mask argument
+// indicates what fields in info are valid; the mask should usually be a
+// superset of fields requested by the Checker's corresponding CheckerReq, but
+// may be missing requested fields in some cases (e.g. if the Checker is
+// registered concurrently with invocations of checkpoints).
+type Checker interface {
+ Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error
+}
+
+// CheckerDefaults may be embedded by implementations of Checker to obtain
+// no-op implementations of Checker methods that may be explicitly overridden.
+type CheckerDefaults struct{}
+
+// Clone implements Checker.Clone.
+func (CheckerDefaults) Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error {
+ return nil
+}
+
+// CheckerReq indicates what checkpoints a corresponding Checker runs at, and
+// what information it requires at those checkpoints.
+type CheckerReq struct {
+ // Points are the set of checkpoints for which the corresponding Checker
+ // must be called. Note that methods not specified in Points may still be
+ // called; implementations of Checker may embed CheckerDefaults to obtain
+ // no-op implementations of Checker methods.
+ Points []Point
+
+ // All of the following fields indicate what fields in the corresponding
+ // XInfo struct will be requested at the corresponding checkpoint.
+ Clone CloneFields
+}
+
+// Global is the method receiver of all seccheck functions.
+var Global state
+
+// state is the type of Global, and is separated out for testing.
+type state struct {
+ // registrationMu serializes all changes to the set of registered Checkers
+ // for all checkpoints.
+ registrationMu sync.Mutex
+
+ // enabledPoints is a bitmask of checkpoints for which at least one Checker
+ // is registered.
+ //
+ // enabledPoints is accessed using atomic memory operations. Mutation of
+ // enabledPoints is serialized by registrationMu.
+ enabledPoints [numPointBitmaskUint32s]uint32
+
+ // registrationSeq supports store-free atomic reads of registeredCheckers.
+ registrationSeq sync.SeqCount
+
+ // checkers is the set of all registered Checkers in order of execution.
+ //
+ // checkers is accessed using instantiations of SeqAtomic functions.
+ // Mutation of checkers is serialized by registrationMu.
+ checkers []Checker
+
+ // All of the following xReq variables indicate what fields in the
+ // corresponding XInfo struct have been requested by any registered
+ // checker, are accessed using atomic memory operations, and are mutated
+ // with registrationMu locked.
+ cloneReq CloneFieldSet
+}
+
+// AppendChecker registers the given Checker to execute at checkpoints. The
+// Checker will execute after all previously-registered Checkers, and only if
+// those Checkers return a nil error.
+func (s *state) AppendChecker(c Checker, req *CheckerReq) {
+ s.registrationMu.Lock()
+ defer s.registrationMu.Unlock()
+ s.cloneReq.AddFieldsLoadable(req.Clone)
+ s.appendCheckerLocked(c)
+ for _, p := range req.Points {
+ word, bit := p/32, p%32
+ atomic.StoreUint32(&s.enabledPoints[word], s.enabledPoints[word]|(uint32(1)<<bit))
+ }
+}
+
+// Enabled returns true if any Checker is registered for the given checkpoint.
+func (s *state) Enabled(p Point) bool {
+ word, bit := p/32, p%32
+ return atomic.LoadUint32(&s.enabledPoints[word])&(uint32(1)<<bit) != 0
+}
+
+func (s *state) getCheckers() []Checker {
+ return SeqAtomicLoadCheckerSlice(&s.registrationSeq, &s.checkers)
+}
+
+// Preconditions: s.registrationMu must be locked.
+func (s *state) appendCheckerLocked(c Checker) {
+ s.registrationSeq.BeginWrite()
+ s.checkers = append(s.checkers, c)
+ s.registrationSeq.EndWrite()
+}
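
For context, a small sketch of how a Checker is registered against the Global state; registerExampleChecker and its package are hypothetical, but the CheckerReq shape matches the one exercised by the tests below:

package examplechecker

import "gvisor.dev/gvisor/pkg/sentry/seccheck"

// registerExampleChecker registers c for the Clone checkpoint, requesting the
// invoker's credentials and the created thread's ID. After this call,
// seccheck.Global.Enabled(seccheck.PointClone) reports true and
// kernel.Task.Clone will invoke c.Clone.
func registerExampleChecker(c seccheck.Checker) {
    seccheck.Global.AppendChecker(c, &seccheck.CheckerReq{
        Points: []seccheck.Point{seccheck.PointClone},
        Clone: seccheck.CloneFields{
            Credentials: true,
            Created: seccheck.TaskFields{
                ThreadID: true,
            },
        },
    })
}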
diff --git a/pkg/sentry/seccheck/seccheck_test.go b/pkg/sentry/seccheck/seccheck_test.go
new file mode 100644
index 000000000..687810d18
--- /dev/null
+++ b/pkg/sentry/seccheck/seccheck_test.go
@@ -0,0 +1,157 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccheck
+
+import (
+ "errors"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+type testChecker struct {
+ CheckerDefaults
+
+ onClone func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error
+}
+
+// Clone implements Checker.Clone.
+func (c *testChecker) Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error {
+ if c.onClone == nil {
+ return nil
+ }
+ return c.onClone(ctx, mask, info)
+}
+
+func TestNoChecker(t *testing.T) {
+ var s state
+ if s.Enabled(PointClone) {
+ t.Errorf("Enabled(PointClone): got true, wanted false")
+ }
+}
+
+func TestCheckerNotRegisteredForPoint(t *testing.T) {
+ var s state
+ s.AppendChecker(&testChecker{}, &CheckerReq{})
+ if s.Enabled(PointClone) {
+ t.Errorf("Enabled(PointClone): got true, wanted false")
+ }
+}
+
+func TestCheckerRegistered(t *testing.T) {
+ var s state
+ checkerCalled := false
+ s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error {
+ checkerCalled = true
+ return nil
+ }}, &CheckerReq{
+ Points: []Point{PointClone},
+ Clone: CloneFields{
+ Credentials: true,
+ },
+ })
+
+ if !s.Enabled(PointClone) {
+ t.Errorf("Enabled(PointClone): got false, wanted true")
+ }
+ if !s.CloneReq().Contains(CloneFieldCredentials) {
+ t.Errorf("CloneReq().Contains(CloneFieldCredentials): got false, wanted true")
+ }
+ if err := s.Clone(context.Background(), CloneFieldSet{}, &CloneInfo{}); err != nil {
+ t.Errorf("Clone(): got %v, wanted nil", err)
+ }
+ if !checkerCalled {
+ t.Errorf("Clone() did not call Checker.Clone()")
+ }
+}
+
+func TestMultipleCheckersRegistered(t *testing.T) {
+ var s state
+ checkersCalled := [2]bool{}
+ s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error {
+ checkersCalled[0] = true
+ return nil
+ }}, &CheckerReq{
+ Points: []Point{PointClone},
+ Clone: CloneFields{
+ Args: true,
+ },
+ })
+ s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error {
+ checkersCalled[1] = true
+ return nil
+ }}, &CheckerReq{
+ Points: []Point{PointClone},
+ Clone: CloneFields{
+ Created: TaskFields{
+ ThreadID: true,
+ },
+ },
+ })
+
+ if !s.Enabled(PointClone) {
+ t.Errorf("Enabled(PointClone): got false, wanted true")
+ }
+ // CloneReq() should return the union of requested fields from all calls to
+ // AppendChecker.
+ req := s.CloneReq()
+ if !req.Contains(CloneFieldArgs) {
+ t.Errorf("req.Contains(CloneFieldArgs): got false, wanted true")
+ }
+ if !req.Created.Contains(TaskFieldThreadID) {
+ t.Errorf("req.Created.Contains(TaskFieldThreadID): got false, wanted true")
+ }
+ if err := s.Clone(context.Background(), CloneFieldSet{}, &CloneInfo{}); err != nil {
+ t.Errorf("Clone(): got %v, wanted nil", err)
+ }
+ for i := range checkersCalled {
+ if !checkersCalled[i] {
+ t.Errorf("Clone() did not call Checker.Clone() index %d", i)
+ }
+ }
+}
+
+func TestCheckpointReturnsFirstCheckerError(t *testing.T) {
+ errFirstChecker := errors.New("first Checker error")
+ errSecondChecker := errors.New("second Checker error")
+
+ var s state
+ checkersCalled := [2]bool{}
+ s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error {
+ checkersCalled[0] = true
+ return errFirstChecker
+ }}, &CheckerReq{
+ Points: []Point{PointClone},
+ })
+ s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error {
+ checkersCalled[1] = true
+ return errSecondChecker
+ }}, &CheckerReq{
+ Points: []Point{PointClone},
+ })
+
+ if !s.Enabled(PointClone) {
+ t.Errorf("Enabled(PointClone): got false, wanted true")
+ }
+ if err := s.Clone(context.Background(), CloneFieldSet{}, &CloneInfo{}); err != errFirstChecker {
+ t.Errorf("Clone(): got %v, wanted %v", err, errFirstChecker)
+ }
+ if !checkersCalled[0] {
+ t.Errorf("Clone() did not call first Checker")
+ }
+ if checkersCalled[1] {
+ t.Errorf("Clone() called second Checker")
+ }
+}
diff --git a/pkg/sentry/seccheck/task.go b/pkg/sentry/seccheck/task.go
new file mode 100644
index 000000000..1dee33203
--- /dev/null
+++ b/pkg/sentry/seccheck/task.go
@@ -0,0 +1,39 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccheck
+
+import (
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+)
+
+// TaskInfo contains information unambiguously identifying a single thread
+// and/or its containing process.
+//
+// +fieldenum Task
+type TaskInfo struct {
+ // ThreadID is the thread's ID in the root PID namespace.
+ ThreadID int32
+
+ // ThreadStartTime is the thread's CLOCK_REALTIME start time.
+ ThreadStartTime ktime.Time
+
+ // ThreadGroupID is the thread's group leader's ID in the root PID
+ // namespace.
+ ThreadGroupID int32
+
+ // ThreadGroupStartTime is the thread's group leader's CLOCK_REALTIME start
+ // time.
+ ThreadGroupStartTime ktime.Time
+}
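
A brief, hypothetical helper to illustrate why each ID in TaskInfo is paired with a start time: thread IDs may be reused after a thread exits, so identity requires both. sameThread and its package are invented for this sketch:

package examplechecker

import "gvisor.dev/gvisor/pkg/sentry/seccheck"

// sameThread reports whether a and b denote the same thread: the IDs must
// match and so must the CLOCK_REALTIME start times, since a bare thread ID
// can be recycled.
func sameThread(a, b seccheck.TaskInfo) bool {
    return a.ThreadID == b.ThreadID && a.ThreadStartTime == b.ThreadStartTime
}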
diff --git a/tools/go_fieldenum/BUILD b/tools/go_fieldenum/BUILD
new file mode 100644
index 000000000..2bfdaeb2f
--- /dev/null
+++ b/tools/go_fieldenum/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "bzl_library", "go_binary")
+
+licenses(["notice"])
+
+go_binary(
+ name = "fieldenum",
+ srcs = ["main.go"],
+ visibility = ["//:sandbox"],
+)
+
+bzl_library(
+ name = "defs_bzl",
+ srcs = ["defs.bzl"],
+ visibility = ["//visibility:private"],
+)
diff --git a/tools/go_fieldenum/defs.bzl b/tools/go_fieldenum/defs.bzl
new file mode 100644
index 000000000..0cd2679ca
--- /dev/null
+++ b/tools/go_fieldenum/defs.bzl
@@ -0,0 +1,29 @@
+"""The go_fieldenum target infers Field, Fields, and FieldSet types for each
+struct in an input source file marked +fieldenum.
+"""
+
+def _go_fieldenum_impl(ctx):
+ output = ctx.outputs.out
+
+ args = ["-pkg=%s" % ctx.attr.package, "-out=%s" % output.path]
+ for src in ctx.attr.srcs:
+ args += [f.path for f in src.files.to_list()]
+
+ ctx.actions.run(
+ inputs = ctx.files.srcs,
+ outputs = [output],
+ mnemonic = "GoFieldenum",
+ progress_message = "Generating Go field enumerators %s" % ctx.label,
+ arguments = args,
+ executable = ctx.executable._tool,
+ )
+
+go_fieldenum = rule(
+ implementation = _go_fieldenum_impl,
+ attrs = {
+ "srcs": attr.label_list(doc = "input source files", mandatory = True, allow_files = True),
+ "package": attr.string(doc = "the package for the generated source file", mandatory = True),
+ "out": attr.output(doc = "output file", mandatory = True),
+ "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_fieldenum:fieldenum")),
+ },
+)
diff --git a/tools/go_fieldenum/main.go b/tools/go_fieldenum/main.go
new file mode 100644
index 000000000..68dfdb3db
--- /dev/null
+++ b/tools/go_fieldenum/main.go
@@ -0,0 +1,310 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary fieldenum emits field bitmasks for all structs in a package marked
+// "+fieldenum".
+package main
+
+import (
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/parser"
+ "go/token"
+ "log"
+ "os"
+ "strings"
+)
+
+var (
+ outputPkg = flag.String("pkg", "", "output package")
+ outputFilename = flag.String("out", "-", "output filename")
+)
+
+func main() {
+ // Parse command line arguments.
+ flag.Parse()
+ if len(*outputPkg) == 0 {
+ log.Fatalf("-pkg must be provided")
+ }
+ if len(flag.Args()) == 0 {
+ log.Fatalf("Input files must be provided")
+ }
+
+ // Parse input files.
+ inputFiles := make([]*ast.File, 0, len(flag.Args()))
+ fset := token.NewFileSet()
+ for _, filename := range flag.Args() {
+ f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
+ if err != nil {
+ log.Fatalf("Failed to parse input file %q: %v", filename, err)
+ }
+ inputFiles = append(inputFiles, f)
+ }
+
+ // Determine which types are marked "+fieldenum" and will consequently have
+ // code generated.
+ fieldEnumTypes := make(map[string]fieldEnumTypeInfo)
+ for _, f := range inputFiles {
+ for _, decl := range f.Decls {
+ d, ok := decl.(*ast.GenDecl)
+ if !ok || d.Tok != token.TYPE || d.Doc == nil || len(d.Specs) == 0 {
+ continue
+ }
+ for _, l := range d.Doc.List {
+ const fieldenumPrefixWithSpace = "// +fieldenum "
+ if l.Text == "// +fieldenum" || strings.HasPrefix(l.Text, fieldenumPrefixWithSpace) {
+ spec := d.Specs[0].(*ast.TypeSpec)
+ name := spec.Name.Name
+ prefix := name
+ if len(l.Text) > len(fieldenumPrefixWithSpace) {
+ prefix = strings.TrimSpace(l.Text[len(fieldenumPrefixWithSpace):])
+ }
+ st, ok := spec.Type.(*ast.StructType)
+ if !ok {
+ log.Fatalf("Type %s is marked +fieldenum, but is not a struct", name)
+ }
+ fieldEnumTypes[name] = fieldEnumTypeInfo{
+ prefix: prefix,
+ structType: st,
+ }
+ break
+ }
+ }
+ }
+ }
+
+ // Collect information for each type for which code is being generated.
+ structInfos := make([]structInfo, 0, len(fieldEnumTypes))
+ needSyncAtomic := false
+ for typeName, typeInfo := range fieldEnumTypes {
+ var si structInfo
+ si.name = typeName
+ si.prefix = typeInfo.prefix
+ for _, field := range typeInfo.structType.Fields.List {
+ name := structFieldName(field)
+ // If the field's type is a type that is also marked +fieldenum,
+ // include a FieldSet for that type in this one's. The field must
+ // be a struct by value, since if it's a pointer then that struct
+ // might also point to or include this one (which would make
+ // FieldSet inclusion circular). It must also be a type defined in
+ // this package, since otherwise we don't know whether it's marked
+ // +fieldenum. Thus, field.Type must be an identifier (rather than
+ // an ast.StarExpr or SelectorExpr).
+ if tident, ok := field.Type.(*ast.Ident); ok {
+ if fieldTypeInfo, ok := fieldEnumTypes[tident.Name]; ok {
+ fsf := fieldSetField{
+ fieldName: name,
+ typePrefix: fieldTypeInfo.prefix,
+ }
+ si.reprByFieldSet = append(si.reprByFieldSet, fsf)
+ si.allFields = append(si.allFields, fsf)
+ continue
+ }
+ }
+ si.reprByBit = append(si.reprByBit, name)
+ si.allFields = append(si.allFields, fieldSetField{
+ fieldName: name,
+ })
+ // sync/atomic import will be needed for FieldSet.Load().
+ needSyncAtomic = true
+ }
+ structInfos = append(structInfos, si)
+ }
+
+ // Build the output file.
+ var b strings.Builder
+ fmt.Fprintf(&b, "// Generated by go_fieldenum.\n\n")
+ fmt.Fprintf(&b, "package %s\n\n", *outputPkg)
+ if needSyncAtomic {
+ fmt.Fprintf(&b, "import \"sync/atomic\"\n\n")
+ }
+ for _, si := range structInfos {
+ si.writeTo(&b)
+ }
+
+ if *outputFilename == "-" {
+ // Write output to stdout.
+ fmt.Printf("%s", b.String())
+ } else {
+ // Write output to file.
+ f, err := os.OpenFile(*outputFilename, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644)
+ if err != nil {
+ log.Fatalf("Failed to open output file %q: %v", *outputFilename, err)
+ }
+ if _, err := f.WriteString(b.String()); err != nil {
+ log.Fatalf("Failed to write output file %q: %v", *outputFilename, err)
+ }
+ f.Close()
+ }
+}
+
+type fieldEnumTypeInfo struct {
+ prefix string
+ structType *ast.StructType
+}
+
+// structInfo contains information about the code generated for a given struct.
+type structInfo struct {
+ // name is the name of the represented struct.
+ name string
+
+ // prefix is the prefix X applied to the name of each generated type and
+ // constant, referred to as X in the comments below for convenience.
+ prefix string
+
+ // reprByBit contains the names of fields in X that should be represented
+ // by a bit in the bit mask XFieldSet.fields, and by a bool in XFields.
+ reprByBit []string
+
+ // reprByFieldSet contains fields in X whose type is a named struct (e.g.
+ // Y) that has a corresponding FieldSet type YFieldSet, and which should
+ // therefore be represented by including a value of type YFieldSet in
+ // XFieldSet, and a value of type YFields in XFields.
+ reprByFieldSet []fieldSetField
+
+ // allFields contains all fields in X in order of declaration. Fields in
+ // reprByBit have fieldSetField.typePrefix == "".
+ allFields []fieldSetField
+}
+
+type fieldSetField struct {
+ fieldName string
+ typePrefix string
+}
+
+func structFieldName(f *ast.Field) string {
+ if len(f.Names) != 0 {
+ return f.Names[0].Name
+ }
+ // For embedded struct fields, the field name is the unqualified type name.
+ texpr := f.Type
+ for {
+ switch t := texpr.(type) {
+ case *ast.StarExpr:
+ texpr = t.X
+ case *ast.SelectorExpr:
+ texpr = t.Sel
+ case *ast.Ident:
+ return t.Name
+ default:
+ panic(fmt.Sprintf("unexpected %T", texpr))
+ }
+ }
+}
+
+// Workaround for Go defect (map membership test isn't usable in an
+// expression).
+func fetContains(xs map[string]*ast.StructType, x string) bool {
+ _, ok := xs[x]
+ return ok
+}
+
+func (si *structInfo) writeTo(b *strings.Builder) {
+ fmt.Fprintf(b, "// A %sField represents a field in %s.\n", si.prefix, si.name)
+ fmt.Fprintf(b, "type %sField uint\n\n", si.prefix)
+ if len(si.reprByBit) != 0 {
+ fmt.Fprintf(b, "// %sFieldX represents %s field X.\n", si.prefix, si.name)
+ fmt.Fprintf(b, "const (\n")
+ fmt.Fprintf(b, "\t%sField%s %sField = iota\n", si.prefix, si.reprByBit[0], si.prefix)
+ for _, fieldName := range si.reprByBit[1:] {
+ fmt.Fprintf(b, "\t%sField%s\n", si.prefix, fieldName)
+ }
+ fmt.Fprintf(b, ")\n\n")
+ }
+
+ fmt.Fprintf(b, "// %sFields represents a set of fields in %s in a literal-friendly form.\n", si.prefix, si.name)
+ fmt.Fprintf(b, "// The zero value of %sFields represents an empty set.\n", si.prefix)
+ fmt.Fprintf(b, "type %sFields struct {\n", si.prefix)
+ for _, fieldSetField := range si.allFields {
+ if fieldSetField.typePrefix == "" {
+ fmt.Fprintf(b, "\t%s bool\n", fieldSetField.fieldName)
+ } else {
+ fmt.Fprintf(b, "\t%s %sFields\n", fieldSetField.fieldName, fieldSetField.typePrefix)
+ }
+ }
+ fmt.Fprintf(b, "}\n\n")
+
+ fmt.Fprintf(b, "// %sFieldSet represents a set of fields in %s in a compact form.\n", si.prefix, si.name)
+ fmt.Fprintf(b, "// The zero value of %sFieldSet represents an empty set.\n", si.prefix)
+ fmt.Fprintf(b, "type %sFieldSet struct {\n", si.prefix)
+ numBitmaskUint32s := (len(si.reprByBit) + 31) / 32
+ for _, fieldSetField := range si.reprByFieldSet {
+ fmt.Fprintf(b, "\t%s %sFieldSet\n", fieldSetField.fieldName, fieldSetField.typePrefix)
+ }
+ if len(si.reprByBit) != 0 {
+ fmt.Fprintf(b, "\tfields [%d]uint32\n", numBitmaskUint32s)
+ }
+ fmt.Fprintf(b, "}\n\n")
+
+ if len(si.reprByBit) != 0 {
+ fmt.Fprintf(b, "// Contains returns true if f is present in the %sFieldSet.\n", si.prefix)
+ fmt.Fprintf(b, "func (fs %sFieldSet) Contains(f %sField) bool {\n", si.prefix, si.prefix)
+ if numBitmaskUint32s == 1 {
+ fmt.Fprintf(b, "\treturn fs.fields[0] & (uint32(1) << uint(f)) != 0\n")
+ } else {
+ fmt.Fprintf(b, "\treturn fs.fields[f/32] & (uint32(1) << (f%%32)) != 0\n")
+ }
+ fmt.Fprintf(b, "}\n\n")
+
+ fmt.Fprintf(b, "// Add adds f to the %sFieldSet.\n", si.prefix)
+ fmt.Fprintf(b, "func (fs *%sFieldSet) Add(f %sField) {\n", si.prefix, si.prefix)
+ if numBitmaskUint32s == 1 {
+ fmt.Fprintf(b, "\tfs.fields[0] |= uint32(1) << uint(f)\n")
+ } else {
+ fmt.Fprintf(b, "\tfs.fields[f/32] |= uint32(1) << (f%%32)\n")
+ }
+ fmt.Fprintf(b, "}\n\n")
+
+ fmt.Fprintf(b, "// Remove removes f from the %sFieldSet.\n", si.prefix)
+ fmt.Fprintf(b, "func (fs *%sFieldSet) Remove(f %sField) {\n", si.prefix, si.prefix)
+ if numBitmaskUint32s == 1 {
+ fmt.Fprintf(b, "\tfs.fields[0] &^= uint32(1) << uint(f)\n")
+ } else {
+ fmt.Fprintf(b, "\tfs.fields[f/32] &^= uint32(1) << (f%%32)\n")
+ }
+ fmt.Fprintf(b, "}\n\n")
+ }
+
+ fmt.Fprintf(b, "// Load returns a copy of the %sFieldSet.\n", si.prefix)
+ fmt.Fprintf(b, "// Load is safe to call concurrently with AddFieldsLoadable, but not Add or Remove.\n")
+ fmt.Fprintf(b, "func (fs *%sFieldSet) Load() (copied %sFieldSet) {\n", si.prefix, si.prefix)
+ for _, fieldSetField := range si.reprByFieldSet {
+ fmt.Fprintf(b, "\tcopied.%s = fs.%s.Load()\n", fieldSetField.fieldName, fieldSetField.fieldName)
+ }
+ for i := 0; i < numBitmaskUint32s; i++ {
+ fmt.Fprintf(b, "\tcopied.fields[%d] = atomic.LoadUint32(&fs.fields[%d])\n", i, i)
+ }
+ fmt.Fprintf(b, "\treturn\n")
+ fmt.Fprintf(b, "}\n\n")
+
+ fmt.Fprintf(b, "// AddFieldsLoadable adds the given fields to the %sFieldSet.\n", si.prefix)
+ fmt.Fprintf(b, "// AddFieldsLoadable is safe to call concurrently with Load, but not other methods (including other calls to AddFieldsLoadable).\n")
+ fmt.Fprintf(b, "func (fs *%sFieldSet) AddFieldsLoadable(fields %sFields) {\n", si.prefix, si.prefix)
+ for _, fieldSetField := range si.reprByFieldSet {
+ fmt.Fprintf(b, "\tfs.%s.AddFieldsLoadable(fields.%s)\n", fieldSetField.fieldName, fieldSetField.fieldName)
+ }
+ for _, fieldName := range si.reprByBit {
+ fieldConstName := fmt.Sprintf("%sField%s", si.prefix, fieldName)
+ fmt.Fprintf(b, "\tif fields.%s {\n", fieldName)
+ if numBitmaskUint32s == 1 {
+ fmt.Fprintf(b, "\t\tatomic.StoreUint32(&fs.fields[0], fs.fields[0] | (uint32(1) << uint(%s)))\n", fieldConstName)
+ } else {
+ fmt.Fprintf(b, "\t\tword, bit := %s/32, %s%%32\n", fieldConstName, fieldConstName)
+ fmt.Fprintf(b, "\t\tatomic.StoreUint32(&fs.fields[word], fs.fields[word] | (uint32(1) << bit))\n")
+ }
+ fmt.Fprintf(b, "\t}\n")
+ }
+ fmt.Fprintf(b, "}\n\n")
+}
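
For reference, roughly what writeTo above emits for the TaskInfo struct in pkg/sentry/seccheck/task.go. This is an abridged sketch reconstructed from the generator logic, not the literal generated file; Add, Remove, Load, and AddFieldsLoadable are generated in the same style as the methods shown:

// Generated by go_fieldenum.

package seccheck

// A TaskField represents a field in TaskInfo.
type TaskField uint

// TaskFieldX represents TaskInfo field X.
const (
    TaskFieldThreadID TaskField = iota
    TaskFieldThreadStartTime
    TaskFieldThreadGroupID
    TaskFieldThreadGroupStartTime
)

// TaskFields represents a set of fields in TaskInfo in a literal-friendly form.
// The zero value of TaskFields represents an empty set.
type TaskFields struct {
    ThreadID             bool
    ThreadStartTime      bool
    ThreadGroupID        bool
    ThreadGroupStartTime bool
}

// TaskFieldSet represents a set of fields in TaskInfo in a compact form.
// The zero value of TaskFieldSet represents an empty set.
type TaskFieldSet struct {
    fields [1]uint32
}

// Contains returns true if f is present in the TaskFieldSet.
func (fs TaskFieldSet) Contains(f TaskField) bool {
    return fs.fields[0]&(uint32(1)<<uint(f)) != 0
}

// Add adds f to the TaskFieldSet.
func (fs *TaskFieldSet) Add(f TaskField) {
    fs.fields[0] |= uint32(1) << uint(f)
}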