diff options
-rw-r--r-- | pkg/sentry/kernel/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/kernel/task.go | 21 | ||||
-rw-r--r-- | pkg/sentry/kernel/task_clone.go | 36 | ||||
-rw-r--r-- | pkg/sentry/seccheck/BUILD | 54 | ||||
-rw-r--r-- | pkg/sentry/seccheck/clone.go | 53 | ||||
-rw-r--r-- | pkg/sentry/seccheck/seccheck.go | 136 | ||||
-rw-r--r-- | pkg/sentry/seccheck/seccheck_test.go | 157 | ||||
-rw-r--r-- | pkg/sentry/seccheck/task.go | 39 | ||||
-rw-r--r-- | tools/go_fieldenum/BUILD | 15 | ||||
-rw-r--r-- | tools/go_fieldenum/defs.bzl | 29 | ||||
-rw-r--r-- | tools/go_fieldenum/main.go | 310 |
11 files changed, 850 insertions, 1 deletions
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 816e60329..c0f13bf52 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -268,6 +268,7 @@ go_library( "//pkg/sentry/mm", "//pkg/sentry/pgalloc", "//pkg/sentry/platform", + "//pkg/sentry/seccheck", "//pkg/sentry/socket/netlink/port", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/time", diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index 59eeb253d..9a95bf44c 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -30,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/sched" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/seccheck" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -874,3 +875,23 @@ func (t *Task) ResetKcov() { t.kcov = nil } } + +// Preconditions: The TaskSet mutex must be locked. +func (t *Task) loadSeccheckInfoLocked(req seccheck.TaskFieldSet, mask *seccheck.TaskFieldSet, info *seccheck.TaskInfo) { + if req.Contains(seccheck.TaskFieldThreadID) { + info.ThreadID = int32(t.k.tasks.Root.tids[t]) + mask.Add(seccheck.TaskFieldThreadID) + } + if req.Contains(seccheck.TaskFieldThreadStartTime) { + info.ThreadStartTime = t.startTime + mask.Add(seccheck.TaskFieldThreadStartTime) + } + if req.Contains(seccheck.TaskFieldThreadGroupID) { + info.ThreadGroupID = int32(t.k.tasks.Root.tgids[t.tg]) + mask.Add(seccheck.TaskFieldThreadGroupID) + } + if req.Contains(seccheck.TaskFieldThreadGroupStartTime) { + info.ThreadGroupStartTime = t.tg.leader.startTime + mask.Add(seccheck.TaskFieldThreadGroupStartTime) + } +} diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index da4b77ca2..26a981f36 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/inet" + "gvisor.dev/gvisor/pkg/sentry/seccheck" "gvisor.dev/gvisor/pkg/usermem" ) @@ -235,7 +236,23 @@ func (t *Task) Clone(args *linux.CloneArgs) (ThreadID, *SyscallControl, error) { // nt that it must receive before its task goroutine starts running. tid := nt.k.tasks.Root.IDOfTask(nt) defer nt.Start(tid) - t.traceCloneEvent(tid) + + if seccheck.Global.Enabled(seccheck.PointClone) { + mask, info := getCloneSeccheckInfo(t, nt, args) + if err := seccheck.Global.Clone(t, mask, &info); err != nil { + // nt has been visible to the rest of the system since NewTask, so + // it may be blocking execve or a group stop, have been notified + // for group signal delivery, had children reparented to it, etc. + // Thus we can't just drop it on the floor. Instead, instruct the + // task goroutine to exit immediately, as quietly as possible. + nt.exitTracerNotified = true + nt.exitTracerAcked = true + nt.exitParentNotified = true + nt.exitParentAcked = true + nt.runState = (*runExitMain)(nil) + return 0, nil, err + } + } // "If fork/clone and execve are allowed by @prog, any child processes will // be constrained to the same filters and system call ABI as the parent." - @@ -260,6 +277,7 @@ func (t *Task) Clone(args *linux.CloneArgs) (ThreadID, *SyscallControl, error) { ntid.CopyOut(t, hostarch.Addr(args.ParentTID)) } + t.traceCloneEvent(tid) kind := ptraceCloneKindClone if args.Flags&linux.CLONE_VFORK != 0 { kind = ptraceCloneKindVfork @@ -279,6 +297,22 @@ func (t *Task) Clone(args *linux.CloneArgs) (ThreadID, *SyscallControl, error) { return ntid, nil, nil } +func getCloneSeccheckInfo(t, nt *Task, args *linux.CloneArgs) (seccheck.CloneFieldSet, seccheck.CloneInfo) { + req := seccheck.Global.CloneReq() + info := seccheck.CloneInfo{ + Credentials: t.Credentials(), + Args: *args, + } + var mask seccheck.CloneFieldSet + mask.Add(seccheck.CloneFieldCredentials) + mask.Add(seccheck.CloneFieldArgs) + t.k.tasks.mu.RLock() + defer t.k.tasks.mu.RUnlock() + t.loadSeccheckInfoLocked(req.Invoker, &mask.Invoker, &info.Invoker) + nt.loadSeccheckInfoLocked(req.Created, &mask.Created, &info.Created) + return mask, info +} + // maybeBeginVforkStop checks if a previously-started vfork child is still // running and has not yet released its MM, such that its parent t should enter // a vforkStop. diff --git a/pkg/sentry/seccheck/BUILD b/pkg/sentry/seccheck/BUILD new file mode 100644 index 000000000..943fa180d --- /dev/null +++ b/pkg/sentry/seccheck/BUILD @@ -0,0 +1,54 @@ +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_fieldenum:defs.bzl", "go_fieldenum") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +licenses(["notice"]) + +go_fieldenum( + name = "seccheck_fieldenum", + srcs = [ + "clone.go", + "task.go", + ], + out = "seccheck_fieldenum.go", + package = "seccheck", +) + +go_template_instance( + name = "seqatomic_checkerslice", + out = "seqatomic_checkerslice_unsafe.go", + package = "seccheck", + suffix = "CheckerSlice", + template = "//pkg/sync/seqatomic:generic_seqatomic", + types = { + "Value": "[]Checker", + }, +) + +go_library( + name = "seccheck", + srcs = [ + "clone.go", + "seccheck.go", + "seccheck_fieldenum.go", + "seqatomic_checkerslice_unsafe.go", + "task.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/gohacks", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/time", + "//pkg/sync", + ], +) + +go_test( + name = "seccheck_test", + size = "small", + srcs = ["seccheck_test.go"], + library = ":seccheck", + deps = ["//pkg/context"], +) diff --git a/pkg/sentry/seccheck/clone.go b/pkg/sentry/seccheck/clone.go new file mode 100644 index 000000000..7546fa021 --- /dev/null +++ b/pkg/sentry/seccheck/clone.go @@ -0,0 +1,53 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seccheck + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" +) + +// CloneInfo contains information used by the Clone checkpoint. +// +// +fieldenum Clone +type CloneInfo struct { + // Invoker identifies the invoking thread. + Invoker TaskInfo + + // Credentials are the invoking thread's credentials. + Credentials *auth.Credentials + + // Args contains the arguments to kernel.Task.Clone(). + Args linux.CloneArgs + + // Created identifies the created thread. + Created TaskInfo +} + +// CloneReq returns fields required by the Clone checkpoint. +func (s *state) CloneReq() CloneFieldSet { + return s.cloneReq.Load() +} + +// Clone is called at the Clone checkpoint. +func (s *state) Clone(ctx context.Context, mask CloneFieldSet, info *CloneInfo) error { + for _, c := range s.getCheckers() { + if err := c.Clone(ctx, mask, *info); err != nil { + return err + } + } + return nil +} diff --git a/pkg/sentry/seccheck/seccheck.go b/pkg/sentry/seccheck/seccheck.go new file mode 100644 index 000000000..b6c9d44ce --- /dev/null +++ b/pkg/sentry/seccheck/seccheck.go @@ -0,0 +1,136 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package seccheck defines a structure for dynamically-configured security +// checks in the sentry. +package seccheck + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sync" +) + +// A Point represents a checkpoint, a point at which a security check occurs. +type Point uint + +// PointX represents the checkpoint X. +const ( + PointClone Point = iota + // Add new Points above this line. + pointLength + + numPointBitmaskUint32s = (int(pointLength)-1)/32 + 1 +) + +// A Checker performs security checks at checkpoints. +// +// Each Checker method X is called at checkpoint X; if the method may return a +// non-nil error and does so, it causes the checked operation to fail +// immediately (without calling subsequent Checkers) and return the error. The +// info argument contains information relevant to the check. The mask argument +// indicates what fields in info are valid; the mask should usually be a +// superset of fields requested by the Checker's corresponding CheckerReq, but +// may be missing requested fields in some cases (e.g. if the Checker is +// registered concurrently with invocations of checkpoints). +type Checker interface { + Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error +} + +// CheckerDefaults may be embedded by implementations of Checker to obtain +// no-op implementations of Checker methods that may be explicitly overridden. +type CheckerDefaults struct{} + +// Clone implements Checker.Clone. +func (CheckerDefaults) Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { + return nil +} + +// CheckerReq indicates what checkpoints a corresponding Checker runs at, and +// what information it requires at those checkpoints. +type CheckerReq struct { + // Points are the set of checkpoints for which the corresponding Checker + // must be called. Note that methods not specified in Points may still be + // called; implementations of Checker may embed CheckerDefaults to obtain + // no-op implementations of Checker methods. + Points []Point + + // All of the following fields indicate what fields in the corresponding + // XInfo struct will be requested at the corresponding checkpoint. + Clone CloneFields +} + +// Global is the method receiver of all seccheck functions. +var Global state + +// state is the type of global, and is separated out for testing. +type state struct { + // registrationMu serializes all changes to the set of registered Checkers + // for all checkpoints. + registrationMu sync.Mutex + + // enabledPoints is a bitmask of checkpoints for which at least one Checker + // is registered. + // + // enabledPoints is accessed using atomic memory operations. Mutation of + // enabledPoints is serialized by registrationMu. + enabledPoints [numPointBitmaskUint32s]uint32 + + // registrationSeq supports store-free atomic reads of registeredCheckers. + registrationSeq sync.SeqCount + + // checkers is the set of all registered Checkers in order of execution. + // + // checkers is accessed using instantiations of SeqAtomic functions. + // Mutation of checkers is serialized by registrationMu. + checkers []Checker + + // All of the following xReq variables indicate what fields in the + // corresponding XInfo struct have been requested by any registered + // checker, are accessed using atomic memory operations, and are mutated + // with registrationMu locked. + cloneReq CloneFieldSet +} + +// AppendChecker registers the given Checker to execute at checkpoints. The +// Checker will execute after all previously-registered Checkers, and only if +// those Checkers return a nil error. +func (s *state) AppendChecker(c Checker, req *CheckerReq) { + s.registrationMu.Lock() + defer s.registrationMu.Unlock() + s.cloneReq.AddFieldsLoadable(req.Clone) + s.appendCheckerLocked(c) + for _, p := range req.Points { + word, bit := p/32, p%32 + atomic.StoreUint32(&s.enabledPoints[word], s.enabledPoints[word]|(uint32(1)<<bit)) + } +} + +// Enabled returns true if any Checker is registered for the given checkpoint. +func (s *state) Enabled(p Point) bool { + word, bit := p/32, p%32 + return atomic.LoadUint32(&s.enabledPoints[word])&(uint32(1)<<bit) != 0 +} + +func (s *state) getCheckers() []Checker { + return SeqAtomicLoadCheckerSlice(&s.registrationSeq, &s.checkers) +} + +// Preconditions: s.registrationMu must be locked. +func (s *state) appendCheckerLocked(c Checker) { + s.registrationSeq.BeginWrite() + s.checkers = append(s.checkers, c) + s.registrationSeq.EndWrite() +} diff --git a/pkg/sentry/seccheck/seccheck_test.go b/pkg/sentry/seccheck/seccheck_test.go new file mode 100644 index 000000000..687810d18 --- /dev/null +++ b/pkg/sentry/seccheck/seccheck_test.go @@ -0,0 +1,157 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seccheck + +import ( + "errors" + "testing" + + "gvisor.dev/gvisor/pkg/context" +) + +type testChecker struct { + CheckerDefaults + + onClone func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error +} + +// Clone implements Checker.Clone. +func (c *testChecker) Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { + if c.onClone == nil { + return nil + } + return c.onClone(ctx, mask, info) +} + +func TestNoChecker(t *testing.T) { + var s state + if s.Enabled(PointClone) { + t.Errorf("Enabled(PointClone): got true, wanted false") + } +} + +func TestCheckerNotRegisteredForPoint(t *testing.T) { + var s state + s.AppendChecker(&testChecker{}, &CheckerReq{}) + if s.Enabled(PointClone) { + t.Errorf("Enabled(PointClone): got true, wanted false") + } +} + +func TestCheckerRegistered(t *testing.T) { + var s state + checkerCalled := false + s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { + checkerCalled = true + return nil + }}, &CheckerReq{ + Points: []Point{PointClone}, + Clone: CloneFields{ + Credentials: true, + }, + }) + + if !s.Enabled(PointClone) { + t.Errorf("Enabled(PointClone): got false, wanted true") + } + if !s.CloneReq().Contains(CloneFieldCredentials) { + t.Errorf("CloneReq().Contains(CloneFieldCredentials): got false, wanted true") + } + if err := s.Clone(context.Background(), CloneFieldSet{}, &CloneInfo{}); err != nil { + t.Errorf("Clone(): got %v, wanted nil", err) + } + if !checkerCalled { + t.Errorf("Clone() did not call Checker.Clone()") + } +} + +func TestMultipleCheckersRegistered(t *testing.T) { + var s state + checkersCalled := [2]bool{} + s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { + checkersCalled[0] = true + return nil + }}, &CheckerReq{ + Points: []Point{PointClone}, + Clone: CloneFields{ + Args: true, + }, + }) + s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { + checkersCalled[1] = true + return nil + }}, &CheckerReq{ + Points: []Point{PointClone}, + Clone: CloneFields{ + Created: TaskFields{ + ThreadID: true, + }, + }, + }) + + if !s.Enabled(PointClone) { + t.Errorf("Enabled(PointClone): got false, wanted true") + } + // CloneReq() should return the union of requested fields from all calls to + // AppendChecker. + req := s.CloneReq() + if !req.Contains(CloneFieldArgs) { + t.Errorf("req.Contains(CloneFieldArgs): got false, wanted true") + } + if !req.Created.Contains(TaskFieldThreadID) { + t.Errorf("req.Created.Contains(TaskFieldThreadID): got false, wanted true") + } + if err := s.Clone(context.Background(), CloneFieldSet{}, &CloneInfo{}); err != nil { + t.Errorf("Clone(): got %v, wanted nil", err) + } + for i := range checkersCalled { + if !checkersCalled[i] { + t.Errorf("Clone() did not call Checker.Clone() index %d", i) + } + } +} + +func TestCheckpointReturnsFirstCheckerError(t *testing.T) { + errFirstChecker := errors.New("first Checker error") + errSecondChecker := errors.New("second Checker error") + + var s state + checkersCalled := [2]bool{} + s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { + checkersCalled[0] = true + return errFirstChecker + }}, &CheckerReq{ + Points: []Point{PointClone}, + }) + s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { + checkersCalled[1] = true + return errSecondChecker + }}, &CheckerReq{ + Points: []Point{PointClone}, + }) + + if !s.Enabled(PointClone) { + t.Errorf("Enabled(PointClone): got false, wanted true") + } + if err := s.Clone(context.Background(), CloneFieldSet{}, &CloneInfo{}); err != errFirstChecker { + t.Errorf("Clone(): got %v, wanted %v", err, errFirstChecker) + } + if !checkersCalled[0] { + t.Errorf("Clone() did not call first Checker") + } + if checkersCalled[1] { + t.Errorf("Clone() called second Checker") + } +} diff --git a/pkg/sentry/seccheck/task.go b/pkg/sentry/seccheck/task.go new file mode 100644 index 000000000..1dee33203 --- /dev/null +++ b/pkg/sentry/seccheck/task.go @@ -0,0 +1,39 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seccheck + +import ( + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" +) + +// TaskInfo contains information unambiguously identifying a single thread +// and/or its containing process. +// +// +fieldenum Task +type TaskInfo struct { + // ThreadID is the thread's ID in the root PID namespace. + ThreadID int32 + + // ThreadStartTime is the thread's CLOCK_REALTIME start time. + ThreadStartTime ktime.Time + + // ThreadGroupID is the thread's group leader's ID in the root PID + // namespace. + ThreadGroupID int32 + + // ThreadGroupStartTime is the thread's group leader's CLOCK_REALTIME start + // time. + ThreadGroupStartTime ktime.Time +} diff --git a/tools/go_fieldenum/BUILD b/tools/go_fieldenum/BUILD new file mode 100644 index 000000000..2bfdaeb2f --- /dev/null +++ b/tools/go_fieldenum/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "bzl_library", "go_binary") + +licenses(["notice"]) + +go_binary( + name = "fieldenum", + srcs = ["main.go"], + visibility = ["//:sandbox"], +) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], +) diff --git a/tools/go_fieldenum/defs.bzl b/tools/go_fieldenum/defs.bzl new file mode 100644 index 000000000..0cd2679ca --- /dev/null +++ b/tools/go_fieldenum/defs.bzl @@ -0,0 +1,29 @@ +"""The go_fieldenum target infers Field, Fields, and FieldSet types for each +struct in an input source file marked +fieldenum. +""" + +def _go_fieldenum_impl(ctx): + output = ctx.outputs.out + + args = ["-pkg=%s" % ctx.attr.package, "-out=%s" % output.path] + for src in ctx.attr.srcs: + args += [f.path for f in src.files.to_list()] + + ctx.actions.run( + inputs = ctx.files.srcs, + outputs = [output], + mnemonic = "GoFieldenum", + progress_message = "Generating Go field enumerators %s" % ctx.label, + arguments = args, + executable = ctx.executable._tool, + ) + +go_fieldenum = rule( + implementation = _go_fieldenum_impl, + attrs = { + "srcs": attr.label_list(doc = "input source files", mandatory = True, allow_files = True), + "package": attr.string(doc = "the package for the generated source file", mandatory = True), + "out": attr.output(doc = "output file", mandatory = True), + "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_fieldenum:fieldenum")), + }, +) diff --git a/tools/go_fieldenum/main.go b/tools/go_fieldenum/main.go new file mode 100644 index 000000000..68dfdb3db --- /dev/null +++ b/tools/go_fieldenum/main.go @@ -0,0 +1,310 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary fieldenum emits field bitmasks for all structs in a package marked +// "+fieldenum". +package main + +import ( + "flag" + "fmt" + "go/ast" + "go/parser" + "go/token" + "log" + "os" + "strings" +) + +var ( + outputPkg = flag.String("pkg", "", "output package") + outputFilename = flag.String("out", "-", "output filename") +) + +func main() { + // Parse command line arguments. + flag.Parse() + if len(*outputPkg) == 0 { + log.Fatalf("-pkg must be provided") + } + if len(flag.Args()) == 0 { + log.Fatalf("Input files must be provided") + } + + // Parse input files. + inputFiles := make([]*ast.File, 0, len(flag.Args())) + fset := token.NewFileSet() + for _, filename := range flag.Args() { + f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) + if err != nil { + log.Fatalf("Failed to parse input file %q: %v", filename, err) + } + inputFiles = append(inputFiles, f) + } + + // Determine which types are marked "+fieldenum" and will consequently have + // code generated. + fieldEnumTypes := make(map[string]fieldEnumTypeInfo) + for _, f := range inputFiles { + for _, decl := range f.Decls { + d, ok := decl.(*ast.GenDecl) + if !ok || d.Tok != token.TYPE || d.Doc == nil || len(d.Specs) == 0 { + continue + } + for _, l := range d.Doc.List { + const fieldenumPrefixWithSpace = "// +fieldenum " + if l.Text == "// +fieldenum" || strings.HasPrefix(l.Text, fieldenumPrefixWithSpace) { + spec := d.Specs[0].(*ast.TypeSpec) + name := spec.Name.Name + prefix := name + if len(l.Text) > len(fieldenumPrefixWithSpace) { + prefix = strings.TrimSpace(l.Text[len(fieldenumPrefixWithSpace):]) + } + st, ok := spec.Type.(*ast.StructType) + if !ok { + log.Fatalf("Type %s is marked +fieldenum, but is not a struct", name) + } + fieldEnumTypes[name] = fieldEnumTypeInfo{ + prefix: prefix, + structType: st, + } + break + } + } + } + } + + // Collect information for each type for which code is being generated. + structInfos := make([]structInfo, 0, len(fieldEnumTypes)) + needSyncAtomic := false + for typeName, typeInfo := range fieldEnumTypes { + var si structInfo + si.name = typeName + si.prefix = typeInfo.prefix + for _, field := range typeInfo.structType.Fields.List { + name := structFieldName(field) + // If the field's type is a type that is also marked +fieldenum, + // include a FieldSet for that type in this one's. The field must + // be a struct by value, since if it's a pointer then that struct + // might also point to or include this one (which would make + // FieldSet inclusion circular). It must also be a type defined in + // this package, since otherwise we don't know whether it's marked + // +fieldenum. Thus, field.Type must be an identifier (rather than + // an ast.StarExpr or SelectorExpr). + if tident, ok := field.Type.(*ast.Ident); ok { + if fieldTypeInfo, ok := fieldEnumTypes[tident.Name]; ok { + fsf := fieldSetField{ + fieldName: name, + typePrefix: fieldTypeInfo.prefix, + } + si.reprByFieldSet = append(si.reprByFieldSet, fsf) + si.allFields = append(si.allFields, fsf) + continue + } + } + si.reprByBit = append(si.reprByBit, name) + si.allFields = append(si.allFields, fieldSetField{ + fieldName: name, + }) + // sync/atomic import will be needed for FieldSet.Load(). + needSyncAtomic = true + } + structInfos = append(structInfos, si) + } + + // Build the output file. + var b strings.Builder + fmt.Fprintf(&b, "// Generated by go_fieldenum.\n\n") + fmt.Fprintf(&b, "package %s\n\n", *outputPkg) + if needSyncAtomic { + fmt.Fprintf(&b, "import \"sync/atomic\"\n\n") + } + for _, si := range structInfos { + si.writeTo(&b) + } + + if *outputFilename == "-" { + // Write output to stdout. + fmt.Printf("%s", b.String()) + } else { + // Write output to file. + f, err := os.OpenFile(*outputFilename, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644) + if err != nil { + log.Fatalf("Failed to open output file %q: %v", *outputFilename, err) + } + if _, err := f.WriteString(b.String()); err != nil { + log.Fatalf("Failed to write output file %q: %v", *outputFilename, err) + } + f.Close() + } +} + +type fieldEnumTypeInfo struct { + prefix string + structType *ast.StructType +} + +// structInfo contains information about the code generated for a given struct. +type structInfo struct { + // name is the name of the represented struct. + name string + + // prefix is the prefix X applied to the name of each generated type and + // constant, referred to as X in the comments below for convenience. + prefix string + + // reprByBit contains the names of fields in X that should be represented + // by a bit in the bit mask XFieldSet.fields, and by a bool in XFields. + reprByBit []string + + // reprByFieldSet contains fields in X whose type is a named struct (e.g. + // Y) that has a corresponding FieldSet type YFieldSet, and which should + // therefore be represented by including a value of type YFieldSet in + // XFieldSet, and a value of type YFields in XFields. + reprByFieldSet []fieldSetField + + // allFields contains all fields in X in order of declaration. Fields in + // reprByBit have fieldSetField.typePrefix == "". + allFields []fieldSetField +} + +type fieldSetField struct { + fieldName string + typePrefix string +} + +func structFieldName(f *ast.Field) string { + if len(f.Names) != 0 { + return f.Names[0].Name + } + // For embedded struct fields, the field name is the unqualified type name. + texpr := f.Type + for { + switch t := texpr.(type) { + case *ast.StarExpr: + texpr = t.X + case *ast.SelectorExpr: + texpr = t.Sel + case *ast.Ident: + return t.Name + default: + panic(fmt.Sprintf("unexpected %T", texpr)) + } + } +} + +// Workaround for Go defect (map membership test isn't usable in an +// expression). +func fetContains(xs map[string]*ast.StructType, x string) bool { + _, ok := xs[x] + return ok +} + +func (si *structInfo) writeTo(b *strings.Builder) { + fmt.Fprintf(b, "// A %sField represents a field in %s.\n", si.prefix, si.name) + fmt.Fprintf(b, "type %sField uint\n\n", si.prefix) + if len(si.reprByBit) != 0 { + fmt.Fprintf(b, "// %sFieldX represents %s field X.\n", si.prefix, si.name) + fmt.Fprintf(b, "const (\n") + fmt.Fprintf(b, "\t%sField%s %sField = iota\n", si.prefix, si.reprByBit[0], si.prefix) + for _, fieldName := range si.reprByBit[1:] { + fmt.Fprintf(b, "\t%sField%s\n", si.prefix, fieldName) + } + fmt.Fprintf(b, ")\n\n") + } + + fmt.Fprintf(b, "// %sFields represents a set of fields in %s in a literal-friendly form.\n", si.prefix, si.name) + fmt.Fprintf(b, "// The zero value of %sFields represents an empty set.\n", si.prefix) + fmt.Fprintf(b, "type %sFields struct {\n", si.prefix) + for _, fieldSetField := range si.allFields { + if fieldSetField.typePrefix == "" { + fmt.Fprintf(b, "\t%s bool\n", fieldSetField.fieldName) + } else { + fmt.Fprintf(b, "\t%s %sFields\n", fieldSetField.fieldName, fieldSetField.typePrefix) + } + } + fmt.Fprintf(b, "}\n\n") + + fmt.Fprintf(b, "// %sFieldSet represents a set of fields in %s in a compact form.\n", si.prefix, si.name) + fmt.Fprintf(b, "// The zero value of %sFieldSet represents an empty set.\n", si.prefix) + fmt.Fprintf(b, "type %sFieldSet struct {\n", si.prefix) + numBitmaskUint32s := (len(si.reprByBit) + 31) / 32 + for _, fieldSetField := range si.reprByFieldSet { + fmt.Fprintf(b, "\t%s %sFieldSet\n", fieldSetField.fieldName, fieldSetField.typePrefix) + } + if len(si.reprByBit) != 0 { + fmt.Fprintf(b, "\tfields [%d]uint32\n", numBitmaskUint32s) + } + fmt.Fprintf(b, "}\n\n") + + if len(si.reprByBit) != 0 { + fmt.Fprintf(b, "// Contains returns true if f is present in the %sFieldSet.\n", si.prefix) + fmt.Fprintf(b, "func (fs %sFieldSet) Contains(f %sField) bool {\n", si.prefix, si.prefix) + if numBitmaskUint32s == 1 { + fmt.Fprintf(b, "\treturn fs.fields[0] & (uint32(1) << uint(f)) != 0\n") + } else { + fmt.Fprintf(b, "\treturn fs.fields[f/32] & (uint32(1) << (f%%32)) != 0\n") + } + fmt.Fprintf(b, "}\n\n") + + fmt.Fprintf(b, "// Add adds f to the %sFieldSet.\n", si.prefix) + fmt.Fprintf(b, "func (fs *%sFieldSet) Add(f %sField) {\n", si.prefix, si.prefix) + if numBitmaskUint32s == 1 { + fmt.Fprintf(b, "\tfs.fields[0] |= uint32(1) << uint(f)\n") + } else { + fmt.Fprintf(b, "\tfs.fields[f/32] |= uint32(1) << (f%%32)\n") + } + fmt.Fprintf(b, "}\n\n") + + fmt.Fprintf(b, "// Remove removes f from the %sFieldSet.\n", si.prefix) + fmt.Fprintf(b, "func (fs *%sFieldSet) Remove(f %sField) {\n", si.prefix, si.prefix) + if numBitmaskUint32s == 1 { + fmt.Fprintf(b, "\tfs.fields[0] &^= uint32(1) << uint(f)\n") + } else { + fmt.Fprintf(b, "\tfs.fields[f/32] &^= uint32(1) << (f%%32)\n") + } + fmt.Fprintf(b, "}\n\n") + } + + fmt.Fprintf(b, "// Load returns a copy of the %sFieldSet.\n", si.prefix) + fmt.Fprintf(b, "// Load is safe to call concurrently with AddFieldsLoadable, but not Add or Remove.\n") + fmt.Fprintf(b, "func (fs *%sFieldSet) Load() (copied %sFieldSet) {\n", si.prefix, si.prefix) + for _, fieldSetField := range si.reprByFieldSet { + fmt.Fprintf(b, "\tcopied.%s = fs.%s.Load()\n", fieldSetField.fieldName, fieldSetField.fieldName) + } + for i := 0; i < numBitmaskUint32s; i++ { + fmt.Fprintf(b, "\tcopied.fields[%d] = atomic.LoadUint32(&fs.fields[%d])\n", i, i) + } + fmt.Fprintf(b, "\treturn\n") + fmt.Fprintf(b, "}\n\n") + + fmt.Fprintf(b, "// AddFieldsLoadable adds the given fields to the %sFieldSet.\n", si.prefix) + fmt.Fprintf(b, "// AddFieldsLoadable is safe to call concurrently with Load, but not other methods (including other calls to AddFieldsLoadable).\n") + fmt.Fprintf(b, "func (fs *%sFieldSet) AddFieldsLoadable(fields %sFields) {\n", si.prefix, si.prefix) + for _, fieldSetField := range si.reprByFieldSet { + fmt.Fprintf(b, "\tfs.%s.AddFieldsLoadable(fields.%s)\n", fieldSetField.fieldName, fieldSetField.fieldName) + } + for _, fieldName := range si.reprByBit { + fieldConstName := fmt.Sprintf("%sField%s", si.prefix, fieldName) + fmt.Fprintf(b, "\tif fields.%s {\n", fieldName) + if numBitmaskUint32s == 1 { + fmt.Fprintf(b, "\t\tatomic.StoreUint32(&fs.fields[0], fs.fields[0] | (uint32(1) << uint(%s)))\n", fieldConstName) + } else { + fmt.Fprintf(b, "\t\tword, bit := %s/32, %s%%32\n", fieldConstName, fieldConstName) + fmt.Fprintf(b, "\t\tatomic.StoreUint32(&fs.fields[word], fs.fields[word] | (uint32(1) << bit))\n") + } + fmt.Fprintf(b, "\t}\n") + } + fmt.Fprintf(b, "}\n\n") +} |