diff options
Diffstat (limited to 'tools')
66 files changed, 4701 insertions, 1378 deletions
diff --git a/tools/bazel.mk b/tools/bazel.mk index 60b50cfb0..1444423e4 100644 --- a/tools/bazel.mk +++ b/tools/bazel.mk @@ -84,7 +84,7 @@ DOCKER_RUN_OPTIONS += -v "$(shell readlink -m $(GCLOUD_CONFIG)):$(GCLOUD_CONFIG) DOCKER_RUN_OPTIONS += -v "/tmp:/tmp" DOCKER_EXEC_OPTIONS := --user $(UID):$(GID) DOCKER_EXEC_OPTIONS += --interactive -ifeq (true,$(shell test -t 0 && echo true)) +ifeq (true,$(shell test -t 1 && echo true)) DOCKER_EXEC_OPTIONS += --tty endif @@ -181,23 +181,13 @@ endif # build_paths extracts the built binary from the bazel stderr output. # -# This could be alternately done by parsing the bazel build event stream, but -# this is a complex schema, and begs the question: what will build the thing -# that parses the output? Bazel? Do we need a separate bootstrapping build -# command here? Yikes, let's just stick with the ugly shell pipeline. -# # The last line is used to prevent terminal shenanigans. build_paths = \ (set -euo pipefail; \ - $(call wrapper,$(BAZEL) build $(BASE_OPTIONS) $(BAZEL_OPTIONS) $(1)) 2>&1 \ - | tee /dev/fd/2 \ - | sed -n -e '/^Target/,$$p' \ - | sed -n -e '/^ \($(subst /,\/,$(subst $(SPACE),\|,$(BUILD_ROOTS)))\)/p' \ - | sed -e 's/ /\n/g' \ - | awk '{$$1=$$1};1' \ - | strings \ - | xargs -r -n 1 -I {} readlink -f "{}" \ - | xargs -r -n 1 -I {} bash -c 'set -xeuo pipefail; $(2)') + $(call wrapper,$(BAZEL) build $(BASE_OPTIONS) $(BAZEL_OPTIONS) $(1)) && \ + $(call wrapper,$(BAZEL) cquery $(BASE_OPTIONS) $(BAZEL_OPTIONS) $(1) --output=starlark --starlark:file=tools/show_paths.bzl) \ + | xargs -r -n 1 -I {} bash -c 'test -e "{}" || exit 0; readlink -f "{}"' \ + | xargs -r -n 1 -I {} bash -c 'set -euo pipefail; $(2)') clean = $(call header,CLEAN) && $(call wrapper,$(BAZEL) clean) build = $(call header,BUILD $(1)) && $(call build_paths,$(1),echo {}) diff --git a/tools/bazeldefs/BUILD b/tools/bazeldefs/BUILD index 24e6f8a94..5295f4a85 100644 --- a/tools/bazeldefs/BUILD +++ b/tools/bazeldefs/BUILD @@ -46,6 +46,11 @@ genrule( outs = ["version.txt"], cmd = "cat bazel-out/stable-status.txt | grep STABLE_VERSION | cut -d' ' -f2- | sed 's/^[^[:digit:]]*//g' >$@", stamp = True, + tags = [ + "manual", + "nobuilder", + "notap", + ], visibility = ["//:sandbox"], ) diff --git a/tools/bazeldefs/cc.bzl b/tools/bazeldefs/cc.bzl index 2831eac5f..57d33726a 100644 --- a/tools/bazeldefs/cc.bzl +++ b/tools/bazeldefs/cc.bzl @@ -9,6 +9,7 @@ cc_test = _cc_test cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain" gtest = "@com_google_googletest//:gtest" gbenchmark = "@com_google_benchmark//:benchmark" +gbenchmark_internal = "@com_google_benchmark//:benchmark" grpcpp = "@com_github_grpc_grpc//:grpc++" vdso_linker_option = "-fuse-ld=gold " diff --git a/tools/bazeldefs/go.bzl b/tools/bazeldefs/go.bzl index da027846b..af3a1c3ee 100644 --- a/tools/bazeldefs/go.bzl +++ b/tools/bazeldefs/go.bzl @@ -6,8 +6,11 @@ load("@io_bazel_rules_go//proto:def.bzl", _go_grpc_library = "go_grpc_library", load("//tools/bazeldefs:defs.bzl", "select_arch", "select_system") gazelle = _gazelle + go_embed_data = _go_embed_data + go_path = _go_path + bazel_worker_proto = "//tools/bazeldefs:worker_protocol_go_proto" def _go_proto_or_grpc_library(go_library_func, name, **kwargs): @@ -15,10 +18,19 @@ def _go_proto_or_grpc_library(go_library_func, name, **kwargs): # If importpath is explicit, pass straight through. go_library_func(name = name, **kwargs) return - deps = [ - dep.replace("_proto", "_go_proto") - for dep in (kwargs.pop("deps", []) or []) - ] + deps = [] + for d in (kwargs.pop("deps", []) or []): + if d == "@com_google_protobuf//:timestamp_proto": + # Special case: this proto has its Go definitions in a different + # repository. + deps.append("@org_golang_google_protobuf//" + + "types/known/timestamppb") + continue + if "//" in d: + repo, path = d.split("//", 1) + deps.append(repo + "//" + path.replace("_proto", "_go_proto")) + else: + deps.append(d.replace("_proto", "_go_proto")) go_library_func( name = name + "_go_proto", importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name + "_go_proto", @@ -130,18 +142,18 @@ def go_context(ctx, goos = None, goarch = None, std = False): elif goarch != go_ctx.sdk.goarch: fail("Internal GOARCH (%s) doesn't match GoSdk GOARCH (%s)." % (goarch, go_ctx.sdk.goarch)) return struct( - go = go_ctx.go, env = go_ctx.env, - nogo_args = [], - stdlib_srcs = go_ctx.sdk.srcs, - runfiles = depset([go_ctx.go] + go_ctx.sdk.srcs + go_ctx.sdk.tools + go_ctx.stdlib.libs), - goos = go_ctx.sdk.goos, + go = go_ctx.go, goarch = go_ctx.sdk.goarch, + goos = go_ctx.sdk.goos, gotags = go_ctx.tags, + nogo_args = [], + runfiles = depset([go_ctx.go] + go_ctx.sdk.srcs + go_ctx.sdk.tools + go_ctx.stdlib.libs), + stdlib_srcs = go_ctx.sdk.srcs, ) def select_goarch(): - return select_arch(arm64 = "arm64", amd64 = "amd64") + return select_arch(amd64 = "amd64", arm64 = "arm64") def select_goos(): return select_system(linux = "linux") diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go index 935154acc..082410697 100644 --- a/tools/bigquery/bigquery.go +++ b/tools/bigquery/bigquery.go @@ -39,13 +39,94 @@ type Suite struct { Timestamp time.Time `bq:"timestamp"` } +func (s *Suite) String() string { + conditions := make([]string, 0, len(s.Conditions)) + for _, c := range s.Conditions { + conditions = append(conditions, c.String()) + } + benchmarks := make([]string, 0, len(s.Benchmarks)) + for _, b := range s.Benchmarks { + benchmarks = append(benchmarks, b.String()) + } + + format := `Suite: +Name: %s +Conditions: %s +Benchmarks: %s +Official: %t +Timestamp: %s +` + + return fmt.Sprintf(format, + s.Name, + strings.Join(conditions, "\n"), + strings.Join(benchmarks, "\n"), + s.Official, + s.Timestamp) +} + // Benchmark represents an individual benchmark in a suite. type Benchmark struct { Name string `bq:"name"` - Condition []*Condition `bq:"condition"` + Condition []*Condition `bq:"cond"` Metric []*Metric `bq:"metric"` } +// String implements the String method for Benchmark +func (bm *Benchmark) String() string { + conditions := make([]string, 0, len(bm.Condition)) + for _, c := range bm.Condition { + conditions = append(conditions, c.String()) + } + metrics := make([]string, 0, len(bm.Metric)) + for _, m := range bm.Metric { + metrics = append(metrics, m.String()) + } + + format := `Condition: +Name: %s +Conditions: %s +Metrics: %s +` + + return fmt.Sprintf(format, + bm.Name, + strings.Join(conditions, "\n"), + strings.Join(metrics, "\n")) +} + +// AddMetric adds a metric to an existing Benchmark. +func (bm *Benchmark) AddMetric(metricName, unit string, sample float64) { + m := &Metric{ + Name: metricName, + Unit: unit, + Sample: sample, + } + bm.Metric = append(bm.Metric, m) +} + +// AddCondition adds a condition to an existing Benchmark. +func (bm *Benchmark) AddCondition(name, value string) { + bm.Condition = append(bm.Condition, &Condition{ + Name: name, + Value: value, + }) +} + +// NewBenchmark initializes a new benchmark. +func NewBenchmark(name string, iters int) *Benchmark { + return &Benchmark{ + Name: name, + Metric: make([]*Metric, 0), + Condition: []*Condition{ + { + Name: "iterations", + Value: strconv.Itoa(iters), + }, + }, + } +} + // Condition represents qualifiers for the benchmark or suite. For example: // Get_Pid/1/real_time would have Benchmark Name "Get_Pid" with "1" // and "real_time" parameters as conditions. Suite conditions include @@ -55,6 +136,10 @@ type Condition struct { Value string `bq:"value"` } +func (c *Condition) String() string { + return fmt.Sprintf("Condition:\nName: %s Value: %s\n", c.Name, c.Value) +} + // Metric holds the actual metric data and unit information for this benchmark. type Metric struct { Name string `bq:"name"` @@ -62,6 +147,10 @@ type Metric struct { Sample float64 `bq:"sample"` } +func (m *Metric) String() string { + return fmt.Sprintf("Metric:\nName: %s Unit: %s Sample: %e\n", m.Name, m.Unit, m.Sample) +} + // InitBigQuery initializes a BigQuery dataset/table in the project. If the dataset/table already exists, it is not duplicated. func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string, opts []option.ClientOption) error { client, err := bq.NewClient(ctx, projectID, opts...) @@ -87,38 +176,6 @@ func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string, opt return nil } -// AddCondition adds a condition to an existing Benchmark. -func (bm *Benchmark) AddCondition(name, value string) { - bm.Condition = append(bm.Condition, &Condition{ - Name: name, - Value: value, - }) -} - -// AddMetric adds a metric to an existing Benchmark. -func (bm *Benchmark) AddMetric(metricName, unit string, sample float64) { - m := &Metric{ - Name: metricName, - Unit: unit, - Sample: sample, - } - bm.Metric = append(bm.Metric, m) -} - -// NewBenchmark initializes a new benchmark. -func NewBenchmark(name string, iters int) *Benchmark { - return &Benchmark{ - Name: name, - Metric: make([]*Metric, 0), - Condition: []*Condition{ - { - Name: "iterations", - Value: strconv.Itoa(iters), - }, - }, - } -} - // NewBenchmarkWithMetric creates a new sending to BigQuery, initialized with a // single iteration and single metric. func NewBenchmarkWithMetric(name, metric, unit string, value float64) *Benchmark { diff --git a/tools/checkescape/BUILD b/tools/checkescape/BUILD index 940538b9e..109b5410c 100644 --- a/tools/checkescape/BUILD +++ b/tools/checkescape/BUILD @@ -8,6 +8,7 @@ go_library( nogo = False, visibility = ["//tools/nogo:__subpackages__"], deps = [ + "//tools/nogo/objdump", "@org_golang_x_tools//go/analysis:go_default_library", "@org_golang_x_tools//go/analysis/passes/buildssa:go_default_library", "@org_golang_x_tools//go/ssa:go_default_library", diff --git a/tools/checkescape/checkescape.go b/tools/checkescape/checkescape.go index c788654a8..ddd1212d7 100644 --- a/tools/checkescape/checkescape.go +++ b/tools/checkescape/checkescape.go @@ -61,21 +61,19 @@ package checkescape import ( "bufio" "bytes" - "flag" "fmt" "go/ast" "go/token" "go/types" "io" "log" - "os" - "os/exec" "path/filepath" "strings" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/buildssa" "golang.org/x/tools/go/ssa" + "gvisor.dev/gvisor/tools/nogo/objdump" ) const ( @@ -92,21 +90,6 @@ const ( exempt = "// escapes" ) -var ( - // Binary is the binary under analysis. - // - // See Reader, below. - binary = flag.String("binary", "", "binary under analysis") - - // Reader is the input stream. - // - // This may be set instead of Binary. - Reader io.Reader - - // objdumpTool is the tool used to dump a binary. - objdumpTool = flag.String("objdump_tool", "", "tool used to dump a binary") -) - // EscapeReason is an escape reason. // // This is a simple enum. @@ -374,31 +357,6 @@ func MergeAll(others []Escapes) (es Escapes) { // Note that the map uses <basename.go>:<line> because that is all that is // provided in the objdump format. Since this is all local, it is sufficient. func loadObjdump() (map[string][]string, error) { - var ( - args []string - stdin io.Reader - ) - if *binary != "" { - args = append(args, *binary) - } else if Reader != nil { - stdin = Reader - } else { - // We have no input stream or binary. - return nil, fmt.Errorf("no binary or reader provided") - } - - // Construct our command. - cmd := exec.Command(*objdumpTool, args...) - cmd.Stdin = stdin - cmd.Stderr = os.Stderr - out, err := cmd.StdoutPipe() - if err != nil { - return nil, err - } - if err := cmd.Start(); err != nil { - return nil, err - } - // Identify calls by address or name. Note that this is also // constructed dynamically below, as we encounted the addresses. // This is because some of the functions (duffzero) may have @@ -431,78 +389,83 @@ func loadObjdump() (map[string][]string, error) { // Build the map. nextFunc := "" // For funcsAllowed. m := make(map[string][]string) - r := bufio.NewReader(out) -NextLine: - for { - line, err := r.ReadString('\n') - if err != nil && err != io.EOF { - return nil, err - } - fields := strings.Fields(line) - - // Is this an "allowed" function definition? - if len(fields) >= 2 && fields[0] == "TEXT" { - nextFunc = strings.TrimSuffix(fields[1], "(SB)") - if _, ok := funcsAllowed[nextFunc]; !ok { - nextFunc = "" // Don't record addresses. - } - } - if nextFunc != "" && len(fields) > 2 { - // Save the given address (in hex form, as it appears). - addrsAllowed[fields[1]] = struct{}{} - } - - // We recognize lines corresponding to actual code (not the - // symbol name or other metadata) and annotate them if they - // correspond to an explicit CALL instruction. We assume that - // the lack of a CALL for a given line is evidence that escape - // analysis has eliminated an allocation. - // - // Lines look like this (including the first space): - // gohacks_unsafe.go:33 0xa39 488b442408 MOVQ 0x8(SP), AX - if len(fields) >= 5 && line[0] == ' ' { - if !strings.Contains(fields[3], "CALL") { - continue + if err := objdump.Load(func(origR io.Reader) error { + r := bufio.NewReader(origR) + NextLine: + for { + line, err := r.ReadString('\n') + if err != nil && err != io.EOF { + return err } - site := fields[0] - target := strings.TrimSuffix(fields[4], "(SB)") + fields := strings.Fields(line) - // Ignore strings containing allowed functions. - if _, ok := funcsAllowed[target]; ok { - continue + // Is this an "allowed" function definition? + if len(fields) >= 2 && fields[0] == "TEXT" { + nextFunc = strings.TrimSuffix(fields[1], "(SB)") + if _, ok := funcsAllowed[nextFunc]; !ok { + nextFunc = "" // Don't record addresses. + } } - if _, ok := addrsAllowed[target]; ok { - continue + if nextFunc != "" && len(fields) > 2 { + // Save the given address (in hex form, as it appears). + addrsAllowed[fields[1]] = struct{}{} } - if len(fields) > 5 { - // This may be a future relocation. Some - // objdump versions describe this differently. - // If it contains any of the functions allowed - // above as a string, we let it go. - softTarget := strings.Join(fields[5:], " ") - for name := range funcsAllowed { - if strings.Contains(softTarget, name) { - continue NextLine + + // We recognize lines corresponding to actual code (not the + // symbol name or other metadata) and annotate them if they + // correspond to an explicit CALL instruction. We assume that + // the lack of a CALL for a given line is evidence that escape + // analysis has eliminated an allocation. + // + // Lines look like this (including the first space): + // gohacks_unsafe.go:33 0xa39 488b442408 MOVQ 0x8(SP), AX + if len(fields) >= 5 && line[0] == ' ' { + if !strings.Contains(fields[3], "CALL") { + continue + } + site := fields[0] + target := strings.TrimSuffix(fields[4], "(SB)") + + // Ignore strings containing allowed functions. + if _, ok := funcsAllowed[target]; ok { + continue + } + if _, ok := addrsAllowed[target]; ok { + continue + } + if len(fields) > 5 { + // This may be a future relocation. Some + // objdump versions describe this differently. + // If it contains any of the functions allowed + // above as a string, we let it go. + softTarget := strings.Join(fields[5:], " ") + for name := range funcsAllowed { + if strings.Contains(softTarget, name) { + continue NextLine + } } } - } - // Does this exist already? - existing, ok := m[site] - if !ok { - existing = make([]string, 0, 1) - } - for _, other := range existing { - if target == other { - continue NextLine + // Does this exist already? + existing, ok := m[site] + if !ok { + existing = make([]string, 0, 1) + } + for _, other := range existing { + if target == other { + continue NextLine + } } + existing = append(existing, target) + m[site] = existing // Update. + } + if err == io.EOF { + break } - existing = append(existing, target) - m[site] = existing // Update. - } - if err == io.EOF { - break } + return nil + }); err != nil { + return nil, err } // Zap any accidental false positives. @@ -518,11 +481,6 @@ NextLine: final[site] = filteredCalls } - // Wait for the dump to finish. - if err := cmd.Wait(); err != nil { - return nil, err - } - return final, nil } diff --git a/tools/checklinkname/BUILD b/tools/checklinkname/BUILD new file mode 100644 index 000000000..0f1b07e24 --- /dev/null +++ b/tools/checklinkname/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "checklinkname", + srcs = [ + "check_linkname.go", + "known.go", + ], + nogo = False, + visibility = ["//tools/nogo:__subpackages__"], + deps = [ + "@org_golang_x_tools//go/analysis:go_default_library", + ], +) diff --git a/tools/checklinkname/README.md b/tools/checklinkname/README.md new file mode 100644 index 000000000..06b3c302d --- /dev/null +++ b/tools/checklinkname/README.md @@ -0,0 +1,54 @@ +# `checklinkname` Analyzer + +`checklinkname` is an analyzer to provide rudimentary type-checking for +`//go:linkname` directives. Since `//go:linkname` only affects linker behavior, +there is no built-in type safety and it is the programmer's responsibility to +ensure the types on either side are compatible. + +`checklinkname` helps with this by checking that uses match expectations, as +defined in this package. + +`known.go` contains the set of known linkname targets. For most functions, we +expect identical types on both sides of the linkname. In a few cases, the types +may be slightly different (e.g., local redefinition of internal type). It is +still the responsibility of the programmer to ensure the signatures in +`known.go` are compatible and safe. + +## Findings + +Here are the most common findings from this package, and how to resolve them. + +### `runtime.foo signature got "BAR" want "BAZ"; stdlib type changed?` + +The definition of `runtime.foo` in the standard library does not match the +expected type in `known.go`. This means that the function signature in the +standard library changed. + +Addressing this will require creating a new linkname directive in a new Go +version build-tagged in any packages using this symbol. Be sure to also check to +ensure use with the new version is safe, as function constraints may have +changed in addition to the signature. + +<!-- TODO(b/165820485): This isn't yet explicitly supported. --> + +`known.go` will also need to be updated to accept the new signature for the new +version of Go. + +### `Cannot find known symbol "runtime.foo"` + +The standard library has removed runtime.foo entirely. Handling is similar to +above, except existing code must transition away from the symbol entirely (note +that is may simply be renamed). + +### `linkname to unknown symbol "mypkg.foo"; add this symbol to checklinkname.knownLinknames type-check against the remote type` + +A package has added a new linkname directive for a symbol not listed in +`known.go`. Address this by adding a new entry for the target symbol. The +`local` field should be the expected type in your package, while `remote` should +be expected type in the remote package (e.g., in the standard library). These +are typically identical, in which case `remote` can be omitted. + +### `usage: //go:linkname localname [linkname]` + +Malformed `//go:linkname` directive. This should be accompanied by a build +failure in the package. diff --git a/tools/checklinkname/check_linkname.go b/tools/checklinkname/check_linkname.go new file mode 100644 index 000000000..5373dd762 --- /dev/null +++ b/tools/checklinkname/check_linkname.go @@ -0,0 +1,229 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package checklinkname ensures that linkname declarations match their source. +package checklinkname + +import ( + "fmt" + "go/ast" + "go/token" + "go/types" + "strings" + + "golang.org/x/tools/go/analysis" +) + +// Analyzer implements the checklinkname analyzer. +var Analyzer = &analysis.Analyzer{ + Name: "checklinkname", + Doc: "verifies that linkname declarations match their source", + Run: run, +} + +// go:linkname can be rather confusing. https://pkg.go.dev/cmd/compile says: +// +// //go:linkname localname [importpath.name] +// +// This special directive does not apply to the Go code that follows it. +// Instead, the //go:linkname directive instructs the compiler to use +// “importpath.name” as the object file symbol name for the variable or +// function declared as “localname” in the source code. If the +// “importpath.name” argument is omitted, the directive uses the symbol's +// default object file symbol name and only has the effect of making the symbol +// accessible to other packages. Because this directive can subvert the type +// system and package modularity, it is only enabled in files that have +// imported "unsafe". +// +// In this package we use the term "local" to refer to the symbol name in the +// same package as the //go:linkname directive, whose name will be changed by +// the linker. We use the term "remote" to refer to the symbol name that we are +// changing to. +// +// In the general case, the local symbol is a function declaration, and the +// remote symbol is a real function in the standard library. + +// linknameSignatures describes a the type signatures of the symbols in a +// //go:linkname directive. +type linknameSignatures struct { + local string + remote string // equivalent to local if "". +} + +func (l *linknameSignatures) Remote() string { + if l.remote == "" { + return l.local + } + return l.remote +} + +// linknameSymbols describes the symbol namess in a single //go:linkname +// directive. +type linknameSymbols struct { + pos token.Pos + local string + remote string +} + +func findLinknames(pass *analysis.Pass, f *ast.File) []linknameSymbols { + var names []linknameSymbols + + for _, cg := range f.Comments { + for _, c := range cg.List { + if len(c.Text) <= 2 || !strings.HasPrefix(c.Text[2:], "go:linkname ") { + continue + } + + f := strings.Fields(c.Text) + if len(f) < 2 || len(f) > 3 { + // Malformed linkname. This is the same error the compiler emits. + pass.Reportf(c.Slash, "usage: //go:linkname localname [linkname]") + } + + if len(f) == 2 { + // "If the “importpath.name” argument is + // omitted, the directive uses the symbol's + // default object file symbol name and only has + // the effect of making the symbol accessible + // to other packages." + // -https://golang.org/cmd/compile + // + // There is no type-checking to be done here. + continue + } + + names = append(names, linknameSymbols{ + pos: c.Slash, + local: f[1], + remote: f[2], + }) + } + } + + return names +} + +func splitSymbol(pkg *types.Package, symbol string) (packagePath, name string) { + // Note that some runtime symbols can have multiple dots. e.g., + // runtime..init_task. + s := strings.SplitN(symbol, ".", 2) + + switch len(s) { + case 1: + // Package name omitted, use current package. + return pkg.Path(), symbol + case 2: + return s[0], s[1] + default: + panic("unreachable") + } +} + +func findObject(pkg *types.Package, symbol string) (types.Object, error) { + packagePath, symbolName := splitSymbol(pkg, symbol) + return findPackageObject(pkg, packagePath, symbolName) +} + +func findPackageObject(pkg *types.Package, packagePath, symbolName string) (types.Object, error) { + if pkg.Path() == packagePath { + o := pkg.Scope().Lookup(symbolName) + if o == nil { + return nil, fmt.Errorf("%q not found in %q (names: %+v)", symbolName, packagePath, pkg.Scope().Names()) + } + return o, nil + } + + for _, p := range pkg.Imports() { + if o, err := findPackageObject(p, packagePath, symbolName); err == nil { + return o, nil + } + } + + return nil, fmt.Errorf("package %q not found", packagePath) +} + +// checkOneLinkname verifies that the type of sym.local matches the type from +// knownLinknames. +func checkOneLinkname(pass *analysis.Pass, f *ast.File, sym linknameSymbols) { + remotePackage, remoteName := splitSymbol(pass.Pkg, sym.remote) + + m, ok := knownLinknames[remotePackage] + if !ok { + pass.Reportf(sym.pos, "linkname to unknown symbol %q; add this symbol to checklinkname.knownLinknames type-check against the remote type", sym.remote) + return + } + + linkname, ok := m[remoteName] + if !ok { + pass.Reportf(sym.pos, "linkname to unknown symbol %q; add this symbol to checklinkname.knownLinknames type-check against the remote type", sym.remote) + return + } + + local, err := findObject(pass.Pkg, sym.local) + if err != nil { + pass.Reportf(sym.pos, "Unable to find symbol %q: %v", sym.local, err) + return + } + + localSig, ok := local.Type().(*types.Signature) + if !ok { + pass.Reportf(local.Pos(), "%q object is not a signature: %+#v", sym.local, local) + return + } + + if linkname.local != localSig.String() { + pass.Reportf(local.Pos(), "%q signature got %q want %q; mismatched types?", sym.local, localSig.String(), linkname.local) + return + } +} + +// checkOneRemote verifies that the type of sym matches wantSig. +func checkOneRemote(pass *analysis.Pass, sym, wantSig string) { + o := pass.Pkg.Scope().Lookup(sym) + if o == nil { + pass.Reportf(pass.Files[0].Package, "Cannot find known symbol %q", sym) + return + } + + sig, ok := o.Type().(*types.Signature) + if !ok { + pass.Reportf(o.Pos(), "%q object is not a signature: %+#v", sym, o) + return + } + + if sig.String() != wantSig { + pass.Reportf(o.Pos(), "%q signature got %q want %q; stdlib type changed?", sym, sig.String(), wantSig) + return + } +} + +func run(pass *analysis.Pass) (interface{}, error) { + // First, check if any remote symbols are in this package. + p, ok := knownLinknames[pass.Pkg.Path()] + if ok { + for sym, l := range p { + checkOneRemote(pass, sym, l.Remote()) + } + } + + // Then check for local //go:linkname directives in this package. + for _, f := range pass.Files { + names := findLinknames(pass, f) + for _, n := range names { + checkOneLinkname(pass, f, n) + } + } + + return nil, nil +} diff --git a/tools/checklinkname/known.go b/tools/checklinkname/known.go new file mode 100644 index 000000000..54e5155fc --- /dev/null +++ b/tools/checklinkname/known.go @@ -0,0 +1,110 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package checklinkname + +// knownLinknames is the set of the symbols for which we can do a rudimentary +// type-check on. +// +// When analyzing the remote package (e.g., runtime), we verify the symbol +// signature matches 'remote'. When analyzing local packages with //go:linkname +// directives, we verify the symbol signature matches 'local'. +// +// Usually these are identical, but may differ slightly if equivalent +// replacement types are used in the local packages, such as a copy of a struct +// or uintptr instead of a pointer type. +// +// NOTE: It is the responsibility of the developer to verify the safety of the +// signatures used here! This analyzer only checks that types match this map; +// it does not verify compatibility of the entries themselves. +// +// //go:linkname directives with no corresponding entry here will trigger a +// finding. +// +// We preform only rudimentary string-based type-checking due to limitations in +// the analysis framework. Ideally, from the local package we'd lookup the +// remote symbol's types.Object and perform robust type-checking. +// Unfortunately, remote symbols are typically loaded from the remote package's +// gcexportdata. Since //go:linkname targets are usually not exported symbols, +// they are no included in gcexportdata and we cannot load their types.Object. +// +// TODO(b/165820485): Add option to specific per-version signatures. +var knownLinknames = map[string]map[string]linknameSignatures{ + "runtime": map[string]linknameSignatures{ + "entersyscall": linknameSignatures{ + local: "func()", + }, + "entersyscallblock": linknameSignatures{ + local: "func()", + }, + "exitsyscall": linknameSignatures{ + local: "func()", + }, + "fastrand": linknameSignatures{ + local: "func() uint32", + }, + "gopark": linknameSignatures{ + // TODO(b/165820485): add verification of waitReason + // size and reason and traceEv values. + local: "func(unlockf func(uintptr, unsafe.Pointer) bool, lock unsafe.Pointer, reason uint8, traceEv byte, traceskip int)", + remote: "func(unlockf func(*runtime.g, unsafe.Pointer) bool, lock unsafe.Pointer, reason runtime.waitReason, traceEv byte, traceskip int)", + }, + "goready": linknameSignatures{ + local: "func(gp uintptr, traceskip int)", + remote: "func(gp *runtime.g, traceskip int)", + }, + "goyield": linknameSignatures{ + local: "func()", + }, + "memmove": linknameSignatures{ + local: "func(to unsafe.Pointer, from unsafe.Pointer, n uintptr)", + }, + "throw": linknameSignatures{ + local: "func(s string)", + }, + }, + "sync": map[string]linknameSignatures{ + "runtime_canSpin": linknameSignatures{ + local: "func(i int) bool", + }, + "runtime_doSpin": linknameSignatures{ + local: "func()", + }, + "runtime_Semacquire": linknameSignatures{ + // The only difference here is the parameter names. We + // can't just change our local use to match remote, as + // the stdlib runtime and sync packages also disagree + // on the name, and the analyzer checks that use as + // well. + local: "func(addr *uint32)", + remote: "func(s *uint32)", + }, + "runtime_Semrelease": linknameSignatures{ + // See above. + local: "func(addr *uint32, handoff bool, skipframes int)", + remote: "func(s *uint32, handoff bool, skipframes int)", + }, + }, + "syscall": map[string]linknameSignatures{ + "runtime_BeforeFork": linknameSignatures{ + local: "func()", + }, + "runtime_AfterFork": linknameSignatures{ + local: "func()", + }, + "runtime_AfterForkInChild": linknameSignatures{ + local: "func()", + }, + }, +} diff --git a/tools/checklinkname/test/BUILD b/tools/checklinkname/test/BUILD new file mode 100644 index 000000000..b29bd84f2 --- /dev/null +++ b/tools/checklinkname/test/BUILD @@ -0,0 +1,9 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "test", + testonly = 1, + srcs = ["test_unsafe.go"], +) diff --git a/tools/checklinkname/test/test_unsafe.go b/tools/checklinkname/test/test_unsafe.go new file mode 100644 index 000000000..a7504591c --- /dev/null +++ b/tools/checklinkname/test/test_unsafe.go @@ -0,0 +1,34 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package test provides linkname test targets. +package test + +import ( + _ "unsafe" // for go:linkname. +) + +//go:linkname DetachedLinkname runtime.fastrand + +//go:linkname attachedLinkname runtime.entersyscall +func attachedLinkname() + +// AttachedLinkname reexports attachedLinkname because go vet doesn't like an +// exported go:linkname without a comment starting with "// AttachedLinkname". +func AttachedLinkname() { + attachedLinkname() +} + +// DetachedLinkname has a linkname elsewhere in the file. +func DetachedLinkname() uint32 diff --git a/tools/checklocks/BUILD b/tools/checklocks/BUILD index 7d4c63dc7..d23b7cde6 100644 --- a/tools/checklocks/BUILD +++ b/tools/checklocks/BUILD @@ -4,11 +4,16 @@ package(licenses = ["notice"]) go_library( name = "checklocks", - srcs = ["checklocks.go"], + srcs = [ + "analysis.go", + "annotations.go", + "checklocks.go", + "facts.go", + "state.go", + ], nogo = False, visibility = ["//tools/nogo:__subpackages__"], deps = [ - "//pkg/log", "@org_golang_x_tools//go/analysis:go_default_library", "@org_golang_x_tools//go/analysis/passes/buildssa:go_default_library", "@org_golang_x_tools//go/ssa:go_default_library", diff --git a/tools/checklocks/README.md b/tools/checklocks/README.md index dfb0275ab..bd4beb649 100644 --- a/tools/checklocks/README.md +++ b/tools/checklocks/README.md @@ -1,16 +1,29 @@ # CheckLocks Analyzer -<!--* freshness: { owner: 'gvisor-eng' reviewed: '2020-10-05' } *--> +<!--* freshness: { owner: 'gvisor-eng' reviewed: '2021-03-21' } *--> -Checklocks is a nogo analyzer that at compile time uses Go's static analysis -tools to identify and flag cases where a field that is guarded by a mutex in the -same struct is accessed outside of a mutex lock. +Checklocks is an analyzer for lock and atomic constraints. The analyzer relies +on explicit annotations to identify fields that should be checked for access. -The analyzer relies on explicit '// +checklocks:<mutex-name>' kind of -annotations to identify fields that should be checked for access. +## Atomic annotations -Individual struct members may be protected by annotations that indicate how they -must be accessed. These annotations are of the form: +Individual struct members may be noted as requiring atomic access. These +annotations are of the form: + +```go +type foo struct { + // +checkatomic + bar int32 +} +``` + +This will ensure that all accesses to bar are atomic, with the exception of +operations on newly allocated objects. + +## Lock annotations + +Individual struct members may be protected by annotations that indicate locking +requirements for accessing members. These annotations are of the form: ```go type foo struct { @@ -64,30 +77,6 @@ annotations from becoming stale over time as fields are renamed, etc. # Currently not supported -1. The analyzer does not correctly handle deferred functions. e.g The following - code is not correctly checked by the analyzer. The defer call is never - evaluated. As a result if the lock was to be say unlocked twice via deferred - functions it would not be caught by the analyzer. - - Similarly deferred anonymous functions are not evaluated either. - -```go -type A struct { - mu sync.Mutex - - // +checklocks:mu - x int -} - -func abc() { - var a A - a.mu.Lock() - defer a.mu.Unlock() - defer a.mu.Unlock() - a.x = 1 -} -``` - 1. Anonymous functions are not correctly evaluated. The analyzer does not currently support specifying annotations on anonymous functions as a result evaluation of a function that accesses protected fields will fail. @@ -107,10 +96,9 @@ func abc() { f() a.mu.Unlock() } - ``` -# Explicitly Not Supported +### Explicitly Not Supported 1. Checking for embedded mutexes as sync.Locker rather than directly as 'sync.Mutex'. In other words, the checker will not track mutex Lock and @@ -140,3 +128,30 @@ func abc() { checklocks. Only struct members can be used. 2. The checker will not support checking for lock ordering violations. + +## Mixed mode + +Some members may allow read-only atomic access, but be protected against writes +by a mutex. Generally, this imposes the following requirements: + +For a read, one of the following must be true: + +1. A lock held be held. +1. The access is atomic. + +For a write, both of the following must be true: + +1. The lock must be held. +1. The write must be atomic. + +In order to annotate a relevant field, simply apply *both* annotations from +above. For example: + +```go +type foo struct { + mu sync.Mutex + // +checklocks:mu + // +checkatomic + bar int32 +} +``` diff --git a/tools/checklocks/analysis.go b/tools/checklocks/analysis.go new file mode 100644 index 000000000..d3fd797d0 --- /dev/null +++ b/tools/checklocks/analysis.go @@ -0,0 +1,628 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package checklocks + +import ( + "go/token" + "go/types" + "strings" + + "golang.org/x/tools/go/ssa" +) + +func gcd(a, b atomicAlignment) atomicAlignment { + for b != 0 { + a, b = b, a%b + } + return a +} + +// typeAlignment returns the type alignment for the given type. +func (pc *passContext) typeAlignment(pkg *types.Package, obj types.Object) atomicAlignment { + requiredOffset := atomicAlignment(1) + if pc.pass.ImportObjectFact(obj, &requiredOffset) { + return requiredOffset + } + + switch x := obj.Type().Underlying().(type) { + case *types.Struct: + fields := make([]*types.Var, x.NumFields()) + for i := 0; i < x.NumFields(); i++ { + fields[i] = x.Field(i) + } + offsets := pc.pass.TypesSizes.Offsetsof(fields) + for i := 0; i < x.NumFields(); i++ { + // Check the offset, and then assuming that this offset + // aligns with the offset for the broader type. + fieldRequired := pc.typeAlignment(pkg, fields[i]) + if offsets[i]%int64(fieldRequired) != 0 { + // The offset of this field is not compatible. + pc.maybeFail(fields[i].Pos(), "have alignment %d, need %d", offsets[i], fieldRequired) + } + // Ensure the requiredOffset is the LCM of the offset. + requiredOffset *= fieldRequired / gcd(requiredOffset, fieldRequired) + } + case *types.Array: + // Export direct alignment requirements. + if named, ok := x.Elem().(*types.Named); ok { + requiredOffset = pc.typeAlignment(pkg, named.Obj()) + } + default: + // Use the compiler's underlying alignment. + requiredOffset = atomicAlignment(pc.pass.TypesSizes.Alignof(obj.Type().Underlying())) + } + + if pkg == obj.Pkg() { + // Cache as an object fact, to subsequent calls. Note that we + // can only export object facts for the package that we are + // currently analyzing. There may be no exported facts for + // array types or alias types, for example. + pc.pass.ExportObjectFact(obj, &requiredOffset) + } + + return requiredOffset +} + +// checkTypeAlignment checks the alignment of the given type. +// +// This calls typeAlignment, which resolves all types recursively. This method +// should be called for all types individual to ensure full coverage. +func (pc *passContext) checkTypeAlignment(pkg *types.Package, typ *types.Named) { + _ = pc.typeAlignment(pkg, typ.Obj()) +} + +// checkAtomicCall checks for an atomic access. +// +// inst is the instruction analyzed, obj is used only for maybeFail. +// +// If mustBeAtomic is true, then we assert that the instruction *is* an atomic +// fucnction call. If it is false, then we assert that it is *not* an atomic +// dispatch. +// +// If readOnly is true, then only atomic read access are allowed. Note that +// readOnly is only meaningful if mustBeAtomic is set. +func (pc *passContext) checkAtomicCall(inst ssa.Instruction, obj types.Object, mustBeAtomic, readOnly bool) { + switch x := inst.(type) { + case *ssa.Call: + if x.Common().IsInvoke() { + if mustBeAtomic { + // This is an illegal interface dispatch. + pc.maybeFail(inst.Pos(), "dynamic dispatch with atomic-only field") + } + return + } + fn, ok := x.Common().Value.(*ssa.Function) + if !ok { + if mustBeAtomic { + // This is an illegal call to a non-static function. + pc.maybeFail(inst.Pos(), "dispatch to non-static function with atomic-only field") + } + return + } + pkg := fn.Package() + if pkg == nil { + if mustBeAtomic { + // This is a call to some shared wrapper function. + pc.maybeFail(inst.Pos(), "dispatch to shared function or wrapper") + } + return + } + var lff lockFunctionFacts // Check for exemption. + if obj := fn.Object(); obj != nil && pc.pass.ImportObjectFact(obj, &lff) && lff.Ignore { + return + } + if name := pkg.Pkg.Name(); name != "atomic" && name != "atomicbitops" { + if mustBeAtomic { + // This is an illegal call to a non-atomic package function. + pc.maybeFail(inst.Pos(), "dispatch to non-atomic function with atomic-only field") + } + return + } + if !mustBeAtomic { + // We are *not* expecting an atomic dispatch. + if _, ok := pc.forced[pc.positionKey(inst.Pos())]; !ok { + pc.maybeFail(inst.Pos(), "unexpected call to atomic function") + } + } + if !strings.HasPrefix(fn.Name(), "Load") && readOnly { + // We are not allowing any reads in this context. + if _, ok := pc.forced[pc.positionKey(inst.Pos())]; !ok { + pc.maybeFail(inst.Pos(), "unexpected call to atomic write function, is a lock missing?") + } + return + } + default: + if mustBeAtomic { + // This is something else entirely. + if _, ok := pc.forced[pc.positionKey(inst.Pos())]; !ok { + pc.maybeFail(inst.Pos(), "illegal use of atomic-only field by %T instruction", inst) + } + return + } + } +} + +func resolveStruct(typ types.Type) (*types.Struct, bool) { + structType, ok := typ.Underlying().(*types.Struct) + if ok { + return structType, true + } + ptrType, ok := typ.Underlying().(*types.Pointer) + if ok { + return resolveStruct(ptrType.Elem()) + } + return nil, false +} + +func findField(typ types.Type, field int) (types.Object, bool) { + structType, ok := resolveStruct(typ) + if !ok { + return nil, false + } + return structType.Field(field), true +} + +// instructionWithReferrers is a generalization over ssa.Field, ssa.FieldAddr. +type instructionWithReferrers interface { + ssa.Instruction + Referrers() *[]ssa.Instruction +} + +// checkFieldAccess checks the validity of a field access. +// +// This also enforces atomicity constraints for fields that must be accessed +// atomically. The parameter isWrite indicates whether this field is used +// downstream for a write operation. +func (pc *passContext) checkFieldAccess(inst instructionWithReferrers, structObj ssa.Value, field int, ls *lockState, isWrite bool) { + var ( + lff lockFieldFacts + lgf lockGuardFacts + guardsFound int + guardsHeld int + ) + + fieldObj, _ := findField(structObj.Type(), field) + pc.pass.ImportObjectFact(fieldObj, &lff) + pc.pass.ImportObjectFact(fieldObj, &lgf) + + for guardName, fl := range lgf.GuardedBy { + guardsFound++ + r := fl.resolve(structObj) + if _, ok := ls.isHeld(r); ok { + guardsHeld++ + continue + } + if _, ok := pc.forced[pc.positionKey(inst.Pos())]; ok { + // Mark this as locked, since it has been forced. + ls.lockField(r) + guardsHeld++ + continue + } + // Note that we may allow this if the disposition is atomic, + // and we are allowing atomic reads only. This will fall into + // the atomic disposition check below, which asserts that the + // access is atomic. Further, guardsHeld < guardsFound will be + // true for this case, so we require it to be read-only. + if lgf.AtomicDisposition != atomicRequired { + // There is no force key, no atomic access and no lock held. + pc.maybeFail(inst.Pos(), "invalid field access, %s must be locked when accessing %s (locks: %s)", guardName, fieldObj.Name(), ls.String()) + } + } + + // Check the atomic access for this field. + switch lgf.AtomicDisposition { + case atomicRequired: + // Check that this is used safely as an input. + readOnly := guardsHeld < guardsFound + if refs := inst.Referrers(); refs != nil { + for _, otherInst := range *refs { + pc.checkAtomicCall(otherInst, fieldObj, true, readOnly) + } + } + // Check that this is not otherwise written non-atomically, + // even if we do hold all the locks. + if isWrite { + pc.maybeFail(inst.Pos(), "non-atomic write of field %s, writes must still be atomic with locks held (locks: %s)", fieldObj.Name(), ls.String()) + } + case atomicDisallow: + // Check that this is *not* used atomically. + if refs := inst.Referrers(); refs != nil { + for _, otherInst := range *refs { + pc.checkAtomicCall(otherInst, fieldObj, false, false) + } + } + } +} + +func (pc *passContext) checkCall(call callCommon, ls *lockState) { + // See: https://godoc.org/golang.org/x/tools/go/ssa#CallCommon + // + // 1. "call" mode: when Method is nil (!IsInvoke), a CallCommon represents an ordinary + // function call of the value in Value, which may be a *Builtin, a *Function or any + // other value of kind 'func'. + // + // Value may be one of: + // (a) a *Function, indicating a statically dispatched call + // to a package-level function, an anonymous function, or + // a method of a named type. + // + // (b) a *MakeClosure, indicating an immediately applied + // function literal with free variables. + // + // (c) a *Builtin, indicating a statically dispatched call + // to a built-in function. + // + // (d) any other value, indicating a dynamically dispatched + // function call. + switch fn := call.Common().Value.(type) { + case *ssa.Function: + var lff lockFunctionFacts + if fn.Object() != nil { + pc.pass.ImportObjectFact(fn.Object(), &lff) + pc.checkFunctionCall(call, fn, &lff, ls) + } else { + // Anonymous functions have no facts, and cannot be + // annotated. We don't check for violations using the + // function facts, since they cannot exist. Instead, we + // do a fresh analysis using the current lock state. + fnls := ls.fork() + for i, arg := range call.Common().Args { + fnls.store(fn.Params[i], arg) + } + pc.checkFunction(call, fn, &lff, fnls, true /* force */) + } + case *ssa.MakeClosure: + // Note that creating and then invoking closures locally is + // allowed, but analysis of passing closures is done when + // checking individual instructions. + pc.checkClosure(call, fn, ls) + default: + return + } +} + +// postFunctionCallUpdate updates all conditions. +func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunctionFacts, ls *lockState) { + // Release all locks not still held. + for fieldName, fg := range lff.HeldOnEntry { + if _, ok := lff.HeldOnExit[fieldName]; ok { + continue + } + r := fg.resolveCall(call.Common().Args, call.Value()) + if s, ok := ls.unlockField(r); !ok { + if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { + pc.maybeFail(call.Pos(), "attempt to release %s (%s), but not held (locks: %s)", fieldName, s, ls.String()) + } + } + } + + // Update all held locks if acquired. + for fieldName, fg := range lff.HeldOnExit { + if _, ok := lff.HeldOnEntry[fieldName]; ok { + continue + } + // Acquire the lock per the annotation. + r := fg.resolveCall(call.Common().Args, call.Value()) + if s, ok := ls.lockField(r); !ok { + if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { + pc.maybeFail(call.Pos(), "attempt to acquire %s (%s), but already held (locks: %s)", fieldName, s, ls.String()) + } + } + } +} + +// checkFunctionCall checks preconditions for function calls, and tracks the +// lock state by recording relevant calls to sync functions. Note that calls to +// atomic functions are tracked by checkFieldAccess by looking directly at the +// referrers (because ordering doesn't matter there, so we need not scan in +// instruction order). +func (pc *passContext) checkFunctionCall(call callCommon, fn *ssa.Function, lff *lockFunctionFacts, ls *lockState) { + // Check all guards required are held. + for fieldName, fg := range lff.HeldOnEntry { + r := fg.resolveCall(call.Common().Args, call.Value()) + if s, ok := ls.isHeld(r); !ok { + if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { + pc.maybeFail(call.Pos(), "must hold %s (%s) to call %s, but not held (locks: %s)", fieldName, s, fn.Name(), ls.String()) + } else { + // Force the lock to be acquired. + ls.lockField(r) + } + } + } + + // Update all lock state accordingly. + pc.postFunctionCallUpdate(call, lff, ls) + + // Check if it's a method dispatch for something in the sync package. + // See: https://godoc.org/golang.org/x/tools/go/ssa#Function + if fn.Package() != nil && fn.Package().Pkg.Name() == "sync" && fn.Signature.Recv() != nil { + switch fn.Name() { + case "Lock", "RLock": + if s, ok := ls.lockField(resolvedValue{value: call.Common().Args[0], valid: true}); !ok { + if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { + // Double locking a mutex that is already locked. + pc.maybeFail(call.Pos(), "%s already locked (locks: %s)", s, ls.String()) + } + } + case "Unlock", "RUnlock": + if s, ok := ls.unlockField(resolvedValue{value: call.Common().Args[0], valid: true}); !ok { + if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { + // Unlocking something that is already unlocked. + pc.maybeFail(call.Pos(), "%s already unlocked (locks: %s)", s, ls.String()) + } + } + } + } +} + +// checkClosure forks the lock state, and creates a binding for the FreeVars of +// the closure. This allows the analysis to resolve the closure. +func (pc *passContext) checkClosure(call callCommon, fn *ssa.MakeClosure, ls *lockState) { + clls := ls.fork() + clfn := fn.Fn.(*ssa.Function) + for i, fv := range clfn.FreeVars { + clls.store(fv, fn.Bindings[i]) + } + + // Note that this is *not* a call to check function call, which checks + // against the function preconditions. Instead, this does a fresh + // analysis of the function from source code with a different state. + var nolff lockFunctionFacts + pc.checkFunction(call, clfn, &nolff, clls, true /* force */) +} + +// freshAlloc indicates that v has been allocated within the local scope. There +// is no lock checking done on objects that are freshly allocated. +func freshAlloc(v ssa.Value) bool { + switch x := v.(type) { + case *ssa.Alloc: + return true + case *ssa.FieldAddr: + return freshAlloc(x.X) + case *ssa.Field: + return freshAlloc(x.X) + case *ssa.IndexAddr: + return freshAlloc(x.X) + case *ssa.Index: + return freshAlloc(x.X) + case *ssa.Convert: + return freshAlloc(x.X) + case *ssa.ChangeType: + return freshAlloc(x.X) + default: + return false + } +} + +// isWrite indicates that this value is used as the addr field in a store. +// +// Note that this may still be used for a write. The return here is optimistic +// but sufficient for basic analysis. +func isWrite(v ssa.Value) bool { + refs := v.Referrers() + if refs == nil { + return false + } + for _, ref := range *refs { + if s, ok := ref.(*ssa.Store); ok && s.Addr == v { + return true + } + } + return false +} + +// callCommon is an ssa.Value that also implements Common. +type callCommon interface { + Pos() token.Pos + Common() *ssa.CallCommon + Value() *ssa.Call +} + +// checkInstruction checks the legality the single instruction based on the +// current lockState. +func (pc *passContext) checkInstruction(inst ssa.Instruction, ls *lockState) (*ssa.Return, *lockState) { + switch x := inst.(type) { + case *ssa.Store: + // Record that this value is holding this other value. This is + // because at the beginning of each ssa execution, there is a + // series of assignments of parameter values to alloc objects. + // This allows us to trace these back to the original + // parameters as aliases above. + // + // Note that this may overwrite an existing value in the lock + // state, but this is intentional. + ls.store(x.Addr, x.Val) + case *ssa.Field: + if !freshAlloc(x.X) { + pc.checkFieldAccess(x, x.X, x.Field, ls, false) + } + case *ssa.FieldAddr: + if !freshAlloc(x.X) { + pc.checkFieldAccess(x, x.X, x.Field, ls, isWrite(x)) + } + case *ssa.Call: + pc.checkCall(x, ls) + case *ssa.Defer: + ls.pushDefer(x) + case *ssa.RunDefers: + for d := ls.popDefer(); d != nil; d = ls.popDefer() { + pc.checkCall(d, ls) + } + case *ssa.MakeClosure: + refs := x.Referrers() + if refs == nil { + // This is strange, it's not used? Ignore this case, + // since it will probably be optimized away. + return nil, nil + } + hasNonCall := false + for _, ref := range *refs { + switch ref.(type) { + case *ssa.Call, *ssa.Defer: + // Analysis will be done on the call itself + // subsequently, including the lock state at + // the time of the call. + default: + // We need to analyze separately. Per below, + // this means that we'll analyze at closure + // construction time no zero assumptions about + // when it will be called. + hasNonCall = true + } + } + if !hasNonCall { + return nil, nil + } + // Analyze the closure without bindings. This means that we + // assume no lock facts or have any existing lock state. Only + // trivial closures are acceptable in this case. + clfn := x.Fn.(*ssa.Function) + var nolff lockFunctionFacts + pc.checkFunction(nil, clfn, &nolff, nil, false /* force */) + case *ssa.Return: + return x, ls // Valid return state. + } + return nil, nil +} + +// checkBasicBlock traverses the control flow graph starting at a set of given +// block and checks each instruction for allowed operations. +func (pc *passContext) checkBasicBlock(fn *ssa.Function, block *ssa.BasicBlock, lff *lockFunctionFacts, parent *lockState, seen map[*ssa.BasicBlock]*lockState) *lockState { + if oldLS, ok := seen[block]; ok && oldLS.isCompatible(parent) { + return nil + } + + // If the lock state is not compatible, then we need to do the + // recursive analysis to ensure that it is still sane. For example, the + // following is guaranteed to generate incompatible locking states: + // + // if foo { + // mu.Lock() + // } + // other stuff ... + // if foo { + // mu.Unlock() + // } + + var ( + rv *ssa.Return + rls *lockState + ) + + // Analyze this block. + seen[block] = parent + ls := parent.fork() + for _, inst := range block.Instrs { + rv, rls = pc.checkInstruction(inst, ls) + if rls != nil { + failed := false + // Validate held locks. + for fieldName, fg := range lff.HeldOnExit { + r := fg.resolveStatic(fn, rv) + if s, ok := rls.isHeld(r); !ok { + if _, ok := pc.forced[pc.positionKey(rv.Pos())]; !ok { + pc.maybeFail(rv.Pos(), "lock %s (%s) not held (locks: %s)", fieldName, s, rls.String()) + failed = true + } else { + // Force the lock to be acquired. + rls.lockField(r) + } + } + } + // Check for other locks, but only if the above didn't trip. + if !failed && rls.count() != len(lff.HeldOnExit) { + pc.maybeFail(rv.Pos(), "return with unexpected locks held (locks: %s)", rls.String()) + } + } + } + + // Analyze all successors. + for _, succ := range block.Succs { + // Collect possible return values, and make sure that the lock + // state aligns with any return value that we may have found + // above. Note that checkBasicBlock will recursively analyze + // the lock state to ensure that Releases and Acquires are + // respected. + if pls := pc.checkBasicBlock(fn, succ, lff, ls, seen); pls != nil { + if rls != nil && !rls.isCompatible(pls) { + if _, ok := pc.forced[pc.positionKey(fn.Pos())]; !ok { + pc.maybeFail(fn.Pos(), "incompatible return states (first: %s, second: %v)", rls.String(), pls.String()) + } + } + rls = pls + } + } + return rls +} + +// checkFunction checks a function invocation, typically starting with nil lockState. +func (pc *passContext) checkFunction(call callCommon, fn *ssa.Function, lff *lockFunctionFacts, parent *lockState, force bool) { + defer func() { + // Mark this function as checked. This is used by the top-level + // loop to ensure that all anonymous functions are scanned, if + // they are not explicitly invoked here. Note that this can + // happen if the anonymous functions are e.g. passed only as + // parameters or used to initialize some structure. + pc.functions[fn] = struct{}{} + }() + if _, ok := pc.functions[fn]; !force && ok { + // This function has already been analyzed at least once. + // That's all we permit for each function, although this may + // cause some anonymous functions to be analyzed in only one + // context. + return + } + + // If no return value is provided, then synthesize one. This is used + // below only to check against the locks preconditions, which may + // include return values. + if call == nil { + call = &ssa.Call{Call: ssa.CallCommon{Value: fn}} + } + + // Initialize ls with any preconditions that require locks to be held + // for the method to be invoked. Note that in the overwhleming majority + // of cases, parent will be nil. However, in the case of closures and + // anonymous functions, we may start with a non-nil lock state. + ls := parent.fork() + for fieldName, fg := range lff.HeldOnEntry { + // The first is the method object itself so we skip that when looking + // for receiver/function parameters. + r := fg.resolveStatic(fn, call.Value()) + if s, ok := ls.lockField(r); !ok { + // This can only happen if the same value is declared + // multiple times, and should be caught by the earlier + // fact scanning. Keep it here as a sanity check. + pc.maybeFail(fn.Pos(), "lock %s (%s) acquired multiple times (locks: %s)", fieldName, s, ls.String()) + } + } + + // Scan the blocks. + seen := make(map[*ssa.BasicBlock]*lockState) + if len(fn.Blocks) > 0 { + pc.checkBasicBlock(fn, fn.Blocks[0], lff, ls, seen) + } + + // Scan the recover block. + if fn.Recover != nil { + pc.checkBasicBlock(fn, fn.Recover, lff, ls, seen) + } + + // Update all lock state accordingly. This will be called only if we + // are doing inline analysis for e.g. an anonymous function. + if call != nil && parent != nil { + pc.postFunctionCallUpdate(call, lff, parent) + } +} diff --git a/tools/checklocks/annotations.go b/tools/checklocks/annotations.go new file mode 100644 index 000000000..371260980 --- /dev/null +++ b/tools/checklocks/annotations.go @@ -0,0 +1,129 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package checklocks + +import ( + "fmt" + + "go/token" + "strconv" + "strings" +) + +const ( + checkLocksAnnotation = "// +checklocks:" + checkLocksAcquires = "// +checklocksacquire:" + checkLocksReleases = "// +checklocksrelease:" + checkLocksIgnore = "// +checklocksignore" + checkLocksForce = "// +checklocksforce" + checkLocksFail = "// +checklocksfail" + checkAtomicAnnotation = "// +checkatomic" +) + +// failData indicates an expected failure. +type failData struct { + pos token.Pos + count int + seen int +} + +// positionKey is a simple position string. +type positionKey string + +// positionKey converts from a token.Pos to a key we can use to track failures +// as the position of the failure annotation is not the same as the position of +// the actual failure (different column/offsets). Hence we ignore these fields +// and only use the file/line numbers to track failures. +func (pc *passContext) positionKey(pos token.Pos) positionKey { + position := pc.pass.Fset.Position(pos) + return positionKey(fmt.Sprintf("%s:%d", position.Filename, position.Line)) +} + +// addFailures adds an expected failure. +func (pc *passContext) addFailures(pos token.Pos, s string) { + count := 1 + if len(s) > 0 && s[0] == ':' { + parsedCount, err := strconv.Atoi(s[1:]) + if err != nil { + pc.pass.Reportf(pos, "unable to parse failure annotation %q: %v", s[1:], err) + return + } + count = parsedCount + } + pc.failures[pc.positionKey(pos)] = &failData{ + pos: pos, + count: count, + } +} + +// addExemption adds an exemption. +func (pc *passContext) addExemption(pos token.Pos) { + pc.exemptions[pc.positionKey(pos)] = struct{}{} +} + +// addForce adds a force annotation. +func (pc *passContext) addForce(pos token.Pos) { + pc.forced[pc.positionKey(pos)] = struct{}{} +} + +// maybeFail checks a potential failure against a specific failure map. +func (pc *passContext) maybeFail(pos token.Pos, fmtStr string, args ...interface{}) { + if fd, ok := pc.failures[pc.positionKey(pos)]; ok { + fd.seen++ + return + } + if _, ok := pc.exemptions[pc.positionKey(pos)]; ok { + return // Ignored, not counted. + } + pc.pass.Reportf(pos, fmtStr, args...) +} + +// checkFailure checks for the expected failure counts. +func (pc *passContext) checkFailures() { + for _, fd := range pc.failures { + if fd.count != fd.seen { + // We are missing expect failures, report as much as possible. + pc.pass.Reportf(fd.pos, "got %d failures, want %d failures", fd.seen, fd.count) + } + } +} + +// extractAnnotations extracts annotations from text. +func (pc *passContext) extractAnnotations(s string, fns map[string]func(p string)) { + for prefix, fn := range fns { + if strings.HasPrefix(s, prefix) { + fn(s[len(prefix):]) + } + } +} + +// extractLineFailures extracts all line-based exceptions. +// +// Note that this applies only to individual line exemptions, and does not +// consider function-wide exemptions, or specific field exemptions, which are +// extracted separately as part of the saved facts for those objects. +func (pc *passContext) extractLineFailures() { + for _, f := range pc.pass.Files { + for _, cg := range f.Comments { + for _, c := range cg.List { + pc.extractAnnotations(c.Text, map[string]func(string){ + checkLocksFail: func(p string) { pc.addFailures(c.Pos(), p) }, + checkLocksIgnore: func(string) { pc.addExemption(c.Pos()) }, + checkLocksForce: func(string) { pc.addForce(c.Pos()) }, + }) + } + } + } +} diff --git a/tools/checklocks/checklocks.go b/tools/checklocks/checklocks.go index 1e877d394..401fb55ec 100644 --- a/tools/checklocks/checklocks.go +++ b/tools/checklocks/checklocks.go @@ -13,32 +13,19 @@ // limitations under the License. // Package checklocks performs lock analysis to identify and flag unprotected -// access to field annotated with a '// +checklocks:<mutex-name>' annotation. +// access to annotated fields. // -// For detailed ussage refer to README.md in the same directory. +// For detailed usage refer to README.md in the same directory. package checklocks import ( - "bytes" - "fmt" "go/ast" "go/token" "go/types" - "reflect" - "regexp" - "strconv" - "strings" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/buildssa" "golang.org/x/tools/go/ssa" - "gvisor.dev/gvisor/pkg/log" -) - -const ( - checkLocksAnnotation = "// +checklocks:" - checkLocksIgnore = "// +checklocksignore" - checkLocksFail = "// +checklocksfail" ) // Analyzer is the main entrypoint. @@ -47,712 +34,123 @@ var Analyzer = &analysis.Analyzer{ Doc: "checks lock preconditions on functions and fields", Run: run, Requires: []*analysis.Analyzer{buildssa.Analyzer}, - FactTypes: []analysis.Fact{(*lockFieldFacts)(nil), (*lockFunctionFacts)(nil)}, -} - -// lockFieldFacts apply on every struct field protected by a lock or that is a -// lock. -type lockFieldFacts struct { - // GuardedBy tracks the names and field numbers that guard this field. - GuardedBy map[string]int - - // IsMutex is true if the field is of type sync.Mutex. - IsMutex bool - - // IsRWMutex is true if the field is of type sync.RWMutex. - IsRWMutex bool - - // FieldNumber is the number of this field in the struct. - FieldNumber int -} - -// AFact implements analysis.Fact.AFact. -func (*lockFieldFacts) AFact() {} - -type functionGuard struct { - // ParameterNumber is the index of the object that contains the guarding mutex. - // This is required during SSA analysis as field names and parameters names are - // not available in SSA. For example, from the example below ParameterNumber would - // be 1 and FieldNumber would correspond to the field number of 'mu' within b's type. - // - // //+checklocks:b.mu - // func (a *A) method(b *B, c *C) { - // ... - // } - ParameterNumber int - - // FieldNumber is the field index of the mutex in the parameter's struct - // type. Refer to example above for more details. - FieldNumber int -} - -// lockFunctionFacts apply on every method. -type lockFunctionFacts struct { - // GuardedBy tracks the names and number of parameter (including receiver) - // lockFuncfields that guard calls to this function. - // The key is the name specified in the checklocks annotation. e.g given - // the following code. - // ``` - // type A struct { - // mu sync.Mutex - // a int - // } - // - // // +checklocks:a.mu - // func xyz(a *A) {..} - // ``` - // - // '`+checklocks:a.mu' will result in an entry in this map as shown below. - // GuardedBy: {"a.mu" => {ParameterNumber: 0, FieldNumber: 0} - GuardedBy map[string]functionGuard -} - -// AFact implements analysis.Fact.AFact. -func (*lockFunctionFacts) AFact() {} - -type positionKey string - -// toPositionKey converts from a token.Position to a key we can use to track -// failures as the position of the failure annotation is not the same as the -// position of the actual failure (different column/offsets). Hence we ignore -// these fields and only use the file/line numbers to track failures. -func toPositionKey(position token.Position) positionKey { - return positionKey(fmt.Sprintf("%s:%d", position.Filename, position.Line)) -} - -type failData struct { - pos token.Pos - count int -} - -func (f failData) String() string { - return fmt.Sprintf("pos: %d, count: %d", f.pos, f.count) + FactTypes: []analysis.Fact{(*atomicAlignment)(nil), (*lockFieldFacts)(nil), (*lockGuardFacts)(nil), (*lockFunctionFacts)(nil)}, } +// passContext is a pass with additional expected failures. type passContext struct { - pass *analysis.Pass - - // exemptions tracks functions that should be exempted from lock checking due - // to '// +checklocksignore' annotation. - exemptions map[types.Object]struct{} - - failures map[positionKey]*failData + pass *analysis.Pass + failures map[positionKey]*failData + exemptions map[positionKey]struct{} + forced map[positionKey]struct{} + functions map[*ssa.Function]struct{} } -var ( - mutexRE = regexp.MustCompile("((.*/)|^)sync.(CrossGoroutineMutex|Mutex)") - rwMutexRE = regexp.MustCompile("((.*/)|^)sync.(CrossGoroutineRWMutex|RWMutex)") -) - -func (pc *passContext) extractFieldAnnotations(field *ast.Field, fieldType *types.Var) *lockFieldFacts { - s := fieldType.Type().String() - // We use HasSuffix below because fieldType can be fully qualified with the - // package name eg for the gvisor sync package mutex fields have the type: - // "<package path>/sync/sync.Mutex" - switch { - case mutexRE.Match([]byte(s)): - return &lockFieldFacts{IsMutex: true} - case rwMutexRE.Match([]byte(s)): - return &lockFieldFacts{IsRWMutex: true} - default: - } - if field.Doc == nil { - return nil - } - fieldFacts := &lockFieldFacts{GuardedBy: make(map[string]int)} - for _, l := range field.Doc.List { - if strings.HasPrefix(l.Text, checkLocksAnnotation) { - guardName := strings.TrimPrefix(l.Text, checkLocksAnnotation) - if _, ok := fieldFacts.GuardedBy[guardName]; ok { - pc.pass.Reportf(field.Pos(), "annotation %s specified more than once.", l.Text) - continue - } - fieldFacts.GuardedBy[guardName] = -1 - } - } - - return fieldFacts -} - -func (pc *passContext) findField(v ssa.Value, fieldNumber int) types.Object { - structType, ok := v.Type().Underlying().(*types.Struct) - if !ok { - structType = v.Type().Underlying().(*types.Pointer).Elem().Underlying().(*types.Struct) - } - return structType.Field(fieldNumber) -} - -// findAndExportStructFacts finds any struct fields that are annotated with the -// "// +checklocks:" annotation and exports relevant facts about the fields to -// be used in later analysis. -func (pc *passContext) findAndExportStructFacts(ss *ast.StructType, structType *types.Struct) { - type fieldRef struct { - fieldObj *types.Var - facts *lockFieldFacts - } - mutexes := make(map[string]*fieldRef) - rwMutexes := make(map[string]*fieldRef) - guardedFields := make(map[string]*fieldRef) - for i, field := range ss.Fields.List { - fieldObj := structType.Field(i) - fieldFacts := pc.extractFieldAnnotations(field, fieldObj) - if fieldFacts == nil { - continue - } - fieldFacts.FieldNumber = i - - ref := &fieldRef{fieldObj, fieldFacts} - if fieldFacts.IsMutex { - mutexes[fieldObj.Name()] = ref - } - if fieldFacts.IsRWMutex { - rwMutexes[fieldObj.Name()] = ref - } - if len(fieldFacts.GuardedBy) != 0 { - guardedFields[fieldObj.Name()] = ref - } - } - - // Export facts about all mutexes. - for _, f := range mutexes { - pc.pass.ExportObjectFact(f.fieldObj, f.facts) - } - // Export facts about all rwMutexes. - for _, f := range rwMutexes { - pc.pass.ExportObjectFact(f.fieldObj, f.facts) - } - - // Validate that guarded fields annotations refer to actual mutexes or - // rwMutexes in the struct. - for _, gf := range guardedFields { - for g := range gf.facts.GuardedBy { - if f, ok := mutexes[g]; ok { - gf.facts.GuardedBy[g] = f.facts.FieldNumber - } else if f, ok := rwMutexes[g]; ok { - gf.facts.GuardedBy[g] = f.facts.FieldNumber - } else { - pc.maybeFail(gf.fieldObj.Pos(), false /* isExempted */, "invalid mutex guard, no such mutex %s in struct %s", g, structType.String()) - continue - } - // Export guarded field fact. - pc.pass.ExportObjectFact(gf.fieldObj, gf.facts) - } - } -} - -func (pc *passContext) findAndExportFuncFacts(d *ast.FuncDecl) { - log.Debugf("finding and exporting function facts\n") - // for each function definition, check for +checklocks:mu annotation, which - // means that the function must be called with that lock held. - fnObj := pc.pass.TypesInfo.ObjectOf(d.Name) - funcFacts := lockFunctionFacts{GuardedBy: make(map[string]functionGuard)} - var ( - ignore bool - ignorePos token.Pos - ) - -outerLoop: - for _, l := range d.Doc.List { - if strings.HasPrefix(l.Text, checkLocksIgnore) { - pc.exemptions[fnObj] = struct{}{} - ignore = true - ignorePos = l.Pos() - continue - } - if strings.HasPrefix(l.Text, checkLocksAnnotation) { - guardName := strings.TrimPrefix(l.Text, checkLocksAnnotation) - if _, ok := funcFacts.GuardedBy[guardName]; ok { - pc.pass.Reportf(l.Pos(), "annotation %s specified more than once.", l.Text) - continue - } - - found := false - x := strings.Split(guardName, ".") - if len(x) != 2 { - pc.pass.Reportf(l.Pos(), "checklocks mutex annotation should be of the form 'a.b'") +// forAllTypes applies the given function over all types. +func (pc *passContext) forAllTypes(fn func(ts *ast.TypeSpec)) { + for _, f := range pc.pass.Files { + for _, decl := range f.Decls { + d, ok := decl.(*ast.GenDecl) + if !ok || d.Tok != token.TYPE { continue } - paramName, fieldName := x[0], x[1] - log.Debugf("paramName: %s, fieldName: %s", paramName, fieldName) - var paramList []*ast.Field - if d.Recv != nil { - paramList = append(paramList, d.Recv.List...) - } - if d.Type.Params != nil { - paramList = append(paramList, d.Type.Params.List...) - } - for paramNum, field := range paramList { - log.Debugf("field names: %+v", field.Names) - if len(field.Names) == 0 { - log.Debugf("skipping because parameter is unnamed", paramName) - continue - } - nameExists := false - for _, name := range field.Names { - if name.Name == paramName { - nameExists = true - } - } - if !nameExists { - log.Debugf("skipping because parameter name(s) does not match : %s", paramName) - continue - } - ptrType, ok := pc.pass.TypesInfo.TypeOf(field.Type).Underlying().(*types.Pointer) - if !ok { - // Since mutexes cannot be copied we only care about parameters that - // are pointer types when checking for guards. - pc.pass.Reportf(l.Pos(), "annotation %s incorrectly specified, parameter name does not refer to a pointer type", l.Text) - continue outerLoop - } - - structType, ok := ptrType.Elem().Underlying().(*types.Struct) - if !ok { - pc.pass.Reportf(l.Pos(), "annotation %s incorrectly specified, parameter name does not refer to a pointer to a struct", l.Text) - continue outerLoop - } - - for i := 0; i < structType.NumFields(); i++ { - if structType.Field(i).Name() == fieldName { - var fieldFacts lockFieldFacts - pc.pass.ImportObjectFact(structType.Field(i), &fieldFacts) - if !fieldFacts.IsMutex && !fieldFacts.IsRWMutex { - pc.pass.Reportf(l.Pos(), "field %s of param %s is not a mutex or an rwmutex", paramName, structType.Field(i)) - continue outerLoop - } - funcFacts.GuardedBy[guardName] = functionGuard{ParameterNumber: paramNum, FieldNumber: i} - found = true - continue outerLoop - } - } - if !found { - pc.pass.Reportf(l.Pos(), "annotation refers to a non-existent field %s in %s", guardName, structType) - continue outerLoop - } - } - if !found { - pc.pass.Reportf(l.Pos(), "annotation refers to a non-existent parameter %s", paramName) - } - } - } - - if len(funcFacts.GuardedBy) == 0 { - return - } - if ignore { - pc.pass.Reportf(ignorePos, "//+checklocksignore cannot be specified with other annotations on the function") - } - funcObj, ok := pc.pass.TypesInfo.Defs[d.Name].(*types.Func) - if !ok { - panic(fmt.Sprintf("function type information missing for %+v", d)) - } - log.Debugf("export fact for d: %+v, funcObj: %+v, funcFacts: %+v\n", d, funcObj, funcFacts) - pc.pass.ExportObjectFact(funcObj, &funcFacts) -} - -type mutexState struct { - // lockedMutexes is used to track which mutexes in a given struct are - // currently locked using the field number of the mutex as the key. - lockedMutexes map[int]struct{} -} - -// locksHeld tracks all currently held locks. -type locksHeld struct { - locks map[ssa.Value]mutexState -} - -// Same returns true if the locks held by other and l are the same. -func (l *locksHeld) Same(other *locksHeld) bool { - return reflect.DeepEqual(l.locks, other.locks) -} - -// Copy creates a copy of all the lock state held by l. -func (l *locksHeld) Copy() *locksHeld { - out := &locksHeld{locks: make(map[ssa.Value]mutexState)} - for ssaVal, mState := range l.locks { - newLM := make(map[int]struct{}) - for k, v := range mState.lockedMutexes { - newLM[k] = v - } - out.locks[ssaVal] = mutexState{lockedMutexes: newLM} - } - return out -} - -func isAlias(first, second ssa.Value) bool { - if first == second { - return true - } - switch x := first.(type) { - case *ssa.Field: - if y, ok := second.(*ssa.Field); ok { - return x.Field == y.Field && isAlias(x.X, y.X) - } - case *ssa.FieldAddr: - if y, ok := second.(*ssa.FieldAddr); ok { - return x.Field == y.Field && isAlias(x.X, y.X) - } - case *ssa.Index: - if y, ok := second.(*ssa.Index); ok { - return isAlias(x.Index, y.Index) && isAlias(x.X, y.X) - } - case *ssa.IndexAddr: - if y, ok := second.(*ssa.IndexAddr); ok { - return isAlias(x.Index, y.Index) && isAlias(x.X, y.X) - } - case *ssa.UnOp: - if y, ok := second.(*ssa.UnOp); ok { - return isAlias(x.X, y.X) - } - } - return false -} - -// checkBasicBlocks traverses the control flow graph starting at a set of given -// block and checks each instruction for allowed operations. -// -// funcFact are the exported facts for the enclosing function for these basic -// blocks. -func (pc *passContext) checkBasicBlocks(blocks []*ssa.BasicBlock, recoverBlock *ssa.BasicBlock, fn *ssa.Function, funcFact lockFunctionFacts) { - if len(blocks) == 0 { - return - } - - // mutexes is used to track currently locked sync.Mutexes/sync.RWMutexes for a - // given *struct identified by ssa.Value. - seen := make(map[*ssa.BasicBlock]*locksHeld) - var scan func(block *ssa.BasicBlock, parent *locksHeld) - scan = func(block *ssa.BasicBlock, parent *locksHeld) { - _, isExempted := pc.exemptions[block.Parent().Object()] - if oldLocksHeld, ok := seen[block]; ok { - if oldLocksHeld.Same(parent) { - return - } - pc.maybeFail(block.Instrs[0].Pos(), isExempted, "failure entering a block %+v with different sets of lock held, oldLocks: %+v, parentLocks: %+v", block, oldLocksHeld, parent) - return - } - seen[block] = parent - var lh = parent.Copy() - for _, inst := range block.Instrs { - pc.checkInstruction(inst, isExempted, lh) - } - for _, b := range block.Succs { - scan(b, lh) - } - } - - // Initialize lh with any preconditions that require locks to be held for the - // method to be invoked. - lh := &locksHeld{locks: make(map[ssa.Value]mutexState)} - for _, fg := range funcFact.GuardedBy { - // The first is the method object itself so we skip that when looking - // for receiver/function parameters. - log.Debugf("fn: %s, fn.Operands() == %+v", fn, fn.Operands(nil)) - r := fn.Params[fg.ParameterNumber] - guardObj := findField(r, fg.FieldNumber) - var fieldFacts lockFieldFacts - pc.pass.ImportObjectFact(guardObj, &fieldFacts) - if fieldFacts.IsMutex || fieldFacts.IsRWMutex { - m, ok := lh.locks[r] - if !ok { - m = mutexState{lockedMutexes: make(map[int]struct{})} - lh.locks[r] = m + for _, gs := range d.Specs { + fn(gs.(*ast.TypeSpec)) } - m.lockedMutexes[fieldFacts.FieldNumber] = struct{}{} - } else { - panic(fmt.Sprintf("function: %+v has an invalid guard that is not a mutex: %+v", fn, guardObj)) - } - } - - // Start scanning from the first basic block. - scan(blocks[0], lh) - - // Validate that all blocks were touched. - for _, b := range blocks { - if _, ok := seen[b]; !ok && b != recoverBlock { - panic(fmt.Sprintf("block %+v was not visited during checkBasicBlocks", b)) - } - } -} - -func (pc *passContext) checkInstruction(inst ssa.Instruction, isExempted bool, lh *locksHeld) { - log.Debugf("checking instruction: %s, isExempted: %t", inst, isExempted) - switch x := inst.(type) { - case *ssa.Field: - pc.checkFieldAccess(inst, x.X, x.Field, isExempted, lh) - case *ssa.FieldAddr: - pc.checkFieldAccess(inst, x.X, x.Field, isExempted, lh) - case *ssa.Call: - pc.checkFunctionCall(x, isExempted, lh) - } -} - -func findField(v ssa.Value, field int) types.Object { - structType, ok := v.Type().Underlying().(*types.Struct) - if !ok { - ptrType, ok := v.Type().Underlying().(*types.Pointer) - if !ok { - return nil - } - structType = ptrType.Elem().Underlying().(*types.Struct) - } - return structType.Field(field) -} - -func (pc *passContext) maybeFail(pos token.Pos, isExempted bool, fmtStr string, args ...interface{}) { - posKey := toPositionKey(pc.pass.Fset.Position(pos)) - log.Debugf("maybeFail: pos: %d, positionKey: %s", pos, posKey) - if fData, ok := pc.failures[posKey]; ok { - fData.count-- - if fData.count == 0 { - delete(pc.failures, posKey) } - return - } - if !isExempted { - pc.pass.Reportf(pos, fmt.Sprintf(fmtStr, args...)) } } -func (pc *passContext) checkFieldAccess(inst ssa.Instruction, structObj ssa.Value, field int, isExempted bool, lh *locksHeld) { - var fieldFacts lockFieldFacts - fieldObj := findField(structObj, field) - pc.pass.ImportObjectFact(fieldObj, &fieldFacts) - log.Debugf("fieldObj: %s, fieldFacts: %+v", fieldObj, fieldFacts) - for _, guardFieldNumber := range fieldFacts.GuardedBy { - guardObj := findField(structObj, guardFieldNumber) - var guardfieldFacts lockFieldFacts - pc.pass.ImportObjectFact(guardObj, &guardfieldFacts) - log.Debugf("guardObj: %s, guardFieldFacts: %+v", guardObj, guardfieldFacts) - if guardfieldFacts.IsMutex || guardfieldFacts.IsRWMutex { - log.Debugf("guard is a mutex") - m, ok := lh.locks[structObj] +// forAllFunctions applies the given function over all functions. +func (pc *passContext) forAllFunctions(fn func(fn *ast.FuncDecl)) { + for _, f := range pc.pass.Files { + for _, decl := range f.Decls { + d, ok := decl.(*ast.FuncDecl) if !ok { - pc.maybeFail(inst.Pos(), isExempted, "invalid field access, %s must be locked when accessing %s", guardObj.Name(), fieldObj.Name()) - continue - } - if _, ok := m.lockedMutexes[guardfieldFacts.FieldNumber]; !ok { - pc.maybeFail(inst.Pos(), isExempted, "invalid field access, %s must be locked when accessing %s", guardObj.Name(), fieldObj.Name()) - } - } else { - panic("incorrect guard that is not a mutex or an RWMutex") - } - } -} - -func (pc *passContext) checkFunctionCall(call *ssa.Call, isExempted bool, lh *locksHeld) { - // See: https://godoc.org/golang.org/x/tools/go/ssa#CallCommon - // - // 1. "call" mode: when Method is nil (!IsInvoke), a CallCommon represents an ordinary - // function call of the value in Value, which may be a *Builtin, a *Function or any - // other value of kind 'func'. - // - // Value may be one of: - // (a) a *Function, indicating a statically dispatched call - // to a package-level function, an anonymous function, or - // a method of a named type. - // - // (b) a *MakeClosure, indicating an immediately applied - // function literal with free variables. - // - // (c) a *Builtin, indicating a statically dispatched call - // to a built-in function. - // - // (d) any other value, indicating a dynamically dispatched - // function call. - fn, ok := call.Common().Value.(*ssa.Function) - if !ok { - return - } - if fn.Object() == nil { - return - } - - // Check if the function should be called with any locks held. - var funcFact lockFunctionFacts - pc.pass.ImportObjectFact(fn.Object(), &funcFact) - if len(funcFact.GuardedBy) > 0 { - for _, fg := range funcFact.GuardedBy { - // The first is the method object itself so we skip that when looking - // for receiver/function parameters. - r := (*call.Value().Operands(nil)[fg.ParameterNumber+1]) - guardObj := findField(r, fg.FieldNumber) - if guardObj == nil { continue } - var fieldFacts lockFieldFacts - pc.pass.ImportObjectFact(guardObj, &fieldFacts) - if fieldFacts.IsMutex || fieldFacts.IsRWMutex { - heldMutexes, ok := lh.locks[r] - if !ok { - log.Debugf("fn: %s, funcFact: %+v", fn, funcFact) - pc.maybeFail(call.Pos(), isExempted, "invalid function call %s must be held", guardObj.Name()) - continue - } - if _, ok := heldMutexes.lockedMutexes[fg.FieldNumber]; !ok { - log.Debugf("fn: %s, funcFact: %+v", fn, funcFact) - pc.maybeFail(call.Pos(), isExempted, "invalid function call %s must be held", guardObj.Name()) - } - } else { - panic(fmt.Sprintf("function: %+v has an invalid guard that is not a mutex: %+v", fn, guardObj)) - } - } - } - - // Check if it's a method dispatch for something in the sync package. - // See: https://godoc.org/golang.org/x/tools/go/ssa#Function - if fn.Package() != nil && fn.Package().Pkg.Name() == "sync" && fn.Signature.Recv() != nil { - r, ok := call.Common().Args[0].(*ssa.FieldAddr) - if !ok { - return - } - guardObj := findField(r.X, r.Field) - var fieldFacts lockFieldFacts - pc.pass.ImportObjectFact(guardObj, &fieldFacts) - if fieldFacts.IsMutex || fieldFacts.IsRWMutex { - switch fn.Name() { - case "Lock", "RLock": - obj := r.X - m := mutexState{lockedMutexes: make(map[int]struct{})} - for k, v := range lh.locks { - if isAlias(r.X, k) { - obj = k - m = v - } - } - if _, ok := m.lockedMutexes[r.Field]; ok { - // Double locking a mutex that is already locked. - pc.maybeFail(call.Pos(), isExempted, "trying to a lock %s when already locked", guardObj.Name()) - return - } - m.lockedMutexes[r.Field] = struct{}{} - lh.locks[obj] = m - case "Unlock", "RUnlock": - // Find the associated locker object. - var ( - obj ssa.Value - m mutexState - ) - for k, v := range lh.locks { - if isAlias(r.X, k) { - obj = k - m = v - break - } - } - if _, ok := m.lockedMutexes[r.Field]; !ok { - pc.maybeFail(call.Pos(), isExempted, "trying to unlock a mutex %s that is already unlocked", guardObj.Name()) - return - } - delete(m.lockedMutexes, r.Field) - if len(m.lockedMutexes) == 0 { - delete(lh.locks, obj) - } - case "RLocker", "DowngradeLock", "TryLock", "TryRLock": - // we explicitly ignore this for now. - default: - panic(fmt.Sprintf("unexpected mutex/rwmutex method invoked: %s", fn.Name())) - } + fn(d) } } } +// run is the main entrypoint. func run(pass *analysis.Pass) (interface{}, error) { pc := &passContext{ pass: pass, - exemptions: make(map[types.Object]struct{}), failures: make(map[positionKey]*failData), + exemptions: make(map[positionKey]struct{}), + forced: make(map[positionKey]struct{}), + functions: make(map[*ssa.Function]struct{}), } // Find all line failure annotations. - for _, f := range pass.Files { - for _, cg := range f.Comments { - for _, c := range cg.List { - if strings.Contains(c.Text, checkLocksFail) { - cnt := 1 - if strings.Contains(c.Text, checkLocksFail+":") { - parts := strings.SplitAfter(c.Text, checkLocksFail+":") - parsedCount, err := strconv.Atoi(parts[1]) - if err != nil { - pc.pass.Reportf(c.Pos(), "invalid checklocks annotation : %s", err) - continue - } - cnt = parsedCount - } - position := toPositionKey(pass.Fset.Position(c.Pos())) - pc.failures[position] = &failData{pos: c.Pos(), count: cnt} - } - } - } - } - - // Find all struct declarations and export any relevant facts. - for _, f := range pass.Files { - for _, decl := range f.Decls { - d, ok := decl.(*ast.GenDecl) - // A GenDecl node (generic declaration node) represents an import, - // constant, type or variable declaration. We only care about struct - // declarations so skip any declaration that doesn't declare a new type. - if !ok || d.Tok != token.TYPE { - continue - } + pc.extractLineFailures() - for _, gs := range d.Specs { - ts := gs.(*ast.TypeSpec) - ss, ok := ts.Type.(*ast.StructType) - if !ok { - continue - } - structType := pass.TypesInfo.TypeOf(ts.Name).Underlying().(*types.Struct) - pc.findAndExportStructFacts(ss, structType) - } + // Find all struct declarations and export relevant facts. + pc.forAllTypes(func(ts *ast.TypeSpec) { + if ss, ok := ts.Type.(*ast.StructType); ok { + structType := pc.pass.TypesInfo.TypeOf(ts.Name).Underlying().(*types.Struct) + pc.exportLockFieldFacts(structType, ss) } - } + }) + pc.forAllTypes(func(ts *ast.TypeSpec) { + if ss, ok := ts.Type.(*ast.StructType); ok { + structType := pc.pass.TypesInfo.TypeOf(ts.Name).Underlying().(*types.Struct) + pc.exportLockGuardFacts(structType, ss) + } + }) - // Find all method calls and export any relevant facts. - for _, f := range pass.Files { - for _, decl := range f.Decls { - d, ok := decl.(*ast.FuncDecl) - // Ignore any non function declarations and any functions that do not have - // any comments. - if !ok || d.Doc == nil { - continue - } - pc.findAndExportFuncFacts(d) + // Check all alignments. + pc.forAllTypes(func(ts *ast.TypeSpec) { + typ, ok := pass.TypesInfo.TypeOf(ts.Name).(*types.Named) + if !ok { + return } - } + pc.checkTypeAlignment(pass.Pkg, typ) + }) - // log all known facts and all failures if debug logging is enabled. - allFacts := pass.AllObjectFacts() - for i := range allFacts { - log.Debugf("fact.object: %+v, fact.Fact: %+v", allFacts[i].Object, allFacts[i].Fact) - } - log.Debugf("all expected failures: %+v", pc.failures) + // Find all function declarations and export relevant facts. + pc.forAllFunctions(func(fn *ast.FuncDecl) { + pc.exportFunctionFacts(fn) + }) // Scan all code looking for invalid accesses. state := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA) for _, fn := range state.SrcFuncs { - var funcFact lockFunctionFacts - // Anonymous(closures) functions do not have an object() but do show up in - // the SSA. - if obj := fn.Object(); obj != nil { - pc.pass.ImportObjectFact(fn.Object(), &funcFact) + // Import function facts generated above. + // + // Note that anonymous(closures) functions do not have an + // object but do show up in the SSA. They can only be invoked + // by named functions in the package, and they are analyzing + // inline on every call. Thus we skip the analysis here. They + // will be hit on calls, or picked up in the pass below. + if obj := fn.Object(); obj == nil { + continue } + var lff lockFunctionFacts + pc.pass.ImportObjectFact(fn.Object(), &lff) - log.Debugf("checking function: %s", fn) - var b bytes.Buffer - ssa.WriteFunction(&b, fn) - log.Debugf("function SSA: %s", b.String()) - if fn.Recover != nil { - pc.checkBasicBlocks([]*ssa.BasicBlock{fn.Recover}, nil, fn, funcFact) + // Do we ignore this? + if lff.Ignore { + continue } - pc.checkBasicBlocks(fn.Blocks, fn.Recover, fn, funcFact) - } - // Scan for remaining failures we expect. - for _, failure := range pc.failures { - // We are missing expect failures, report as much as possible. - pass.Reportf(failure.pos, "expected %d failures", failure.count) + // Check the basic blocks in the function. + pc.checkFunction(nil, fn, &lff, nil, false /* force */) } + for _, fn := range state.SrcFuncs { + // Ensure all anonymous functions are hit. They are not + // permitted to have any lock preconditions. + if obj := fn.Object(); obj != nil { + continue + } + var nolff lockFunctionFacts + pc.checkFunction(nil, fn, &nolff, nil, false /* force */) + } + + // Check for expected failures. + pc.checkFailures() return nil, nil } diff --git a/tools/checklocks/facts.go b/tools/checklocks/facts.go new file mode 100644 index 000000000..34c9f5ef1 --- /dev/null +++ b/tools/checklocks/facts.go @@ -0,0 +1,624 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package checklocks + +import ( + "fmt" + "go/ast" + "go/token" + "go/types" + "regexp" + "strings" + + "golang.org/x/tools/go/ssa" +) + +// atomicAlignment is saved per type. +// +// This represents the alignment required for the type, which may +// be implied and imposed by other types within the aggregate type. +type atomicAlignment int + +// AFact implements analysis.Fact.AFact. +func (*atomicAlignment) AFact() {} + +// atomicDisposition is saved per field. +// +// This represents how the field must be accessed. It must either +// be non-atomic (default), atomic or ignored. +type atomicDisposition int + +const ( + atomicDisallow atomicDisposition = iota + atomicIgnore + atomicRequired +) + +// fieldList is a simple list of fields, used in two types below. +// +// Note that the integers in this list refer to one of two things: +// - A positive integer refers to a field index in a struct. +// - A negative integer refers to a field index in a struct, where +// that field is a pointer and must be subsequently resolved. +type fieldList []int + +// resolvedValue is an ssa.Value with additional fields. +// +// This can be resolved to a string as part of a lock state. +type resolvedValue struct { + value ssa.Value + valid bool + fieldList []int +} + +// findExtract finds a relevant extract. This must exist within the referrers +// to the call object. If this doesn't then the object which is locked is never +// consumed, and we should consider this a bug. +func findExtract(v ssa.Value, index int) (ssa.Value, bool) { + if refs := v.Referrers(); refs != nil { + for _, inst := range *refs { + if x, ok := inst.(*ssa.Extract); ok && x.Tuple == v && x.Index == index { + return inst.(ssa.Value), true + } + } + } + return nil, false +} + +// resolve resolves the given field list. +func (fl fieldList) resolve(v ssa.Value) (rv resolvedValue) { + return resolvedValue{ + value: v, + fieldList: fl, + valid: true, + } +} + +// valueAsString returns a string representing this value. +// +// This must align with how the string is generated in valueAsString. +func (rv resolvedValue) valueAsString(ls *lockState) string { + typ := rv.value.Type() + s := ls.valueAsString(rv.value) + for i, fieldNumber := range rv.fieldList { + switch { + case fieldNumber > 0: + field, ok := findField(typ, fieldNumber-1) + if !ok { + // This can't be resolved, return for debugging. + return fmt.Sprintf("{%s+%v}", s, rv.fieldList[i:]) + } + s = fmt.Sprintf("&(%s.%s)", s, field.Name()) + typ = field.Type() + case fieldNumber < 1: + field, ok := findField(typ, (-fieldNumber)-1) + if !ok { + // See above. + return fmt.Sprintf("{%s+%v}", s, rv.fieldList[i:]) + } + s = fmt.Sprintf("*(&(%s.%s))", s, field.Name()) + typ = field.Type() + } + } + return s +} + +// lockFieldFacts apply on every struct field. +type lockFieldFacts struct { + // IsMutex is true if the field is of type sync.Mutex. + IsMutex bool + + // IsRWMutex is true if the field is of type sync.RWMutex. + IsRWMutex bool + + // IsPointer indicates if the field is a pointer. + IsPointer bool + + // FieldNumber is the number of this field in the struct. + FieldNumber int +} + +// AFact implements analysis.Fact.AFact. +func (*lockFieldFacts) AFact() {} + +// lockGuardFacts contains guard information. +type lockGuardFacts struct { + // GuardedBy is the set of locks that are guarding this field. The key + // is the original annotation value, and the field list is the object + // traversal path. + GuardedBy map[string]fieldList + + // AtomicDisposition is the disposition for this field. Note that this + // can affect the interpretation of the GuardedBy field above, see the + // relevant comment. + AtomicDisposition atomicDisposition +} + +// AFact implements analysis.Fact.AFact. +func (*lockGuardFacts) AFact() {} + +// functionGuard is used by lockFunctionFacts, below. +type functionGuard struct { + // ParameterNumber is the index of the object that contains the + // guarding mutex. From this parameter, a walk is performed + // subsequently using the resolve method. + // + // Note that is ParameterNumber is beyond the size of parameters, then + // it may return to a return value. This applies only for the Acquires + // relation below. + ParameterNumber int + + // NeedsExtract is used in the case of a return value, and indicates + // that the field must be extracted from a tuple. + NeedsExtract bool + + // FieldList is the traversal path to the object. + FieldList fieldList +} + +// resolveReturn resolves a return value. +// +// Precondition: rv is either an ssa.Value, or an *ssa.Return. +func (fg *functionGuard) resolveReturn(rv interface{}, args int) resolvedValue { + if rv == nil { + // For defers and other objects, this may be nil. This is + // handled in state.go in the actual lock checking logic. + return resolvedValue{ + value: nil, + valid: false, + } + } + index := fg.ParameterNumber - args + // If this is a *ssa.Return object, i.e. we are analyzing the function + // and not the call site, then we can just pull the result directly. + if r, ok := rv.(*ssa.Return); ok { + return fg.FieldList.resolve(r.Results[index]) + } + if fg.NeedsExtract { + // Resolve on the extracted field, this is necessary if the + // type here is not an explicit return. Note that rv must be an + // ssa.Value, since it is not an *ssa.Return. + v, ok := findExtract(rv.(ssa.Value), index) + if !ok { + return resolvedValue{ + value: v, + valid: false, + } + } + return fg.FieldList.resolve(v) + } + if index != 0 { + // This should not happen, NeedsExtract should always be set. + panic("NeedsExtract is false, but return value index is non-zero") + } + // Resolve on the single return. + return fg.FieldList.resolve(rv.(ssa.Value)) +} + +// resolveStatic returns an ssa.Value representing the given field. +// +// Precondition: per resolveReturn. +func (fg *functionGuard) resolveStatic(fn *ssa.Function, rv interface{}) resolvedValue { + if fg.ParameterNumber >= len(fn.Params) { + return fg.resolveReturn(rv, len(fn.Params)) + } + return fg.FieldList.resolve(fn.Params[fg.ParameterNumber]) +} + +// resolveCall returns an ssa.Value representing the given field. +func (fg *functionGuard) resolveCall(args []ssa.Value, rv ssa.Value) resolvedValue { + if fg.ParameterNumber >= len(args) { + return fg.resolveReturn(rv, len(args)) + } + return fg.FieldList.resolve(args[fg.ParameterNumber]) +} + +// lockFunctionFacts apply on every method. +type lockFunctionFacts struct { + // HeldOnEntry tracks the names and number of parameter (including receiver) + // lockFuncfields that guard calls to this function. + // + // The key is the name specified in the checklocks annotation. e.g given + // the following code: + // + // ``` + // type A struct { + // mu sync.Mutex + // a int + // } + // + // // +checklocks:a.mu + // func xyz(a *A) {..} + // ``` + // + // '`+checklocks:a.mu' will result in an entry in this map as shown below. + // HeldOnEntry: {"a.mu" => {ParameterNumber: 0, FieldNumbers: {0}} + // + // Unlikely lockFieldFacts, there is no atomic interpretation. + HeldOnEntry map[string]functionGuard + + // HeldOnExit tracks the locks that are expected to be held on exit. + HeldOnExit map[string]functionGuard + + // Ignore means this function has local analysis ignores. + // + // This is not used outside the local package. + Ignore bool +} + +// AFact implements analysis.Fact.AFact. +func (*lockFunctionFacts) AFact() {} + +// checkGuard validates the guardName. +func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guardName string, allowReturn bool) (functionGuard, bool) { + if _, ok := lff.HeldOnEntry[guardName]; ok { + pc.maybeFail(d.Pos(), "annotation %s specified more than once, already required", guardName) + return functionGuard{}, false + } + if _, ok := lff.HeldOnExit[guardName]; ok { + pc.maybeFail(d.Pos(), "annotation %s specified more than once, already acquired", guardName) + return functionGuard{}, false + } + fg, ok := pc.findFunctionGuard(d, guardName, allowReturn) + return fg, ok +} + +// addGuardedBy adds a field to both HeldOnEntry and HeldOnExit. +func (lff *lockFunctionFacts) addGuardedBy(pc *passContext, d *ast.FuncDecl, guardName string) { + if fg, ok := lff.checkGuard(pc, d, guardName, false /* allowReturn */); ok { + if lff.HeldOnEntry == nil { + lff.HeldOnEntry = make(map[string]functionGuard) + } + if lff.HeldOnExit == nil { + lff.HeldOnExit = make(map[string]functionGuard) + } + lff.HeldOnEntry[guardName] = fg + lff.HeldOnExit[guardName] = fg + } +} + +// addAcquires adds a field to HeldOnExit. +func (lff *lockFunctionFacts) addAcquires(pc *passContext, d *ast.FuncDecl, guardName string) { + if fg, ok := lff.checkGuard(pc, d, guardName, true /* allowReturn */); ok { + if lff.HeldOnExit == nil { + lff.HeldOnExit = make(map[string]functionGuard) + } + lff.HeldOnExit[guardName] = fg + } +} + +// addReleases adds a field to HeldOnEntry. +func (lff *lockFunctionFacts) addReleases(pc *passContext, d *ast.FuncDecl, guardName string) { + if fg, ok := lff.checkGuard(pc, d, guardName, false /* allowReturn */); ok { + if lff.HeldOnEntry == nil { + lff.HeldOnEntry = make(map[string]functionGuard) + } + lff.HeldOnEntry[guardName] = fg + } +} + +// fieldListFor returns the fieldList for the given object. +func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index int, fieldName string, checkMutex bool) (int, bool) { + var lff lockFieldFacts + if !pc.pass.ImportObjectFact(fieldObj, &lff) { + // This should not happen: we export facts for all fields. + panic(fmt.Sprintf("no lockFieldFacts available for field %s", fieldName)) + } + // Check that it is indeed a mutex. + if checkMutex && !lff.IsMutex && !lff.IsRWMutex { + pc.maybeFail(pos, "field %s is not a mutex or an rwmutex", fieldName) + return 0, false + } + // Return the resolution path. + if lff.IsPointer { + return -(index + 1), true + } + return (index + 1), true +} + +// resolveOneField resolves a field in a single struct. +func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct, fieldName string, checkMutex bool) (fl fieldList, fieldObj types.Object, ok bool) { + // Scan to match the next field. + for i := 0; i < structType.NumFields(); i++ { + fieldObj := structType.Field(i) + if fieldObj.Name() != fieldName { + continue + } + flOne, ok := pc.fieldListFor(pos, fieldObj, i, fieldName, checkMutex) + if !ok { + return nil, nil, false + } + fl = append(fl, flOne) + return fl, fieldObj, true + } + // Is this an embed? + for i := 0; i < structType.NumFields(); i++ { + fieldObj := structType.Field(i) + if !fieldObj.Embedded() { + continue + } + // Is this an embedded struct? + structType, ok := resolveStruct(fieldObj.Type()) + if !ok { + continue + } + // Need to check that there is a resolution path. If there is + // no resolution path that's not a failure: we just continue + // scanning the next embed to find a match. + flEmbed, okEmbed := pc.fieldListFor(pos, fieldObj, i, fieldName, false) + flCont, fieldObjCont, okCont := pc.resolveOneField(pos, structType, fieldName, checkMutex) + if okEmbed && okCont { + fl = append(fl, flEmbed) + fl = append(fl, flCont...) + return fl, fieldObjCont, true + } + } + pc.maybeFail(pos, "field %s does not exist", fieldName) + return nil, nil, false +} + +// resolveField resolves a set of fields given a string, such a 'a.b.c'. +// +// Note that this checks that the final element is a mutex of some kind, and +// will fail appropriately. +func (pc *passContext) resolveField(pos token.Pos, structType *types.Struct, parts []string) (fl fieldList, ok bool) { + for partNumber, fieldName := range parts { + flOne, fieldObj, ok := pc.resolveOneField(pos, structType, fieldName, partNumber >= len(parts)-1 /* checkMutex */) + if !ok { + // Error already reported. + return nil, false + } + fl = append(fl, flOne...) + if partNumber < len(parts)-1 { + // Traverse to the next type. + structType, ok = resolveStruct(fieldObj.Type()) + if !ok { + pc.maybeFail(pos, "invalid intermediate field %s", fieldName) + return fl, false + } + } + } + return fl, true +} + +var ( + mutexRE = regexp.MustCompile("((.*/)|^)sync.(CrossGoroutineMutex|Mutex)") + rwMutexRE = regexp.MustCompile("((.*/)|^)sync.(CrossGoroutineRWMutex|RWMutex)") +) + +// exportLockFieldFacts finds all struct fields that are mutexes, and ensures +// that they are annotated properly. +// +// This information is consumed subsequently by exportLockGuardFacts, and this +// function must be called first on all structures. +func (pc *passContext) exportLockFieldFacts(structType *types.Struct, ss *ast.StructType) { + for i, field := range ss.Fields.List { + lff := &lockFieldFacts{ + FieldNumber: i, + } + // We use HasSuffix below because fieldType can be fully + // qualified with the package name eg for the gvisor sync + // package mutex fields have the type: + // "<package path>/sync/sync.Mutex" + fieldObj := structType.Field(i) + s := fieldObj.Type().String() + switch { + case mutexRE.MatchString(s): + lff.IsMutex = true + case rwMutexRE.MatchString(s): + lff.IsRWMutex = true + } + // Save whether this is a pointer. + _, lff.IsPointer = fieldObj.Type().Underlying().(*types.Pointer) + // We must always export the lockFieldFacts, since traversal + // can take place along any object in the struct. + pc.pass.ExportObjectFact(fieldObj, lff) + // If this is an anonymous type, then we won't discover it via + // the AST global declarations. We can recurse from here. + if ss, ok := field.Type.(*ast.StructType); ok { + if st, ok := fieldObj.Type().(*types.Struct); ok { + pc.exportLockFieldFacts(st, ss) + } + } + } +} + +// exportLockGuardFacts finds all relevant guard information for structures. +// +// This function requires exportLockFieldFacts be called first on all +// structures. +func (pc *passContext) exportLockGuardFacts(structType *types.Struct, ss *ast.StructType) { + for i, field := range ss.Fields.List { + fieldObj := structType.Field(i) + if field.Doc != nil { + var ( + lff lockFieldFacts + lgf lockGuardFacts + ) + pc.pass.ImportObjectFact(structType.Field(i), &lff) + for _, l := range field.Doc.List { + pc.extractAnnotations(l.Text, map[string]func(string){ + checkAtomicAnnotation: func(string) { + switch lgf.AtomicDisposition { + case atomicRequired: + pc.maybeFail(fieldObj.Pos(), "annotation is redundant, already atomic required") + case atomicIgnore: + pc.maybeFail(fieldObj.Pos(), "annotation is contradictory, already atomic ignored") + } + lgf.AtomicDisposition = atomicRequired + }, + checkLocksIgnore: func(string) { + switch lgf.AtomicDisposition { + case atomicIgnore: + pc.maybeFail(fieldObj.Pos(), "annotation is redundant, already atomic ignored") + case atomicRequired: + pc.maybeFail(fieldObj.Pos(), "annotation is contradictory, already atomic required") + } + lgf.AtomicDisposition = atomicIgnore + }, + checkLocksAnnotation: func(guardName string) { + // Check for a duplicate annotation. + if _, ok := lgf.GuardedBy[guardName]; ok { + pc.maybeFail(fieldObj.Pos(), "annotation %s specified more than once", guardName) + return + } + fl, ok := pc.resolveField(fieldObj.Pos(), structType, strings.Split(guardName, ".")) + if ok { + // If we successfully resolved + // the field, then save it. + if lgf.GuardedBy == nil { + lgf.GuardedBy = make(map[string]fieldList) + } + lgf.GuardedBy[guardName] = fl + } + }, + }) + } + // Save only if there is something meaningful. + if len(lgf.GuardedBy) > 0 || lgf.AtomicDisposition != atomicDisallow { + pc.pass.ExportObjectFact(structType.Field(i), &lgf) + } + } + // See above, for anonymous structure fields. + if ss, ok := field.Type.(*ast.StructType); ok { + if st, ok := fieldObj.Type().(*types.Struct); ok { + pc.exportLockGuardFacts(st, ss) + } + } + } +} + +// countFields gives an accurate field count, according for unnamed arguments +// and return values and the compact identifier format. +func countFields(fl []*ast.Field) (count int) { + for _, field := range fl { + if len(field.Names) == 0 { + count++ + continue + } + count += len(field.Names) + } + return +} + +// matchFieldList attempts to match the given field. +func (pc *passContext) matchFieldList(pos token.Pos, fl []*ast.Field, guardName string) (functionGuard, bool) { + parts := strings.Split(guardName, ".") + parameterName := parts[0] + parameterNumber := 0 + for _, field := range fl { + // See countFields, above. + if len(field.Names) == 0 { + parameterNumber++ + continue + } + for _, name := range field.Names { + if name.Name != parameterName { + parameterNumber++ + continue + } + ptrType, ok := pc.pass.TypesInfo.TypeOf(field.Type).Underlying().(*types.Pointer) + if !ok { + // Since mutexes cannot be copied we only care + // about parameters that are pointer types when + // checking for guards. + pc.maybeFail(pos, "parameter name %s does not refer to a pointer type", parameterName) + return functionGuard{}, false + } + structType, ok := ptrType.Elem().Underlying().(*types.Struct) + if !ok { + // Fields can only be in named structures. + pc.maybeFail(pos, "parameter name %s does not refer to a pointer to a struct", parameterName) + return functionGuard{}, false + } + fg := functionGuard{ + ParameterNumber: parameterNumber, + } + fl, ok := pc.resolveField(pos, structType, parts[1:]) + fg.FieldList = fl + return fg, ok // If ok is false, already failed. + } + } + return functionGuard{}, false +} + +// findFunctionGuard identifies the parameter number and field number for a +// particular string of the 'a.b'. +// +// This function will report any errors directly. +func (pc *passContext) findFunctionGuard(d *ast.FuncDecl, guardName string, allowReturn bool) (functionGuard, bool) { + var ( + parameterList []*ast.Field + returnList []*ast.Field + ) + if d.Recv != nil { + parameterList = append(parameterList, d.Recv.List...) + } + if d.Type.Params != nil { + parameterList = append(parameterList, d.Type.Params.List...) + } + if fg, ok := pc.matchFieldList(d.Pos(), parameterList, guardName); ok { + return fg, ok + } + if allowReturn { + if d.Type.Results != nil { + returnList = append(returnList, d.Type.Results.List...) + } + if fg, ok := pc.matchFieldList(d.Pos(), returnList, guardName); ok { + // Fix this up to apply to the return value, as noted + // in fg.ParameterNumber. For the ssa analysis, we must + // record whether this has multiple results, since + // *ssa.Call indicates: "The Call instruction yields + // the function result if there is exactly one. + // Otherwise it returns a tuple, the components of + // which are accessed via Extract." + fg.ParameterNumber += countFields(parameterList) + fg.NeedsExtract = countFields(returnList) > 1 + return fg, ok + } + } + // We never saw a matching parameter. + pc.maybeFail(d.Pos(), "annotation %s does not have a matching parameter", guardName) + return functionGuard{}, false +} + +// exportFunctionFacts exports relevant function findings. +func (pc *passContext) exportFunctionFacts(d *ast.FuncDecl) { + if d.Doc == nil || d.Doc.List == nil { + return + } + var lff lockFunctionFacts + for _, l := range d.Doc.List { + pc.extractAnnotations(l.Text, map[string]func(string){ + checkLocksIgnore: func(string) { + // Note that this applies to all atomic + // analysis as well. There is no provided way + // to selectively ignore only lock analysis or + // atomic analysis, as we expect this use to be + // extremely rare. + lff.Ignore = true + }, + checkLocksAnnotation: func(guardName string) { lff.addGuardedBy(pc, d, guardName) }, + checkLocksAcquires: func(guardName string) { lff.addAcquires(pc, d, guardName) }, + checkLocksReleases: func(guardName string) { lff.addReleases(pc, d, guardName) }, + }) + } + + // Export the function facts if there is anything to save. + if lff.Ignore || len(lff.HeldOnEntry) > 0 || len(lff.HeldOnExit) > 0 { + funcObj := pc.pass.TypesInfo.Defs[d.Name].(*types.Func) + pc.pass.ExportObjectFact(funcObj, &lff) + } +} diff --git a/tools/checklocks/state.go b/tools/checklocks/state.go new file mode 100644 index 000000000..57061a32e --- /dev/null +++ b/tools/checklocks/state.go @@ -0,0 +1,315 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package checklocks + +import ( + "fmt" + "go/token" + "go/types" + "strings" + "sync/atomic" + + "golang.org/x/tools/go/ssa" +) + +// lockState tracks the locking state and aliases. +type lockState struct { + // lockedMutexes is used to track which mutexes in a given struct are + // currently locked. Note that most of the heavy lifting is done by + // valueAsString below, which maps to specific structure fields, etc. + lockedMutexes []string + + // stored stores values that have been stored in memory, bound to + // FreeVars or passed as Parameterse. + stored map[ssa.Value]ssa.Value + + // used is a temporary map, used only for valueAsString. It prevents + // multiple use of the same memory location. + used map[ssa.Value]struct{} + + // defers are the stack of defers that have been pushed. + defers []*ssa.Defer + + // refs indicates the number of references on this structure. If it's + // greater than one, we will do copy-on-write. + refs *int32 +} + +// newLockState makes a new lockState. +func newLockState() *lockState { + refs := int32(1) // Not shared. + return &lockState{ + lockedMutexes: make([]string, 0), + used: make(map[ssa.Value]struct{}), + stored: make(map[ssa.Value]ssa.Value), + defers: make([]*ssa.Defer, 0), + refs: &refs, + } +} + +// fork forks the locking state. When a lockState is forked, any modifications +// will cause maps to be copied. +func (l *lockState) fork() *lockState { + if l == nil { + return newLockState() + } + atomic.AddInt32(l.refs, 1) + return &lockState{ + lockedMutexes: l.lockedMutexes, + used: make(map[ssa.Value]struct{}), + stored: l.stored, + defers: l.defers, + refs: l.refs, + } +} + +// modify indicates that this state will be modified. +func (l *lockState) modify() { + if atomic.LoadInt32(l.refs) > 1 { + // Copy the lockedMutexes. + lm := make([]string, len(l.lockedMutexes)) + copy(lm, l.lockedMutexes) + l.lockedMutexes = lm + + // Copy the stored values. + s := make(map[ssa.Value]ssa.Value) + for k, v := range l.stored { + s[k] = v + } + l.stored = s + + // Reset the used values. + l.used = make(map[ssa.Value]struct{}) + + // Copy the defers. + ds := make([]*ssa.Defer, len(l.defers)) + copy(ds, l.defers) + l.defers = ds + + // Drop our reference. + atomic.AddInt32(l.refs, -1) + newRefs := int32(1) // Not shared. + l.refs = &newRefs + } +} + +// isHeld indicates whether the field is held is not. +func (l *lockState) isHeld(rv resolvedValue) (string, bool) { + if !rv.valid { + return rv.valueAsString(l), false + } + s := rv.valueAsString(l) + for _, k := range l.lockedMutexes { + if k == s { + return s, true + } + } + return s, false +} + +// lockField locks the given field. +// +// If false is returned, the field was already locked. +func (l *lockState) lockField(rv resolvedValue) (string, bool) { + if !rv.valid { + return rv.valueAsString(l), false + } + s := rv.valueAsString(l) + for _, k := range l.lockedMutexes { + if k == s { + return s, false + } + } + l.modify() + l.lockedMutexes = append(l.lockedMutexes, s) + return s, true +} + +// unlockField unlocks the given field. +// +// If false is returned, the field was not locked. +func (l *lockState) unlockField(rv resolvedValue) (string, bool) { + if !rv.valid { + return rv.valueAsString(l), false + } + s := rv.valueAsString(l) + for i, k := range l.lockedMutexes { + if k == s { + // Copy the last lock in and truncate. + l.modify() + l.lockedMutexes[i] = l.lockedMutexes[len(l.lockedMutexes)-1] + l.lockedMutexes = l.lockedMutexes[:len(l.lockedMutexes)-1] + return s, true + } + } + return s, false +} + +// store records an alias. +func (l *lockState) store(addr ssa.Value, v ssa.Value) { + l.modify() + l.stored[addr] = v +} + +// isSubset indicates other holds all the locks held by l. +func (l *lockState) isSubset(other *lockState) bool { + held := 0 // Number in l, held by other. + for _, k := range l.lockedMutexes { + for _, ok := range other.lockedMutexes { + if k == ok { + held++ + break + } + } + } + return held >= len(l.lockedMutexes) +} + +// count indicates the number of locks held. +func (l *lockState) count() int { + return len(l.lockedMutexes) +} + +// isCompatible returns true if the states are compatible. +func (l *lockState) isCompatible(other *lockState) bool { + return l.isSubset(other) && other.isSubset(l) +} + +// elemType is a type that implements the Elem function. +type elemType interface { + Elem() types.Type +} + +// valueAsString returns a string for a given value. +// +// This decomposes the value into the simplest possible representation in terms +// of parameters, free variables and globals. During resolution, stored values +// may be transferred, as well as bound free variables. +// +// Nil may not be passed here. +func (l *lockState) valueAsString(v ssa.Value) string { + switch x := v.(type) { + case *ssa.Parameter: + // Was this provided as a paramter for a local anonymous + // function invocation? + v, ok := l.stored[x] + if ok { + return l.valueAsString(v) + } + return fmt.Sprintf("{param:%s}", x.Name()) + case *ssa.Global: + return fmt.Sprintf("{global:%s}", x.Name()) + case *ssa.FreeVar: + // Attempt to resolve this, in case we are being invoked in a + // scope where all the variables are bound. + v, ok := l.stored[x] + if ok { + // The FreeVar is typically bound to a location, so we + // check what's been stored there. Note that the second + // may map to the same FreeVar, which we can check. + stored, ok := l.stored[v] + if ok { + return l.valueAsString(stored) + } + } + return fmt.Sprintf("{freevar:%s}", x.Name()) + case *ssa.Convert: + // Just disregard conversion. + return l.valueAsString(x.X) + case *ssa.ChangeType: + // Ditto, disregard. + return l.valueAsString(x.X) + case *ssa.UnOp: + if x.Op != token.MUL { + break + } + // Is this loading a free variable? If yes, then this can be + // resolved in the original isAlias function. + if fv, ok := x.X.(*ssa.FreeVar); ok { + return l.valueAsString(fv) + } + // Should be try to resolve via a memory address? This needs to + // be done since a memory location can hold its own value. + if _, ok := l.used[x.X]; !ok { + // Check if we know what the accessed location holds. + // This is used to disambiguate memory locations. + v, ok := l.stored[x.X] + if ok { + l.used[x.X] = struct{}{} + defer func() { delete(l.used, x.X) }() + return l.valueAsString(v) + } + } + // x.X.Type is pointer. We must construct this type + // dynamically, since the ssa.Value could be synthetic. + return fmt.Sprintf("*(%s)", l.valueAsString(x.X)) + case *ssa.Field: + structType, ok := resolveStruct(x.X.Type()) + if !ok { + // This should not happen. + panic(fmt.Sprintf("structType not available for struct: %#v", x.X)) + } + fieldObj := structType.Field(x.Field) + return fmt.Sprintf("%s.%s", l.valueAsString(x.X), fieldObj.Name()) + case *ssa.FieldAddr: + structType, ok := resolveStruct(x.X.Type()) + if !ok { + // This should not happen. + panic(fmt.Sprintf("structType not available for struct: %#v", x.X)) + } + fieldObj := structType.Field(x.Field) + return fmt.Sprintf("&(%s.%s)", l.valueAsString(x.X), fieldObj.Name()) + case *ssa.Index: + return fmt.Sprintf("%s[%s]", l.valueAsString(x.X), l.valueAsString(x.Index)) + case *ssa.IndexAddr: + return fmt.Sprintf("&(%s[%s])", l.valueAsString(x.X), l.valueAsString(x.Index)) + case *ssa.Lookup: + return fmt.Sprintf("%s[%s]", l.valueAsString(x.X), l.valueAsString(x.Index)) + case *ssa.Extract: + return fmt.Sprintf("%s[%d]", l.valueAsString(x.Tuple), x.Index) + } + + // In the case of any other type (e.g. this may be an alloc, a return + // value, etc.), just return the literal pointer value to the Value. + // This will be unique within the ssa graph, and so if two values are + // equal, they are from the same type. + return fmt.Sprintf("{%T:%p}", v, v) +} + +// String returns the full lock state. +func (l *lockState) String() string { + if l.count() == 0 { + return "no locks held" + } + return strings.Join(l.lockedMutexes, ",") +} + +// pushDefer pushes a defer onto the stack. +func (l *lockState) pushDefer(d *ssa.Defer) { + l.modify() + l.defers = append(l.defers, d) +} + +// popDefer pops a defer from the stack. +func (l *lockState) popDefer() *ssa.Defer { + // Does not technically modify the underlying slice. + count := len(l.defers) + if count == 0 { + return nil + } + d := l.defers[count-1] + l.defers = l.defers[:count-1] + return d +} diff --git a/tools/checklocks/test/BUILD b/tools/checklocks/test/BUILD index b055e71d9..d4d98c256 100644 --- a/tools/checklocks/test/BUILD +++ b/tools/checklocks/test/BUILD @@ -4,5 +4,18 @@ package(licenses = ["notice"]) go_library( name = "test", - srcs = ["test.go"], + srcs = [ + "alignment.go", + "anon.go", + "atomics.go", + "basics.go", + "branches.go", + "closures.go", + "defer.go", + "incompat.go", + "methods.go", + "parameters.go", + "return.go", + "test.go", + ], ) diff --git a/tools/checklocks/test/alignment.go b/tools/checklocks/test/alignment.go new file mode 100644 index 000000000..cd857ff73 --- /dev/null +++ b/tools/checklocks/test/alignment.go @@ -0,0 +1,51 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +type alignedStruct32 struct { + v int32 +} + +type alignedStruct64 struct { + v int64 +} + +type alignedStructGood struct { + v0 alignedStruct32 + v1 alignedStruct32 + v2 alignedStruct64 +} + +type alignedStructGoodArray0 struct { + v0 [3]alignedStruct32 + v1 [3]alignedStruct32 + v2 alignedStruct64 +} + +type alignedStructGoodArray1 [16]alignedStructGood + +type alignedStructBad struct { + v0 alignedStruct32 + v1 alignedStruct64 + v2 alignedStruct32 +} + +type alignedStructBadArray0 struct { + v0 [3]alignedStruct32 + v1 [2]alignedStruct64 + v2 [1]alignedStruct32 +} + +type alignedStructBadArray1 [16]alignedStructBad diff --git a/tools/checklocks/test/anon.go b/tools/checklocks/test/anon.go new file mode 100644 index 000000000..a1f6bddda --- /dev/null +++ b/tools/checklocks/test/anon.go @@ -0,0 +1,35 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import "sync" + +type anonStruct struct { + anon struct { + mu sync.RWMutex + // +checklocks:mu + x int + } +} + +func testAnonAccessValid(tc *anonStruct) { + tc.anon.mu.Lock() + tc.anon.x = 1 + tc.anon.mu.Unlock() +} + +func testAnonAccessInvalid(tc *anonStruct) { + tc.anon.x = 1 // +checklocksfail +} diff --git a/tools/checklocks/test/atomics.go b/tools/checklocks/test/atomics.go new file mode 100644 index 000000000..8e060d8a2 --- /dev/null +++ b/tools/checklocks/test/atomics.go @@ -0,0 +1,91 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "sync" + "sync/atomic" +) + +type atomicStruct struct { + accessedNormally int32 + + // +checkatomic + accessedAtomically int32 + + // +checklocksignore + ignored int32 +} + +func testNormalAccess(tc *atomicStruct, v chan int32, p chan *int32) { + v <- tc.accessedNormally + p <- &tc.accessedNormally +} + +func testAtomicAccess(tc *atomicStruct, v chan int32) { + v <- atomic.LoadInt32(&tc.accessedAtomically) +} + +func testAtomicAccessInvalid(tc *atomicStruct, v chan int32) { + v <- atomic.LoadInt32(&tc.accessedNormally) // +checklocksfail +} + +func testNormalAccessInvalid(tc *atomicStruct, v chan int32, p chan *int32) { + v <- tc.accessedAtomically // +checklocksfail + p <- &tc.accessedAtomically // +checklocksfail +} + +func testIgnored(tc *atomicStruct, v chan int32, p chan *int32) { + v <- atomic.LoadInt32(&tc.ignored) + v <- tc.ignored + p <- &tc.ignored +} + +type atomicMixedStruct struct { + mu sync.Mutex + + // +checkatomic + // +checklocks:mu + accessedMixed int32 +} + +func testAtomicMixedValidRead(tc *atomicMixedStruct, v chan int32) { + v <- atomic.LoadInt32(&tc.accessedMixed) +} + +func testAtomicMixedInvalidRead(tc *atomicMixedStruct, v chan int32, p chan *int32) { + v <- tc.accessedMixed // +checklocksfail + p <- &tc.accessedMixed // +checklocksfail +} + +func testAtomicMixedValidLockedWrite(tc *atomicMixedStruct, v chan int32, p chan *int32) { + tc.mu.Lock() + atomic.StoreInt32(&tc.accessedMixed, 1) + tc.mu.Unlock() +} + +func testAtomicMixedInvalidLockedWrite(tc *atomicMixedStruct, v chan int32, p chan *int32) { + tc.mu.Lock() + tc.accessedMixed = 1 // +checklocksfail:2 + tc.mu.Unlock() +} + +func testAtomicMixedInvalidAtomicWrite(tc *atomicMixedStruct, v chan int32, p chan *int32) { + atomic.StoreInt32(&tc.accessedMixed, 1) // +checklocksfail +} + +func testAtomicMixedInvalidWrite(tc *atomicMixedStruct, v chan int32, p chan *int32) { + tc.accessedMixed = 1 // +checklocksfail:2 +} diff --git a/tools/checklocks/test/basics.go b/tools/checklocks/test/basics.go new file mode 100644 index 000000000..7a773171f --- /dev/null +++ b/tools/checklocks/test/basics.go @@ -0,0 +1,145 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "sync" +) + +func testLockedAccessValid(tc *oneGuardStruct) { + tc.mu.Lock() + tc.guardedField = 1 + tc.mu.Unlock() +} + +func testLockedAccessIgnore(tc *oneGuardStruct) { + tc.mu.Lock() + tc.unguardedField = 1 + tc.mu.Unlock() +} + +func testUnlockedAccessInvalidWrite(tc *oneGuardStruct) { + tc.guardedField = 2 // +checklocksfail +} + +func testUnlockedAccessInvalidRead(tc *oneGuardStruct) { + x := tc.guardedField // +checklocksfail + _ = x +} + +func testUnlockedAccessValid(tc *oneGuardStruct) { + tc.unguardedField = 2 +} + +func testCallValidAccess(tc *oneGuardStruct) { + callValidAccess(tc) +} + +func callValidAccess(tc *oneGuardStruct) { + tc.mu.Lock() + tc.guardedField = 1 + tc.mu.Unlock() +} + +func testCallValueMixup(tc *oneGuardStruct) { + callValueMixup(tc, tc) +} + +func callValueMixup(tc1, tc2 *oneGuardStruct) { + tc1.mu.Lock() + tc2.guardedField = 2 // +checklocksfail + tc1.mu.Unlock() +} + +func testCallPreconditionsInvalid(tc *oneGuardStruct) { + callPreconditions(tc) // +checklocksfail +} + +func testCallPreconditionsValid(tc *oneGuardStruct) { + tc.mu.Lock() + callPreconditions(tc) + tc.mu.Unlock() +} + +// +checklocks:tc.mu +func callPreconditions(tc *oneGuardStruct) { + tc.guardedField = 1 +} + +type nestedFieldsStruct struct { + mu sync.Mutex + + // +checklocks:mu + nestedStruct struct { + nested1 int + nested2 int + } +} + +func testNestedGuardValid(tc *nestedFieldsStruct) { + tc.mu.Lock() + tc.nestedStruct.nested1 = 1 + tc.nestedStruct.nested2 = 2 + tc.mu.Unlock() +} + +func testNestedGuardInvalid(tc *nestedFieldsStruct) { + tc.nestedStruct.nested1 = 1 // +checklocksfail +} + +type rwGuardStruct struct { + rwMu sync.RWMutex + + // +checklocks:rwMu + guardedField int +} + +func testRWValidRead(tc *rwGuardStruct) { + tc.rwMu.Lock() + tc.guardedField = 1 + tc.rwMu.Unlock() +} + +func testRWValidWrite(tc *rwGuardStruct) { + tc.rwMu.RLock() + tc.guardedField = 2 + tc.rwMu.RUnlock() +} + +func testRWInvalidWrite(tc *rwGuardStruct) { + tc.guardedField = 3 // +checklocksfail +} + +func testRWInvalidRead(tc *rwGuardStruct) { + x := tc.guardedField + 3 // +checklocksfail + _ = x +} + +func testTwoLocksDoubleGuardStructValid(tc *twoLocksDoubleGuardStruct) { + tc.mu.Lock() + tc.secondMu.Lock() + tc.doubleGuardedField = 1 + tc.secondMu.Unlock() +} + +func testTwoLocksDoubleGuardStructOnlyOne(tc *twoLocksDoubleGuardStruct) { + tc.mu.Lock() + tc.doubleGuardedField = 2 // +checklocksfail + tc.mu.Unlock() +} + +func testTwoLocksDoubleGuardStructInvalid(tc *twoLocksDoubleGuardStruct) { + tc.doubleGuardedField = 3 // +checklocksfail:2 +} diff --git a/tools/checklocks/test/branches.go b/tools/checklocks/test/branches.go new file mode 100644 index 000000000..81fec29e5 --- /dev/null +++ b/tools/checklocks/test/branches.go @@ -0,0 +1,56 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "math/rand" +) + +func testInconsistentReturn(tc *oneGuardStruct) { // +checklocksfail + if x := rand.Intn(10); x%2 == 1 { + tc.mu.Lock() + } +} + +func testConsistentBranching(tc *oneGuardStruct) { + x := rand.Intn(10) + if x%2 == 1 { + tc.mu.Lock() + } else { + tc.mu.Lock() + } + tc.guardedField = 1 + if x%2 == 1 { + tc.mu.Unlock() + } else { + tc.mu.Unlock() + } +} + +func testInconsistentBranching(tc *oneGuardStruct) { // +checklocksfail:2 + // We traverse the control flow graph in all consistent ways. We cannot + // determine however, that the first if block and second if block will + // evaluate to the same condition. Therefore, there are two consistent + // paths through this code, and two inconsistent paths. Either way, the + // guardedField should be also marked as an invalid access. + x := rand.Intn(10) + if x%2 == 1 { + tc.mu.Lock() + } + tc.guardedField = 1 // +checklocksfail + if x%2 == 1 { + tc.mu.Unlock() // +checklocksforce + } +} diff --git a/tools/checklocks/test/closures.go b/tools/checklocks/test/closures.go new file mode 100644 index 000000000..7da87540a --- /dev/null +++ b/tools/checklocks/test/closures.go @@ -0,0 +1,100 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +func testClosureInvalid(tc *oneGuardStruct) { + // This is expected to fail. + callClosure(func() { + tc.guardedField = 1 // +checklocksfail + }) +} + +func testClosureUnsupported(tc *oneGuardStruct) { + // Locked outside the closure, so may or may not be valid. This cannot + // be handled and we should explicitly fail. This can't be handled + // because of the call through callClosure, below, which means the + // closure will actually be passed as a value somewhere. + tc.mu.Lock() + callClosure(func() { + tc.guardedField = 1 // +checklocksfail + }) + tc.mu.Unlock() +} + +func testClosureValid(tc *oneGuardStruct) { + // All locking happens within the closure. This should not present a + // problem for analysis. + callClosure(func() { + tc.mu.Lock() + tc.guardedField = 1 + tc.mu.Unlock() + }) +} + +func testClosureInline(tc *oneGuardStruct) { + // If the closure is being dispatching inline only, then we should be + // able to analyze this call and give it a thumbs up. + tc.mu.Lock() + func() { + tc.guardedField = 1 + }() + tc.mu.Unlock() +} + +func testAnonymousInvalid(tc *oneGuardStruct) { + // Invalid, as per testClosureInvalid above. + callAnonymous(func(tc *oneGuardStruct) { + tc.guardedField = 1 // +checklocksfail + }, tc) +} + +func testAnonymousUnsupported(tc *oneGuardStruct) { + // Not supportable, as per testClosureUnsupported above. + tc.mu.Lock() + callAnonymous(func(tc *oneGuardStruct) { + tc.guardedField = 1 // +checklocksfail + }, tc) + tc.mu.Unlock() +} + +func testAnonymousValid(tc *oneGuardStruct) { + // Valid, as per testClosureValid above. + callAnonymous(func(tc *oneGuardStruct) { + tc.mu.Lock() + tc.guardedField = 1 + tc.mu.Unlock() + }, tc) +} + +func testAnonymousInline(tc *oneGuardStruct) { + // Unlike the closure case, we are able to dynamically infer the set of + // preconditions for the function dispatch and assert that this is + // a valid call. + tc.mu.Lock() + func(tc *oneGuardStruct) { + tc.guardedField = 1 + }(tc) + tc.mu.Unlock() +} + +//go:noinline +func callClosure(fn func()) { + fn() +} + +//go:noinline +func callAnonymous(fn func(*oneGuardStruct), tc *oneGuardStruct) { + fn(tc) +} diff --git a/tools/checklocks/test/defer.go b/tools/checklocks/test/defer.go new file mode 100644 index 000000000..6e574e5eb --- /dev/null +++ b/tools/checklocks/test/defer.go @@ -0,0 +1,38 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +func testDeferValidUnlock(tc *oneGuardStruct) { + tc.mu.Lock() + tc.guardedField = 1 + defer tc.mu.Unlock() +} + +func testDeferValidAccess(tc *oneGuardStruct) { + tc.mu.Lock() + defer func() { + tc.guardedField = 1 + tc.mu.Unlock() + }() +} + +func testDeferInvalidAccess(tc *oneGuardStruct) { + tc.mu.Lock() + defer func() { + // N.B. Executed after tc.mu.Unlock(). + tc.guardedField = 1 // +checklocksfail + }() + tc.mu.Unlock() +} diff --git a/tools/checklocks/test/incompat.go b/tools/checklocks/test/incompat.go new file mode 100644 index 000000000..b39bc66c1 --- /dev/null +++ b/tools/checklocks/test/incompat.go @@ -0,0 +1,54 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "sync" +) + +// unsupportedLockerStruct verifies that trying to annotate a field that is not a +// sync.Mutex or sync.RWMutex results in a failure. +type unsupportedLockerStruct struct { + mu sync.Locker + + // +checklocks:mu + x int // +checklocksfail +} + +// badFieldsStruct verifies that refering invalid fields fails. +type badFieldsStruct struct { + // +checklocks:mu + x int // +checklocksfail +} + +// redundantStruct verifies that redundant annotations fail. +type redundantStruct struct { + mu sync.Mutex + + // +checklocks:mu + // +checklocks:mu + x int // +checklocksfail +} + +// conflictsStruct verifies that conflicting annotations fail. +type conflictsStruct struct { + // +checkatomicignore + // +checkatomic + x int // +checklocksfail + + // +checkatomic + // +checkatomicignore + y int // +checklocksfail +} diff --git a/tools/checklocks/test/methods.go b/tools/checklocks/test/methods.go new file mode 100644 index 000000000..72e26fca6 --- /dev/null +++ b/tools/checklocks/test/methods.go @@ -0,0 +1,117 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "sync" +) + +type testMethods struct { + mu sync.Mutex + + // +checklocks:mu + guardedField int +} + +func (t *testMethods) methodValid() { + t.mu.Lock() + t.guardedField = 1 + t.mu.Unlock() +} + +func (t *testMethods) methodInvalid() { + t.guardedField = 2 // +checklocksfail +} + +// +checklocks:t.mu +func (t *testMethods) MethodLocked(a, b, c int) { + t.guardedField = 3 +} + +// +checklocksignore +func (t *testMethods) methodIgnore() { + t.guardedField = 2 +} + +func testMethodCallsValid(tc *testMethods) { + tc.methodValid() +} + +func testMethodCallsValidPreconditions(tc *testMethods) { + tc.mu.Lock() + tc.MethodLocked(1, 2, 3) + tc.mu.Unlock() +} + +func testMethodCallsInvalid(tc *testMethods) { + tc.MethodLocked(4, 5, 6) // +checklocksfail +} + +func testMultipleParameters(tc1, tc2, tc3 *testMethods) { + tc1.mu.Lock() + tc1.guardedField = 1 + tc2.guardedField = 2 // +checklocksfail + tc3.guardedField = 3 // +checklocksfail + tc1.mu.Unlock() +} + +type testMethodsWithParameters struct { + mu sync.Mutex + + // +checklocks:mu + guardedField int +} + +type ptrToTestMethodsWithParameters *testMethodsWithParameters + +// +checklocks:t.mu +// +checklocks:a.mu +func (t *testMethodsWithParameters) methodLockedWithParameters(a *testMethodsWithParameters, b *testMethodsWithParameters) { + t.guardedField = a.guardedField + b.guardedField = a.guardedField // +checklocksfail +} + +// +checklocks:t.mu +// +checklocks:a.mu +// +checklocks:b.mu +func (t *testMethodsWithParameters) methodLockedWithPtrType(a *testMethodsWithParameters, b ptrToTestMethodsWithParameters) { + t.guardedField = a.guardedField + b.guardedField = a.guardedField +} + +// +checklocks:a.mu +func standaloneFunctionWithGuard(a *testMethodsWithParameters) { + a.guardedField = 1 + a.mu.Unlock() + a.guardedField = 1 // +checklocksfail +} + +type testMethodsWithEmbedded struct { + mu sync.Mutex + + // +checklocks:mu + guardedField int + p *testMethodsWithParameters +} + +// +checklocks:t.mu +func (t *testMethodsWithEmbedded) DoLocked(a, b *testMethodsWithParameters) { + t.guardedField = 1 + a.mu.Lock() + b.mu.Lock() + t.p.methodLockedWithParameters(a, b) // +checklocksfail + a.mu.Unlock() + b.mu.Unlock() +} diff --git a/tools/checklocks/test/parameters.go b/tools/checklocks/test/parameters.go new file mode 100644 index 000000000..5b9e664b6 --- /dev/null +++ b/tools/checklocks/test/parameters.go @@ -0,0 +1,48 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +func testParameterPassingbyAddrValid(tc *oneGuardStruct) { + tc.mu.Lock() + nestedWithGuardByAddr(&tc.guardedField, &tc.unguardedField) + tc.mu.Unlock() +} + +func testParameterPassingByAddrInalid(tc *oneGuardStruct) { + nestedWithGuardByAddr(&tc.guardedField, &tc.unguardedField) // +checklocksfail +} + +func testParameterPassingByValueValid(tc *oneGuardStruct) { + tc.mu.Lock() + nestedWithGuardByValue(tc.guardedField, tc.unguardedField) + tc.mu.Unlock() +} + +func testParameterPassingByValueInalid(tc *oneGuardStruct) { + nestedWithGuardByValue(tc.guardedField, tc.unguardedField) // +checklocksfail +} + +func nestedWithGuardByAddr(guardedField, unguardedField *int) { + *guardedField = 4 + *unguardedField = 5 +} + +func nestedWithGuardByValue(guardedField, unguardedField int) { + // read the fields to keep SA4009 static analyzer happy. + _ = guardedField + _ = unguardedField + guardedField = 4 + unguardedField = 5 +} diff --git a/tools/checklocks/test/return.go b/tools/checklocks/test/return.go new file mode 100644 index 000000000..47c7b6773 --- /dev/null +++ b/tools/checklocks/test/return.go @@ -0,0 +1,61 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +// +checklocks:tc.mu +func testReturnInvalidGuard() (tc *oneGuardStruct) { // +checklocksfail + return new(oneGuardStruct) +} + +// +checklocksrelease:tc.mu +func testReturnInvalidRelease() (tc *oneGuardStruct) { // +checklocksfail + return new(oneGuardStruct) +} + +// +checklocksacquire:tc.mu +func testReturnInvalidAcquire() (tc *oneGuardStruct) { + return new(oneGuardStruct) // +checklocksfail +} + +// +checklocksacquire:tc.mu +func testReturnValidAcquire() (tc *oneGuardStruct) { + tc = new(oneGuardStruct) + tc.mu.Lock() + return tc +} + +func testReturnAcquireCall() { + tc := testReturnValidAcquire() + tc.guardedField = 1 + tc.mu.Unlock() +} + +// +checklocksacquire:tc.val.mu +// +checklocksacquire:tc.ptr.mu +func testReturnValidNestedAcquire() (tc *nestedGuardStruct) { + tc = new(nestedGuardStruct) + tc.ptr = new(oneGuardStruct) + tc.val.mu.Lock() + tc.ptr.mu.Lock() + return tc +} + +func testReturnNestedAcquireCall() { + tc := testReturnValidNestedAcquire() + tc.val.guardedField = 1 + tc.ptr.guardedField = 1 + tc.val.mu.Unlock() + tc.ptr.mu.Unlock() +} diff --git a/tools/checklocks/test/test.go b/tools/checklocks/test/test.go index 05693c183..cbf6b1635 100644 --- a/tools/checklocks/test/test.go +++ b/tools/checklocks/test/test.go @@ -13,99 +13,24 @@ // limitations under the License. // Package test is a test package. +// +// Tests are all compilation tests in separate files. package test import ( - "math/rand" "sync" ) -type oneGuarded struct { +// oneGuardStruct has one guarded field. +type oneGuardStruct struct { mu sync.Mutex // +checklocks:mu - guardedField int - + guardedField int unguardedField int } -func testAccessOne() { - var tc oneGuarded - // Valid access - tc.mu.Lock() - tc.guardedField = 1 - tc.unguardedField = 1 - tc.mu.Unlock() - - // Valid access as unguarded field is not protected by mu. - tc.unguardedField = 2 - - // Invalid access - tc.guardedField = 2 // +checklocksfail - - // Invalid read of a guarded field. - x := tc.guardedField // +checklocksfail - _ = x -} - -func testFunctionCallsNoParameters() { - // Couple of regular function calls with no parameters. - funcCallWithValidAccess() - funcCallWithInvalidAccess() -} - -func funcCallWithValidAccess() { - var tc2 oneGuarded - // Valid tc2 access - tc2.mu.Lock() - tc2.guardedField = 1 - tc2.mu.Unlock() -} - -func funcCallWithInvalidAccess() { - var tc oneGuarded - var tc2 oneGuarded - // Invalid access, wrong mutex is held. - tc.mu.Lock() - tc2.guardedField = 2 // +checklocksfail - tc.mu.Unlock() -} - -func testParameterPassing() { - var tc oneGuarded - - // Valid call where a guardedField is passed to a function as a parameter. - tc.mu.Lock() - nestedWithGuardByAddr(&tc.guardedField, &tc.unguardedField) - tc.mu.Unlock() - - // Invalid call where a guardedField is passed to a function as a parameter - // without holding locks. - nestedWithGuardByAddr(&tc.guardedField, &tc.unguardedField) // +checklocksfail - - // Valid call where a guardedField is passed to a function as a parameter. - tc.mu.Lock() - nestedWithGuardByValue(tc.guardedField, tc.unguardedField) - tc.mu.Unlock() - - // Invalid call where a guardedField is passed to a function as a parameter - // without holding locks. - nestedWithGuardByValue(tc.guardedField, tc.unguardedField) // +checklocksfail -} - -func nestedWithGuardByAddr(guardedField, unguardedField *int) { - *guardedField = 4 - *unguardedField = 5 -} - -func nestedWithGuardByValue(guardedField, unguardedField int) { - // read the fields to keep SA4009 static analyzer happy. - _ = guardedField - _ = unguardedField - guardedField = 4 - unguardedField = 5 -} - -type twoGuarded struct { +// twoGuardStruct has two guarded fields. +type twoGuardStruct struct { mu sync.Mutex // +checklocks:mu guardedField1 int @@ -113,250 +38,27 @@ type twoGuarded struct { guardedField2 int } -type twoLocks struct { +// twoLocksStruct has two locks and two fields. +type twoLocksStruct struct { mu sync.Mutex secondMu sync.Mutex - // +checklocks:mu guardedField1 int // +checklocks:secondMu guardedField2 int } -type twoLocksDoubleGuard struct { +// twoLocksDoubleGuardStruct has two locks and a single field with two guards. +type twoLocksDoubleGuardStruct struct { mu sync.Mutex secondMu sync.Mutex - // +checklocks:mu // +checklocks:secondMu doubleGuardedField int } -func testTwoLocksDoubleGuard() { - var tc twoLocksDoubleGuard - - // Double guarded field - tc.mu.Lock() - tc.secondMu.Lock() - tc.doubleGuardedField = 1 - tc.secondMu.Unlock() - - // This should fail as we released the secondMu. - tc.doubleGuardedField = 2 // +checklocksfail - tc.mu.Unlock() - - // This should fail as well as now we are not holding any locks. - // - // This line triggers two failures one for each mutex, hence the 2 after - // fail. - tc.doubleGuardedField = 3 // +checklocksfail:2 -} - -type rwGuarded struct { - rwMu sync.RWMutex - - // +checklocks:rwMu - rwGuardedField int -} - -func testRWGuarded() { - var tc rwGuarded - - // Assignment w/ exclusive lock should pass. - tc.rwMu.Lock() - tc.rwGuardedField = 1 - tc.rwMu.Unlock() - - // Assignment w/ RWLock should pass as we don't differentiate between - // Lock/RLock. - tc.rwMu.RLock() - tc.rwGuardedField = 2 - tc.rwMu.RUnlock() - - // Assignment w/o hold Lock() should fail. - tc.rwGuardedField = 3 // +checklocksfail - - // Reading w/o holding lock should fail. - x := tc.rwGuardedField + 3 // +checklocksfail - _ = x -} - -type nestedFields struct { - mu sync.Mutex - - // +checklocks:mu - nestedStruct struct { - nested1 int - nested2 int - } -} - -func testNestedStructGuards() { - var tc nestedFields - // Valid access with mu held. - tc.mu.Lock() - tc.nestedStruct.nested1 = 1 - tc.nestedStruct.nested2 = 2 - tc.mu.Unlock() - - // Invalid access to nested1 wihout holding mu. - tc.nestedStruct.nested1 = 1 // +checklocksfail -} - -type testCaseMethods struct { - mu sync.Mutex - - // +checklocks:mu - guardedField int -} - -func (t *testCaseMethods) Method() { - // Valid access - t.mu.Lock() - t.guardedField = 1 - t.mu.Unlock() - - // invalid access - t.guardedField = 2 // +checklocksfail -} - -// +checklocks:t.mu -func (t *testCaseMethods) MethodLocked(a, b, c int) { - t.guardedField = 3 -} - -// +checklocksignore -func (t *testCaseMethods) IgnoredMethod() { - // Invalid access but should not fail as the function is annotated - // with "// +checklocksignore" - t.guardedField = 2 -} - -func testMethodCalls() { - var tc2 testCaseMethods - - // Valid use, tc2.Method acquires lock. - tc2.Method() - - // Valid access tc2.mu is held before calling tc2.MethodLocked. - tc2.mu.Lock() - tc2.MethodLocked(1, 2, 3) - tc2.mu.Unlock() - - // Invalid access no locks are being held. - tc2.MethodLocked(4, 5, 6) // +checklocksfail -} - -type noMutex struct { - f int - g int -} - -func (n noMutex) method() { - n.f = 1 - n.f = n.g -} - -func testNoMutex() { - var n noMutex - n.method() -} - -func testMultiple() { - var tc1, tc2, tc3 testCaseMethods - - tc1.mu.Lock() - - // Valid access we are holding tc1's lock. - tc1.guardedField = 1 - - // Invalid access we are not holding tc2 or tc3's lock. - tc2.guardedField = 2 // +checklocksfail - tc3.guardedField = 3 // +checklocksfail - tc1.mu.Unlock() -} - -func testConditionalBranchingLocks() { - var tc2 testCaseMethods - x := rand.Intn(10) - if x%2 == 1 { - tc2.mu.Lock() - } - // This is invalid access as tc2.mu is not held if we never entered - // the if block. - tc2.guardedField = 1 // +checklocksfail - - var tc3 testCaseMethods - if x%2 == 1 { - tc3.mu.Lock() - } else { - tc3.mu.Lock() - } - // This is valid as tc3.mu is held in if and else blocks. - tc3.guardedField = 1 -} - -type testMethodWithParams struct { - mu sync.Mutex - - // +checklocks:mu - guardedField int -} - -type ptrToTestMethodWithParams *testMethodWithParams - -// +checklocks:t.mu -// +checklocks:a.mu -func (t *testMethodWithParams) methodLockedWithParams(a *testMethodWithParams, b *testMethodWithParams) { - t.guardedField = a.guardedField - b.guardedField = a.guardedField // +checklocksfail -} - -// +checklocks:t.mu -// +checklocks:a.mu -// +checklocks:b.mu -func (t *testMethodWithParams) methodLockedWithPtrType(a *testMethodWithParams, b ptrToTestMethodWithParams) { - t.guardedField = a.guardedField - b.guardedField = a.guardedField -} - -// +checklocks:a.mu -func standaloneFunctionWithGuard(a *testMethodWithParams) { - a.guardedField = 1 - a.mu.Unlock() - a.guardedField = 1 // +checklocksfail -} - -type testMethodWithEmbedded struct { - mu sync.Mutex - - // +checklocks:mu - guardedField int - p *testMethodWithParams -} - -// +checklocks:t.mu -func (t *testMethodWithEmbedded) DoLocked() { - var a, b testMethodWithParams - t.guardedField = 1 - a.mu.Lock() - b.mu.Lock() - t.p.methodLockedWithParams(&a, &b) // +checklocksfail - a.mu.Unlock() - b.mu.Unlock() -} - -// UnsupportedLockerExample is a test that verifies that trying to annotate a -// field that is not a sync.Mutex/RWMutex results in a failure. -type UnsupportedLockerExample struct { - mu sync.Locker - - // +checklocks:mu - x int // +checklocksfail -} - -func abc() { - var mu sync.Mutex - a := UnsupportedLockerExample{mu: &mu} - a.x = 1 +// nestedGuardStruct nests oneGuardStruct fields. +type nestedGuardStruct struct { + val oneGuardStruct + ptr *oneGuardStruct } diff --git a/tools/constraintutil/BUILD b/tools/constraintutil/BUILD new file mode 100644 index 000000000..004b708c4 --- /dev/null +++ b/tools/constraintutil/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "constraintutil", + srcs = ["constraintutil.go"], + marshal = False, + stateify = False, + visibility = ["//tools:__subpackages__"], +) + +go_test( + name = "constraintutil_test", + size = "small", + srcs = ["constraintutil_test.go"], + library = ":constraintutil", +) diff --git a/tools/constraintutil/constraintutil.go b/tools/constraintutil/constraintutil.go new file mode 100644 index 000000000..fb3fbe5c2 --- /dev/null +++ b/tools/constraintutil/constraintutil.go @@ -0,0 +1,169 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package constraintutil provides utilities for working with Go build +// constraints. +package constraintutil + +import ( + "bufio" + "bytes" + "fmt" + "go/build/constraint" + "io" + "os" + "strings" +) + +// FromReader extracts the build constraint from the Go source or assembly file +// whose contents are read by r. +func FromReader(r io.Reader) (constraint.Expr, error) { + // See go/build.parseFileHeader() for the "official" logic that this is + // derived from. + const ( + slashStar = "/*" + starSlash = "*/" + gobuildPrefix = "//go:build" + ) + s := bufio.NewScanner(r) + var ( + inSlashStar = false // between /* and */ + haveGobuild = false + e constraint.Expr + ) +Lines: + for s.Scan() { + line := bytes.TrimSpace(s.Bytes()) + if !inSlashStar && constraint.IsGoBuild(string(line)) { + if haveGobuild { + return nil, fmt.Errorf("multiple go:build directives") + } + haveGobuild = true + var err error + e, err = constraint.Parse(string(line)) + if err != nil { + return nil, err + } + } + ThisLine: + for len(line) > 0 { + if inSlashStar { + if i := bytes.Index(line, []byte(starSlash)); i >= 0 { + inSlashStar = false + line = bytes.TrimSpace(line[i+len(starSlash):]) + continue ThisLine + } + continue Lines + } + if bytes.HasPrefix(line, []byte("//")) { + continue Lines + } + // Note that if /* appears in the line, but not at the beginning, + // then the line is still non-empty, so skipping this and + // terminating below is correct. + if bytes.HasPrefix(line, []byte(slashStar)) { + inSlashStar = true + line = bytes.TrimSpace(line[len(slashStar):]) + continue ThisLine + } + // A non-empty non-comment line terminates scanning for go:build. + break Lines + } + } + return e, s.Err() +} + +// FromString extracts the build constraint from the Go source or assembly file +// containing the given data. If no build constraint applies to the file, it +// returns nil. +func FromString(str string) (constraint.Expr, error) { + return FromReader(strings.NewReader(str)) +} + +// FromFile extracts the build constraint from the Go source or assembly file +// at the given path. If no build constraint applies to the file, it returns +// nil. +func FromFile(path string) (constraint.Expr, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + return FromReader(f) +} + +// Combine returns a constraint.Expr that evaluates to true iff all expressions +// in es evaluate to true. If es is empty, Combine returns nil. +// +// Preconditions: All constraint.Exprs in es are non-nil. +func Combine(es []constraint.Expr) constraint.Expr { + switch len(es) { + case 0: + return nil + case 1: + return es[0] + default: + a := &constraint.AndExpr{es[0], es[1]} + for i := 2; i < len(es); i++ { + a = &constraint.AndExpr{a, es[i]} + } + return a + } +} + +// CombineFromFiles returns a build constraint expression that evaluates to +// true iff the build constraints from all of the given Go source or assembly +// files evaluate to true. If no build constraints apply to any of the given +// files, it returns nil. +func CombineFromFiles(paths []string) (constraint.Expr, error) { + var es []constraint.Expr + for _, path := range paths { + e, err := FromFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read build constraints from %q: %v", path, err) + } + if e != nil { + es = append(es, e) + } + } + return Combine(es), nil +} + +// Lines returns a string containing build constraint directives for the given +// constraint.Expr, including two trailing newlines, as appropriate for a Go +// source or assembly file. At least a go:build directive will be emitted; if +// the constraint is expressible using +build directives as well, then +build +// directives will also be emitted. +// +// If e is nil, Lines returns the empty string. +func Lines(e constraint.Expr) string { + if e == nil { + return "" + } + + var b strings.Builder + b.WriteString("//go:build ") + b.WriteString(e.String()) + b.WriteByte('\n') + + if pblines, err := constraint.PlusBuildLines(e); err == nil { + for _, line := range pblines { + b.WriteString(line) + b.WriteByte('\n') + } + } + + b.WriteByte('\n') + return b.String() +} diff --git a/tools/constraintutil/constraintutil_test.go b/tools/constraintutil/constraintutil_test.go new file mode 100644 index 000000000..eeabd8dcf --- /dev/null +++ b/tools/constraintutil/constraintutil_test.go @@ -0,0 +1,138 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package constraintutil + +import ( + "go/build/constraint" + "testing" +) + +func TestFileParsing(t *testing.T) { + for _, test := range []struct { + name string + data string + expr string + }{ + { + name: "Empty", + }, + { + name: "NoConstraint", + data: "// copyright header\n\npackage main", + }, + { + name: "ConstraintOnFirstLine", + data: "//go:build amd64\n#include \"textflag.h\"", + expr: "amd64", + }, + { + name: "ConstraintAfterSlashSlashComment", + data: "// copyright header\n\n//go:build linux\n\npackage newlib", + expr: "linux", + }, + { + name: "ConstraintAfterSlashStarComment", + data: "/*\ncopyright header\n*/\n\n//go:build !race\n\npackage oldlib", + expr: "!race", + }, + { + name: "ConstraintInSlashSlashComment", + data: "// blah blah //go:build windows", + }, + { + name: "ConstraintInSlashStarComment", + data: "/*\n//go:build windows\n*/", + }, + { + name: "ConstraintAfterPackageClause", + data: "package oops\n//go:build race", + }, + { + name: "ConstraintAfterCppInclude", + data: "#include \"textflag.h\"\n//go:build arm64", + }, + } { + t.Run(test.name, func(t *testing.T) { + e, err := FromString(test.data) + if err != nil { + t.Fatalf("FromString(%q) failed: %v", test.data, err) + } + if e == nil { + if len(test.expr) != 0 { + t.Errorf("FromString(%q): got no constraint, wanted %q", test.data, test.expr) + } + } else { + got := e.String() + if len(test.expr) == 0 { + t.Errorf("FromString(%q): got %q, wanted no constraint", test.data, got) + } else if got != test.expr { + t.Errorf("FromString(%q): got %q, wanted %q", test.data, got, test.expr) + } + } + }) + } +} + +func TestCombine(t *testing.T) { + for _, test := range []struct { + name string + in []string + out string + }{ + { + name: "0", + }, + { + name: "1", + in: []string{"amd64 || arm64"}, + out: "amd64 || arm64", + }, + { + name: "2", + in: []string{"amd64", "amd64 && linux"}, + out: "amd64 && amd64 && linux", + }, + { + name: "3", + in: []string{"amd64", "amd64 || arm64", "amd64 || riscv64"}, + out: "amd64 && (amd64 || arm64) && (amd64 || riscv64)", + }, + } { + t.Run(test.name, func(t *testing.T) { + inexprs := make([]constraint.Expr, 0, len(test.in)) + for _, estr := range test.in { + line := "//go:build " + estr + e, err := constraint.Parse(line) + if err != nil { + t.Fatalf("constraint.Parse(%q) failed: %v", line, err) + } + inexprs = append(inexprs, e) + } + outexpr := Combine(inexprs) + if outexpr == nil { + if len(test.out) != 0 { + t.Errorf("Combine(%v): got no constraint, wanted %q", test.in, test.out) + } + } else { + got := outexpr.String() + if len(test.out) == 0 { + t.Errorf("Combine(%v): got %q, wanted no constraint", test.in, got) + } else if got != test.out { + t.Errorf("Combine(%v): got %q, wanted %q", test.in, got, test.out) + } + } + }) + } +} diff --git a/tools/defs.bzl b/tools/defs.bzl index 27542a2f5..f4266e1de 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -9,7 +9,7 @@ load("//tools/go_stateify:defs.bzl", "go_stateify") load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") load("//tools/nogo:defs.bzl", "nogo_test") load("//tools/bazeldefs:defs.bzl", _arch_genrule = "arch_genrule", _build_test = "build_test", _bzl_library = "bzl_library", _coreutil = "coreutil", _default_installer = "default_installer", _default_net_util = "default_net_util", _more_shards = "more_shards", _most_shards = "most_shards", _proto_library = "proto_library", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path", _version = "version") -load("//tools/bazeldefs:cc.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _gbenchmark = "gbenchmark", _grpcpp = "grpcpp", _gtest = "gtest", _vdso_linker_option = "vdso_linker_option") +load("//tools/bazeldefs:cc.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _gbenchmark = "gbenchmark", _gbenchmark_internal = "gbenchmark_internal", _grpcpp = "grpcpp", _gtest = "gtest", _vdso_linker_option = "vdso_linker_option") load("//tools/bazeldefs:go.bzl", _bazel_worker_proto = "bazel_worker_proto", _gazelle = "gazelle", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_rule = "go_rule", _go_test = "go_test", _select_goarch = "select_goarch", _select_goos = "select_goos") load("//tools/bazeldefs:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms") @@ -37,6 +37,7 @@ cc_library = _cc_library cc_test = _cc_test cc_toolchain = _cc_toolchain gbenchmark = _gbenchmark +gbenchmark_internal = _gbenchmark_internal gtest = _gtest grpcpp = _grpcpp vdso_linker_option = _vdso_linker_option diff --git a/tools/go_fieldenum/BUILD b/tools/go_fieldenum/BUILD new file mode 100644 index 000000000..2bfdaeb2f --- /dev/null +++ b/tools/go_fieldenum/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "bzl_library", "go_binary") + +licenses(["notice"]) + +go_binary( + name = "fieldenum", + srcs = ["main.go"], + visibility = ["//:sandbox"], +) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], +) diff --git a/tools/go_fieldenum/defs.bzl b/tools/go_fieldenum/defs.bzl new file mode 100644 index 000000000..0cd2679ca --- /dev/null +++ b/tools/go_fieldenum/defs.bzl @@ -0,0 +1,29 @@ +"""The go_fieldenum target infers Field, Fields, and FieldSet types for each +struct in an input source file marked +fieldenum. +""" + +def _go_fieldenum_impl(ctx): + output = ctx.outputs.out + + args = ["-pkg=%s" % ctx.attr.package, "-out=%s" % output.path] + for src in ctx.attr.srcs: + args += [f.path for f in src.files.to_list()] + + ctx.actions.run( + inputs = ctx.files.srcs, + outputs = [output], + mnemonic = "GoFieldenum", + progress_message = "Generating Go field enumerators %s" % ctx.label, + arguments = args, + executable = ctx.executable._tool, + ) + +go_fieldenum = rule( + implementation = _go_fieldenum_impl, + attrs = { + "srcs": attr.label_list(doc = "input source files", mandatory = True, allow_files = True), + "package": attr.string(doc = "the package for the generated source file", mandatory = True), + "out": attr.output(doc = "output file", mandatory = True), + "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_fieldenum:fieldenum")), + }, +) diff --git a/tools/go_fieldenum/main.go b/tools/go_fieldenum/main.go new file mode 100644 index 000000000..68dfdb3db --- /dev/null +++ b/tools/go_fieldenum/main.go @@ -0,0 +1,310 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary fieldenum emits field bitmasks for all structs in a package marked +// "+fieldenum". +package main + +import ( + "flag" + "fmt" + "go/ast" + "go/parser" + "go/token" + "log" + "os" + "strings" +) + +var ( + outputPkg = flag.String("pkg", "", "output package") + outputFilename = flag.String("out", "-", "output filename") +) + +func main() { + // Parse command line arguments. + flag.Parse() + if len(*outputPkg) == 0 { + log.Fatalf("-pkg must be provided") + } + if len(flag.Args()) == 0 { + log.Fatalf("Input files must be provided") + } + + // Parse input files. + inputFiles := make([]*ast.File, 0, len(flag.Args())) + fset := token.NewFileSet() + for _, filename := range flag.Args() { + f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) + if err != nil { + log.Fatalf("Failed to parse input file %q: %v", filename, err) + } + inputFiles = append(inputFiles, f) + } + + // Determine which types are marked "+fieldenum" and will consequently have + // code generated. + fieldEnumTypes := make(map[string]fieldEnumTypeInfo) + for _, f := range inputFiles { + for _, decl := range f.Decls { + d, ok := decl.(*ast.GenDecl) + if !ok || d.Tok != token.TYPE || d.Doc == nil || len(d.Specs) == 0 { + continue + } + for _, l := range d.Doc.List { + const fieldenumPrefixWithSpace = "// +fieldenum " + if l.Text == "// +fieldenum" || strings.HasPrefix(l.Text, fieldenumPrefixWithSpace) { + spec := d.Specs[0].(*ast.TypeSpec) + name := spec.Name.Name + prefix := name + if len(l.Text) > len(fieldenumPrefixWithSpace) { + prefix = strings.TrimSpace(l.Text[len(fieldenumPrefixWithSpace):]) + } + st, ok := spec.Type.(*ast.StructType) + if !ok { + log.Fatalf("Type %s is marked +fieldenum, but is not a struct", name) + } + fieldEnumTypes[name] = fieldEnumTypeInfo{ + prefix: prefix, + structType: st, + } + break + } + } + } + } + + // Collect information for each type for which code is being generated. + structInfos := make([]structInfo, 0, len(fieldEnumTypes)) + needSyncAtomic := false + for typeName, typeInfo := range fieldEnumTypes { + var si structInfo + si.name = typeName + si.prefix = typeInfo.prefix + for _, field := range typeInfo.structType.Fields.List { + name := structFieldName(field) + // If the field's type is a type that is also marked +fieldenum, + // include a FieldSet for that type in this one's. The field must + // be a struct by value, since if it's a pointer then that struct + // might also point to or include this one (which would make + // FieldSet inclusion circular). It must also be a type defined in + // this package, since otherwise we don't know whether it's marked + // +fieldenum. Thus, field.Type must be an identifier (rather than + // an ast.StarExpr or SelectorExpr). + if tident, ok := field.Type.(*ast.Ident); ok { + if fieldTypeInfo, ok := fieldEnumTypes[tident.Name]; ok { + fsf := fieldSetField{ + fieldName: name, + typePrefix: fieldTypeInfo.prefix, + } + si.reprByFieldSet = append(si.reprByFieldSet, fsf) + si.allFields = append(si.allFields, fsf) + continue + } + } + si.reprByBit = append(si.reprByBit, name) + si.allFields = append(si.allFields, fieldSetField{ + fieldName: name, + }) + // sync/atomic import will be needed for FieldSet.Load(). + needSyncAtomic = true + } + structInfos = append(structInfos, si) + } + + // Build the output file. + var b strings.Builder + fmt.Fprintf(&b, "// Generated by go_fieldenum.\n\n") + fmt.Fprintf(&b, "package %s\n\n", *outputPkg) + if needSyncAtomic { + fmt.Fprintf(&b, "import \"sync/atomic\"\n\n") + } + for _, si := range structInfos { + si.writeTo(&b) + } + + if *outputFilename == "-" { + // Write output to stdout. + fmt.Printf("%s", b.String()) + } else { + // Write output to file. + f, err := os.OpenFile(*outputFilename, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644) + if err != nil { + log.Fatalf("Failed to open output file %q: %v", *outputFilename, err) + } + if _, err := f.WriteString(b.String()); err != nil { + log.Fatalf("Failed to write output file %q: %v", *outputFilename, err) + } + f.Close() + } +} + +type fieldEnumTypeInfo struct { + prefix string + structType *ast.StructType +} + +// structInfo contains information about the code generated for a given struct. +type structInfo struct { + // name is the name of the represented struct. + name string + + // prefix is the prefix X applied to the name of each generated type and + // constant, referred to as X in the comments below for convenience. + prefix string + + // reprByBit contains the names of fields in X that should be represented + // by a bit in the bit mask XFieldSet.fields, and by a bool in XFields. + reprByBit []string + + // reprByFieldSet contains fields in X whose type is a named struct (e.g. + // Y) that has a corresponding FieldSet type YFieldSet, and which should + // therefore be represented by including a value of type YFieldSet in + // XFieldSet, and a value of type YFields in XFields. + reprByFieldSet []fieldSetField + + // allFields contains all fields in X in order of declaration. Fields in + // reprByBit have fieldSetField.typePrefix == "". + allFields []fieldSetField +} + +type fieldSetField struct { + fieldName string + typePrefix string +} + +func structFieldName(f *ast.Field) string { + if len(f.Names) != 0 { + return f.Names[0].Name + } + // For embedded struct fields, the field name is the unqualified type name. + texpr := f.Type + for { + switch t := texpr.(type) { + case *ast.StarExpr: + texpr = t.X + case *ast.SelectorExpr: + texpr = t.Sel + case *ast.Ident: + return t.Name + default: + panic(fmt.Sprintf("unexpected %T", texpr)) + } + } +} + +// Workaround for Go defect (map membership test isn't usable in an +// expression). +func fetContains(xs map[string]*ast.StructType, x string) bool { + _, ok := xs[x] + return ok +} + +func (si *structInfo) writeTo(b *strings.Builder) { + fmt.Fprintf(b, "// A %sField represents a field in %s.\n", si.prefix, si.name) + fmt.Fprintf(b, "type %sField uint\n\n", si.prefix) + if len(si.reprByBit) != 0 { + fmt.Fprintf(b, "// %sFieldX represents %s field X.\n", si.prefix, si.name) + fmt.Fprintf(b, "const (\n") + fmt.Fprintf(b, "\t%sField%s %sField = iota\n", si.prefix, si.reprByBit[0], si.prefix) + for _, fieldName := range si.reprByBit[1:] { + fmt.Fprintf(b, "\t%sField%s\n", si.prefix, fieldName) + } + fmt.Fprintf(b, ")\n\n") + } + + fmt.Fprintf(b, "// %sFields represents a set of fields in %s in a literal-friendly form.\n", si.prefix, si.name) + fmt.Fprintf(b, "// The zero value of %sFields represents an empty set.\n", si.prefix) + fmt.Fprintf(b, "type %sFields struct {\n", si.prefix) + for _, fieldSetField := range si.allFields { + if fieldSetField.typePrefix == "" { + fmt.Fprintf(b, "\t%s bool\n", fieldSetField.fieldName) + } else { + fmt.Fprintf(b, "\t%s %sFields\n", fieldSetField.fieldName, fieldSetField.typePrefix) + } + } + fmt.Fprintf(b, "}\n\n") + + fmt.Fprintf(b, "// %sFieldSet represents a set of fields in %s in a compact form.\n", si.prefix, si.name) + fmt.Fprintf(b, "// The zero value of %sFieldSet represents an empty set.\n", si.prefix) + fmt.Fprintf(b, "type %sFieldSet struct {\n", si.prefix) + numBitmaskUint32s := (len(si.reprByBit) + 31) / 32 + for _, fieldSetField := range si.reprByFieldSet { + fmt.Fprintf(b, "\t%s %sFieldSet\n", fieldSetField.fieldName, fieldSetField.typePrefix) + } + if len(si.reprByBit) != 0 { + fmt.Fprintf(b, "\tfields [%d]uint32\n", numBitmaskUint32s) + } + fmt.Fprintf(b, "}\n\n") + + if len(si.reprByBit) != 0 { + fmt.Fprintf(b, "// Contains returns true if f is present in the %sFieldSet.\n", si.prefix) + fmt.Fprintf(b, "func (fs %sFieldSet) Contains(f %sField) bool {\n", si.prefix, si.prefix) + if numBitmaskUint32s == 1 { + fmt.Fprintf(b, "\treturn fs.fields[0] & (uint32(1) << uint(f)) != 0\n") + } else { + fmt.Fprintf(b, "\treturn fs.fields[f/32] & (uint32(1) << (f%%32)) != 0\n") + } + fmt.Fprintf(b, "}\n\n") + + fmt.Fprintf(b, "// Add adds f to the %sFieldSet.\n", si.prefix) + fmt.Fprintf(b, "func (fs *%sFieldSet) Add(f %sField) {\n", si.prefix, si.prefix) + if numBitmaskUint32s == 1 { + fmt.Fprintf(b, "\tfs.fields[0] |= uint32(1) << uint(f)\n") + } else { + fmt.Fprintf(b, "\tfs.fields[f/32] |= uint32(1) << (f%%32)\n") + } + fmt.Fprintf(b, "}\n\n") + + fmt.Fprintf(b, "// Remove removes f from the %sFieldSet.\n", si.prefix) + fmt.Fprintf(b, "func (fs *%sFieldSet) Remove(f %sField) {\n", si.prefix, si.prefix) + if numBitmaskUint32s == 1 { + fmt.Fprintf(b, "\tfs.fields[0] &^= uint32(1) << uint(f)\n") + } else { + fmt.Fprintf(b, "\tfs.fields[f/32] &^= uint32(1) << (f%%32)\n") + } + fmt.Fprintf(b, "}\n\n") + } + + fmt.Fprintf(b, "// Load returns a copy of the %sFieldSet.\n", si.prefix) + fmt.Fprintf(b, "// Load is safe to call concurrently with AddFieldsLoadable, but not Add or Remove.\n") + fmt.Fprintf(b, "func (fs *%sFieldSet) Load() (copied %sFieldSet) {\n", si.prefix, si.prefix) + for _, fieldSetField := range si.reprByFieldSet { + fmt.Fprintf(b, "\tcopied.%s = fs.%s.Load()\n", fieldSetField.fieldName, fieldSetField.fieldName) + } + for i := 0; i < numBitmaskUint32s; i++ { + fmt.Fprintf(b, "\tcopied.fields[%d] = atomic.LoadUint32(&fs.fields[%d])\n", i, i) + } + fmt.Fprintf(b, "\treturn\n") + fmt.Fprintf(b, "}\n\n") + + fmt.Fprintf(b, "// AddFieldsLoadable adds the given fields to the %sFieldSet.\n", si.prefix) + fmt.Fprintf(b, "// AddFieldsLoadable is safe to call concurrently with Load, but not other methods (including other calls to AddFieldsLoadable).\n") + fmt.Fprintf(b, "func (fs *%sFieldSet) AddFieldsLoadable(fields %sFields) {\n", si.prefix, si.prefix) + for _, fieldSetField := range si.reprByFieldSet { + fmt.Fprintf(b, "\tfs.%s.AddFieldsLoadable(fields.%s)\n", fieldSetField.fieldName, fieldSetField.fieldName) + } + for _, fieldName := range si.reprByBit { + fieldConstName := fmt.Sprintf("%sField%s", si.prefix, fieldName) + fmt.Fprintf(b, "\tif fields.%s {\n", fieldName) + if numBitmaskUint32s == 1 { + fmt.Fprintf(b, "\t\tatomic.StoreUint32(&fs.fields[0], fs.fields[0] | (uint32(1) << uint(%s)))\n", fieldConstName) + } else { + fmt.Fprintf(b, "\t\tword, bit := %s/32, %s%%32\n", fieldConstName, fieldConstName) + fmt.Fprintf(b, "\t\tatomic.StoreUint32(&fs.fields[word], fs.fields[word] | (uint32(1) << bit))\n") + } + fmt.Fprintf(b, "\t}\n") + } + fmt.Fprintf(b, "}\n\n") +} diff --git a/tools/go_generics/go_merge/BUILD b/tools/go_generics/go_merge/BUILD index 5e0487e93..211e6b3ed 100644 --- a/tools/go_generics/go_merge/BUILD +++ b/tools/go_generics/go_merge/BUILD @@ -7,6 +7,6 @@ go_binary( srcs = ["main.go"], visibility = ["//:sandbox"], deps = [ - "//tools/tags", + "//tools/constraintutil", ], ) diff --git a/tools/go_generics/go_merge/main.go b/tools/go_generics/go_merge/main.go index 801f2354f..81394ddce 100644 --- a/tools/go_generics/go_merge/main.go +++ b/tools/go_generics/go_merge/main.go @@ -25,9 +25,8 @@ import ( "os" "path/filepath" "strconv" - "strings" - "gvisor.dev/gvisor/tools/tags" + "gvisor.dev/gvisor/tools/constraintutil" ) var ( @@ -131,6 +130,12 @@ func main() { } f.Decls = newDecls + // Infer build constraints for the output file. + bcexpr, err := constraintutil.CombineFromFiles(flag.Args()) + if err != nil { + fatalf("Failed to read build constraints: %v\n", err) + } + // Write the output file. var buf bytes.Buffer if err := format.Node(&buf, fset, f); err != nil { @@ -141,9 +146,7 @@ func main() { fatalf("opening output: %v\n", err) } defer outf.Close() - if t := tags.Aggregate(flag.Args()); len(t) > 0 { - fmt.Fprintf(outf, "%s\n\n", strings.Join(t.Lines(), "\n")) - } + outf.WriteString(constraintutil.Lines(bcexpr)) if _, err := outf.Write(buf.Bytes()); err != nil { fatalf("write: %v\n", err) } diff --git a/tools/go_generics/rules_tests/template_test.go b/tools/go_generics/rules_tests/template_test.go index b2a3446ef..6f4d140da 100644 --- a/tools/go_generics/rules_tests/template_test.go +++ b/tools/go_generics/rules_tests/template_test.go @@ -20,14 +20,16 @@ import ( ) func TestMax(t *testing.T) { - var a int = max(10, 20) + var a int + a = max(10, 20) if a != 20 { t.Errorf("Bad result of max, got %v, want %v", a, 20) } } func TestIntConst(t *testing.T) { - var a int = add(10) + var a int + a = add(10) if a != 30 { t.Errorf("Bad result of add, got %v, want %v", a, 30) } diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD index c2747d94c..aaa203115 100644 --- a/tools/go_marshal/gomarshal/BUILD +++ b/tools/go_marshal/gomarshal/BUILD @@ -18,5 +18,5 @@ go_library( visibility = [ "//:sandbox", ], - deps = ["//tools/tags"], + deps = ["//tools/constraintutil"], ) diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 00961c90d..4c23637c0 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -25,7 +25,7 @@ import ( "sort" "strings" - "gvisor.dev/gvisor/tools/tags" + "gvisor.dev/gvisor/tools/constraintutil" ) // List of identifiers we use in generated code that may conflict with a @@ -123,16 +123,18 @@ func (g *Generator) writeHeader() error { var b sourceBuffer b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n") - // Emit build tags. - b.emit("// If there are issues with build tag aggregation, see\n") - b.emit("// tools/go_marshal/gomarshal/generator.go:writeHeader(). The build tags here\n") - b.emit("// come from the input set of files used to generate this file. This input set\n") - b.emit("// is filtered based on pre-defined file suffixes related to build tags, see \n") - b.emit("// tools/defs.bzl:calculate_sets().\n\n") - - if t := tags.Aggregate(g.inputs); len(t) > 0 { - b.emit(strings.Join(t.Lines(), "\n")) - b.emit("\n\n") + bcexpr, err := constraintutil.CombineFromFiles(g.inputs) + if err != nil { + return err + } + if bcexpr != nil { + // Emit build constraints. + b.emit("// If there are issues with build constraint aggregation, see\n") + b.emit("// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here\n") + b.emit("// come from the input set of files used to generate this file. This input set\n") + b.emit("// is filtered based on pre-defined file suffixes related to build constraints,\n") + b.emit("// see tools/defs.bzl:calculate_sets().\n\n") + b.emit(constraintutil.Lines(bcexpr)) } // Package header. @@ -553,11 +555,12 @@ func (g *Generator) writeTests(ts []*testGenerator) error { b.reset() b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n") - // Emit build tags. - if t := tags.Aggregate(g.inputs); len(t) > 0 { - b.emit(strings.Join(t.Lines(), "\n")) - b.emit("\n\n") + // Emit build constraints. + bcexpr, err := constraintutil.CombineFromFiles(g.inputs) + if err != nil { + return err } + b.emit(constraintutil.Lines(bcexpr)) b.emit("package %s\n\n", g.pkg) if err := b.write(g.outputTest); err != nil { diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD index e872560a9..d315be060 100644 --- a/tools/go_marshal/test/BUILD +++ b/tools/go_marshal/test/BUILD @@ -41,10 +41,10 @@ go_test( srcs = ["marshal_test.go"], deps = [ ":test", + "//pkg/errors/linuxerr", "//pkg/hostarch", "//pkg/marshal", "//pkg/marshal/primitive", - "//pkg/syserror", "//pkg/usermem", "//tools/go_marshal/analysis", "@com_github_google_go_cmp//cmp:go_default_library", diff --git a/tools/go_marshal/test/marshal_test.go b/tools/go_marshal/test/marshal_test.go index 43bafbf96..dec3e84fd 100644 --- a/tools/go_marshal/test/marshal_test.go +++ b/tools/go_marshal/test/marshal_test.go @@ -27,16 +27,16 @@ import ( "unsafe" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" - "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/tools/go_marshal/analysis" "gvisor.dev/gvisor/tools/go_marshal/test" ) -var simulatedErr error = syserror.EFAULT +var simulatedErr error = linuxerr.EFAULT // mockCopyContext implements marshal.CopyContext. type mockCopyContext struct { diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD index 913558b4e..ad66981c7 100644 --- a/tools/go_stateify/BUILD +++ b/tools/go_stateify/BUILD @@ -6,7 +6,7 @@ go_binary( name = "stateify", srcs = ["main.go"], visibility = ["//:sandbox"], - deps = ["//tools/tags"], + deps = ["//tools/constraintutil"], ) bzl_library( diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go index 93022f504..3cf00b5dd 100644 --- a/tools/go_stateify/main.go +++ b/tools/go_stateify/main.go @@ -28,7 +28,7 @@ import ( "strings" "sync" - "gvisor.dev/gvisor/tools/tags" + "gvisor.dev/gvisor/tools/constraintutil" ) var ( @@ -214,10 +214,13 @@ func main() { // Automated warning. fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n") - // Emit build tags. - if t := tags.Aggregate(flag.Args()); len(t) > 0 { - fmt.Fprintf(outputFile, "%s\n\n", strings.Join(t.Lines(), "\n")) + // Emit build constraints. + bcexpr, err := constraintutil.CombineFromFiles(flag.Args()) + if err != nil { + fmt.Fprintf(os.Stderr, "Failed to infer build constraints: %v", err) + os.Exit(1) } + outputFile.WriteString(constraintutil.Lines(bcexpr)) // Emit the package name. _, pkg := filepath.Split(*fullPkg) @@ -359,7 +362,12 @@ func main() { fmt.Fprintf(outputFile, " stateSourceObject.LoadWait(%d, &%s.%s)\n", fields[name], recv, name) } emitSaveValue := func(name, typName string) { - fmt.Fprintf(outputFile, " var %sValue %s = %s.save%s()\n", name, typName, recv, camelCased(name)) + // Emit typName to be more robust against code generation bugs, + // but instead of one line make two lines to silence ST1023 + // finding (i.e. avoid nogo finding: "should omit type $typName + // from declaration; it will be inferred from the right-hand side") + fmt.Fprintf(outputFile, " var %sValue %s\n", name, typName) + fmt.Fprintf(outputFile, " %sValue = %s.save%s()\n", name, recv, camelCased(name)) fmt.Fprintf(outputFile, " stateSinkObject.SaveValue(%d, %sValue)\n", fields[name], name) } emitSave := func(name string) { diff --git a/tools/installers/containerd.sh b/tools/installers/containerd.sh index e598bce89..b8da1fe42 100755 --- a/tools/installers/containerd.sh +++ b/tools/installers/containerd.sh @@ -20,6 +20,9 @@ declare -r CONTAINERD_VERSION=${1:-1.3.0} declare -r CONTAINERD_MAJOR="$(echo ${CONTAINERD_VERSION} | awk -F '.' '{ print $1; }')" declare -r CONTAINERD_MINOR="$(echo ${CONTAINERD_VERSION} | awk -F '.' '{ print $2; }')" +# We're running Go 1.16, but using pre-module containerd and cri-tools. +export GO111MODULE=off + # Default to an older version for crictl for containerd <= 1.2. if [[ "${CONTAINERD_MAJOR}" -eq 1 ]] && [[ "${CONTAINERD_MINOR}" -le 2 ]]; then declare -r CRITOOLS_VERSION=${CRITOOLS_VERSION:-1.13.0} @@ -29,8 +32,8 @@ fi # Helper for Go packages below. install_helper() { - PACKAGE="${1}" - TAG="${2}" + declare -r PACKAGE="${1}" + declare -r TAG="${2}" # Clone the repository. mkdir -p "${GOPATH}"/src/$(dirname "${PACKAGE}") && \ @@ -71,8 +74,8 @@ done # Install containerd & cri-tools. declare -rx GOPATH=$(mktemp -d --tmpdir gopathXXXXX) -install_helper github.com/containerd/containerd "v${CONTAINERD_VERSION}" "${GOPATH}" -install_helper github.com/kubernetes-sigs/cri-tools "v${CRITOOLS_VERSION}" "${GOPATH}" +install_helper github.com/containerd/containerd "v${CONTAINERD_VERSION}" +install_helper github.com/kubernetes-sigs/cri-tools "v${CRITOOLS_VERSION}" # Configure containerd-shim. declare -r shim_config_path=/etc/containerd/runsc/config.toml diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD index 6c6f604b5..d72821377 100644 --- a/tools/nogo/BUILD +++ b/tools/nogo/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "bzl_library", "go_library", "select_goarch", "select_goos") +load("//tools:defs.bzl", "bzl_library", "go_library", "go_test", "select_goarch", "select_goos") load("//tools/nogo:defs.bzl", "nogo_objdump_tool", "nogo_stdlib", "nogo_target") package(licenses = ["notice"]) @@ -35,8 +35,10 @@ go_library( visibility = ["//:sandbox"], deps = [ "//tools/checkescape", + "//tools/checklinkname", "//tools/checklocks", "//tools/checkunsafe", + "//tools/nogo/objdump", "//tools/worker", "@co_honnef_go_tools//staticcheck:go_default_library", "@co_honnef_go_tools//stylecheck:go_default_library", @@ -68,9 +70,16 @@ go_library( "@org_golang_x_tools//go/analysis/passes/unsafeptr:go_default_library", "@org_golang_x_tools//go/analysis/passes/unusedresult:go_default_library", "@org_golang_x_tools//go/gcexportdata:go_default_library", + "@org_golang_x_tools//go/types/objectpath:go_default_library", ], ) +go_test( + name = "nogo_test", + srcs = ["config_test.go"], + library = ":nogo", +) + bzl_library( name = "defs_bzl", srcs = ["defs.bzl"], diff --git a/tools/nogo/analyzers.go b/tools/nogo/analyzers.go index 2b3c03fec..db8bbdb8a 100644 --- a/tools/nogo/analyzers.go +++ b/tools/nogo/analyzers.go @@ -47,6 +47,7 @@ import ( "honnef.co/go/tools/stylecheck" "gvisor.dev/gvisor/tools/checkescape" + "gvisor.dev/gvisor/tools/checklinkname" "gvisor.dev/gvisor/tools/checklocks" "gvisor.dev/gvisor/tools/checkunsafe" ) @@ -80,6 +81,7 @@ var AllAnalyzers = []*analysis.Analyzer{ unusedresult.Analyzer, checkescape.Analyzer, checkunsafe.Analyzer, + checklinkname.Analyzer, checklocks.Analyzer, } @@ -115,11 +117,11 @@ func register(all []*analysis.Analyzer) { func init() { // Add all staticcheck analyzers. for _, a := range staticcheck.Analyzers { - AllAnalyzers = append(AllAnalyzers, a) + AllAnalyzers = append(AllAnalyzers, a.Analyzer) } // Add all stylecheck analyzers. for _, a := range stylecheck.Analyzers { - AllAnalyzers = append(AllAnalyzers, a) + AllAnalyzers = append(AllAnalyzers, a.Analyzer) } // Register lists. diff --git a/tools/nogo/build.go b/tools/nogo/build.go index d173cff1f..4067bb480 100644 --- a/tools/nogo/build.go +++ b/tools/nogo/build.go @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build go1.1 +// +build go1.1 + package nogo import ( diff --git a/tools/nogo/check/main.go b/tools/nogo/check/main.go index 3a6c3fb08..0e7e92965 100644 --- a/tools/nogo/check/main.go +++ b/tools/nogo/check/main.go @@ -62,7 +62,8 @@ func run([]string) int { // Check & load the configuration. if *packageFile != "" && *stdlibFile != "" { - log.Fatalf("unable to perform stdlib and package analysis; provide only one!") + fmt.Fprintf(os.Stderr, "unable to perform stdlib and package analysis; provide only one!") + return 1 } // Run the configuration. @@ -75,18 +76,21 @@ func run([]string) int { c := loadConfig(*packageFile, new(nogo.PackageConfig)).(*nogo.PackageConfig) findings, factData, err = nogo.CheckPackage(c, nogo.AllAnalyzers, nil) } else { - log.Fatalf("please provide at least one of package or stdlib!") + fmt.Fprintf(os.Stderr, "please provide at least one of package or stdlib!") + return 1 } // Check that analysis was successful. if err != nil { - log.Fatalf("error performing analysis: %v", err) + fmt.Fprintf(os.Stderr, "error performing analysis: %v", err) + return 1 } // Save facts. if *factsOutput != "" { if err := ioutil.WriteFile(*factsOutput, factData, 0644); err != nil { - log.Fatalf("error saving findings to %q: %v", *factsOutput, err) + fmt.Fprintf(os.Stderr, "error saving findings to %q: %v", *factsOutput, err) + return 1 } } @@ -94,10 +98,12 @@ func run([]string) int { if *findingsOutput != "" { w, err := os.OpenFile(*findingsOutput, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { - log.Fatalf("error opening output file %q: %v", *findingsOutput, err) + fmt.Fprintf(os.Stderr, "error opening output file %q: %v", *findingsOutput, err) + return 1 } if err := nogo.WriteFindingsTo(w, findings, false /* json */); err != nil { - log.Fatalf("error writing findings to %q: %v", *findingsOutput, err) + fmt.Fprintf(os.Stderr, "error writing findings to %q: %v", *findingsOutput, err) + return 1 } } else { for _, finding := range findings { diff --git a/tools/nogo/config.go b/tools/nogo/config.go index 6436f9d34..ee2533610 100644 --- a/tools/nogo/config.go +++ b/tools/nogo/config.go @@ -186,16 +186,19 @@ func (a AnalyzerConfig) merge(other AnalyzerConfig) { } } -func (a AnalyzerConfig) shouldReport(groupConfig *Group, fullPos, msg string) bool { +// shouldReport returns whether the finding should be reported or suppressed. +// It returns !ok if there is no configuration sufficient to decide one way or +// another. +func (a AnalyzerConfig) shouldReport(groupConfig *Group, fullPos, msg string) (report, ok bool) { gc, ok := a[groupConfig.Name] if !ok { - return groupConfig.Default + return false, false } // Note that if a section appears for a particular group // for a particular analyzer, then it will now be enabled, // and the group default no longer applies. - return gc.shouldReport(fullPos, msg) + return gc.shouldReport(fullPos, msg), true } // Config is a nogo configuration. @@ -298,7 +301,8 @@ func (c *Config) ShouldReport(finding Finding) bool { } // Suppress via global rule? - if !c.Global.shouldReport(groupConfig, fullPos, finding.Message) { + report, ok := c.Global.shouldReport(groupConfig, fullPos, finding.Message) + if ok && !report { return false } @@ -307,5 +311,9 @@ func (c *Config) ShouldReport(finding Finding) bool { if !ok { return groupConfig.Default } - return ac.shouldReport(groupConfig, fullPos, finding.Message) + report, ok = ac.shouldReport(groupConfig, fullPos, finding.Message) + if !ok { + return groupConfig.Default + } + return report } diff --git a/tools/nogo/config_test.go b/tools/nogo/config_test.go new file mode 100644 index 000000000..685cffbec --- /dev/null +++ b/tools/nogo/config_test.go @@ -0,0 +1,301 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.package nogo +package nogo + +import ( + "go/token" + "testing" +) + +// TestShouldReport validates the suppression behavior of Config.ShouldReport. +func TestShouldReport(t *testing.T) { + config := &Config{ + Groups: []Group{ + { + Name: "default-enabled", + Regex: "^default-enabled/", + Default: true, + }, + { + Name: "default-disabled", + Regex: "^default-disabled/", + Default: false, + }, + { + Name: "default-disabled-omitted-from-global", + Regex: "^default-disabled-omitted-from-global/", + Default: false, + }, + }, + Global: AnalyzerConfig{ + "default-enabled": &ItemConfig{ + Exclude: []string{"excluded.go"}, + Suppress: []string{"suppressed"}, + }, + "default-disabled": &ItemConfig{ + Exclude: []string{"excluded.go"}, + Suppress: []string{"suppressed"}, + }, + // Omitting default-disabled-omitted-from-global here + // has no effect on configuration below. + }, + Analyzers: map[AnalyzerName]AnalyzerConfig{ + "analyzer-suppressions": AnalyzerConfig{ + // Suppress some. + "default-enabled": &ItemConfig{ + Exclude: []string{"limited-exclude.go"}, + Suppress: []string{"limited suppress"}, + }, + // Enable all. + "default-disabled": nil, + }, + "enabled-for-default-disabled": AnalyzerConfig{ + "default-disabled": nil, + "default-disabled-omitted-from-global": nil, + }, + }, + } + + if err := config.Compile(); err != nil { + t.Fatalf("Compile(%+v) = %v, want nil", config, err) + } + + cases := []struct { + name string + finding Finding + want bool + }{ + { + name: "enabled", + finding: Finding{ + Category: "foo", + Position: token.Position{ + Filename: "default-enabled/file.go", + Offset: 0, + Line: 1, + Column: 1, + }, + Message: "message", + }, + want: true, + }, + { + name: "ungrouped", + finding: Finding{ + Category: "foo", + Position: token.Position{ + Filename: "ungrouped/file.go", + Offset: 0, + Line: 1, + Column: 1, + }, + Message: "message", + }, + want: true, + }, + { + name: "suppressed", + finding: Finding{ + Category: "foo", + Position: token.Position{ + Filename: "default-enabled/file.go", + Offset: 0, + Line: 1, + Column: 1, + }, + Message: "message suppressed", + }, + want: false, + }, + { + name: "excluded", + finding: Finding{ + Category: "foo", + Position: token.Position{ + Filename: "default-enabled/excluded.go", + Offset: 0, + Line: 1, + Column: 1, + }, + Message: "message", + }, + want: false, + }, + { + name: "disabled", + finding: Finding{ + Category: "foo", + Position: token.Position{ + Filename: "default-disabled/file.go", + Offset: 0, + Line: 1, + Column: 1, + }, + Message: "message", + }, + want: false, + }, + { + name: "analyzer suppressed", + finding: Finding{ + Category: "analyzer-suppressions", + Position: token.Position{ + Filename: "default-enabled/file.go", + Offset: 0, + Line: 1, + Column: 1, + }, + Message: "message limited suppress", + }, + want: false, + }, + { + name: "analyzer suppressed not global", + finding: Finding{ + // Doesn't apply outside of analyzer-suppressions. + Category: "foo", + Position: token.Position{ + Filename: "default-enabled/file.go", + Offset: 0, + Line: 1, + Column: 1, + }, + Message: "message limited suppress", + }, + want: true, + }, + { + name: "analyzer suppressed grouped", + finding: Finding{ + Category: "analyzer-suppressions", + Position: token.Position{ + // Doesn't apply outside of default-enabled. + Filename: "default-disabled/file.go", + Offset: 0, + Line: 1, + Column: 1, + }, + Message: "message limited suppress", + }, + want: true, + }, + { + name: "analyzer excluded", + finding: Finding{ + Category: "analyzer-suppressions", + Position: token.Position{ + Filename: "default-enabled/limited-exclude.go", + Offset: 0, + Line: 1, + Column: 1, + }, + Message: "message", + }, + want: false, + }, + { + name: "analyzer excluded not global", + finding: Finding{ + // Doesn't apply outside of analyzer-suppressions. + Category: "foo", + Position: token.Position{ + Filename: "default-enabled/limited-exclude.go", + Offset: 0, + Line: 1, + Column: 1, + }, + Message: "message", + }, + want: true, + }, + { + name: "analyzer excluded grouped", + finding: Finding{ + Category: "analyzer-suppressions", + Position: token.Position{ + // Doesn't apply outside of default-enabled. + Filename: "default-disabled/limited-exclude.go", + Offset: 0, + Line: 1, + Column: 1, + }, + Message: "message", + }, + want: true, + }, + { + name: "disabled-omitted", + finding: Finding{ + Category: "foo", + Position: token.Position{ + Filename: "default-disabled-omitted-from-global/file.go", + Offset: 0, + Line: 1, + Column: 1, + }, + Message: "message", + }, + want: false, + }, + { + name: "default enabled applies to customized analyzer", + finding: Finding{ + Category: "enabled-for-default-disabled", + Position: token.Position{ + Filename: "default-enabled/file.go", + Offset: 0, + Line: 1, + Column: 1, + }, + Message: "message", + }, + want: true, + }, + { + name: "default overridden in customized analyzer", + finding: Finding{ + Category: "enabled-for-default-disabled", + Position: token.Position{ + Filename: "default-disabled/file.go", + Offset: 0, + Line: 1, + Column: 1, + }, + Message: "message", + }, + want: true, + }, + { + name: "default overridden in customized analyzer even when omitted from global", + finding: Finding{ + Category: "enabled-for-default-disabled", + Position: token.Position{ + Filename: "default-disabled-omitted-from-global/file.go", + Offset: 0, + Line: 1, + Column: 1, + }, + Message: "message", + }, + want: true, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := config.ShouldReport(tc.finding); got != tc.want { + t.Errorf("ShouldReport(%+v) = %v, want %v", tc.finding, got, tc.want) + } + }) + } +} diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl index ddf5816a6..dc9a8b24e 100644 --- a/tools/nogo/defs.bzl +++ b/tools/nogo/defs.bzl @@ -160,6 +160,11 @@ def _nogo_stdlib_impl(ctx): return [NogoStdlibInfo( facts = facts, raw_findings = raw_findings, + ), DefaultInfo( + # Declare the facts and findings as default outputs. This is not + # strictly required, but ensures that the target still perform analysis + # when built directly rather than just indirectly via a nogo_test. + files = depset([facts, raw_findings]), )] nogo_stdlib = go_rule( @@ -198,6 +203,22 @@ NogoInfo = provider( }, ) +def _select_objfile(files): + """Returns (.a file, .x file, is_archive). + + If no .a file is available, then the first .x file will be returned + instead, and vice versa. If neither are available, then the first provided + file will be returned.""" + a_files = [f for f in files if f.path.endswith(".a")] + x_files = [f for f in files if f.path.endswith(".x")] + if not len(x_files) and not len(a_files): + return (files[0], files[0], False) + if not len(x_files): + x_files = a_files + if not len(a_files): + a_files = x_files + return a_files[0], x_files[0], True + def _nogo_aspect_impl(target, ctx): # If this is a nogo rule itself (and not the shadow of a go_library or # go_binary rule created by such a rule), then we simply return nothing. @@ -232,20 +253,14 @@ def _nogo_aspect_impl(target, ctx): deps = deps + info.deps # Start with all target files and srcs as input. - inputs = target.files.to_list() + srcs + binaries = target.files.to_list() + inputs = binaries + srcs # Generate a shell script that dumps the binary. Annoyingly, this seems # necessary as the context in which a run_shell command runs does not seem # to cleanly allow us redirect stdout to the actual output file. Perhaps # I'm missing something here, but the intermediate script does work. - binaries = target.files.to_list() - objfiles = [f for f in binaries if f.path.endswith(".a")] - if len(objfiles) > 0: - # Prefer the .a files for go_library targets. - target_objfile = objfiles[0] - else: - # Use the raw binary for go_binary and go_test targets. - target_objfile = binaries[0] + target_objfile, target_xfile, has_objfile = _select_objfile(binaries) inputs.append(target_objfile) # Extract the importpath for this package. @@ -274,10 +289,8 @@ def _nogo_aspect_impl(target, ctx): # Configure where to find the binary & fact files. Note that this will # use .x and .a regardless of whether this is a go_binary rule, since # these dependencies must be go_library rules. - x_files = [f.path for f in info.binaries if f.path.endswith(".x")] - if not len(x_files): - x_files = [f.path for f in info.binaries if f.path.endswith(".a")] - import_map[info.importpath] = x_files[0] + _, x_file, _ = _select_objfile(info.binaries) + import_map[info.importpath] = x_file.path fact_map[info.importpath] = info.facts.path # Collect all findings; duplicates are resolved at the end. @@ -287,6 +300,11 @@ def _nogo_aspect_impl(target, ctx): inputs.append(info.facts) inputs += info.binaries + # Add the module itself, for the type sanity check. This applies only to + # the libraries, and not binaries or tests. + if has_objfile: + import_map[importpath] = target_xfile.path + # Add the standard library facts. stdlib_info = ctx.attr._nogo_stdlib[NogoStdlibInfo] stdlib_facts = stdlib_info.facts diff --git a/tools/nogo/filter/main.go b/tools/nogo/filter/main.go index d50336b9b..4a925d03c 100644 --- a/tools/nogo/filter/main.go +++ b/tools/nogo/filter/main.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Binary check is the nogo entrypoint. +// Binary filter is the filters and reports nogo findings. package main import ( diff --git a/tools/nogo/findings.go b/tools/nogo/findings.go index 329a7062e..a73bf1a09 100644 --- a/tools/nogo/findings.go +++ b/tools/nogo/findings.go @@ -109,7 +109,7 @@ func ExtractFindingsFromFile(filename string, asJSON bool) (FindingSet, error) { return ExtractFindingsFrom(r, asJSON) } -// ExtractFindingsFromBytes loads findings from bytes. +// ExtractFindingsFrom loads findings from an io.Reader. func ExtractFindingsFrom(r io.Reader, asJSON bool) (findings FindingSet, err error) { if asJSON { dec := json.NewDecoder(r) diff --git a/tools/nogo/nogo.go b/tools/nogo/nogo.go index acee7c8bc..d95d7652f 100644 --- a/tools/nogo/nogo.go +++ b/tools/nogo/nogo.go @@ -41,9 +41,10 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/internal/facts" "golang.org/x/tools/go/gcexportdata" + "golang.org/x/tools/go/types/objectpath" // Special case: flags live here and change overall behavior. - "gvisor.dev/gvisor/tools/checkescape" + "gvisor.dev/gvisor/tools/nogo/objdump" "gvisor.dev/gvisor/tools/worker" ) @@ -216,6 +217,11 @@ func (i *importer) Import(path string) (*types.Package, error) { } } + // Check the cache. + if pkg, ok := i.cache[path]; ok && pkg.Complete() { + return pkg, nil + } + // Actually load the data. realPath, ok := i.ImportMap[path] var ( @@ -327,6 +333,9 @@ func CheckStdlib(config *StdlibConfig, analyzers []*analysis.Analyzer) (allFindi // Closure to check a single package. localStdlibFacts := make(stdlibFacts) localStdlibErrs := make(map[string]error) + stdlibCachedFacts.Lookup([]string{""}, func() worker.Sizer { + return localStdlibFacts + }) var checkOne func(pkg string) error // Recursive. checkOne = func(pkg string) error { // Is this already done? @@ -355,11 +364,11 @@ func CheckStdlib(config *StdlibConfig, analyzers []*analysis.Analyzer) (allFindi } // Provide the input. - oldReader := checkescape.Reader - checkescape.Reader = rc // For analysis. + oldReader := objdump.Reader + objdump.Reader = rc // For analysis. defer func() { rc.Close() - checkescape.Reader = oldReader // Restore. + objdump.Reader = oldReader // Restore. }() // Run the analysis. @@ -406,6 +415,56 @@ func CheckStdlib(config *StdlibConfig, analyzers []*analysis.Analyzer) (allFindi return allFindings, buf.Bytes(), nil } +// sanityCheckScope checks that all object in astTypes map to the correct +// objects in binaryTypes. Note that we don't check whether the sets are the +// same, we only care about the fidelity of objects in astTypes. +// +// When an inconsistency is identified, we record it in the astToBinaryMap. +// This allows us to dynamically replace facts and correct for the issue. The +// total number of mismatches is returned. +func sanityCheckScope(astScope *types.Scope, binaryTypes *types.Package, binaryScope *types.Scope, astToBinary map[types.Object]types.Object) error { + for _, x := range astScope.Names() { + fe := astScope.Lookup(x) + path, err := objectpath.For(fe) + if err != nil { + continue // Not an encoded object. + } + se, err := objectpath.Object(binaryTypes, path) + if err != nil { + continue // May be unused, see below. + } + if fe.Id() != se.Id() { + // These types are incompatible. This means that when + // this objectpath is loading from the binaryTypes (for + // dependencies) it will resolve to a fact for that + // type. We don't actually care about this error since + // we do the rewritten, but may as well alert. + log.Printf("WARNING: Object %s is a victim of go/issues/44195.", fe.Id()) + } + se = binaryScope.Lookup(x) + if se == nil { + // The fact may not be exported in the objectdata, if + // it is package internal. This is fine, as nothing out + // of this package can use these symbols. + continue + } + // Save the translation. + astToBinary[fe] = se + } + for i := 0; i < astScope.NumChildren(); i++ { + if err := sanityCheckScope(astScope.Child(i), binaryTypes, binaryScope, astToBinary); err != nil { + return err + } + } + return nil +} + +// sanityCheckTypes checks that two types are sane. The total number of +// mismatches is returned. +func sanityCheckTypes(astTypes, binaryTypes *types.Package, astToBinary map[types.Object]types.Object) error { + return sanityCheckScope(astTypes.Scope(), binaryTypes, binaryTypes.Scope(), astToBinary) +} + // CheckPackage runs all given analyzers. // // The implementation was adapted from [1], which was in turn adpated from [2]. @@ -450,17 +509,46 @@ func CheckPackage(config *PackageConfig, analyzers []*analysis.Analyzer, importC Scopes: make(map[ast.Node]*types.Scope), Selections: make(map[*ast.SelectorExpr]*types.Selection), } - types, err := typeConfig.Check(config.ImportPath, imp.fset, syntax, typesInfo) + astTypes, err := typeConfig.Check(config.ImportPath, imp.fset, syntax, typesInfo) if err != nil && imp.lastErr != ErrSkip { return nil, nil, fmt.Errorf("error checking types: %w", err) } - // Load all package facts. - facts, err := facts.Decode(types, config.factLoader) + // Load all facts using the astTypes, although it may need reconciling + // later on. See the fact functions below. + astFacts, err := facts.Decode(astTypes, config.factLoader) if err != nil { return nil, nil, fmt.Errorf("error decoding facts: %w", err) } + // Sanity check all types and record metadata to prevent + // https://github.com/golang/go/issues/44195. + // + // This block loads the binary types, whose encoding will be well + // defined and aligned with any downstream consumers. Below in the fact + // functions for the analysis, we serialize types to both the astFacts + // and the binaryFacts if available. The binaryFacts are the final + // encoded facts in order to ensure compatibility. We keep the + // intermediate astTypes in order to allow exporting and importing + // within the local package under analysis. + var ( + astToBinary = make(map[types.Object]types.Object) + binaryFacts *facts.Set + ) + if _, ok := config.ImportMap[config.ImportPath]; ok { + binaryTypes, err := imp.Import(config.ImportPath) + if err != nil { + return nil, nil, fmt.Errorf("error loading self: %w", err) + } + if err := sanityCheckTypes(astTypes, binaryTypes, astToBinary); err != nil { + return nil, nil, fmt.Errorf("error sanity checking types: %w", err) + } + binaryFacts, err = facts.Decode(binaryTypes, config.factLoader) + if err != nil { + return nil, nil, fmt.Errorf("error decoding facts: %w", err) + } + } + // Register fact types and establish dependencies between analyzers. // The visit closure will execute recursively, and populate results // will all required analysis results. @@ -479,15 +567,15 @@ func CheckPackage(config *PackageConfig, analyzers []*analysis.Analyzer, importC } // Run the analysis. - factFilter := make(map[reflect.Type]bool) + localFactsFilter := make(map[reflect.Type]bool) for _, f := range a.FactTypes { - factFilter[reflect.TypeOf(f)] = true + localFactsFilter[reflect.TypeOf(f)] = true } p := &analysis.Pass{ Analyzer: a, Fset: imp.fset, Files: syntax, - Pkg: types, + Pkg: astTypes, TypesInfo: typesInfo, ResultOf: results, // All results. Report: func(d analysis.Diagnostic) { @@ -497,13 +585,29 @@ func CheckPackage(config *PackageConfig, analyzers []*analysis.Analyzer, importC Message: d.Message, }) }, - ImportPackageFact: facts.ImportPackageFact, - ExportPackageFact: facts.ExportPackageFact, - ImportObjectFact: facts.ImportObjectFact, - ExportObjectFact: facts.ExportObjectFact, - AllPackageFacts: func() []analysis.PackageFact { return facts.AllPackageFacts(factFilter) }, - AllObjectFacts: func() []analysis.ObjectFact { return facts.AllObjectFacts(factFilter) }, - TypesSizes: typesSizes, + ImportPackageFact: astFacts.ImportPackageFact, + ExportPackageFact: func(fact analysis.Fact) { + astFacts.ExportPackageFact(fact) + if binaryFacts != nil { + binaryFacts.ExportPackageFact(fact) + } + }, + ImportObjectFact: astFacts.ImportObjectFact, + ExportObjectFact: func(obj types.Object, fact analysis.Fact) { + astFacts.ExportObjectFact(obj, fact) + // Note that if no object is recorded in + // astToBinary and binaryFacts != nil, then the + // object doesn't appear in the exported data. + // It was likely an internal object to the + // package, and there is no meaningful + // downstream consumer of the fact. + if binaryObj, ok := astToBinary[obj]; ok && binaryFacts != nil { + binaryFacts.ExportObjectFact(binaryObj, fact) + } + }, + AllPackageFacts: func() []analysis.PackageFact { return astFacts.AllPackageFacts(localFactsFilter) }, + AllObjectFacts: func() []analysis.ObjectFact { return astFacts.AllObjectFacts(localFactsFilter) }, + TypesSizes: typesSizes, } result, err := a.Run(p) if err != nil { @@ -528,8 +632,14 @@ func CheckPackage(config *PackageConfig, analyzers []*analysis.Analyzer, importC } } - // Return all findings. - return findings, facts.Encode(), nil + // Return all findings. Note that we have a preference to returning the + // binary facts if available, so that downstream consumers of these + // facts will find the export aligns with the internal type details. + // See the block above with the call to sanityCheckTypes. + if binaryFacts != nil { + return findings, binaryFacts.Encode(), nil + } + return findings, astFacts.Encode(), nil } func init() { diff --git a/tools/tags/BUILD b/tools/nogo/objdump/BUILD index 1c02e2c89..da56efdf7 100644 --- a/tools/tags/BUILD +++ b/tools/nogo/objdump/BUILD @@ -3,9 +3,8 @@ load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( - name = "tags", - srcs = ["tags.go"], - marshal = False, - stateify = False, + name = "objdump", + srcs = ["objdump.go"], + nogo = False, visibility = ["//tools:__subpackages__"], ) diff --git a/tools/nogo/objdump/objdump.go b/tools/nogo/objdump/objdump.go new file mode 100644 index 000000000..48484abf3 --- /dev/null +++ b/tools/nogo/objdump/objdump.go @@ -0,0 +1,96 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package objdump is a wrapper around relevant objdump flags. +package objdump + +import ( + "flag" + "fmt" + "io" + "os" + "os/exec" +) + +var ( + // Binary is the binary under analysis. + // + // See Reader, below. + binary = flag.String("binary", "", "binary under analysis") + + // Reader is the input stream. + // + // This may be set instead of Binary. + Reader io.Reader + + // objdumpTool is the tool used to dump a binary. + objdumpTool = flag.String("objdump_tool", "", "tool used to dump a binary") +) + +// LoadRaw reads the raw object output. +func LoadRaw(fn func(r io.Reader) error) error { + var r io.Reader + if *binary != "" { + f, err := os.Open(*binary) + if err != nil { + return err + } + defer f.Close() + r = f + } else if Reader != nil { + r = Reader + } else { + // We have no input stream. + return fmt.Errorf("no binary or reader provided") + } + return fn(r) +} + +// Load reads the objdump output. +func Load(fn func(r io.Reader) error) error { + var ( + args []string + stdin io.Reader + ) + if *binary != "" { + args = append(args, *binary) + } else if Reader != nil { + stdin = Reader + } else { + // We have no input stream or binary. + return fmt.Errorf("no binary or reader provided") + } + + // Construct our command. + cmd := exec.Command(*objdumpTool, args...) + cmd.Stdin = stdin + cmd.Stderr = os.Stderr + out, err := cmd.StdoutPipe() + if err != nil { + return err + } + if err := cmd.Start(); err != nil { + return err + } + + // Call the user hook. + userErr := fn(out) + + // Wait for the dump to finish. + if err := cmd.Wait(); userErr == nil && err != nil { + return err + } + + return userErr +} diff --git a/tools/parsers/version.go b/tools/parsers/version.go index ab9194b9d..c250f4a2a 100644 --- a/tools/parsers/version.go +++ b/tools/parsers/version.go @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build go1.1 +// +build go1.1 + package main // version is set during linking. diff --git a/tools/show_paths.bzl b/tools/show_paths.bzl new file mode 100644 index 000000000..f0126ac7b --- /dev/null +++ b/tools/show_paths.bzl @@ -0,0 +1,27 @@ +"""Formatter to extract the output files from a target.""" + +def format(target): + provider_map = providers(target) + if not provider_map: + return "" + outputs = dict() + + # Try to resolve in order. + files_to_run = provider_map.get("FilesToRunProvider", None) + default_info = provider_map.get("DefaultInfo", None) + output_group_info = provider_map.get("OutputGroupInfo", None) + if files_to_run and files_to_run.executable: + outputs[files_to_run.executable.path] = True + elif default_info: + for x in default_info.files: + outputs[x.path] = True + elif output_group_info: + for entry in dir(output_group_info): + # Filter out all built-ins and anything that is not a depset. + if entry.startswith("_") or not hasattr(getattr(output_group_info, entry), "to_list"): + continue + for x in getattr(output_group_info, entry).to_list(): + outputs[x.path] = True + + # Return all found files. + return "\n".join(outputs.keys()) diff --git a/tools/tags/tags.go b/tools/tags/tags.go deleted file mode 100644 index f35904e0a..000000000 --- a/tools/tags/tags.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package tags is a utility for parsing build tags. -package tags - -import ( - "fmt" - "io/ioutil" - "strings" -) - -// OrSet is a set of tags on a single line. -// -// Note that tags may include ",", and we don't distinguish this case in the -// logic below. Ideally, this constraints can be split into separate top-level -// build tags in order to resolve any issues. -type OrSet []string - -// Line returns the line for this or. -func (or OrSet) Line() string { - return fmt.Sprintf("// +build %s", strings.Join([]string(or), " ")) -} - -// AndSet is the set of all OrSets. -type AndSet []OrSet - -// Lines returns the lines to be printed. -func (and AndSet) Lines() (ls []string) { - for _, or := range and { - ls = append(ls, or.Line()) - } - return -} - -// Join joins this AndSet with another. -func (and AndSet) Join(other AndSet) AndSet { - return append(and, other...) -} - -// Tags returns the unique set of +build tags. -// -// Derived form the runtime's canBuild. -func Tags(file string) (tags AndSet) { - data, err := ioutil.ReadFile(file) - if err != nil { - return nil - } - // Check file contents for // +build lines. - for _, p := range strings.Split(string(data), "\n") { - p = strings.TrimSpace(p) - if p == "" { - continue - } - if !strings.HasPrefix(p, "//") { - break - } - if !strings.Contains(p, "+build") { - continue - } - fields := strings.Fields(p[2:]) - if len(fields) < 1 || fields[0] != "+build" { - continue - } - tags = append(tags, OrSet(fields[1:])) - } - return tags -} - -// Aggregate aggregates all tags from a set of files. -// -// Note that these may be in conflict, in which case the build will fail. -func Aggregate(files []string) (tags AndSet) { - for _, file := range files { - tags = tags.Join(Tags(file)) - } - return tags -} diff --git a/tools/verity/measure_tool.go b/tools/verity/measure_tool.go index 0d314ae70..4a0bc497a 100644 --- a/tools/verity/measure_tool.go +++ b/tools/verity/measure_tool.go @@ -21,12 +21,14 @@ import ( "io/ioutil" "log" "os" + "strings" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" ) var path = flag.String("path", "", "path to the verity file system.") +var rawpath = flag.String("rawpath", "", "path to the raw file system.") const maxDigestSize = 64 @@ -40,6 +42,14 @@ func main() { if *path == "" { log.Fatalf("no path provided") } + if *rawpath == "" { + log.Fatalf("no rawpath provided") + } + // TODO(b/182315468): Optimize the Merkle tree generate process to + // allow only updating certain files/directories. + if err := clearMerkle(*rawpath); err != nil { + log.Fatalf("Failed to clear merkle files in %s: %v", *rawpath, err) + } if err := enableDir(*path); err != nil { log.Fatalf("Failed to enable file system %s: %v", *path, err) } @@ -49,6 +59,26 @@ func main() { } } +func clearMerkle(path string) error { + files, err := ioutil.ReadDir(path) + if err != nil { + return err + } + + for _, file := range files { + if file.IsDir() { + if err := clearMerkle(path + "/" + file.Name()); err != nil { + return err + } + } else if strings.HasPrefix(file.Name(), ".merkle.verity") { + if err := os.Remove(path + "/" + file.Name()); err != nil { + return err + } + } + } + return nil +} + // enableDir enables verity features on all the files and sub-directories within // path. func enableDir(path string) error { |