diff options
Diffstat (limited to 'tools')
69 files changed, 4577 insertions, 740 deletions
diff --git a/tools/BUILD b/tools/BUILD index e73a9c885..34b950644 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -1,3 +1 @@ package(licenses = ["notice"]) - -exports_files(["nogo.js"]) diff --git a/tools/bazel.mk b/tools/bazel.mk new file mode 100644 index 000000000..45fbbecca --- /dev/null +++ b/tools/bazel.mk @@ -0,0 +1,106 @@ +#!/usr/bin/make -f + +# Copyright 2018 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# See base Makefile. +BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \ + git rev-parse --abbrev-ref HEAD 2>/dev/null) | \ + xargs -n 1 basename 2>/dev/null) + +# Bazel container configuration (see below). +USER ?= gvisor +DOCKER_NAME ?= gvisor-bazel +DOCKER_RUN_OPTIONS ?= --privileged +BAZEL_CACHE := $(shell readlink -m ~/.cache/bazel/) +GCLOUD_CONFIG := $(shell readlink -m ~/.config/gcloud/) +DOCKER_SOCKET := /var/run/docker.sock + +# Non-configurable. +UID := $(shell id -u ${USER}) +GID := $(shell id -g ${USER}) +FULL_DOCKER_RUN_OPTIONS := $(DOCKER_RUN_OPTIONS) +FULL_DOCKER_RUN_OPTIONS += -v "$(BAZEL_CACHE):$(BAZEL_CACHE)" +FULL_DOCKER_RUN_OPTIONS += -v "$(GCLOUD_CONFIG):$(GCLOUD_CONFIG)" +FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)" + +## +## Bazel helpers. +## +## This file supports targets that wrap bazel in a running Docker +## container to simplify development. Some options are available to +## control the behavior of this container: +## USER - The in-container user. +## DOCKER_RUN_OPTIONS - Options for the container (default: --privileged, required for tests). +## DOCKER_NAME - The container name (default: gvisor-bazel-HASH). +## BAZEL_CACHE - The bazel cache directory (default: detected). +## GCLOUD_CONFIG - The gcloud config directory (detect: detected). +## DOCKER_SOCKET - The Docker socket (default: detected). +## +bazel-server-start: load-default ## Starts the bazel server. + docker run -d --rm \ + --name $(DOCKER_NAME) \ + --user 0:0 \ + -v "$(CURDIR):$(CURDIR)" \ + --workdir "$(CURDIR)" \ + --tmpfs /tmp:rw,exec \ + --entrypoint "" \ + $(FULL_DOCKER_RUN_OPTIONS) \ + gvisor.dev/images/default \ + sh -c "groupadd --gid $(GID) --non-unique $(USER) && \ + useradd --uid $(UID) --non-unique --no-create-home --gid $(GID) -d $(HOME) $(USER) && \ + bazel version && \ + while :; do sleep 3600; done" + @while :; do if docker logs $(DOCKER_NAME) 2>/dev/null | grep "Build label:" >/dev/null; then break; fi; sleep 1; done +.PHONY: bazel-server-start + +bazel-shutdown: ## Shuts down a running bazel server. + @docker exec --user $(UID):$(GID) $(DOCKER_NAME) bazel shutdown; rc=$$?; docker kill $(DOCKER_NAME) || [[ $$rc -ne 0 ]] +.PHONY: bazel-shutdown + +bazel-alias: ## Emits an alias that can be used within the shell. + @echo "alias bazel='docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) bazel'" +.PHONY: bazel-alias + +bazel-server: ## Ensures that the server exists. Used as an internal target. + @docker exec $(DOCKER_NAME) true || $(MAKE) bazel-server-start +.PHONY: bazel-server + +build_paths = docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) sh -c 'bazel build $(OPTIONS) $(TARGETS) 2>&1 \ + | tee /dev/fd/2 \ + | grep -E "^ bazel-bin/" \ + | awk "{print $$1;}"' \ + | xargs -n 1 -I {} sh -c "$(1)" + +build: bazel-server + @$(call build_paths,echo {}) +.PHONY: build + +copy: bazel-server +ifeq (,$(DESTINATION)) + $(error Destination not provided.) +endif + @$(call build_paths,cp -a {} $(DESTINATION)) + +run: bazel-server + @$(call build_paths,{} $(ARGS)) +.PHONY: run + +sudo: bazel-server + @$(call build_paths,sudo -E {} $(ARGS)) +.PHONY: sudo + +test: bazel-server + @docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) bazel test $(OPTIONS) $(TARGETS) +.PHONY: test diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl index 905b16d41..3c22aec24 100644 --- a/tools/bazeldefs/defs.bzl +++ b/tools/bazeldefs/defs.bzl @@ -1,34 +1,91 @@ """Bazel implementations of standard rules.""" load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", _cc_flags_supplier = "cc_flags_supplier") -load("@io_bazel_rules_go//go:def.bzl", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_library = "go_library", _go_test = "go_test", _go_tool_library = "go_tool_library") -load("@io_bazel_rules_go//proto:def.bzl", _go_proto_library = "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "GoLibrary", _go_binary = "go_binary", _go_context = "go_context", _go_embed_data = "go_embed_data", _go_library = "go_library", _go_test = "go_test") +load("@io_bazel_rules_go//proto:def.bzl", _go_grpc_library = "go_grpc_library", _go_proto_library = "go_proto_library") load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test") load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") -load("@io_bazel_rules_docker//go:image.bzl", _go_image = "go_image") -load("@io_bazel_rules_docker//container:container.bzl", _container_image = "container_image") load("@pydeps//:requirements.bzl", _py_requirement = "requirement") +load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", _cc_grpc_library = "cc_grpc_library") -container_image = _container_image -cc_binary = _cc_binary cc_library = _cc_library cc_flags_supplier = _cc_flags_supplier cc_proto_library = _cc_proto_library cc_test = _cc_test cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain" -go_image = _go_image go_embed_data = _go_embed_data gtest = "@com_google_googletest//:gtest" +grpcpp = "@com_github_grpc_grpc//:grpc++" gbenchmark = "@com_google_benchmark//:benchmark" loopback = "//tools/bazeldefs:loopback" -proto_library = native.proto_library pkg_deb = _pkg_deb pkg_tar = _pkg_tar py_library = native.py_library py_binary = native.py_binary py_test = native.py_test +def proto_library(name, has_services = None, **kwargs): + native.proto_library( + name = name, + **kwargs + ) + +def cc_grpc_library(name, **kwargs): + _cc_grpc_library(name = name, grpc_only = True, **kwargs) + +def _go_proto_or_grpc_library(go_library_func, name, **kwargs): + deps = [ + dep.replace("_proto", "_go_proto") + for dep in (kwargs.pop("deps", []) or []) + ] + go_library_func( + name = name + "_go_proto", + importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name + "_go_proto", + proto = ":" + name + "_proto", + deps = deps, + **kwargs + ) + +def go_proto_library(name, **kwargs): + _go_proto_or_grpc_library(_go_proto_library, name, **kwargs) + +def go_grpc_and_proto_libraries(name, **kwargs): + _go_proto_or_grpc_library(_go_grpc_library, name, **kwargs) + +def cc_binary(name, static = False, **kwargs): + """Run cc_binary. + + Args: + name: name of the target. + static: make a static binary if True + **kwargs: the rest of the args. + """ + if static: + # How to statically link a c++ program that uses threads, like for gRPC: + # https://gcc.gnu.org/legacy-ml/gcc-help/2010-05/msg00029.html + if "linkopts" not in kwargs: + kwargs["linkopts"] = [] + kwargs["linkopts"] += [ + "-static", + "-lstdc++", + "-Wl,--whole-archive", + "-lpthread", + "-Wl,--no-whole-archive", + ] + _cc_binary( + name = name, + **kwargs + ) + def go_binary(name, static = False, pure = False, **kwargs): + """Build a go binary. + + Args: + name: name of the target. + static: build a static binary. + pure: build without cgo. + **kwargs: rest of the arguments are passed to _go_binary. + """ if static: kwargs["static"] = "on" if pure: @@ -38,6 +95,10 @@ def go_binary(name, static = False, pure = False, **kwargs): **kwargs ) +def go_importpath(target): + """Returns the importpath for the target.""" + return target[GoLibrary].importpath + def go_library(name, **kwargs): _go_library( name = name, @@ -45,25 +106,17 @@ def go_library(name, **kwargs): **kwargs ) -def go_tool_library(name, **kwargs): - _go_tool_library( - name = name, - importpath = "gvisor.dev/gvisor/" + native.package_name(), - **kwargs - ) - -def go_proto_library(name, proto, **kwargs): - deps = kwargs.pop("deps", []) - _go_proto_library( - name = name, - importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name, - proto = proto, - deps = [dep.replace("_proto", "_go_proto") for dep in deps], - **kwargs - ) +def go_test(name, pure = False, library = None, **kwargs): + """Build a go test. -def go_test(name, **kwargs): - library = kwargs.pop("library", None) + Args: + name: name of the output binary. + pure: should it be built without cgo. + library: the library to embed. + **kwargs: rest of the arguments to pass to _go_test. + """ + if pure: + kwargs["pure"] = "on" if library: kwargs["embed"] = [library] _go_test( @@ -71,6 +124,34 @@ def go_test(name, **kwargs): **kwargs ) +def go_rule(rule, implementation, **kwargs): + """Wraps a rule definition with Go attributes. + + Args: + rule: rule function (typically rule or aspect). + implementation: implementation function. + **kwargs: other arguments to pass to rule. + + Returns: + The result of invoking the rule. + """ + attrs = kwargs.pop("attrs", []) + attrs["_go_context_data"] = attr.label(default = "@io_bazel_rules_go//:go_context_data") + attrs["_stdlib"] = attr.label(default = "@io_bazel_rules_go//:stdlib") + toolchains = kwargs.get("toolchains", []) + ["@io_bazel_rules_go//go:toolchain"] + return rule(implementation, attrs = attrs, toolchains = toolchains, **kwargs) + +def go_context(ctx): + go_ctx = _go_context(ctx) + return struct( + go = go_ctx.go, + env = go_ctx.env, + runfiles = depset([go_ctx.go] + go_ctx.sdk.tools + go_ctx.stdlib.libs), + goos = go_ctx.sdk.goos, + goarch = go_ctx.sdk.goarch, + tags = go_ctx.tags, + ) + def py_requirement(name, direct = True): return _py_requirement(name) diff --git a/tools/bazeldefs/platforms.bzl b/tools/bazeldefs/platforms.bzl index 92b0b5fc0..132040c20 100644 --- a/tools/bazeldefs/platforms.bzl +++ b/tools/bazeldefs/platforms.bzl @@ -2,15 +2,10 @@ # Platform to associated tags. platforms = { - "ptrace": [ - # TODO(b/120560048): Make the tests run without this tag. - "no-sandbox", - ], + "ptrace": [], "kvm": [ "manual", "local", - # TODO(b/120560048): Make the tests run without this tag. - "no-sandbox", ], } diff --git a/tools/bigquery/BUILD b/tools/bigquery/BUILD new file mode 100644 index 000000000..5748fb390 --- /dev/null +++ b/tools/bigquery/BUILD @@ -0,0 +1,10 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "bigquery", + testonly = 1, + srcs = ["bigquery.go"], + deps = ["@com_google_cloud_go_bigquery//:go_default_library"], +) diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go new file mode 100644 index 000000000..56f0dc5c9 --- /dev/null +++ b/tools/bigquery/bigquery.go @@ -0,0 +1,121 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package bigquery defines a BigQuery schema for benchmarks. +// +// This package contains a schema for BigQuery and methods for publishing +// benchmark data into tables. +package bigquery + +import ( + "context" + "fmt" + "strings" + "time" + + bq "cloud.google.com/go/bigquery" +) + +// Benchmark is the top level structure of recorded benchmark data. BigQuery +// will infer the schema from this. +type Benchmark struct { + Name string `bq:"name"` + Timestamp time.Time `bq:"timestamp"` + Official bool `bq:"official"` + Metric []*Metric `bq:"metric"` + Metadata *Metadata `bq:"metadata"` +} + +// Metric holds the actual metric data and unit information for this benchmark. +type Metric struct { + Name string `bq:"name"` + Unit string `bq:"unit"` + Sample float64 `bq:"sample"` +} + +// Metadata about this benchmark. +type Metadata struct { + CL string `bq:"changelist"` + IterationID string `bq:"iteration_id"` + PendingCL string `bq:"pending_cl"` + Workflow string `bq:"workflow"` + Platform string `bq:"platform"` + Gofer string `bq:"gofer"` +} + +// InitBigQuery initializes a BigQuery dataset/table in the project. If the dataset/table already exists, it is not duplicated. +func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string) error { + client, err := bq.NewClient(ctx, projectID) + if err != nil { + return fmt.Errorf("failed to initialize client on project %s: %v", projectID, err) + } + defer client.Close() + + dataset := client.Dataset(datasetID) + if err := dataset.Create(ctx, nil); err != nil && !checkDuplicateError(err) { + return fmt.Errorf("failed to create dataset: %s: %v", datasetID, err) + } + + table := dataset.Table(tableID) + schema, err := bq.InferSchema(Benchmark{}) + if err != nil { + return fmt.Errorf("failed to infer schema: %v", err) + } + + if err := table.Create(ctx, &bq.TableMetadata{Schema: schema}); err != nil && !checkDuplicateError(err) { + return fmt.Errorf("failed to create table: %s: %v", tableID, err) + } + return nil +} + +// AddMetric adds a metric to an existing Benchmark. +func (bm *Benchmark) AddMetric(metricName, unit string, sample float64) { + m := &Metric{ + Name: metricName, + Unit: unit, + Sample: sample, + } + bm.Metric = append(bm.Metric, m) +} + +// NewBenchmark initializes a new benchmark. +func NewBenchmark(name string, official bool) *Benchmark { + return &Benchmark{ + Name: name, + Timestamp: time.Now().UTC(), + Official: official, + Metric: make([]*Metric, 0), + } +} + +// SendBenchmarks sends the slice of benchmarks to the BigQuery dataset/table. +func SendBenchmarks(ctx context.Context, benchmarks []*Benchmark, projectID, datasetID, tableID string) error { + client, err := bq.NewClient(ctx, projectID) + if err != nil { + return fmt.Errorf("Failed to initialize client on project: %s: %v", projectID, err) + } + defer client.Close() + + uploader := client.Dataset(datasetID).Table(tableID).Uploader() + if err = uploader.Put(ctx, benchmarks); err != nil { + return fmt.Errorf("failed to upload benchmarks to proejct %s, table %s.%s: %v", projectID, datasetID, tableID, err) + } + + return nil +} + +// BigQuery will error "409" for duplicate tables and datasets. +func checkDuplicateError(err error) bool { + return strings.Contains(err.Error(), "googleapi: Error 409: Already Exists") +} diff --git a/tools/checkescape/BUILD b/tools/checkescape/BUILD new file mode 100644 index 000000000..b8c3ddf44 --- /dev/null +++ b/tools/checkescape/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "checkescape", + srcs = ["checkescape.go"], + nogo = False, + visibility = ["//tools/nogo:__subpackages__"], + deps = [ + "//tools/nogo/data", + "@org_golang_x_tools//go/analysis:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/buildssa:go_tool_library", + "@org_golang_x_tools//go/ssa:go_tool_library", + ], +) diff --git a/tools/checkescape/checkescape.go b/tools/checkescape/checkescape.go new file mode 100644 index 000000000..571e9a6e6 --- /dev/null +++ b/tools/checkescape/checkescape.go @@ -0,0 +1,726 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package checkescape allows recursive escape analysis for hot paths. +// +// The analysis tracks multiple types of escapes, in two categories. First, +// 'hard' escapes are explicit allocations. Second, 'soft' escapes are +// interface dispatches or dynamic function dispatches; these don't necessarily +// escape but they *may* escape. The analysis is capable of making assertions +// recursively: soft escapes cannot be analyzed in this way, and therefore +// count as escapes for recursive purposes. +// +// The different types of escapes are as follows, with the category in +// parentheses: +// +// heap: A direct allocation is made on the heap (hard). +// builtin: A call is made to a built-in allocation function (hard). +// stack: A stack split as part of a function preamble (soft). +// interface: A call is made via an interface whicy *may* escape (soft). +// dynamic: A dynamic function is dispatched which *may* escape (soft). +// +// To the use the package, annotate a function-level comment with either the +// line "// +checkescape" or "// +checkescape:OPTION[,OPTION]". In the second +// case, the OPTION field is either a type above, or one of: +// +// local: Escape analysis is limited to local hard escapes only. +// all: All the escapes are included. +// hard: All hard escapes are included. +// +// If the "// +checkescape" annotation is provided, this is equivalent to +// provided the local and hard options. +// +// Some examples of this syntax are: +// +// +checkescape:all - Analyzes for all escapes in this function and all calls. +// +checkescape:local - Analyzes only for default local hard escapes. +// +checkescape:heap - Only analyzes for heap escapes. +// +checkescape:interface,dynamic - Only checks for dynamic calls and interface calls. +// +checkescape - Does the same as +checkescape:local,hard. +// +// Note that all of the above can be inverted by using +mustescape. The +// +checkescape keyword will ensure failure if the class of escape occurs, +// whereas +mustescape will fail if the given class of escape does not occur. +// +// Local exemptions can be made by a comment of the form "// escapes: reason." +// This must appear on the line of the escape and will also apply to callers of +// the function as well (for non-local escape analysis). +package checkescape + +import ( + "bufio" + "bytes" + "fmt" + "go/ast" + "go/token" + "go/types" + "io" + "os" + "path/filepath" + "strconv" + "strings" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/buildssa" + "golang.org/x/tools/go/ssa" + "gvisor.dev/gvisor/tools/nogo/data" +) + +const ( + // magic is the magic annotation. + magic = "// +checkescape" + + // magicParams is the magic annotation with specific parameters. + magicParams = magic + ":" + + // testMagic is the test magic annotation (parameters required). + testMagic = "// +mustescape:" + + // exempt is the exemption annotation. + exempt = "// escapes:" +) + +// escapingBuiltins are builtins known to escape. +// +// These are lowered at an earlier stage of compilation to explicit function +// calls, but are not available for recursive analysis. +var escapingBuiltins = []string{ + "append", + "makemap", + "newobject", + "mallocgc", +} + +// Analyzer defines the entrypoint. +var Analyzer = &analysis.Analyzer{ + Name: "checkescape", + Doc: "surfaces recursive escape analysis results", + Run: run, + Requires: []*analysis.Analyzer{buildssa.Analyzer}, + FactTypes: []analysis.Fact{(*packageEscapeFacts)(nil)}, +} + +// packageEscapeFacts is the set of all functions in a package, and whether or +// not they recursively pass escape analysis. +// +// All the type names for receivers are encoded in the full key. The key +// represents the fully qualified package and type name used at link time. +type packageEscapeFacts struct { + Funcs map[string][]Escape +} + +// AFact implements analysis.Fact.AFact. +func (*packageEscapeFacts) AFact() {} + +// CallSite is a single call site. +// +// These can be chained. +type CallSite struct { + LocalPos token.Pos + Resolved LinePosition +} + +// Escape is a single escape instance. +type Escape struct { + Reason EscapeReason + Detail string + Chain []CallSite +} + +// LinePosition is a low-resolution token.Position. +// +// This is used to match against possible exemptions placed in the source. +type LinePosition struct { + Filename string + Line int +} + +// String implements fmt.Stringer.String. +func (e *LinePosition) String() string { + return fmt.Sprintf("%s:%d", e.Filename, e.Line) +} + +// String implements fmt.Stringer.String. +// +// Note that this string will contain new lines. +func (e *Escape) String() string { + var b bytes.Buffer + fmt.Fprintf(&b, "%s", e.Reason.String()) + for i, cs := range e.Chain { + if i == len(e.Chain)-1 { + fmt.Fprintf(&b, "\n @ %s → %s", cs.Resolved.String(), e.Detail) + } else { + fmt.Fprintf(&b, "\n + %s", cs.Resolved.String()) + } + } + return b.String() +} + +// EscapeReason is an escape reason. +// +// This is a simple enum. +type EscapeReason int + +const ( + interfaceInvoke EscapeReason = iota + unknownPackage + allocation + builtin + dynamicCall + stackSplit + reasonCount // Count for below. +) + +// String returns the string for the EscapeReason. +// +// Note that this also implicitly defines the reverse string -> EscapeReason +// mapping, which is the word before the colon (computed below). +func (e EscapeReason) String() string { + switch e { + case interfaceInvoke: + return "interface: function invocation via interface" + case unknownPackage: + return "unknown: no package information available" + case allocation: + return "heap: call to runtime heap allocation" + case builtin: + return "builtin: call to runtime builtin" + case dynamicCall: + return "dynamic: call via dynamic function" + case stackSplit: + return "stack: stack split on function entry" + default: + panic(fmt.Sprintf("unknown reason: %d", e)) + } +} + +var hardReasons = []EscapeReason{ + allocation, + builtin, +} + +var softReasons = []EscapeReason{ + interfaceInvoke, + unknownPackage, + dynamicCall, + stackSplit, +} + +var allReasons = append(hardReasons, softReasons...) + +var escapeTypes = func() map[string]EscapeReason { + result := make(map[string]EscapeReason) + for _, r := range allReasons { + parts := strings.Split(r.String(), ":") + result[parts[0]] = r // Key before ':'. + } + return result +}() + +// EscapeCount counts escapes. +// +// It is used to avoid accumulating too many escapes for the same reason, for +// the same function. We limit each class to 3 instances (arbitrarily). +type EscapeCount struct { + byReason [reasonCount]uint32 +} + +// maxRecordsPerReason is the number of explicit records. +// +// See EscapeCount (and usage), and Record implementation. +const maxRecordsPerReason = 5 + +// Record records the reason or returns false if it should not be added. +func (ec *EscapeCount) Record(reason EscapeReason) bool { + ec.byReason[reason]++ + if ec.byReason[reason] > maxRecordsPerReason { + return false + } + return true +} + +// loadObjdump reads the objdump output. +// +// This records if there is a call any function for every source line. It is +// used only to remove false positives for escape analysis. The call will be +// elided if escape analysis is able to put the object on the heap exclusively. +func loadObjdump() (map[LinePosition]string, error) { + f, err := os.Open(data.Objdump) + if err != nil { + return nil, err + } + defer f.Close() + + // Build the map. + m := make(map[LinePosition]string) + r := bufio.NewReader(f) + var ( + lastField string + lastPos LinePosition + ) + for { + line, err := r.ReadString('\n') + if err != nil && err != io.EOF { + return nil, err + } + + // We recognize lines corresponding to actual code (not the + // symbol name or other metadata) and annotate them if they + // correspond to an explicit CALL instruction. We assume that + // the lack of a CALL for a given line is evidence that escape + // analysis has eliminated an allocation. + // + // Lines look like this (including the first space): + // gohacks_unsafe.go:33 0xa39 488b442408 MOVQ 0x8(SP), AX + if len(line) > 0 && line[0] == ' ' { + fields := strings.Fields(line) + if !strings.Contains(fields[3], "CALL") { + continue + } + + // Ignore strings containing duffzero, which is just + // used by stack allocations for types that are large + // enough to warrant Duff's device. + if strings.Contains(line, "runtime.duffzero") { + continue + } + + // Ignore the racefuncenter call, which is used for + // race builds. This does not escape. + if strings.Contains(line, "runtime.racefuncenter") { + continue + } + + // Calculate the filename and line. Note that per the + // example above, the filename is not a fully qualified + // base, just the basename (what we require). + if fields[0] != lastField { + parts := strings.SplitN(fields[0], ":", 2) + lineNum, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return nil, err + } + lastPos = LinePosition{ + Filename: parts[0], + Line: int(lineNum), + } + lastField = fields[0] + } + if _, ok := m[lastPos]; ok { + continue // Already marked. + } + + // Save the actual call for the detail. + m[lastPos] = strings.Join(fields[3:], " ") + } + if err == io.EOF { + break + } + } + + return m, nil +} + +// poser is a type that implements Pos. +type poser interface { + Pos() token.Pos +} + +// run performs the analysis. +func run(pass *analysis.Pass) (interface{}, error) { + calls, err := loadObjdump() + if err != nil { + return nil, err + } + pef := packageEscapeFacts{ + Funcs: make(map[string][]Escape), + } + linePosition := func(inst, parent poser) LinePosition { + p := pass.Fset.Position(inst.Pos()) + if (p.Filename == "" || p.Line == 0) && parent != nil { + p = pass.Fset.Position(parent.Pos()) + } + return LinePosition{ + Filename: filepath.Base(p.Filename), + Line: p.Line, + } + } + hasCall := func(inst poser) (string, bool) { + p := linePosition(inst, nil) + s, ok := calls[p] + return s, ok + } + callSite := func(inst ssa.Instruction) CallSite { + return CallSite{ + LocalPos: inst.Pos(), + Resolved: linePosition(inst, inst.Parent()), + } + } + escapes := func(reason EscapeReason, detail string, inst ssa.Instruction, ec *EscapeCount) []Escape { + if !ec.Record(reason) { + return nil // Skip. + } + es := Escape{ + Reason: reason, + Detail: detail, + Chain: []CallSite{callSite(inst)}, + } + return []Escape{es} + } + resolve := func(sub []Escape, inst ssa.Instruction, ec *EscapeCount) (es []Escape) { + for _, e := range sub { + if !ec.Record(e.Reason) { + continue // Skip. + } + es = append(es, Escape{ + Reason: e.Reason, + Detail: e.Detail, + Chain: append([]CallSite{callSite(inst)}, e.Chain...), + }) + } + return es + } + state := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA) + + var loadFunc func(*ssa.Function) []Escape // Used below. + + analyzeInstruction := func(inst ssa.Instruction, ec *EscapeCount) []Escape { + switch x := inst.(type) { + case *ssa.Call: + if x.Call.IsInvoke() { + // This is an interface dispatch. There is no + // way to know if this is actually escaping or + // not, since we don't know the underlying + // type. + call, _ := hasCall(inst) + return escapes(interfaceInvoke, call, inst, ec) + } + switch x := x.Call.Value.(type) { + case *ssa.Function: + if x.Pkg == nil { + // Can't resolve the package. + return escapes(unknownPackage, "no package", inst, ec) + } + + // Atomic functions are instrinics. We can + // assume that they don't escape. + if x.Pkg.Pkg.Name() == "atomic" { + return nil + } + + // Is this a local function? If yes, call the + // function to load the local function. The + // local escapes are the escapes found in the + // local function. + if x.Pkg.Pkg == pass.Pkg { + return resolve(loadFunc(x), inst, ec) + } + + // Recursively collect information from + // the other analyzers. + var imp packageEscapeFacts + if !pass.ImportPackageFact(x.Pkg.Pkg, &imp) { + // Unable to import the dependency; we must + // declare these as escaping. + return escapes(unknownPackage, "no analysis", inst, ec) + } + + // The escapes of this instruction are the + // escapes of the called function directly. + return resolve(imp.Funcs[x.RelString(x.Pkg.Pkg)], inst, ec) + case *ssa.Builtin: + // Ignore elided escapes. + if _, has := hasCall(inst); !has { + return nil + } + + // Check if the builtin is escaping. + for _, name := range escapingBuiltins { + if x.Name() == name { + return escapes(builtin, name, inst, ec) + } + } + default: + // All dynamic calls are counted as soft + // escapes. They are similar to interface + // dispatches. We cannot actually look up what + // this refers to using static analysis alone. + call, _ := hasCall(inst) + return escapes(dynamicCall, call, inst, ec) + } + case *ssa.Alloc: + // Ignore non-heap allocations. + if !x.Heap { + return nil + } + + // Ignore elided escapes. + call, has := hasCall(inst) + if !has { + return nil + } + + // This is a real heap allocation. + return escapes(allocation, call, inst, ec) + case *ssa.MakeMap: + return escapes(builtin, "makemap", inst, ec) + case *ssa.MakeSlice: + return escapes(builtin, "makeslice", inst, ec) + case *ssa.MakeClosure: + return escapes(builtin, "makeclosure", inst, ec) + case *ssa.MakeChan: + return escapes(builtin, "makechan", inst, ec) + } + return nil // No escapes. + } + + var analyzeBasicBlock func(*ssa.BasicBlock, *EscapeCount) []Escape // Recursive. + analyzeBasicBlock = func(block *ssa.BasicBlock, ec *EscapeCount) (rval []Escape) { + for _, inst := range block.Instrs { + rval = append(rval, analyzeInstruction(inst, ec)...) + } + return rval // N.B. may be empty. + } + + loadFunc = func(fn *ssa.Function) []Escape { + // Is this already available? + name := fn.RelString(pass.Pkg) + if es, ok := pef.Funcs[name]; ok { + return es + } + + // In the case of a true cycle, we assume that the current + // function itself has no escapes until the rest of the + // analysis is complete. This will trip the above in the case + // of a cycle of any kind. + pef.Funcs[name] = nil + + // Perform the basic analysis. + var ( + es []Escape + ec EscapeCount + ) + if fn.Recover != nil { + es = append(es, analyzeBasicBlock(fn.Recover, &ec)...) + } + for _, block := range fn.Blocks { + es = append(es, analyzeBasicBlock(block, &ec)...) + } + + // Check for a stack split. + if call, has := hasCall(fn); has { + es = append(es, Escape{ + Reason: stackSplit, + Detail: call, + Chain: []CallSite{CallSite{ + LocalPos: fn.Pos(), + Resolved: linePosition(fn, fn.Parent()), + }}, + }) + } + + // Save the result and return. + pef.Funcs[name] = es + return es + } + + // Complete all local functions. + for _, fn := range state.SrcFuncs { + loadFunc(fn) + } + + // Build the exception list. + exemptions := make(map[LinePosition]string) + for _, f := range pass.Files { + for _, cg := range f.Comments { + for _, c := range cg.List { + p := pass.Fset.Position(c.Slash) + if strings.HasPrefix(c.Text, exempt) { + exemptions[LinePosition{ + Filename: filepath.Base(p.Filename), + Line: p.Line, + }] = c.Text[len(exempt):] + } + } + } + } + + // Delete everything matching the excemtions. + // + // This has the implication that exceptions are applied recursively, + // since this now modified set is what will be saved. + for name, escapes := range pef.Funcs { + var newEscapes []Escape + for _, escape := range escapes { + isExempt := false + for line, _ := range exemptions { + // Note that an exemption applies if it is + // marked as an exemption anywhere in the call + // chain. It need not be marked as escapes in + // the function itself, nor in the top-level + // caller. + for _, callSite := range escape.Chain { + if callSite.Resolved == line { + isExempt = true + break + } + } + if isExempt { + break + } + } + if !isExempt { + // Record this escape; not an exception. + newEscapes = append(newEscapes, escape) + } + } + pef.Funcs[name] = newEscapes // Update. + } + + // Export all findings for future packages. + pass.ExportPackageFact(&pef) + + // Scan all functions for violations. + for _, f := range pass.Files { + // Scan all declarations. + for _, decl := range f.Decls { + fdecl, ok := decl.(*ast.FuncDecl) + // Function declaration? + if !ok { + continue + } + // Is there a comment? + if fdecl.Doc == nil { + continue + } + var ( + reasons []EscapeReason + found bool + local bool + testReasons = make(map[EscapeReason]bool) // reason -> local? + ) + // Does the comment contain a +checkescape line? + for _, c := range fdecl.Doc.List { + if !strings.HasPrefix(c.Text, magic) && !strings.HasPrefix(c.Text, testMagic) { + continue + } + if c.Text == magic { + // Default: hard reasons, local only. + reasons = hardReasons + local = true + } else if strings.HasPrefix(c.Text, magicParams) { + // Extract specific reasons. + types := strings.Split(c.Text[len(magicParams):], ",") + found = true // For below. + for i := 0; i < len(types); i++ { + if types[i] == "local" { + // Limit search to local escapes. + local = true + } else if types[i] == "all" { + // Append all reasons. + reasons = append(reasons, allReasons...) + } else if types[i] == "hard" { + // Append all hard reasons. + reasons = append(reasons, hardReasons...) + } else { + r, ok := escapeTypes[types[i]] + if !ok { + // This is not a valid escape reason. + pass.Reportf(fdecl.Pos(), "unknown reason: %v", types[i]) + continue + } + reasons = append(reasons, r) + } + } + } else if strings.HasPrefix(c.Text, testMagic) { + types := strings.Split(c.Text[len(testMagic):], ",") + local := false + for i := 0; i < len(types); i++ { + if types[i] == "local" { + local = true + } else { + r, ok := escapeTypes[types[i]] + if !ok { + // This is not a valid escape reason. + pass.Reportf(fdecl.Pos(), "unknown reason: %v", types[i]) + continue + } + if v, ok := testReasons[r]; ok && v { + // Already registered as local. + continue + } + testReasons[r] = local + } + } + } + } + if len(reasons) == 0 && found { + // A magic annotation was provided, but no reasons. + pass.Reportf(fdecl.Pos(), "no reasons provided") + continue + } + + // Scan for matches. + fn := pass.TypesInfo.Defs[fdecl.Name].(*types.Func) + name := state.Pkg.Prog.FuncValue(fn).RelString(pass.Pkg) + es, ok := pef.Funcs[name] + if !ok { + pass.Reportf(fdecl.Pos(), "internal error: function %s not found.", name) + continue + } + for _, e := range es { + for _, r := range reasons { + // Is does meet our local requirement? + if local && len(e.Chain) > 1 { + continue + } + // Does this match the reason? Emit + // with a full stack trace that + // explains why this violates our + // constraints. + if e.Reason == r { + pass.Reportf(e.Chain[0].LocalPos, "%s", e.String()) + } + } + } + + // Scan for test (required) matches. + testReasonsFound := make(map[EscapeReason]bool) + for _, e := range es { + // Is this local? + local, ok := testReasons[e.Reason] + wantLocal := len(e.Chain) == 1 + testReasonsFound[e.Reason] = wantLocal + if !ok { + continue + } + if local == wantLocal { + delete(testReasons, e.Reason) + } + } + for reason, local := range testReasons { + // We didn't find the escapes we wanted. + pass.Reportf(fdecl.Pos(), fmt.Sprintf("testescapes not found: reason=%s, local=%t", reason, local)) + } + if len(testReasons) > 0 { + // Dump all reasons found to help in debugging. + for _, e := range es { + pass.Reportf(e.Chain[0].LocalPos, "escape found: %s", e.String()) + } + } + } + } + + return nil, nil +} diff --git a/tools/checkescape/test1/BUILD b/tools/checkescape/test1/BUILD new file mode 100644 index 000000000..783403247 --- /dev/null +++ b/tools/checkescape/test1/BUILD @@ -0,0 +1,9 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "test1", + srcs = ["test1.go"], + visibility = ["//tools/checkescape/test2:__pkg__"], +) diff --git a/tools/checkescape/test1/test1.go b/tools/checkescape/test1/test1.go new file mode 100644 index 000000000..68d3f72cc --- /dev/null +++ b/tools/checkescape/test1/test1.go @@ -0,0 +1,195 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package test1 is a test package. +package test1 + +import ( + "fmt" + "reflect" +) + +// Interface is a generic interface. +type Interface interface { + Foo() +} + +// Type is a concrete implementation of Interface. +type Type struct { + A uint64 + B uint64 +} + +// Foo implements Interface.Foo. +//go:nosplit +func (t Type) Foo() { + fmt.Printf("%v", t) // Never executed. +} + +// +checkescape:all,hard +//go:nosplit +func InterfaceFunction(i Interface) { + // Do nothing; exported for tests. +} + +// +checkesacape:all,hard +//go:nosplit +func TypeFunction(t *Type) { +} + +// +mustescape:local,builtin +//go:noinline +//go:nosplit +func BuiltinMap(x int) map[string]bool { + return make(map[string]bool) +} + +// +mustescape:builtin +//go:noinline +//go:nosplit +func builtinMapRec(x int) map[string]bool { + return BuiltinMap(x) +} + +// +temustescapestescape:local,builtin +//go:noinline +//go:nosplit +func BuiltinClosure(x int) func() { + return func() { + fmt.Printf("%v", x) + } +} + +// +mustescape:builtin +//go:noinline +//go:nosplit +func builtinClosureRec(x int) func() { + return BuiltinClosure(x) +} + +// +mustescape:local,builtin +//go:noinline +//go:nosplit +func BuiltinMakeSlice(x int) []byte { + return make([]byte, x) +} + +// +mustescape:builtin +//go:noinline +//go:nosplit +func builtinMakeSliceRec(x int) []byte { + return BuiltinMakeSlice(x) +} + +// +mustescape:local,builtin +//go:noinline +//go:nosplit +func BuiltinAppend(x []byte) []byte { + return append(x, 0) +} + +// +mustescape:builtin +//go:noinline +//go:nosplit +func builtinAppendRec() []byte { + return BuiltinAppend(nil) +} + +// +mustescape:local,builtin +//go:noinline +//go:nosplit +func BuiltinChan() chan int { + return make(chan int) +} + +// +mustescape:builtin +//go:noinline +//go:nosplit +func builtinChanRec() chan int { + return BuiltinChan() +} + +// +mustescape:local,heap +//go:noinline +//go:nosplit +func Heap() *Type { + var t Type + return &t +} + +// +mustescape:heap +//go:noinline +//go:nosplit +func heapRec() *Type { + return Heap() +} + +// +mustescape:local,interface +//go:noinline +//go:nosplit +func Dispatch(i Interface) { + i.Foo() +} + +// +mustescape:interface +//go:noinline +//go:nosplit +func dispatchRec(i Interface) { + Dispatch(i) +} + +// +mustescape:local,dynamic +//go:noinline +//go:nosplit +func Dynamic(f func()) { + f() +} + +// +mustescape:dynamic +//go:noinline +//go:nosplit +func dynamicRec(f func()) { + Dynamic(f) +} + +// +mustescape:local,unknown +//go:noinline +//go:nosplit +func Unknown() { + _ = reflect.TypeOf((*Type)(nil)) // Does not actually escape. +} + +// +mustescape:unknown +//go:noinline +//go:nosplit +func unknownRec() { + Unknown() +} + +//go:noinline +//go:nosplit +func internalFunc() { +} + +// +mustescape:local,stack +//go:noinline +func Split() { + internalFunc() +} + +// +mustescape:stack +//go:noinline +func splitRec() { + Split() +} diff --git a/tools/checkescape/test2/BUILD b/tools/checkescape/test2/BUILD new file mode 100644 index 000000000..5a11e4b43 --- /dev/null +++ b/tools/checkescape/test2/BUILD @@ -0,0 +1,9 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "test2", + srcs = ["test2.go"], + deps = ["//tools/checkescape/test1"], +) diff --git a/tools/checkescape/test2/test2.go b/tools/checkescape/test2/test2.go new file mode 100644 index 000000000..7fce3e3be --- /dev/null +++ b/tools/checkescape/test2/test2.go @@ -0,0 +1,94 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package test2 is a test package that imports test1. +package test2 + +import ( + "gvisor.dev/gvisor/tools/checkescape/test1" +) + +// +checkescape:all +//go:nosplit +func interfaceFunctionCrossPkg() { + var i test1.Interface + test1.InterfaceFunction(i) +} + +// +checkesacape:all +//go:nosplit +func typeFunctionCrossPkg() { + var t test1.Type + test1.TypeFunction(&t) +} + +// +mustescape:builtin +//go:noinline +func builtinMapCrossPkg(x int) map[string]bool { + return test1.BuiltinMap(x) +} + +// +mustescape:builtin +//go:noinline +func builtinClosureCrossPkg(x int) func() { + return test1.BuiltinClosure(x) +} + +// +mustescape:builtin +//go:noinline +func builtinMakeSliceCrossPkg(x int) []byte { + return test1.BuiltinMakeSlice(x) +} + +// +mustescape:builtin +//go:noinline +func builtinAppendCrossPkg() []byte { + return test1.BuiltinAppend(nil) +} + +// +mustescape:builtin +//go:noinline +func builtinChanCrossPkg() chan int { + return test1.BuiltinChan() +} + +// +mustescape:heap +//go:noinline +func heapCrossPkg() *test1.Type { + return test1.Heap() +} + +// +mustescape:interface +//go:noinline +func dispatchCrossPkg(i test1.Interface) { + test1.Dispatch(i) +} + +// +mustescape:dynamic +//go:noinline +func dynamicCrossPkg(f func()) { + test1.Dynamic(f) +} + +// +mustescape:unknown +//go:noinline +func unknownCrossPkg() { + test1.Unknown() +} + +// +mustescape:stack +//go:noinline +func splitCrosssPkt() { + test1.Split() +} diff --git a/tools/checkunsafe/BUILD b/tools/checkunsafe/BUILD index 4f1a31a6d..0c264151b 100644 --- a/tools/checkunsafe/BUILD +++ b/tools/checkunsafe/BUILD @@ -1,11 +1,12 @@ -load("//tools:defs.bzl", "go_tool_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) -go_tool_library( +go_library( name = "checkunsafe", srcs = ["check_unsafe.go"], - visibility = ["//:sandbox"], + nogo = False, + visibility = ["//tools/nogo:__subpackages__"], deps = [ "@org_golang_x_tools//go/analysis:go_tool_library", ], diff --git a/tools/defs.bzl b/tools/defs.bzl index 15a310403..cdaf281f3 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -7,36 +7,38 @@ change for Google-internal and bazel-compatible rules. load("//tools/go_stateify:defs.bzl", "go_stateify") load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") -load("//tools/bazeldefs:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") +load("//tools/bazeldefs:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _default_installer = "default_installer", _default_net_util = "default_net_util", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _grpcpp = "grpcpp", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms") load("//tools/bazeldefs:tags.bzl", "go_suffixes") +load("//tools/nogo:defs.bzl", "nogo_test") # Delegate directly. cc_binary = _cc_binary +cc_flags_supplier = _cc_flags_supplier +cc_grpc_library = _cc_grpc_library cc_library = _cc_library cc_test = _cc_test cc_toolchain = _cc_toolchain -cc_flags_supplier = _cc_flags_supplier -container_image = _container_image +default_installer = _default_installer +default_net_util = _default_net_util +gbenchmark = _gbenchmark go_embed_data = _go_embed_data -go_image = _go_image go_test = _go_test -go_tool_library = _go_tool_library gtest = _gtest -gbenchmark = _gbenchmark +grpcpp = _grpcpp +loopback = _loopback pkg_deb = _pkg_deb pkg_tar = _pkg_tar -py_library = _py_library py_binary = _py_binary -py_test = _py_test +py_library = _py_library py_requirement = _py_requirement +py_test = _py_test select_arch = _select_arch select_system = _select_system -loopback = _loopback -default_installer = _default_installer -default_net_util = _default_net_util -platforms = _platforms + +# Platform options. default_platform = _default_platform +platforms = _platforms def go_binary(name, **kwargs): """Wraps the standard go_binary. @@ -88,7 +90,7 @@ def go_imports(name, src, out): cmd = ("$(location @org_golang_x_tools//cmd/goimports:goimports) $(SRCS) > $@"), ) -def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, marshal_debug = False, **kwargs): +def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, marshal_debug = False, nogo = False, **kwargs): """Wraps the standard go_library and does stateification and marshalling. The recommended way is to use this rule with mostly identical configuration as the native @@ -174,6 +176,11 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F deps = all_deps, **kwargs ) + if nogo: + nogo_test( + name = name + "_nogo", + deps = [":" + name], + ) if marshal: # Ignore importpath for go_test. @@ -190,33 +197,52 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F **kwargs ) -def proto_library(name, srcs, **kwargs): +def proto_library(name, srcs, deps = None, has_services = 0, **kwargs): """Wraps the standard proto_library. - Given a proto_library named "foo", this produces three different targets: + Given a proto_library named "foo", this produces up to five different + targets: - foo_proto: proto_library rule. - foo_go_proto: go_proto_library rule. - foo_cc_proto: cc_proto_library rule. + - foo_go_grpc_proto: go_grpc_library rule. + - foo_cc_grpc_proto: cc_grpc_library rule. Args: + name: the name to which _proto, _go_proto, etc, will be appended. srcs: the proto sources. + deps: for the proto library and the go_proto_library. + has_services: 1 to build gRPC code, otherwise 0. **kwargs: standard proto_library arguments. """ - deps = kwargs.pop("deps", []) _proto_library( name = name + "_proto", srcs = srcs, deps = deps, + has_services = has_services, **kwargs ) - _go_proto_library( - name = name + "_go_proto", - proto = ":" + name + "_proto", - deps = deps, - **kwargs - ) + if has_services: + _go_grpc_and_proto_libraries( + name = name, + deps = deps, + **kwargs + ) + else: + _go_proto_library( + name = name, + deps = deps, + **kwargs + ) _cc_proto_library( name = name + "_cc_proto", deps = [":" + name + "_proto"], **kwargs ) + if has_services: + _cc_grpc_library( + name = name + "_cc_grpc_proto", + srcs = [":" + name + "_proto"], + deps = [":" + name + "_cc_proto"], + **kwargs + ) diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl index c5be52ecd..8c9995fd4 100644 --- a/tools/go_generics/defs.bzl +++ b/tools/go_generics/defs.bzl @@ -105,7 +105,6 @@ def _go_template_instance_impl(ctx): executable = ctx.executable._tool, ) - # TODO: How can we get the dependencies out? return struct( files = depset([output]), ) diff --git a/tools/go_marshal/analysis/analysis_unsafe.go b/tools/go_marshal/analysis/analysis_unsafe.go index 9a9a4f298..cd55cf5cb 100644 --- a/tools/go_marshal/analysis/analysis_unsafe.go +++ b/tools/go_marshal/analysis/analysis_unsafe.go @@ -161,6 +161,10 @@ func AlignmentCheck(t *testing.T, typ reflect.Type) (ok bool, delta uint64) { if typ.NumField() > 0 && nextXOff != int(typ.Size()) { implicitPad := int(typ.Size()) - nextXOff f := typ.Field(typ.NumField() - 1) // Final field + if tag, ok := f.Tag.Lookup("marshal"); ok && tag == "unaligned" { + // Final field explicitly marked unaligned. + break + } t.Fatalf("Suspect offset for field %s.%s at the end of %s, detected an implicit %d byte padding from offset %d to %d at the end of the struct; either add %d bytes of explict padding at end of the struct or tag the final field %s as `marshal:\"unaligned\"`.", typ.Name(), f.Name, typ.Name(), implicitPad, nextXOff, typ.Size(), implicitPad, f.Name) } diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl index d79786a68..323e33882 100644 --- a/tools/go_marshal/defs.bzl +++ b/tools/go_marshal/defs.bzl @@ -53,9 +53,10 @@ go_marshal = rule( # marshal_deps are the dependencies requied by generated code. marshal_deps = [ - "//tools/go_marshal/marshal", + "//pkg/gohacks", "//pkg/safecopy", "//pkg/usermem", + "//tools/go_marshal/marshal", ] # marshal_test_deps are required by test targets. diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 729489de5..177013dbb 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -28,12 +28,6 @@ import ( "gvisor.dev/gvisor/tools/tags" ) -const ( - marshalImport = "gvisor.dev/gvisor/tools/go_marshal/marshal" - safecopyImport = "gvisor.dev/gvisor/pkg/safecopy" - usermemImport = "gvisor.dev/gvisor/pkg/usermem" -) - // List of identifiers we use in generated code that may conflict with a // similarly-named source identifier. Abort gracefully when we see these to // avoid potentially confusing compilation failures in generated code. @@ -44,8 +38,8 @@ const ( // All recievers are single letters, so we don't allow import aliases to be a // single letter. var badIdents = []string{ - "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "idx", "inner", "len", - "ptr", "src", "srcs", "task", "val", + "addr", "blk", "buf", "dst", "dsts", "count", "err", "hdr", "idx", "inner", + "length", "limit", "ptr", "size", "src", "srcs", "task", "val", // All single-letter identifiers. } @@ -110,9 +104,10 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G g.imports.add("reflect") g.imports.add("runtime") g.imports.add("unsafe") - g.imports.add(marshalImport) - g.imports.add(safecopyImport) - g.imports.add(usermemImport) + g.imports.add("gvisor.dev/gvisor/pkg/gohacks") + g.imports.add("gvisor.dev/gvisor/pkg/safecopy") + g.imports.add("gvisor.dev/gvisor/pkg/usermem") + g.imports.add("gvisor.dev/gvisor/tools/go_marshal/marshal") return &g, nil } @@ -194,10 +189,73 @@ func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) { return files, fsets, nil } +// sliceAPI carries information about the '+marshal slice' directive. +type sliceAPI struct { + // Comment node in the AST containing the +marshal tag. + comment *ast.Comment + // Identifier fragment to use when naming generated functions for the slice + // API. + ident string + // Whether the generated functions should reference the newtype name, or the + // inner type name. Only meaningful on newtype declarations on primitives. + inner bool +} + +// marshallableType carries information about a type marked with the '+marshal' +// directive. +type marshallableType struct { + spec *ast.TypeSpec + slice *sliceAPI +} + +func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) marshallableType { + mt := marshallableType{ + spec: spec, + slice: nil, + } + + var unhandledTags []string + + for _, tag := range strings.Fields(strings.TrimPrefix(tagLine.Text, "// +marshal")) { + if strings.HasPrefix(tag, "slice:") { + tokens := strings.Split(tag, ":") + if len(tokens) < 2 || len(tokens) > 3 { + abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive has invalid 'slice' clause. Expecting format 'slice:<IDENTIFIER>[:inner]', got '%v'", tag)) + } + if len(tokens[1]) == 0 { + abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has empty identifier argument. Expecting '+marshal slice:identifier'") + } + + sa := &sliceAPI{ + comment: tagLine, + ident: tokens[1], + } + mt.slice = sa + + if len(tokens) == 3 { + if tokens[2] != "inner" { + abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has an invalid argument. Expecting '+marshal slice:<IDENTIFIER>[:inner]'") + } + sa.inner = true + } + + continue + } + + unhandledTags = append(unhandledTags, tag) + } + + if len(unhandledTags) > 0 { + abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive contained the following unknown clauses: %v", strings.Join(unhandledTags, " "))) + } + + return mt +} + // collectMarshallableTypes walks the parsed AST and collects a list of type // declarations for which we need to generate the Marshallable interface. -func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec { - var types []*ast.TypeSpec +func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []marshallableType { + var types []marshallableType for _, decl := range a.Decls { gdecl, ok := decl.(*ast.GenDecl) // Type declaration? @@ -212,9 +270,11 @@ func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*a } // Does the comment contain a "+marshal" line? marked := false + var tagLine *ast.Comment for _, c := range gdecl.Doc.List { - if c.Text == "// +marshal" { + if strings.HasPrefix(c.Text, "// +marshal") { marked = true + tagLine = c break } } @@ -229,20 +289,17 @@ func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*a switch t.Type.(type) { case *ast.StructType: debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name) - types = append(types, t) - continue case *ast.Ident: // Newtype on primitive. debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name) - types = append(types, t) - continue case *ast.ArrayType: // Newtype on array. debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name) - types = append(types, t) - continue + default: + // A user specifically requested marshalling on this type, but we + // don't support it. + abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name)) } - // A user specifically requested marshalling on this type, but we - // don't support it. - abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name)) + types = append(types, newMarshallableType(f, tagLine, t)) + } } return types @@ -269,7 +326,7 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp // Make sure we have an import that doesn't use any local names that // would conflict with identifiers in the generated code. - if len(i.name) == 1 { + if len(i.name) == 1 && i.name != "_" { abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name)) } if _, ok := badIdentsMap[i.name]; ok { @@ -281,19 +338,28 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp } -func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { - i := newInterfaceGenerator(t, fset) - switch ty := t.Type.(type) { +func (g *Generator) generateOne(t marshallableType, fset *token.FileSet) *interfaceGenerator { + i := newInterfaceGenerator(t.spec, fset) + switch ty := t.spec.Type.(type) { case *ast.StructType: - i.validateStruct(t, ty) + i.validateStruct(t.spec, ty) i.emitMarshallableForStruct(ty) + if t.slice != nil { + i.emitMarshallableSliceForStruct(ty, t.slice) + } case *ast.Ident: i.validatePrimitiveNewtype(ty) i.emitMarshallableForPrimitiveNewtype(ty) + if t.slice != nil { + i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice) + } case *ast.ArrayType: - i.validateArrayNewtype(t.Name, ty) + i.validateArrayNewtype(t.spec.Name, ty) // After validate, we can safely call arrayLen. - i.emitMarshallableForArrayNewtype(t.Name, ty.Elt.(*ast.Ident), arrayLen(ty)) + i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident)) + if t.slice != nil { + abortAt(fset.Position(t.slice.comment.Slash), fmt.Sprintf("Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?")) + } default: // This should've been filtered out by collectMarshallabeTypes. panic(fmt.Sprintf("Unexpected type %+v", ty)) @@ -303,9 +369,9 @@ func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interface // generateOneTestSuite generates a test suite for the automatically generated // implementations type t. -func (g *Generator) generateOneTestSuite(t *ast.TypeSpec) *testGenerator { - i := newTestGenerator(t) - i.emitTests() +func (g *Generator) generateOneTestSuite(t marshallableType) *testGenerator { + i := newTestGenerator(t.spec) + i.emitTests(t.slice) return i } @@ -355,7 +421,7 @@ func (g *Generator) Run() error { // the list of imports we need to copy to the generated code. for name, _ := range impl.is { if !g.imports.markUsed(name) { - panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'", impl.typeName(), name)) + panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name)) } } ts = append(ts, g.generateOneTestSuite(t)) @@ -413,7 +479,7 @@ func (g *Generator) writeTests(ts []*testGenerator) error { // empty example instead. if len(ts) == 0 { b.reset() - b.emit("func ExampleEmptyTestSuite() {\n") + b.emit("func Example() {\n") b.inIndent(func() { b.emit("// This example is intentionally empty to ensure this file contains at least\n") b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n") diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index 8babf61d2..e3c3dac63 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -15,8 +15,10 @@ package gomarshal import ( + "fmt" "go/ast" "go/token" + "strings" ) // interfaceGenerator generates marshalling interfaces for a single type. @@ -72,7 +74,6 @@ func (g *interfaceGenerator) recordUsedMarshallable(m string) { func (g *interfaceGenerator) recordUsedImport(i string) { g.is[i] = struct{}{} - } func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) { @@ -163,3 +164,113 @@ func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) { g.recordPotentiallyNonPackedField(accessor) } } + +// emitCastToByteSlice unsafely casts an arbitrary type's underlying memory to a +// byte slice, bypassing escape analysis. The caller is responsible for ensuring +// srcPtr lives until they're done with dstVar, the runtime does not consider +// dstVar dependent on srcPtr due to the escape analysis bypass. +// +// srcPtr must be a pointer. +// +// This function uses internally uses the identifier "hdr", and cannot be used +// in a context where it is already bound. +func (g *interfaceGenerator) emitCastToByteSlice(srcPtr, dstVar, lenExpr string) { + g.recordUsedImport("gohacks") + g.emit("// Construct a slice backed by dst's underlying memory.\n") + g.emit("var %s []byte\n", dstVar) + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar) + g.emit("hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(%s)))\n", srcPtr) + g.emit("hdr.Len = %s\n", lenExpr) + g.emit("hdr.Cap = %s\n\n", lenExpr) +} + +// emitCastToByteSlice unsafely casts a slice with elements of an abitrary type +// to a byte slice. As part of the cast, the byte slice is made to look +// independent of the src slice by bypassing escape analysis. This means the +// byte slice can be used without causing the source to escape. The caller is +// responsible for ensuring srcPtr lives until they're done with dstVar, as the +// runtime no longer considers dstVar dependent on srcPtr and is free to GC it. +// +// srcPtr must be a pointer. +// +// This function uses internally uses the identifiers "ptr", "val" and "hdr", +// and cannot be used in a context where these identifiers are already bound. +func (g *interfaceGenerator) emitCastSliceToByteSlice(srcPtr, dstVar, lenExpr string) { + g.emitNoEscapeSliceDataPointer(srcPtr, "val") + + g.emit("// Construct a slice backed by dst's underlying memory.\n") + g.emit("var %s []byte\n", dstVar) + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar) + g.emit("hdr.Data = uintptr(val)\n") + g.emit("hdr.Len = %s\n", lenExpr) + g.emit("hdr.Cap = %s\n\n", lenExpr) +} + +// emitNoEscapeSliceDataPointer unsafely casts a slice's data pointer to an +// unsafe.Pointer, bypassing escape analysis. The caller is responsible for +// ensuring srcPtr lives until they're done with dstVar, as the runtime no +// longer considers dstVar dependent on srcPtr and is free to GC it. +// +// srcPtr must be a pointer. +// +// This function uses internally uses the identifier "ptr" cannot be used in a +// context where this identifier is already bound. +func (g *interfaceGenerator) emitNoEscapeSliceDataPointer(srcPtr, dstVar string) { + g.recordUsedImport("gohacks") + g.emit("ptr := unsafe.Pointer(%s)\n", srcPtr) + g.emit("%s := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data))\n\n", dstVar) +} + +func (g *interfaceGenerator) emitKeepAlive(ptrVar string) { + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", ptrVar) + g.emit("// must live until the use above.\n") + g.emit("runtime.KeepAlive(%s)\n", ptrVar) +} + +func (g *interfaceGenerator) expandBinaryExpr(b *strings.Builder, e *ast.BinaryExpr) { + switch x := e.X.(type) { + case *ast.BinaryExpr: + // Recursively expand sub-expression. + g.expandBinaryExpr(b, x) + case *ast.Ident: + fmt.Fprintf(b, "%s", x.Name) + case *ast.BasicLit: + fmt.Fprintf(b, "%s", x.Value) + default: + g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers") + } + + fmt.Fprintf(b, "%s", e.Op) + + switch y := e.Y.(type) { + case *ast.BinaryExpr: + // Recursively expand sub-expression. + g.expandBinaryExpr(b, y) + case *ast.Ident: + fmt.Fprintf(b, "%s", y.Name) + case *ast.BasicLit: + fmt.Fprintf(b, "%s", y.Value) + default: + g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers") + } +} + +// arrayLenExpr returns a string containing a valid golang expression +// representing the length of array a. The returned expression should be treated +// as a single value, and will be already parenthesized as required. +func (g *interfaceGenerator) arrayLenExpr(a *ast.ArrayType) string { + var b strings.Builder + + switch l := a.Len.(type) { + case *ast.Ident: + fmt.Fprintf(&b, "%s", l.Name) + case *ast.BasicLit: + fmt.Fprintf(&b, "%s", l.Value) + case *ast.BinaryExpr: + g.expandBinaryExpr(&b, l) + return fmt.Sprintf("(%s)", b.String()) + default: + g.abortAt(l.Pos(), "Cannot convert this array len expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers") + } + return b.String() +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go index da36d9305..72ef03a22 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go @@ -27,20 +27,12 @@ func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name)) } - if _, ok := a.Len.(*ast.BasicLit); !ok { - g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don't use consts or expressions")) - } - if _, ok := a.Elt.(*ast.Ident); !ok { g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt))) } - - if arrayLen(a) <= 0 { - g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?")) - } } -func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, len int) { +func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *ast.ArrayType, elt *ast.Ident) { g.recordUsedImport("io") g.recordUsedImport("marshal") g.recordUsedImport("reflect") @@ -49,13 +41,16 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, g.recordUsedImport("unsafe") g.recordUsedImport("usermem") + lenExpr := g.arrayLenExpr(a) + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("//go:nosplit\n") g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) g.inIndent(func() { if size, dynamic := g.scalarSize(elt); !dynamic { - g.emit("return %d\n", size*len) + g.emit("return %d * %s\n", size, lenExpr) } else { - g.emit("return (*%s)(nil).SizeBytes() * %d\n", n.Name, len) + g.emit("return (*%s)(nil).SizeBytes() * %s\n", n.Name, lenExpr) } }) g.emit("}\n\n") @@ -63,7 +58,7 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("for idx := 0; idx < %d; idx++ {\n", len) + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) g.inIndent(func() { g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst") }) @@ -74,7 +69,7 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("for idx := 0; idx < %d; idx++ {\n", len) + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) g.inIndent(func() { g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src") }) @@ -83,6 +78,7 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, g.emit("}\n\n") g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("//go:nosplit\n") g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) g.inIndent(func() { g.emit("// Array newtypes are always packed.\n") @@ -104,79 +100,46 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, }) g.emit("}\n\n") + g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") + }) + g.emit("}\n\n") + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { - // Fast serialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyOutBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyOutBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") + g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r) }) g.emit("}\n\n") g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyInBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyInBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") }) g.emit("}\n\n") g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("len, err := w.Write(buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the Write.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return int64(len), err\n") + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := w.Write(buf)\n") + g.emitKeepAlive(g.r) + g.emit("return int64(length), err\n") }) g.emit("}\n\n") diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go index 159397825..39f654ea8 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go @@ -104,6 +104,7 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) g.recordUsedImport("usermem") g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("//go:nosplit\n") g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) g.inIndent(func() { if size, dynamic := g.scalarSize(nt); !dynamic { @@ -129,6 +130,7 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) g.emit("}\n\n") g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("//go:nosplit\n") g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) g.inIndent(func() { g.emit("// Scalar newtypes are always packed.\n") @@ -150,80 +152,138 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) }) g.emit("}\n\n") + g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") + }) + g.emit("}\n\n") + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { - // Fast serialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyOutBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyOutBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") + g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r) }) g.emit("}\n\n") g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyInBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyInBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") }) g.emit("}\n\n") g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("len, err := w.Write(buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the Write.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return int64(len), err\n") + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := w.Write(buf)\n") + g.emitKeepAlive(g.r) + g.emit("return int64(length), err\n") + + }) + g.emit("}\n\n") +} + +func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Ident, slice *sliceAPI) { + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + + eltType := g.typeName() + if slice.inner { + eltType = nt.Name + } + + g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, eltType) + g.emit("//go:nosplit\n") + g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, eltType) + g.inIndent(func() { + g.emit("count := len(dst)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emitCastSliceToByteSlice("&dst", "buf", "size * count") + + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive("dst") + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, eltType) + g.emit("//go:nosplit\n") + g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, eltType) + g.inIndent(func() { + g.emit("count := len(src)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emitCastSliceToByteSlice("&src", "buf", "size * count") + + g.emit("length, err := task.CopyOutBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive("src") + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) + g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(src)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emitNoEscapeSliceDataPointer("&src", "val") + + g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n") + g.emitKeepAlive("src") + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) + g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(dst)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emitNoEscapeSliceDataPointer("&dst", "val") + g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n") + g.emitKeepAlive("dst") + g.emit("return length, err\n") }) g.emit("}\n\n") } diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go index e66a38b2e..9cd3c9579 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go @@ -62,8 +62,8 @@ func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType // No validation to perform on selector fields. However this // callback must still be provided. }, - array: func(n, _ *ast.Ident, len int) { - g.validateArrayNewtype(n, f.Type.(*ast.ArrayType)) + array: func(n *ast.Ident, a *ast.ArrayType, _ *ast.Ident) { + g.validateArrayNewtype(n, a) }, unhandled: func(_ *ast.Ident) { g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type))) @@ -72,20 +72,24 @@ func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType }) } -func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { - // Is g.t a packed struct without consideing field types? - thisPacked := true +func (g *interfaceGenerator) isStructPacked(st *ast.StructType) bool { + packed := true forEachStructField(st, func(f *ast.Field) { if f.Tag != nil { if f.Tag.Value == "`marshal:\"unaligned\"`" { - if thisPacked { + if packed { debugfAt(g.f.Position(g.t.Pos()), fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name)) - thisPacked = false + packed = false } } } }) + return packed +} + +func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { + thisPacked := g.isStructPacked(st) g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) @@ -108,16 +112,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.recordUsedMarshallable(tName) dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) }, - array: func(n, t *ast.Ident, len int) { - if len < 1 { - // Zero-length arrays should've been rejected by validate(). - panic("unreachable") - } + array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + lenExpr := g.arrayLenExpr(a) if size, dynamic := g.scalarSize(t); !dynamic { - primitiveSize += size * len + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr)) } else { g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len)) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr)) } }, }.dispatch) @@ -148,7 +149,7 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.shift("dst", len) } else { // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can referece here + // an instance of the dynamic type we can reference here // (since the version in this struct is anonymous). Use // a typed nil pointer to call SizeBytes() instead. g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) @@ -158,24 +159,30 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") }, selector: func(n, tX, tSel *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name) + g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name) + return + } g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") }, - array: func(n, t *ast.Ident, size int) { + array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + lenExpr := g.arrayLenExpr(a) if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("dst", len*size) + g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name) + if size, dynamic := g.scalarSize(t); !dynamic { + g.emit("dst = dst[%d*(%s):]\n", size, lenExpr) } else { // We can't use shiftDynamic here because we don't have // an instance of the dynamic type we can reference here // (since the version in this struct is anonymous). Use // a typed nil pointer to call SizeBytes() instead. - g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) + g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr) } return } - g.emit("for idx := 0; idx < %d; idx++ {\n", size) + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) g.inIndent(func() { g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") }) @@ -195,11 +202,11 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { if len, dynamic := g.scalarSize(t); !dynamic { g.shift("src", len) } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can reference here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name) + // We don't have an instance of the dynamic type we can + // reference here (since the version in this struct is + // anonymous). Use a typed nil pointer to call + // SizeBytes() instead. + g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name)) g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) } return @@ -207,24 +214,31 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") }, selector: func(n, tX, tSel *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name) + g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name) + g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name)) + return + } g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") }, - array: func(n, t *ast.Ident, size int) { + array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + lenExpr := g.arrayLenExpr(a) if n.Name == "_" { - g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("src", len*size) + g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr) + if size, dynamic := g.scalarSize(t); !dynamic { + g.emit("src = src[%d*(%s):]\n", size, lenExpr) } else { // We can't use shiftDynamic here because we don't have // an instance of the dynamic type we can referece here // (since the version in this struct is anonymous). Use // a typed nil pointer to call SizeBytes() instead. - g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) + g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr) } return } - g.emit("for idx := 0; idx < %d; idx++ {\n", size) + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) g.inIndent(func() { g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") }) @@ -235,6 +249,7 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("}\n\n") g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("//go:nosplit\n") g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) g.inIndent(func() { expr, fieldsMaybePacked := g.areFieldsPackedExpression() @@ -302,17 +317,17 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { }) g.emit("}\n\n") - g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") + g.emit("//go:nosplit\n") g.recordUsedImport("marshal") g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) - g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) - g.emit("%s.MarshalBytes(buf)\n", g.r) - g.emit("_, err := task.CopyOutBytes(addr, buf)\n") - g.emit("return err\n") + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) + g.emit("%s.MarshalBytes(buf) // escapes: fallback.\n", g.r) + g.emit("return task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") } if thisPacked { g.recordUsedImport("reflect") @@ -324,48 +339,41 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("}\n\n") } // Fast serialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyOutBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyOutBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") } else { fallback() } }) g.emit("}\n\n") + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("//go:nosplit\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r) + }) + g.emit("}\n\n") + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("//go:nosplit\n") g.recordUsedImport("marshal") g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) - g.emit("_, err := task.CopyInBytes(addr, buf)\n") - g.emit("if err != nil {\n") - g.inIndent(func() { - g.emit("return err\n") - }) - g.emit("}\n") - - g.emit("%s.UnmarshalBytes(buf)\n", g.r) - g.emit("return nil\n") + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emit("// Unmarshal unconditionally. If we had a short copy-in, this results in a\n") + g.emit("// partially unmarshalled struct.\n") + g.emit("%s.UnmarshalBytes(buf) // escapes: fallback.\n", g.r) + g.emit("return length, err\n") } if thisPacked { g.recordUsedImport("reflect") @@ -377,25 +385,11 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("}\n\n") } // Fast deserialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyInBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyInBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") } else { fallback() } @@ -410,8 +404,8 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) g.emit("%s.MarshalBytes(buf)\n", g.r) - g.emit("n, err := w.Write(buf)\n") - g.emit("return int64(n), err\n") + g.emit("length, err := w.Write(buf)\n") + g.emit("return int64(length), err\n") } if thisPacked { g.recordUsedImport("reflect") @@ -423,25 +417,199 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("}\n\n") } // Fast serialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("len, err := w.Write(buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the Write.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return int64(len), err\n") + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := w.Write(buf)\n") + g.emitKeepAlive(g.r) + g.emit("return int64(length), err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") +} + +func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, slice *sliceAPI) { + thisPacked := g.isStructPacked(st) + + if slice.inner { + abortAt(g.f.Position(slice.comment.Slash), fmt.Sprintf("The ':inner' argument to '+marshal slice:%s:inner' is only applicable to newtypes on primitives. Remove it from this struct declaration.", slice.ident)) + } + + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + + g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, g.typeName()) + g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(dst)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(size * count)\n") + g.emit("length, err := task.CopyInBytes(addr, buf)\n\n") + + g.emit("// Unmarshal as much as possible, even on error. First handle full objects.\n") + g.emit("limit := length/size\n") + g.emit("for idx := 0; idx < limit; idx++ {\n") + g.inIndent(func() { + g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n") + }) + g.emit("}\n\n") + + g.emit("// Handle any final partial object.\n") + g.emit("if length < size*count && length%size != 0 {\n") + g.inIndent(func() { + g.emit("idx := limit\n") + g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n") + }) + g.emit("}\n\n") + + g.emit("return length, err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if _, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !dst[0].Packed() {\n") + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast deserialization. + g.emitCastSliceToByteSlice("&dst", "buf", "size * count") + + g.emit("length, err := task.CopyInBytes(addr, buf)\n") + g.emitKeepAlive("dst") + g.emit("return length, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, g.typeName()) + g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(src)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(size * count)\n") + g.emit("for idx := 0; idx < count; idx++ {\n") + g.inIndent(func() { + g.emit("src[idx].MarshalBytes(buf[size*idx:size*(idx+1)])\n") + }) + g.emit("}\n") + g.emit("return task.CopyOutBytes(addr, buf)\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if _, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !src[0].Packed() {\n") + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emitCastSliceToByteSlice("&src", "buf", "size * count") + + g.emit("length, err := task.CopyOutBytes(addr, buf)\n") + g.emitKeepAlive("src") + g.emit("return length, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) + g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(src)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("for idx := 0; idx < count; idx++ {\n") + g.inIndent(func() { + g.emit("src[idx].MarshalBytes(dst[size*idx:(size)*(idx+1)])\n") + }) + g.emit("}\n") + g.emit("return size * count, nil\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if _, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !src[0].Packed() {\n") + g.inIndent(fallback) + g.emit("}\n\n") + } + g.emitNoEscapeSliceDataPointer("&src", "val") + + g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n") + g.emitKeepAlive("src") + g.emit("return length, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) + g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(dst)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("for idx := 0; idx < count; idx++ {\n") + g.inIndent(func() { + g.emit("dst[idx].UnmarshalBytes(src[size*idx:size*(idx+1)])\n") + }) + g.emit("}\n") + g.emit("return size * count, nil\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if _, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !dst[0].Packed() {\n") + g.inIndent(fallback) + g.emit("}\n\n") + } + g.emitNoEscapeSliceDataPointer("&dst", "val") + + g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n") + g.emitKeepAlive("dst") + g.emit("return length, err\n") } else { fallback() } diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index fd992e44a..631295373 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -30,6 +30,11 @@ var standardImports = []string{ "gvisor.dev/gvisor/tools/go_marshal/analysis", } +var sliceAPIImports = []string{ + "encoding/binary", + "gvisor.dev/gvisor/pkg/usermem", +} + type testGenerator struct { sourceBuffer @@ -58,6 +63,11 @@ func newTestGenerator(t *ast.TypeSpec) *testGenerator { for _, i := range standardImports { g.imports.add(i).markUsed() } + // These imports are used if a type requests the slice API. Don't + // mark them as used by default. + for _, i := range sliceAPIImports { + g.imports.add(i) + } return g } @@ -132,6 +142,42 @@ func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() { }) } +func (g *testGenerator) emitTestMarshalUnmarshalSlicePreservesData(slice *sliceAPI) { + for _, name := range []string{"binary", "usermem"} { + if !g.imports.markUsed(name) { + panic(fmt.Sprintf("Generated test for '%s' referenced a non-existent import with local name '%s'", g.typeName(), name)) + } + } + + g.inTestFunction("TestSafeMarshalUnmarshalSlicePreservesData", func() { + g.emit("var x, y, yUnsafe [8]%s\n", g.typeName()) + g.emit("analysis.RandomizeValue(&x)\n\n") + g.emit("size := (*%s)(nil).SizeBytes() * len(x)\n", g.typeName()) + g.emit("buf := bytes.NewBuffer(make([]byte, size))\n") + g.emit("buf.Reset()\n") + g.emit("if err := binary.Write(buf, usermem.ByteOrder, x[:]); err != nil {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"binary.Write failed: %v\", err))\n") + }) + g.emit("}\n") + g.emit("bufUnsafe := make([]byte, size)\n") + g.emit("MarshalUnsafe%s(x[:], bufUnsafe)\n\n", slice.ident) + + g.emit("UnmarshalUnsafe%s(y[:], buf.Bytes())\n", slice.ident) + g.emit("if !reflect.DeepEqual(x, y) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across binary.Write/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") + }) + g.emit("}\n") + g.emit("UnmarshalUnsafe%s(yUnsafe[:], bufUnsafe)\n", slice.ident) + g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafeSlice/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") + }) + g.emit("}\n\n") + }) +} + func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() { g.inTestFunction("TestWriteToUnmarshalPreservesData", func() { g.emit("var x, y, yUnsafe %s\n", g.typeName()) @@ -170,12 +216,16 @@ func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() { }) } -func (g *testGenerator) emitTests() { +func (g *testGenerator) emitTests(slice *sliceAPI) { g.emitTestNonZeroSize() g.emitTestSuspectAlignment() g.emitTestMarshalUnmarshalPreservesData() g.emitTestWriteToUnmarshalPreservesData() g.emitTestSizeBytesOnTypedNilPtr() + + if slice != nil { + g.emitTestMarshalUnmarshalSlicePreservesData(slice) + } } func (g *testGenerator) write(out io.Writer) error { diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go index a0936e013..d94314302 100644 --- a/tools/go_marshal/gomarshal/util.go +++ b/tools/go_marshal/gomarshal/util.go @@ -25,7 +25,6 @@ import ( "path" "reflect" "sort" - "strconv" "strings" ) @@ -75,29 +74,10 @@ func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) { type fieldDispatcher struct { primitive func(n, t *ast.Ident) selector func(n, tX, tSel *ast.Ident) - array func(n, t *ast.Ident, size int) + array func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) unhandled func(n *ast.Ident) } -// Precondition: a must have a literal for the array length. Consts and -// expressions are not allowed as array lengths, and should be rejected by the -// caller. -func arrayLen(a *ast.ArrayType) int { - if a.Len == nil { - // Probably a slice? Must be handled by caller. - panic("Nil array length in array type") - } - lenLit, ok := a.Len.(*ast.BasicLit) - if !ok { - panic("Array has non-literal for length") - } - len, err := strconv.Atoi(lenLit.Value) - if err != nil { - panic(fmt.Sprintf("Failed to parse array length '%s' as number: %v", lenLit.Value, err)) - } - return len -} - // Precondition: All dispatch callbacks that will be invoked must be // provided. Embedded fields are not allowed, len(f.Names) >= 1. func (fd fieldDispatcher) dispatch(f *ast.Field) { @@ -123,7 +103,7 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) { case *ast.ArrayType: switch t := v.Elt.(type) { case *ast.Ident: - fd.array(name, t, arrayLen(v)) + fd.array(name, v, t) default: // Should be handled with a better error message during validate. panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t)) @@ -285,6 +265,11 @@ type importStmt struct { aliased bool // Indicates whether this import was referenced by generated code. used bool + // AST node and file set representing the import statement, if any. These + // are only non-nil if the import statement originates from an input source + // file. + spec *ast.ImportSpec + fset *token.FileSet } func newImport(p string) *importStmt { @@ -310,14 +295,27 @@ func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { name: name, path: p, aliased: spec.Name != nil, + spec: spec, + fset: f, } } +// String implements fmt.Stringer.String. This generates a string for the import +// statement appropriate for writing directly to generated code. func (i *importStmt) String() string { if i.aliased { - return fmt.Sprintf("%s \"%s\"", i.name, i.path) + return fmt.Sprintf("%s %q", i.name, i.path) + } + return fmt.Sprintf("%q", i.path) +} + +// debugString returns a debug string representing an import statement. This +// representation is not valid golang code and is used for debugging output. +func (i *importStmt) debugString() string { + if i.spec != nil && i.fset != nil { + return fmt.Sprintf("%s: %s", i.fset.Position(i.spec.Path.Pos()), i) } - return fmt.Sprintf("\"%s\"", i.path) + return fmt.Sprintf("(go-marshal import): %s", i) } func (i *importStmt) markUsed() { @@ -329,40 +327,78 @@ func (i *importStmt) equivalent(other *importStmt) bool { } // importTable represents a collection of importStmts. +// +// An importTable may contain multiple import statements referencing the same +// local name. All import statements aliasing to the same local name are +// technically ambiguous, as if such an import name is used in the generated +// code, it's not clear which import statement it refers to. We ignore any +// potential collisions until actually writing the import table to the generated +// source file. See importTable.write. +// +// Given the following import statements across all the files comprising a +// package marshalled: +// +// "sync" +// "pkg/sync" +// "pkg/sentry/kernel" +// ktime "pkg/sentry/kernel/time" +// +// An importTable representing them would look like this: +// +// importTable { +// is: map[string][]*importStmt { +// "sync": []*importStmt{ +// importStmt{name:"sync", path:"sync", aliased:false} +// importStmt{name:"sync", path:"pkg/sync", aliased:false} +// }, +// "kernel": []*importStmt{importStmt{ +// name: "kernel", +// path: "pkg/sentry/kernel", +// aliased: false +// }}, +// "ktime": []*importStmt{importStmt{ +// name: "ktime", +// path: "pkg/sentry/kernel/time", +// aliased: true, +// }}, +// } +// } +// +// Note that the local name "sync" is assigned to two different import +// statements. This is possible if the import statements are from different +// source files in the same package. +// +// Since go-marshal generates a single output file per package regardless of the +// number of input files, if "sync" is referenced by any generated code, it's +// unclear which import statement "sync" refers to. While it's theoretically +// possible to resolve this by assigning a unique local alias to each instance +// of the sync package, go-marshal currently aborts when it encounters such an +// ambiguity. +// +// TODO(b/151478251): importTable considers the final component of an import +// path to be the package name, but this is only a convention. The actual +// package name is determined by the package statement in the source files for +// the package. type importTable struct { // Map of imports and whether they should be copied to the output. - is map[string]*importStmt + is map[string][]*importStmt } func newImportTable() *importTable { return &importTable{ - is: make(map[string]*importStmt), + is: make(map[string][]*importStmt), } } -// Merges import statements from other into i. Collisions in import statements -// result in a panic. +// Merges import statements from other into i. func (i *importTable) merge(other *importTable) { - for name, im := range other.is { - if dup, ok := i.is[name]; ok && !dup.equivalent(im) { - panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im)) - } - - i.is[name] = im + for name, ims := range other.is { + i.is[name] = append(i.is[name], ims...) } } func (i *importTable) addStmt(s *importStmt) *importStmt { - if old, ok := i.is[s.name]; ok && !old.equivalent(s) { - // A collision should always be between an import inserted by the - // go-marshal tool and an import from the original source file (assuming - // the original source file was valid). We could theoretically handle - // the collision by assigning a local name to our import. However, this - // would need to be plumbed throughout the generator. Given that - // collisions should be rare, simply panic on collision. - panic(fmt.Sprintf("Import collision: old: %s as %v; new: %v as %v", old.path, old.name, s.path, s.name)) - } - i.is[s.name] = s + i.is[s.name] = append(i.is[s.name], s) return s } @@ -378,16 +414,20 @@ func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *impor // Marks the import named n as used. If no such import is in the table, returns // false. func (i *importTable) markUsed(n string) bool { - if n, ok := i.is[n]; ok { - n.markUsed() + if ns, ok := i.is[n]; ok { + for _, n := range ns { + n.markUsed() + } return true } return false } func (i *importTable) clear() { - for _, i := range i.is { - i.used = false + for _, is := range i.is { + for _, i := range is { + i.used = false + } } } @@ -398,9 +438,42 @@ func (i *importTable) write(out io.Writer) error { } imports := make([]string, 0, len(i.is)) - for _, i := range i.is { - if i.used { - imports = append(imports, i.String()) + for name, is := range i.is { + var lastUsed *importStmt + var ambiguous bool + + for _, i := range is { + if i.used { + if lastUsed != nil { + if !i.equivalent(lastUsed) { + ambiguous = true + } + } + lastUsed = i + } + } + + if ambiguous { + // We have two or more import statements across the different source + // files that share a local name, and at least one of these imports + // are used by the generated code. This ambiguity can't be resolved + // by go-marshal and requires the user intervention. Dump a list of + // the colliding import statements and let the user modify the input + // files as appropriate. + var b strings.Builder + fmt.Fprintf(&b, "The imported name %q is used by one of the types marked for marshalling, and which import statement the code refers to is ambiguous. Perhaps give the imports unique local names?\n\n", name) + fmt.Fprintf(&b, "The following %d import statements are ambiguous for the local name %q:\n", len(is), name) + // Note: len(is) is guaranteed to be 1 or greater or ambiguous can't + // be true. Therefore the slicing below is safe. + for _, i := range is[:len(is)-1] { + fmt.Fprintf(&b, " %v\n", i.debugString()) + } + fmt.Fprintf(&b, " %v", is[len(is)-1].debugString()) + panic(b.String()) + } + + if lastUsed != nil { + imports = append(imports, lastUsed.String()) } } sort.Strings(imports) diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go index f129788e0..cb2166252 100644 --- a/tools/go_marshal/marshal/marshal.go +++ b/tools/go_marshal/marshal/marshal.go @@ -42,7 +42,11 @@ type Task interface { CopyInBytes(addr usermem.Addr, b []byte) (int, error) } -// Marshallable represents a type that can be marshalled to and from memory. +// Marshallable represents operations on a type that can be marshalled to and +// from memory. +// +// go-marshal automatically generates implementations for this interface for +// types marked as '+marshal'. type Marshallable interface { io.WriterTo @@ -54,12 +58,18 @@ type Marshallable interface { // likely make use of the type of these fields). SizeBytes() int - // MarshalBytes serializes a copy of a type to dst. dst must be at least - // SizeBytes() long. + // MarshalBytes serializes a copy of a type to dst. dst may be smaller than + // SizeBytes(), which results in a part of the struct being marshalled. Note + // that this may have unexpected results for non-packed types, as implicit + // padding needs to be taken into account when reasoning about how much of + // the type is serialized. MarshalBytes(dst []byte) - // UnmarshalBytes deserializes a type from src. src must be at least - // SizeBytes() long. + // UnmarshalBytes deserializes a type from src. src may be smaller than + // SizeBytes(), which results in a partially deserialized struct. Note that + // this may have unexpected results for non-packed types, as implicit + // padding needs to be taken into account when reasoning about how much of + // the type is deserialized. UnmarshalBytes(src []byte) // Packed returns true if the marshalled size of the type is the same as the @@ -67,13 +77,20 @@ type Marshallable interface { // starting at unaligned addresses (should always be true by default for ABI // structs, verified by automatically generated tests when using // go_marshal), and has no fields marked `marshal:"unaligned"`. + // + // Packed must return the same result for all possible values of the type + // implementing it. Violating this constraint implies the type doesn't have + // a static memory layout, and will lead to memory corruption. + // Go-marshal-generated code reuses the result of Packed for multiple values + // of the same type. Packed() bool // MarshalUnsafe serializes a type by bulk copying its in-memory // representation to the dst buffer. This is only safe to do when the type // has no implicit padding, see Marshallable.Packed. When Packed would // return false, MarshalUnsafe should fall back to the safer but slower - // MarshalBytes. + // MarshalBytes. dst may be smaller than SizeBytes(), see comment for + // MarshalBytes for implications. MarshalUnsafe(dst []byte) // UnmarshalUnsafe deserializes a type by directly copying to the underlying @@ -82,7 +99,8 @@ type Marshallable interface { // This allows much faster unmarshalling of types which have no implicit // padding, see Marshallable.Packed. When Packed would return false, // UnmarshalUnsafe should fall back to the safer but slower unmarshal - // mechanism implemented in UnmarshalBytes. + // mechanism implemented in UnmarshalBytes. src may be smaller than + // SizeBytes(), see comment for UnmarshalBytes for implications. UnmarshalUnsafe(src []byte) // CopyIn deserializes a Marshallable type from a task's memory. This may @@ -91,12 +109,79 @@ type Marshallable interface { // marshalled does not escape. The implementation should avoid creating // extra copies in memory by directly deserializing to the object's // underlying memory. - CopyIn(task Task, addr usermem.Addr) error + // + // If the copy-in from the task memory is only partially successful, CopyIn + // should still attempt to deserialize as much data as possible. See comment + // for UnmarshalBytes. + CopyIn(task Task, addr usermem.Addr) (int, error) // CopyOut serializes a Marshallable type to a task's memory. This may only // be called from a task goroutine. This is more efficient than calling // MarshalUnsafe on Marshallable.Packed types, as the type being serialized // does not escape. The implementation should avoid creating extra copies in // memory by directly serializing from the object's underlying memory. - CopyOut(task Task, addr usermem.Addr) error + // + // The copy-out to the task memory may be partially successful, in which + // case CopyOut returns how much data was serialized. See comment for + // MarshalBytes for implications. + CopyOut(task Task, addr usermem.Addr) (int, error) + + // CopyOutN is like CopyOut, but explicitly requests a partial + // copy-out. Note that this may yield unexpected results for non-packed + // types and the caller may only want to allow this for packed types. See + // comment on MarshalBytes. + // + // The limit must be less than or equal to SizeBytes(). + CopyOutN(task Task, addr usermem.Addr, limit int) (int, error) } + +// go-marshal generates additional functions for a type based on additional +// clauses to the +marshal directive. They are documented below. +// +// Slice API +// ========= +// +// Adding a "slice" clause to the +marshal directive for structs or newtypes on +// primitives like this: +// +// // +marshal slice:FooSlice +// type Foo struct { ... } +// +// Generates four additional functions for marshalling slices of Foos like this: +// +// // MarshalUnsafeFooSlice is like Foo.MarshalUnsafe, buf for a []Foo. It's +// // more efficient that repeatedly calling calling Foo.MarshalUnsafe over a +// // []Foo in a loop. +// func MarshalUnsafeFooSlice(src []Foo, dst []byte) (int, error) { ... } +// +// // UnmarshalUnsafeFooSlice is like Foo.UnmarshalUnsafe, buf for a []Foo. It's +// // more efficient that repeatedly calling calling Foo.UnmarshalUnsafe over a +// // []Foo in a loop. +// func UnmarshalUnsafeFooSlice(dst []Foo, src []byte) (int, error) { ... } +// +// // CopyFooSliceIn copies in a slice of Foo objects from the task's memory. +// func CopyFooSliceIn(task marshal.Task, addr usermem.Addr, dst []Foo) (int, error) { ... } +// +// // CopyFooSliceIn copies out a slice of Foo objects to the task's memory. +// func CopyFooSliceOut(task marshal.Task, addr usermem.Addr, src []Foo) (int, error) { ... } +// +// The name of the functions are of the format "Copy%sIn" and "Copy%sOut", where +// %s is the first argument to the slice clause. This directive is not supported +// for newtypes on arrays. +// +// The slice clause also takes an optional second argument, which must be the +// value "inner": +// +// // +marshal slice:Int32Slice:inner +// type Int32 int32 +// +// This is only valid on newtypes on primitives, and causes the generated +// functions to accept slices of the inner type instead: +// +// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []int32) (int, error) { ... } +// +// Without "inner", they would instead be: +// +// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []Int32) (int, error) { ... } +// +// This may help avoid a cast depending on how the generated functions are used. diff --git a/tools/go_marshal/primitive/BUILD b/tools/go_marshal/primitive/BUILD new file mode 100644 index 000000000..cc08ba63a --- /dev/null +++ b/tools/go_marshal/primitive/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "primitive", + srcs = [ + "primitive.go", + ], + marshal = True, + visibility = [ + "//:sandbox", + ], + deps = [ + "//pkg/usermem", + "//tools/go_marshal/marshal", + ], +) diff --git a/tools/go_marshal/primitive/primitive.go b/tools/go_marshal/primitive/primitive.go new file mode 100644 index 000000000..ebcf130ae --- /dev/null +++ b/tools/go_marshal/primitive/primitive.go @@ -0,0 +1,175 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package primitive defines marshal.Marshallable implementations for primitive +// types. +package primitive + +import ( + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" +) + +// Int16 is a marshal.Marshallable implementation for int16. +// +// +marshal slice:Int16Slice:inner +type Int16 int16 + +// Uint16 is a marshal.Marshallable implementation for uint16. +// +// +marshal slice:Uint16Slice:inner +type Uint16 uint16 + +// Int32 is a marshal.Marshallable implementation for int32. +// +// +marshal slice:Int32Slice:inner +type Int32 int32 + +// Uint32 is a marshal.Marshallable implementation for uint32. +// +// +marshal slice:Uint32Slice:inner +type Uint32 uint32 + +// Int64 is a marshal.Marshallable implementation for int64. +// +// +marshal slice:Int64Slice:inner +type Int64 int64 + +// Uint64 is a marshal.Marshallable implementation for uint64. +// +// +marshal slice:Uint64Slice:inner +type Uint64 uint64 + +// Below, we define some convenience functions for marshalling primitive types +// using the newtypes above, without requiring superfluous casts. + +// 16-bit integers + +// CopyInt16In is a convenient wrapper for copying in an int16 from the task's +// memory. +func CopyInt16In(task marshal.Task, addr usermem.Addr, dst *int16) (int, error) { + var buf Int16 + n, err := buf.CopyIn(task, addr) + if err != nil { + return n, err + } + *dst = int16(buf) + return n, nil +} + +// CopyInt16Out is a convenient wrapper for copying out an int16 to the task's +// memory. +func CopyInt16Out(task marshal.Task, addr usermem.Addr, src int16) (int, error) { + srcP := Int16(src) + return srcP.CopyOut(task, addr) +} + +// CopyUint16In is a convenient wrapper for copying in a uint16 from the task's +// memory. +func CopyUint16In(task marshal.Task, addr usermem.Addr, dst *uint16) (int, error) { + var buf Uint16 + n, err := buf.CopyIn(task, addr) + if err != nil { + return n, err + } + *dst = uint16(buf) + return n, nil +} + +// CopyUint16Out is a convenient wrapper for copying out a uint16 to the task's +// memory. +func CopyUint16Out(task marshal.Task, addr usermem.Addr, src uint16) (int, error) { + srcP := Uint16(src) + return srcP.CopyOut(task, addr) +} + +// 32-bit integers + +// CopyInt32In is a convenient wrapper for copying in an int32 from the task's +// memory. +func CopyInt32In(task marshal.Task, addr usermem.Addr, dst *int32) (int, error) { + var buf Int32 + n, err := buf.CopyIn(task, addr) + if err != nil { + return n, err + } + *dst = int32(buf) + return n, nil +} + +// CopyInt32Out is a convenient wrapper for copying out an int32 to the task's +// memory. +func CopyInt32Out(task marshal.Task, addr usermem.Addr, src int32) (int, error) { + srcP := Int32(src) + return srcP.CopyOut(task, addr) +} + +// CopyUint32In is a convenient wrapper for copying in a uint32 from the task's +// memory. +func CopyUint32In(task marshal.Task, addr usermem.Addr, dst *uint32) (int, error) { + var buf Uint32 + n, err := buf.CopyIn(task, addr) + if err != nil { + return n, err + } + *dst = uint32(buf) + return n, nil +} + +// CopyUint32Out is a convenient wrapper for copying out a uint32 to the task's +// memory. +func CopyUint32Out(task marshal.Task, addr usermem.Addr, src uint32) (int, error) { + srcP := Uint32(src) + return srcP.CopyOut(task, addr) +} + +// 64-bit integers + +// CopyInt64In is a convenient wrapper for copying in an int64 from the task's +// memory. +func CopyInt64In(task marshal.Task, addr usermem.Addr, dst *int64) (int, error) { + var buf Int64 + n, err := buf.CopyIn(task, addr) + if err != nil { + return n, err + } + *dst = int64(buf) + return n, nil +} + +// CopyInt64Out is a convenient wrapper for copying out an int64 to the task's +// memory. +func CopyInt64Out(task marshal.Task, addr usermem.Addr, src int64) (int, error) { + srcP := Int64(src) + return srcP.CopyOut(task, addr) +} + +// CopyUint64In is a convenient wrapper for copying in a uint64 from the task's +// memory. +func CopyUint64In(task marshal.Task, addr usermem.Addr, dst *uint64) (int, error) { + var buf Uint64 + n, err := buf.CopyIn(task, addr) + if err != nil { + return n, err + } + *dst = uint64(buf) + return n, nil +} + +// CopyUint64Out is a convenient wrapper for copying out a uint64 to the task's +// memory. +func CopyUint64Out(task marshal.Task, addr usermem.Addr, src uint64) (int, error) { + srcP := Uint64(src) + return srcP.CopyOut(task, addr) +} diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD index f27c5ce52..2fbcc8a03 100644 --- a/tools/go_marshal/test/BUILD +++ b/tools/go_marshal/test/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_binary", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") licenses(["notice"]) @@ -25,17 +25,20 @@ go_library( testonly = 1, srcs = ["test.go"], marshal = True, + visibility = ["//tools/go_marshal/test:__subpackages__"], deps = ["//tools/go_marshal/test/external"], ) -go_binary( - name = "escape", - testonly = 1, - srcs = ["escape.go"], - gc_goopts = ["-m"], +go_test( + name = "marshal_test", + size = "small", + srcs = ["marshal_test.go"], deps = [ ":test", + "//pkg/syserror", "//pkg/usermem", + "//tools/go_marshal/analysis", "//tools/go_marshal/marshal", + "@com_github_google_go-cmp//cmp:go_default_library", ], ) diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go index c79defe9e..224d308c7 100644 --- a/tools/go_marshal/test/benchmark_test.go +++ b/tools/go_marshal/test/benchmark_test.go @@ -176,3 +176,45 @@ func BenchmarkGoMarshalUnsafe(b *testing.B) { panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) } } + +func BenchmarkBinarySlice(b *testing.B) { + var s1, s2 [64]test.Stat + analysis.RandomizeValue(&s1) + + size := binary.Size(s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := make([]byte, 0, size) + buf = binary.Marshal(buf, usermem.ByteOrder, &s1) + binary.Unmarshal(buf, usermem.ByteOrder, &s2) + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} + +func BenchmarkGoMarshalUnsafeSlice(b *testing.B) { + var s1, s2 [64]test.Stat + analysis.RandomizeValue(&s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := make([]byte, (*test.Stat)(nil).SizeBytes()*len(s1)) + test.MarshalUnsafeStatSlice(s1[:], buf) + test.UnmarshalUnsafeStatSlice(s2[:], buf) + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} diff --git a/tools/go_marshal/test/escape.go b/tools/go_marshal/test/escape.go deleted file mode 100644 index 184f05ea3..000000000 --- a/tools/go_marshal/test/escape.go +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This binary provides a convienient target for analyzing how the go-marshal -// API causes its various arguments to escape to the heap. To use, build and -// observe the output from the go compiler's escape analysis: -// -// $ bazel build :escape -// ... -// escape.go:67:2: moved to heap: task -// escape.go:77:31: make([]byte, size) escapes to heap -// escape.go:87:31: make([]byte, size) escapes to heap -// escape.go:96:6: moved to heap: stat -// ... -// -// This is not an automated test, but simply a minimal binary for easy analysis. -package main - -import ( - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" - "gvisor.dev/gvisor/tools/go_marshal/test" -) - -// dummyTask implements marshal.Task. -type dummyTask struct { -} - -func (*dummyTask) CopyScratchBuffer(size int) []byte { - return make([]byte, size) -} - -func (*dummyTask) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) { - return len(b), nil -} - -func (*dummyTask) CopyInBytes(addr usermem.Addr, b []byte) (int, error) { - return len(b), nil -} - -func (task *dummyTask) MarshalBytes(addr usermem.Addr, marshallable marshal.Marshallable) { - buf := task.CopyScratchBuffer(marshallable.SizeBytes()) - marshallable.MarshalBytes(buf) - task.CopyOutBytes(addr, buf) -} - -func (task *dummyTask) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marshallable) { - buf := task.CopyScratchBuffer(marshallable.SizeBytes()) - marshallable.MarshalUnsafe(buf) - task.CopyOutBytes(addr, buf) -} - -// Expected escapes: -// - task: passed to marshal.Marshallable.CopyOut as the marshal.Task interface. -func doCopyOut() { - task := dummyTask{} - var stat test.Stat - stat.CopyOut(&task, usermem.Addr(0xf000ba12)) -} - -// Expected escapes: -// - buf: make allocates on the heap. -func doMarshalBytesDirect() { - task := dummyTask{} - var stat test.Stat - buf := task.CopyScratchBuffer(stat.SizeBytes()) - stat.MarshalBytes(buf) - task.CopyOutBytes(usermem.Addr(0xf000ba12), buf) -} - -// Expected escapes: -// - buf: make allocates on the heap. -func doMarshalUnsafeDirect() { - task := dummyTask{} - var stat test.Stat - buf := task.CopyScratchBuffer(stat.SizeBytes()) - stat.MarshalUnsafe(buf) - task.CopyOutBytes(usermem.Addr(0xf000ba12), buf) -} - -// Expected escapes: -// - stat: passed to dummyTask.MarshalBytes as the marshal.Marshallable interface. -func doMarshalBytesViaMarshallable() { - task := dummyTask{} - var stat test.Stat - task.MarshalBytes(usermem.Addr(0xf000ba12), &stat) -} - -// Expected escapes: -// - stat: passed to dummyTask.MarshalUnsafe as the marshal.Marshallable interface. -func doMarshalUnsafeViaMarshallable() { - task := dummyTask{} - var stat test.Stat - task.MarshalUnsafe(usermem.Addr(0xf000ba12), &stat) -} - -func main() { - doCopyOut() - doMarshalBytesDirect() - doMarshalUnsafeDirect() - doMarshalBytesViaMarshallable() - doMarshalUnsafeViaMarshallable() -} diff --git a/tools/go_marshal/test/escape/BUILD b/tools/go_marshal/test/escape/BUILD new file mode 100644 index 000000000..f74e6ffae --- /dev/null +++ b/tools/go_marshal/test/escape/BUILD @@ -0,0 +1,14 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "escape", + testonly = 1, + srcs = ["escape.go"], + deps = [ + "//pkg/usermem", + "//tools/go_marshal/marshal", + "//tools/go_marshal/test", + ], +) diff --git a/tools/go_marshal/test/escape/escape.go b/tools/go_marshal/test/escape/escape.go new file mode 100644 index 000000000..6a46ddbf8 --- /dev/null +++ b/tools/go_marshal/test/escape/escape.go @@ -0,0 +1,95 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package escape + +import ( + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/test" +) + +// dummyTask implements marshal.Task. +type dummyTask struct { +} + +func (*dummyTask) CopyScratchBuffer(size int) []byte { + return make([]byte, size) +} + +func (*dummyTask) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) { + return len(b), nil +} + +func (*dummyTask) CopyInBytes(addr usermem.Addr, b []byte) (int, error) { + return len(b), nil +} + +func (t *dummyTask) MarshalBytes(addr usermem.Addr, marshallable marshal.Marshallable) { + buf := t.CopyScratchBuffer(marshallable.SizeBytes()) + marshallable.MarshalBytes(buf) + t.CopyOutBytes(addr, buf) +} + +func (t *dummyTask) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marshallable) { + buf := t.CopyScratchBuffer(marshallable.SizeBytes()) + marshallable.MarshalUnsafe(buf) + t.CopyOutBytes(addr, buf) +} + +// +checkescape:all +//go:nosplit +func doCopyIn(t *dummyTask) { + var stat test.Stat + stat.CopyIn(t, usermem.Addr(0xf000ba12)) +} + +// +checkescape:all +//go:nosplit +func doCopyOut(t *dummyTask) { + var stat test.Stat + stat.CopyOut(t, usermem.Addr(0xf000ba12)) +} + +// +mustescape:builtin +// +mustescape:stack +func doMarshalBytesDirect(t *dummyTask) { + var stat test.Stat + buf := t.CopyScratchBuffer(stat.SizeBytes()) + stat.MarshalBytes(buf) + t.CopyOutBytes(usermem.Addr(0xf000ba12), buf) +} + +// +mustescape:builtin +// +mustescape:stack +func doMarshalUnsafeDirect(t *dummyTask) { + var stat test.Stat + buf := t.CopyScratchBuffer(stat.SizeBytes()) + stat.MarshalUnsafe(buf) + t.CopyOutBytes(usermem.Addr(0xf000ba12), buf) +} + +// +mustescape:local,heap +// +mustescape:stack +func doMarshalBytesViaMarshallable(t *dummyTask) { + var stat test.Stat + t.MarshalBytes(usermem.Addr(0xf000ba12), &stat) +} + +// +mustescape:local,heap +// +mustescape:stack +func doMarshalUnsafeViaMarshallable(t *dummyTask) { + var stat test.Stat + t.MarshalUnsafe(usermem.Addr(0xf000ba12), &stat) +} diff --git a/tools/go_marshal/test/external/external.go b/tools/go_marshal/test/external/external.go index 4be3722f3..26fe8e0c8 100644 --- a/tools/go_marshal/test/external/external.go +++ b/tools/go_marshal/test/external/external.go @@ -21,3 +21,11 @@ package external type External struct { j int64 } + +// NotPacked is an unaligned Marshallable type for use in testing. +// +// +marshal +type NotPacked struct { + a int32 + b byte `marshal:"unaligned"` +} diff --git a/tools/go_marshal/test/marshal_test.go b/tools/go_marshal/test/marshal_test.go new file mode 100644 index 000000000..16829ee45 --- /dev/null +++ b/tools/go_marshal/test/marshal_test.go @@ -0,0 +1,515 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package marshal_test contains manual tests for the marshal interface. These +// are intended to test behaviour not covered by the automatically generated +// tests. +package marshal_test + +import ( + "bytes" + "encoding/binary" + "fmt" + "reflect" + "runtime" + "testing" + "unsafe" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/analysis" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/test" +) + +var simulatedErr error = syserror.EFAULT + +// mockTask implements marshal.Task. +type mockTask struct { + taskMem usermem.BytesIO +} + +// populate fills the task memory with the contents of val. +func (t *mockTask) populate(val interface{}) { + var buf bytes.Buffer + // Use binary.Write so we aren't testing go-marshal against its own + // potentially buggy implementation. + if err := binary.Write(&buf, usermem.ByteOrder, val); err != nil { + panic(err) + } + t.taskMem.Bytes = buf.Bytes() +} + +func (t *mockTask) setLimit(n int) { + if len(t.taskMem.Bytes) < n { + grown := make([]byte, n) + copy(grown, t.taskMem.Bytes) + t.taskMem.Bytes = grown + return + } + t.taskMem.Bytes = t.taskMem.Bytes[:n] +} + +// CopyScratchBuffer implements marshal.Task.CopyScratchBuffer. +func (t *mockTask) CopyScratchBuffer(size int) []byte { + return make([]byte, size) +} + +// CopyOutBytes implements marshal.Task.CopyOutBytes. The implementation +// completely ignores the target address and stores a copy of b in its +// internally buffer, overriding any previous contents. +func (t *mockTask) CopyOutBytes(_ usermem.Addr, b []byte) (int, error) { + return t.taskMem.CopyOut(nil, 0, b, usermem.IOOpts{}) +} + +// CopyInBytes implements marshal.Task.CopyInBytes. The implementation +// completely ignores the source address and always fills b from the begining of +// its internal buffer. +func (t *mockTask) CopyInBytes(_ usermem.Addr, b []byte) (int, error) { + return t.taskMem.CopyIn(nil, 0, b, usermem.IOOpts{}) +} + +// unsafeMemory returns the underlying memory for m. The returned slice is only +// valid for the lifetime for m. The garbage collector isn't aware that the +// returned slice is related to m, the caller must ensure m lives long enough. +func unsafeMemory(m marshal.Marshallable) []byte { + if !m.Packed() { + // We can't return a slice pointing to the underlying memory + // since the layout isn't packed. Allocate a temporary buffer + // and marshal instead. + var buf bytes.Buffer + if err := binary.Write(&buf, usermem.ByteOrder, m); err != nil { + panic(err) + } + return buf.Bytes() + } + + // reflect.ValueOf(m) + // .Elem() // Unwrap interface to inner concrete object + // .Addr() // Pointer value to object + // .Pointer() // Actual address from the pointer value + ptr := reflect.ValueOf(m).Elem().Addr().Pointer() + + size := m.SizeBytes() + + var mem []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&mem)) + hdr.Data = ptr + hdr.Len = size + hdr.Cap = size + + return mem +} + +// unsafeMemorySlice returns the underlying memory for m. The returned slice is +// only valid for the lifetime for m. The garbage collector isn't aware that the +// returned slice is related to m, the caller must ensure m lives long enough. +// +// Precondition: m must be a slice. +func unsafeMemorySlice(m interface{}, elt marshal.Marshallable) []byte { + kind := reflect.TypeOf(m).Kind() + if kind != reflect.Slice { + panic("unsafeMemorySlice called on non-slice") + } + + if !elt.Packed() { + // We can't return a slice pointing to the underlying memory + // since the layout isn't packed. Allocate a temporary buffer + // and marshal instead. + var buf bytes.Buffer + if err := binary.Write(&buf, usermem.ByteOrder, m); err != nil { + panic(err) + } + return buf.Bytes() + } + + v := reflect.ValueOf(m) + length := v.Len() * elt.SizeBytes() + + var mem []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&mem)) + hdr.Data = v.Pointer() // This is a pointer to the first elem for slices. + hdr.Len = length + hdr.Cap = length + + return mem +} + +func isZeroes(buf []byte) bool { + for _, b := range buf { + if b != 0 { + return false + } + } + return true +} + +// compareMemory compares the first n bytes of two chuncks of memory represented +// by expected and actual. +func compareMemory(t *testing.T, expected, actual []byte, n int) { + t.Logf("Expected (%d): %v (%d) + (%d) %v\n", len(expected), expected[:n], n, len(expected)-n, expected[n:]) + t.Logf("Actual (%d): %v (%d) + (%d) %v\n", len(actual), actual[:n], n, len(actual)-n, actual[n:]) + + if diff := cmp.Diff(expected[:n], actual[:n]); diff != "" { + t.Errorf("Memory buffers don't match:\n--- expected only\n+++ actual only\n%v", diff) + } +} + +// limitedCopyIn populates task memory with src, then unmarshals task memory to +// dst. The task signals an error at limit bytes during copy-in, which should +// result in a truncated unmarshalling. +func limitedCopyIn(t *testing.T, src, dst marshal.Marshallable, limit int) { + var task mockTask + task.populate(src) + task.setLimit(limit) + + n, err := dst.CopyIn(&task, usermem.Addr(0)) + if n != limit { + t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n) + } + if err != simulatedErr { + t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err) + } + + expectedMem := unsafeMemory(src) + defer runtime.KeepAlive(src) + actualMem := unsafeMemory(dst) + defer runtime.KeepAlive(dst) + + compareMemory(t, expectedMem, actualMem, n) + + // The last n bytes should be zero for actual, since actual was + // zero-initialized, and CopyIn shouldn't have touched those bytes. However + // we can only guarantee we didn't touch anything in the last n bytes if the + // layout is packed. + if dst.Packed() && !isZeroes(actualMem[n:]) { + t.Errorf("Expected the last %d bytes of copied in object to be zeroes, got %v\n", dst.SizeBytes()-n, actualMem) + } +} + +// limitedCopyOut marshals src to task memory. The task signals an error at +// limit bytes during copy-out, which should result in a truncated marshalling. +func limitedCopyOut(t *testing.T, src marshal.Marshallable, limit int) { + var task mockTask + task.setLimit(limit) + + n, err := src.CopyOut(&task, usermem.Addr(0)) + if n != limit { + t.Errorf("CopyOut copied unexpected number of bytes, expected %d, got %d", limit, n) + } + if err != simulatedErr { + t.Errorf("CopyOut returned unexpected error, expected %v, got %v", simulatedErr, err) + } + + expectedMem := unsafeMemory(src) + defer runtime.KeepAlive(src) + actualMem := task.taskMem.Bytes + + compareMemory(t, expectedMem, actualMem, n) +} + +// copyOutN marshals src to task memory, requesting the marshalling to be +// limited to limit bytes. +func copyOutN(t *testing.T, src marshal.Marshallable, limit int) { + var task mockTask + task.setLimit(limit) + + n, err := src.CopyOutN(&task, usermem.Addr(0), limit) + if err != nil { + t.Errorf("CopyOut returned unexpected error: %v", err) + } + if n != limit { + t.Errorf("CopyOut copied unexpected number of bytes, expected %d, got %d", limit, n) + } + + expectedMem := unsafeMemory(src) + defer runtime.KeepAlive(src) + actualMem := task.taskMem.Bytes + + t.Logf("Expected: %v + %v\n", expectedMem[:n], expectedMem[n:]) + t.Logf("Actual : %v + %v\n", actualMem[:n], actualMem[n:]) + + compareMemory(t, expectedMem, actualMem, n) +} + +// TestLimitedMarshalling verifies marshalling/unmarshalling succeeds when the +// underyling copy in/out operations partially succeed. +func TestLimitedMarshalling(t *testing.T) { + types := []reflect.Type{ + // Packed types. + reflect.TypeOf((*test.Type2)(nil)), + reflect.TypeOf((*test.Type3)(nil)), + reflect.TypeOf((*test.Timespec)(nil)), + reflect.TypeOf((*test.Stat)(nil)), + reflect.TypeOf((*test.InetAddr)(nil)), + reflect.TypeOf((*test.SignalSet)(nil)), + reflect.TypeOf((*test.SignalSetAlias)(nil)), + // Non-packed types. + reflect.TypeOf((*test.Type1)(nil)), + reflect.TypeOf((*test.Type4)(nil)), + reflect.TypeOf((*test.Type5)(nil)), + reflect.TypeOf((*test.Type6)(nil)), + reflect.TypeOf((*test.Type7)(nil)), + reflect.TypeOf((*test.Type8)(nil)), + } + + for _, tyPtr := range types { + // Remove one level of pointer-indirection from the type. We get this + // back when we pass the type to reflect.New. + ty := tyPtr.Elem() + + // Partial copy-in. + t.Run(fmt.Sprintf("PartialCopyIn_%v", ty), func(t *testing.T) { + expected := reflect.New(ty).Interface().(marshal.Marshallable) + actual := reflect.New(ty).Interface().(marshal.Marshallable) + analysis.RandomizeValue(expected) + + limitedCopyIn(t, expected, actual, expected.SizeBytes()/2) + }) + + // Partial copy-out. + t.Run(fmt.Sprintf("PartialCopyOut_%v", ty), func(t *testing.T) { + expected := reflect.New(ty).Interface().(marshal.Marshallable) + analysis.RandomizeValue(expected) + + limitedCopyOut(t, expected, expected.SizeBytes()/2) + }) + + // Explicitly request partial copy-out. + t.Run(fmt.Sprintf("PartialCopyOutN_%v", ty), func(t *testing.T) { + expected := reflect.New(ty).Interface().(marshal.Marshallable) + analysis.RandomizeValue(expected) + + copyOutN(t, expected, expected.SizeBytes()/2) + }) + } +} + +// TestLimitedMarshalling verifies marshalling/unmarshalling of slices of +// marshallable types succeed when the underyling copy in/out operations +// partially succeed. +func TestLimitedSliceMarshalling(t *testing.T) { + types := []struct { + arrayPtrType reflect.Type + copySliceIn func(task marshal.Task, addr usermem.Addr, dstSlice interface{}) (int, error) + copySliceOut func(task marshal.Task, addr usermem.Addr, srcSlice interface{}) (int, error) + unsafeMemory func(arrPtr interface{}) []byte + }{ + // Packed types. + { + reflect.TypeOf((*[20]test.Stat)(nil)), + func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + slice := dst.(*[20]test.Stat)[:] + return test.CopyStatSliceIn(task, addr, slice) + }, + func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + slice := src.(*[20]test.Stat)[:] + return test.CopyStatSliceOut(task, addr, slice) + }, + func(a interface{}) []byte { + slice := a.(*[20]test.Stat)[:] + return unsafeMemorySlice(slice, &slice[0]) + }, + }, + { + reflect.TypeOf((*[1]test.Stat)(nil)), + func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + slice := dst.(*[1]test.Stat)[:] + return test.CopyStatSliceIn(task, addr, slice) + }, + func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + slice := src.(*[1]test.Stat)[:] + return test.CopyStatSliceOut(task, addr, slice) + }, + func(a interface{}) []byte { + slice := a.(*[1]test.Stat)[:] + return unsafeMemorySlice(slice, &slice[0]) + }, + }, + { + reflect.TypeOf((*[5]test.SignalSetAlias)(nil)), + func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + slice := dst.(*[5]test.SignalSetAlias)[:] + return test.CopySignalSetAliasSliceIn(task, addr, slice) + }, + func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + slice := src.(*[5]test.SignalSetAlias)[:] + return test.CopySignalSetAliasSliceOut(task, addr, slice) + }, + func(a interface{}) []byte { + slice := a.(*[5]test.SignalSetAlias)[:] + return unsafeMemorySlice(slice, &slice[0]) + }, + }, + // Non-packed types. + { + reflect.TypeOf((*[20]test.Type1)(nil)), + func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + slice := dst.(*[20]test.Type1)[:] + return test.CopyType1SliceIn(task, addr, slice) + }, + func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + slice := src.(*[20]test.Type1)[:] + return test.CopyType1SliceOut(task, addr, slice) + }, + func(a interface{}) []byte { + slice := a.(*[20]test.Type1)[:] + return unsafeMemorySlice(slice, &slice[0]) + }, + }, + { + reflect.TypeOf((*[1]test.Type1)(nil)), + func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + slice := dst.(*[1]test.Type1)[:] + return test.CopyType1SliceIn(task, addr, slice) + }, + func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + slice := src.(*[1]test.Type1)[:] + return test.CopyType1SliceOut(task, addr, slice) + }, + func(a interface{}) []byte { + slice := a.(*[1]test.Type1)[:] + return unsafeMemorySlice(slice, &slice[0]) + }, + }, + { + reflect.TypeOf((*[7]test.Type8)(nil)), + func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + slice := dst.(*[7]test.Type8)[:] + return test.CopyType8SliceIn(task, addr, slice) + }, + func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + slice := src.(*[7]test.Type8)[:] + return test.CopyType8SliceOut(task, addr, slice) + }, + func(a interface{}) []byte { + slice := a.(*[7]test.Type8)[:] + return unsafeMemorySlice(slice, &slice[0]) + }, + }, + } + + for _, tt := range types { + // The body of this loop is generic over the type tt.arrayPtrType, with + // the help of reflection. To aid in readability, comments below show + // the equivalent go code assuming + // tt.arrayPtrType = typeof(*[20]test.Stat). + + // Equivalent: + // var x *[20]test.Stat + // arrayTy := reflect.TypeOf(*x) + arrayTy := tt.arrayPtrType.Elem() + + // Partial copy-in of slices. + t.Run(fmt.Sprintf("PartialCopySliceIn_%v", arrayTy), func(t *testing.T) { + // Equivalent: + // var x [20]test.Stat + // length := len(x) + length := arrayTy.Len() + if length < 1 { + panic("Test type can't be zero-length array") + } + // Equivalent: + // elem := new(test.Stat).(marshal.Marshallable) + elem := reflect.New(arrayTy.Elem()).Interface().(marshal.Marshallable) + + // Equivalent: + // var expected, actual interface{} + // expected = new([20]test.Stat) + // actual = new([20]test.Stat) + expected := reflect.New(arrayTy).Interface() + actual := reflect.New(arrayTy).Interface() + + analysis.RandomizeValue(expected) + + limit := (length * elem.SizeBytes()) / 2 + // Also make sure the limit is partially inside one of the elements. + limit += elem.SizeBytes() / 2 + analysis.RandomizeValue(expected) + + var task mockTask + task.populate(expected) + task.setLimit(limit) + + n, err := tt.copySliceIn(&task, usermem.Addr(0), actual) + if n != limit { + t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n) + } + if n < length*elem.SizeBytes() && err != simulatedErr { + t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err) + } + + expectedMem := tt.unsafeMemory(expected) + defer runtime.KeepAlive(expected) + actualMem := tt.unsafeMemory(actual) + defer runtime.KeepAlive(actual) + + compareMemory(t, expectedMem, actualMem, n) + + // The last n bytes should be zero for actual, since actual was + // zero-initialized, and CopyIn shouldn't have touched those bytes. However + // we can only guarantee we didn't touch anything in the last n bytes if the + // layout is packed. + if elem.Packed() && !isZeroes(actualMem[n:]) { + t.Errorf("Expected the last %d bytes of copied in object to be zeroes, got %v\n", (elem.SizeBytes()*length)-n, actualMem) + } + }) + + // Partial copy-out of slices. + t.Run(fmt.Sprintf("PartialCopySliceOut_%v", arrayTy), func(t *testing.T) { + // Equivalent: + // var x [20]test.Stat + // length := len(x) + length := arrayTy.Len() + if length < 1 { + panic("Test type can't be zero-length array") + } + // Equivalent: + // elem := new(test.Stat).(marshal.Marshallable) + elem := reflect.New(arrayTy.Elem()).Interface().(marshal.Marshallable) + + // Equivalent: + // var expected, actual interface{} + // expected = new([20]test.Stat) + // actual = new([20]test.Stat) + expected := reflect.New(arrayTy).Interface() + + analysis.RandomizeValue(expected) + + limit := (length * elem.SizeBytes()) / 2 + // Also make sure the limit is partially inside one of the elements. + limit += elem.SizeBytes() / 2 + analysis.RandomizeValue(expected) + + var task mockTask + task.populate(expected) + task.setLimit(limit) + + n, err := tt.copySliceOut(&task, usermem.Addr(0), expected) + if n != limit { + t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n) + } + if n < length*elem.SizeBytes() && err != simulatedErr { + t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err) + } + + expectedMem := tt.unsafeMemory(expected) + defer runtime.KeepAlive(expected) + actualMem := task.taskMem.Bytes + + compareMemory(t, expectedMem, actualMem, n) + }) + } +} diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go index c829db6da..f75ca1b7f 100644 --- a/tools/go_marshal/test/test.go +++ b/tools/go_marshal/test/test.go @@ -23,7 +23,7 @@ import ( // Type1 is a test data type. // -// +marshal +// +marshal slice:Type1Slice type Type1 struct { a Type2 x, y int64 // Multiple field names. @@ -75,6 +75,34 @@ type Type5 struct { m int64 } +// Type6 is a test data type ends mid-word. +// +// +marshal +type Type6 struct { + a int64 + b int64 + // If c isn't marked unaligned, analysis fails (as it should, since + // the unsafe API corrupts Type7). + c byte `marshal:"unaligned"` +} + +// Type7 is a test data type that contains a child struct that ends +// mid-word. +// +marshal +type Type7 struct { + x Type6 + y int64 +} + +// Type8 is a test data type which contains an external non-packed field. +// +// +marshal slice:Type8Slice +type Type8 struct { + a int64 + np ex.NotPacked + b int64 +} + // Timespec represents struct timespec in <time.h>. // // +marshal @@ -85,7 +113,7 @@ type Timespec struct { // Stat represents struct stat. // -// +marshal +// +marshal slice:StatSlice type Stat struct { Dev uint64 Ino uint64 @@ -111,10 +139,38 @@ type InetAddr [4]byte // SignalSet is an example marshallable newtype on a primitive. // -// +marshal +// +marshal slice:SignalSetSlice:inner type SignalSet uint64 // SignalSetAlias is an example newtype on another marshallable type. // -// +marshal +// +marshal slice:SignalSetAliasSlice type SignalSetAlias SignalSet + +const sizeA = 64 +const sizeB = 8 + +// TestArray is a test data structure on an array with a constant length. +// +// +marshal +type TestArray [sizeA]int32 + +// TestArray2 is a newtype on an array with a simple arithmetic expression of +// constants for the array length. +// +// +marshal +type TestArray2 [sizeA * sizeB]int32 + +// TestArray2 is a newtype on an array with a simple arithmetic expression of +// mixed constants and literals for the array length. +// +// +marshal +type TestArray3 [sizeA*sizeB + 12]int32 + +// Type9 is a test data type containing an array with a non-literal length. +// +// +marshal +type Type9 struct { + x int64 + y [sizeA]int32 +} diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go index 3437aa476..309ee9c21 100644 --- a/tools/go_stateify/main.go +++ b/tools/go_stateify/main.go @@ -206,7 +206,7 @@ func main() { initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *fullPkg, name, name, name, name)) } emitZeroCheck := func(name string) { - fmt.Fprintf(outputFile, " if !%sIsZeroValue(x.%s) { m.Failf(\"%s is %%v, expected zero\", x.%s) }\n", statePrefix, name, name, name) + fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { m.Failf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, name, name) } emitLoadValue := func(name, typName string) { fmt.Fprintf(outputFile, " m.LoadValue(\"%s\", new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", name, typName, camelCased(name), typName) diff --git a/tools/image_build.sh b/tools/image_build.sh deleted file mode 100755 index 9b20a740d..000000000 --- a/tools/image_build.sh +++ /dev/null @@ -1,98 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script is responsible for building a new GCP image that: 1) has nested -# virtualization enabled, and 2) has been completely set up with the -# image_setup.sh script. This script should be idempotent, as we memoize the -# setup script with a hash and check for that name. -# -# The GCP project name should be defined via a gcloud config. - -set -xeo pipefail - -# Parameters. -declare -r ZONE=${ZONE:-us-central1-f} -declare -r USERNAME=${USERNAME:-test} -declare -r IMAGE_PROJECT=${IMAGE_PROJECT:-ubuntu-os-cloud} -declare -r IMAGE_FAMILY=${IMAGE_FAMILY:-ubuntu-1604-lts} - -# Random names. -declare -r DISK_NAME=$(mktemp -u disk-XXXXXX | tr A-Z a-z) -declare -r SNAPSHOT_NAME=$(mktemp -u snapshot-XXXXXX | tr A-Z a-z) -declare -r INSTANCE_NAME=$(mktemp -u build-XXXXXX | tr A-Z a-z) - -# Hashes inputs. -declare -r SETUP_BLOB=$(echo ${ZONE} ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && sha256sum "$@") -declare -r SETUP_HASH=$(echo ${SETUP_BLOB} | sha256sum - | cut -d' ' -f1 | cut -c 1-16) -declare -r IMAGE_NAME=${IMAGE_NAME:-image-}${SETUP_HASH} - -# Does the image already exist? Skip the build. -declare -r existing=$(gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)") -if ! [[ -z "${existing}" ]]; then - echo "${existing}" - exit 0 -fi - -# Set the zone for all actions. -gcloud config set compute/zone "${ZONE}" - -# Start a unique instance. Note that this instance will have a unique persistent -# disk as it's boot disk with the same name as the instance. -gcloud compute instances create \ - --quiet \ - --image-project "${IMAGE_PROJECT}" \ - --image-family "${IMAGE_FAMILY}" \ - --boot-disk-size "200GB" \ - "${INSTANCE_NAME}" -function cleanup { - gcloud compute instances delete --quiet "${INSTANCE_NAME}" -} -trap cleanup EXIT - -# Wait for the instance to become available. -declare attempts=0 -while [[ "${attempts}" -lt 30 ]]; do - attempts=$((${attempts}+1)) - if gcloud compute ssh "${USERNAME}"@"${INSTANCE_NAME}" -- true; then - break - fi -done -if [[ "${attempts}" -ge 30 ]]; then - echo "too many attempts: failed" - exit 1 -fi - -# Run the install scripts provided. -for arg; do - gcloud compute ssh "${USERNAME}"@"${INSTANCE_NAME}" -- sudo bash - <"${arg}" -done - -# Stop the instance; required before creating an image. -gcloud compute instances stop --quiet "${INSTANCE_NAME}" - -# Create a snapshot of the instance disk. -gcloud compute disks snapshot \ - --quiet \ - --zone="${ZONE}" \ - --snapshot-names="${SNAPSHOT_NAME}" \ - "${INSTANCE_NAME}" - -# Create the disk image. -gcloud compute images create \ - --quiet \ - --source-snapshot="${SNAPSHOT_NAME}" \ - --licenses="https://www.googleapis.com/compute/v1/projects/vm-options/global/licenses/enable-vmx" \ - "${IMAGE_NAME}" diff --git a/tools/installers/BUILD b/tools/installers/BUILD index d78a265ca..caa7b1983 100644 --- a/tools/installers/BUILD +++ b/tools/installers/BUILD @@ -17,6 +17,14 @@ sh_binary( ) sh_binary( + name = "images", + srcs = ["images.sh"], + data = [ + "//images", + ], +) + +sh_binary( name = "master", srcs = ["master.sh"], ) diff --git a/tools/installers/head.sh b/tools/installers/head.sh index 9de8f138c..7fc566ebd 100755 --- a/tools/installers/head.sh +++ b/tools/installers/head.sh @@ -15,7 +15,7 @@ # limitations under the License. # Install our runtime. -$(dirname $0)/runsc install +$(find . -executable -type f -name runsc) install # Restart docker. service docker restart || true diff --git a/tools/installers/images.sh b/tools/installers/images.sh new file mode 100755 index 000000000..52e750f57 --- /dev/null +++ b/tools/installers/images.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# Copyright 2020 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -xeuo pipefail + +# Find the images directory. +for images in $(find . -type d -name images); do + if [[ -f "${images}"/Makefile ]]; then + make -C "${images}" load-all-images + fi +done diff --git a/tools/make_repository.sh b/tools/make_repository.sh index 27ffbc9f3..c91fd283c 100755 --- a/tools/make_repository.sh +++ b/tools/make_repository.sh @@ -17,14 +17,37 @@ # Parse arguments. We require more than two arguments, which are the private # keyring, the e-mail associated with the signer, and the list of packages. if [ "$#" -le 3 ]; then - echo "usage: $0 <private-key> <signer-email> <component> <root> <packages...>" + echo "usage: $0 <private-key> <signer-email> <root> <packages...>" exit 1 fi declare -r private_key=$(readlink -e "$1"); shift declare -r signer="$1"; shift -declare -r component="$1"; shift declare -r root="$1"; shift +# Ensure that we have the correct packages installed. +function apt_install() { + while true; do + sudo apt-get update && + sudo apt-get install -y "$@" && + true + result="${?}" + case $result in + 0) + break + ;; + 100) + # 100 is the error code that apt-get returns. + ;; + *) + exit $result + ;; + esac + done +} +dpkg-sig --help >/dev/null || apt_install dpkg-sig +apt-ftparchive --help >/dev/null || apt_install apt-utils +xz --help >/dev/null || apt_install xz-utils + # Verbose from this point. set -xeo pipefail @@ -78,7 +101,7 @@ for dir in "${root}"/pool/*/binary-*; do name=$(basename "${dir}") arch=${name##binary-} arches+=("${arch}") - repo_packages="${tmpdir}"/"${component}"/"${name}" + repo_packages="${tmpdir}"/main/"${name}" mkdir -p "${repo_packages}" (cd "${root}" && apt-ftparchive --arch "${arch}" packages pool > "${repo_packages}"/Packages) (cd "${repo_packages}" && cat Packages | gzip > Packages.gz) @@ -91,7 +114,7 @@ APT { FTPArchive { Release { Architectures "${arches[@]}"; - Components "${component}"; + Components "main"; }; }; }; diff --git a/tools/nogo.js b/tools/nogo.js deleted file mode 100644 index fc0a4d1f0..000000000 --- a/tools/nogo.js +++ /dev/null @@ -1,7 +0,0 @@ -{ - "checkunsafe": { - "exclude_files": { - "/external/": "not subject to constraint" - } - } -} diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD new file mode 100644 index 000000000..c21b09511 --- /dev/null +++ b/tools/nogo/BUILD @@ -0,0 +1,49 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "nogo", + srcs = [ + "build.go", + "config.go", + "matchers.go", + "nogo.go", + "register.go", + ], + nogo = False, + visibility = ["//:sandbox"], + deps = [ + "//tools/checkescape", + "//tools/checkunsafe", + "//tools/nogo/data", + "@org_golang_x_tools//go/analysis:go_tool_library", + "@org_golang_x_tools//go/analysis/internal/facts:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/asmdecl:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/assign:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/atomic:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/bools:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/buildtag:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/cgocall:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/composite:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/copylock:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/errorsas:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/httpresponse:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/loopclosure:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/lostcancel:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/nilfunc:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/nilness:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/printf:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/shadow:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/shift:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/stdmethods:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/stringintconv:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/structtag:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/tests:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/unmarshal:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/unreachable:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/unsafeptr:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/unusedresult:go_tool_library", + "@org_golang_x_tools//go/gcexportdata:go_tool_library", + ], +) diff --git a/tools/nogo/README.md b/tools/nogo/README.md new file mode 100644 index 000000000..6e4db18de --- /dev/null +++ b/tools/nogo/README.md @@ -0,0 +1,31 @@ +# Extended "nogo" analysis + +This package provides a build aspect that perform nogo analysis. This will be +automatically injected to all relevant libraries when using the default +`go_binary` and `go_library` rules. + +It exists for several reasons. + +* The default `nogo` provided by bazel is insufficient with respect to the + possibility of binary analysis. This package allows us to analyze the + generated binary in addition to using the standard analyzers. + +* The configuration provided in this package is much richer than the standard + `nogo` JSON blob. Specifically, it allows us to exclude specific structures + from the composite rules (such as the Ranges that are common with the set + types). + +* The bazel version of `nogo` is run directly against the `go_library` and + `go_binary` targets, meaning that any change to the configuration requires a + rebuild from scratch (for some reason included all C++ source files in the + process). Using an aspect is more efficient in this regard. + +* The checks supported by this package are exported as tests, which makes it + easier to reason about and plumb into the build system. + +* For uninteresting reasons, it is impossible to integrate the default `nogo` + analyzer provided by bazel with internal Google tooling. To provide a + consistent experience, this package allows those systems to be unified. + +To use this package, import `nogo_test` from `defs.bzl` and add a single +dependency which is a `go_binary` or `go_library` rule. diff --git a/tools/nogo/build.go b/tools/nogo/build.go new file mode 100644 index 000000000..1c0d08661 --- /dev/null +++ b/tools/nogo/build.go @@ -0,0 +1,36 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nogo + +import ( + "fmt" + "io" + "os" +) + +var ( + // internalPrefix is the internal path prefix. Note that this is not + // special, as paths should be passed relative to the repository root + // and should not have any special prefix applied. + internalPrefix = fmt.Sprintf("^") + + // externalPrefix is external workspace packages. + externalPrefix = "^external/" +) + +// findStdPkg needs to find the bundled standard library packages. +func findStdPkg(path, GOOS, GOARCH string) (io.ReadCloser, error) { + return os.Open(fmt.Sprintf("external/go_sdk/pkg/%s_%s/%s.a", GOOS, GOARCH, path)) +} diff --git a/tools/nogo/check/BUILD b/tools/nogo/check/BUILD new file mode 100644 index 000000000..e2d76cd5c --- /dev/null +++ b/tools/nogo/check/BUILD @@ -0,0 +1,12 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +# Note that the check binary must be public, since an aspect may be applied +# across lots of different rules in different repositories. +go_binary( + name = "check", + srcs = ["main.go"], + visibility = ["//visibility:public"], + deps = ["//tools/nogo"], +) diff --git a/tools/nogo/check/main.go b/tools/nogo/check/main.go new file mode 100644 index 000000000..3828edf3a --- /dev/null +++ b/tools/nogo/check/main.go @@ -0,0 +1,24 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary check is the nogo entrypoint. +package main + +import ( + "gvisor.dev/gvisor/tools/nogo" +) + +func main() { + nogo.Main() +} diff --git a/tools/nogo/config.go b/tools/nogo/config.go new file mode 100644 index 000000000..6958fca69 --- /dev/null +++ b/tools/nogo/config.go @@ -0,0 +1,116 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nogo + +import ( + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/asmdecl" + "golang.org/x/tools/go/analysis/passes/assign" + "golang.org/x/tools/go/analysis/passes/atomic" + "golang.org/x/tools/go/analysis/passes/bools" + "golang.org/x/tools/go/analysis/passes/buildtag" + "golang.org/x/tools/go/analysis/passes/cgocall" + "golang.org/x/tools/go/analysis/passes/composite" + "golang.org/x/tools/go/analysis/passes/copylock" + "golang.org/x/tools/go/analysis/passes/errorsas" + "golang.org/x/tools/go/analysis/passes/httpresponse" + "golang.org/x/tools/go/analysis/passes/loopclosure" + "golang.org/x/tools/go/analysis/passes/lostcancel" + "golang.org/x/tools/go/analysis/passes/nilfunc" + "golang.org/x/tools/go/analysis/passes/nilness" + "golang.org/x/tools/go/analysis/passes/printf" + "golang.org/x/tools/go/analysis/passes/shadow" + "golang.org/x/tools/go/analysis/passes/shift" + "golang.org/x/tools/go/analysis/passes/stdmethods" + "golang.org/x/tools/go/analysis/passes/stringintconv" + "golang.org/x/tools/go/analysis/passes/structtag" + "golang.org/x/tools/go/analysis/passes/tests" + "golang.org/x/tools/go/analysis/passes/unmarshal" + "golang.org/x/tools/go/analysis/passes/unreachable" + "golang.org/x/tools/go/analysis/passes/unsafeptr" + "golang.org/x/tools/go/analysis/passes/unusedresult" + + "gvisor.dev/gvisor/tools/checkescape" + "gvisor.dev/gvisor/tools/checkunsafe" +) + +var analyzerConfig = map[*analysis.Analyzer]matcher{ + // Standard analyzers. + asmdecl.Analyzer: alwaysMatches(), + assign.Analyzer: externalExcluded( + ".*gazelle/walk/walk.go", // False positive. + ), + atomic.Analyzer: alwaysMatches(), + bools.Analyzer: alwaysMatches(), + buildtag.Analyzer: alwaysMatches(), + cgocall.Analyzer: alwaysMatches(), + composite.Analyzer: and( + disableMatches(), // Disabled for now. + resultExcluded{ + "Object_", + "Range{", + }, + ), + copylock.Analyzer: internalMatches(), // Common external issues (e.g. protos). + errorsas.Analyzer: alwaysMatches(), + httpresponse.Analyzer: alwaysMatches(), + loopclosure.Analyzer: alwaysMatches(), + lostcancel.Analyzer: internalMatches(), // Common external issues. + nilfunc.Analyzer: alwaysMatches(), + nilness.Analyzer: and( + internalMatches(), // Common "tautological checks". + internalExcluded( + "pkg/sentry/platform/kvm/kvm_test.go", // Intentional. + "tools/bigquery/bigquery.go", // False positive. + ), + ), + printf.Analyzer: alwaysMatches(), + shift.Analyzer: alwaysMatches(), + stdmethods.Analyzer: internalMatches(), // Common external issues (e.g. methods named "Write"). + stringintconv.Analyzer: and( + internalExcluded(), + externalExcluded( + ".*protobuf/.*.go", // Bad conversions. + ".*flate/huffman_bit_writer.go", // Bad conversion. + ), + ), + shadow.Analyzer: disableMatches(), // Disabled for now. + structtag.Analyzer: internalMatches(), // External not subject to rules. + tests.Analyzer: alwaysMatches(), + unmarshal.Analyzer: alwaysMatches(), + unreachable.Analyzer: internalMatches(), + unsafeptr.Analyzer: and( + internalMatches(), + internalExcluded( + ".*_test.go", // Exclude tests. + "pkg/flipcall/.*_unsafe.go", // Special case. + "pkg/gohacks/gohacks_unsafe.go", // Special case. + "pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go", // Special case. + "pkg/sentry/platform/kvm/bluepill_unsafe.go", // Special case. + "pkg/sentry/platform/kvm/machine_unsafe.go", // Special case. + "pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go", // Special case. + "pkg/sentry/platform/safecopy/safecopy_unsafe.go", // Special case. + "pkg/sentry/vfs/mount_unsafe.go", // Special case. + "pkg/sentry/platform/systrap/stub_unsafe.go", // Special case. + "pkg/sentry/platform/systrap/switchto_google_unsafe.go", // Special case. + "pkg/sentry/platform/systrap/sysmsg_thread_unsafe.go", // Special case. + ), + ), + unusedresult.Analyzer: alwaysMatches(), + + // Internal analyzers: external packages not subject. + checkescape.Analyzer: internalMatches(), + checkunsafe.Analyzer: internalMatches(), +} diff --git a/tools/nogo/data/BUILD b/tools/nogo/data/BUILD new file mode 100644 index 000000000..b7564cc44 --- /dev/null +++ b/tools/nogo/data/BUILD @@ -0,0 +1,10 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "data", + srcs = ["data.go"], + nogo = False, + visibility = ["//tools:__subpackages__"], +) diff --git a/tools/nogo/data/data.go b/tools/nogo/data/data.go new file mode 100644 index 000000000..eb84d0d27 --- /dev/null +++ b/tools/nogo/data/data.go @@ -0,0 +1,21 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package data contains shared data for nogo analysis. +// +// This is used to break a dependency cycle. +package data + +// Objdump is the dumped binary under analysis. +var Objdump string diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl new file mode 100644 index 000000000..6560b57c8 --- /dev/null +++ b/tools/nogo/defs.bzl @@ -0,0 +1,172 @@ +"""Nogo rules.""" + +load("//tools/bazeldefs:defs.bzl", "go_context", "go_importpath", "go_rule") + +# NogoInfo is the serialized set of package facts for a nogo analysis. +# +# Each go_library rule will generate a corresponding nogo rule, which will run +# with the source files as input. Note however, that the individual nogo rules +# are simply stubs that enter into the shadow dependency tree (the "aspect"). +NogoInfo = provider( + fields = { + "facts": "serialized package facts", + "importpath": "package import path", + "binaries": "package binary files", + }, +) + +def _nogo_aspect_impl(target, ctx): + # If this is a nogo rule itself (and not the shadow of a go_library or + # go_binary rule created by such a rule), then we simply return nothing. + # All work is done in the shadow properties for go rules. For a proto + # library, we simply skip the analysis portion but still need to return a + # valid NogoInfo to reference the generated binary. + if ctx.rule.kind == "go_library": + srcs = ctx.rule.files.srcs + elif ctx.rule.kind == "go_proto_library" or ctx.rule.kind == "go_wrap_cc": + srcs = [] + else: + return [NogoInfo()] + + # Construct the Go environment from the go_context.env dictionary. + env_prefix = " ".join(["%s=%s" % (key, value) for (key, value) in go_context(ctx).env.items()]) + + # Start with all target files and srcs as input. + inputs = target.files.to_list() + srcs + + # Generate a shell script that dumps the binary. Annoyingly, this seems + # necessary as the context in which a run_shell command runs does not seem + # to cleanly allow us redirect stdout to the actual output file. Perhaps + # I'm missing something here, but the intermediate script does work. + binaries = target.files.to_list() + disasm_file = ctx.actions.declare_file(target.label.name + ".out") + dumper = ctx.actions.declare_file("%s-dumper" % ctx.label.name) + ctx.actions.write(dumper, "\n".join([ + "#!/bin/bash", + "%s %s tool objdump %s > %s\n" % ( + env_prefix, + go_context(ctx).go.path, + [f.path for f in binaries if f.path.endswith(".a")][0], + disasm_file.path, + ), + ]), is_executable = True) + ctx.actions.run( + inputs = binaries, + outputs = [disasm_file], + tools = go_context(ctx).runfiles, + mnemonic = "GoObjdump", + progress_message = "Objdump %s" % target.label, + executable = dumper, + ) + inputs.append(disasm_file) + + # Extract the importpath for this package. + importpath = go_importpath(target) + + # The nogo tool requires a configfile serialized in JSON format to do its + # work. This must line up with the nogo.Config fields. + facts = ctx.actions.declare_file(target.label.name + ".facts") + config = struct( + ImportPath = importpath, + GoFiles = [src.path for src in srcs if src.path.endswith(".go")], + NonGoFiles = [src.path for src in srcs if not src.path.endswith(".go")], + GOOS = go_context(ctx).goos, + GOARCH = go_context(ctx).goarch, + Tags = go_context(ctx).tags, + FactMap = {}, # Constructed below. + ImportMap = {}, # Constructed below. + FactOutput = facts.path, + Objdump = disasm_file.path, + ) + + # Collect all info from shadow dependencies. + for dep in ctx.rule.attr.deps: + # There will be no file attribute set for all transitive dependencies + # that are not go_library or go_binary rules, such as a proto rules. + # This is handled by the ctx.rule.kind check above. + info = dep[NogoInfo] + if not hasattr(info, "facts"): + continue + + # Configure where to find the binary & fact files. Note that this will + # use .x and .a regardless of whether this is a go_binary rule, since + # these dependencies must be go_library rules. + x_files = [f.path for f in info.binaries if f.path.endswith(".x")] + if not len(x_files): + x_files = [f.path for f in info.binaries if f.path.endswith(".a")] + config.ImportMap[info.importpath] = x_files[0] + config.FactMap[info.importpath] = info.facts.path + + # Ensure the above are available as inputs. + inputs.append(info.facts) + inputs += info.binaries + + # Write the configuration and run the tool. + config_file = ctx.actions.declare_file(target.label.name + ".cfg") + ctx.actions.write(config_file, config.to_json()) + inputs.append(config_file) + + # Run the nogo tool itself. + ctx.actions.run( + inputs = inputs, + outputs = [facts], + tools = go_context(ctx).runfiles, + executable = ctx.files._nogo[0], + mnemonic = "GoStaticAnalysis", + progress_message = "Analyzing %s" % target.label, + arguments = ["-config=%s" % config_file.path], + ) + + # Return the package facts as output. + return [NogoInfo( + facts = facts, + importpath = importpath, + binaries = binaries, + )] + +nogo_aspect = go_rule( + aspect, + implementation = _nogo_aspect_impl, + attr_aspects = ["deps"], + attrs = { + "_nogo": attr.label( + default = "//tools/nogo/check:check", + allow_single_file = True, + ), + }, +) + +def _nogo_test_impl(ctx): + """Check nogo findings.""" + + # Build a runner that checks for the existence of the facts file. Note that + # the actual build will fail in the case of a broken analysis. We things + # this way so that any test applied is effectively pushed down to all + # upstream dependencies through the aspect. + inputs = [] + runner = ctx.actions.declare_file("%s-executer" % ctx.label.name) + runner_content = ["#!/bin/bash"] + for dep in ctx.attr.deps: + info = dep[NogoInfo] + inputs.append(info.facts) + + # Draw a sweet unicode checkmark with the package name (in green). + runner_content.append("echo -e \"\\033[0;32m\\xE2\\x9C\\x94\\033[0;31m\\033[0m %s\"" % info.importpath) + runner_content.append("exit 0\n") + ctx.actions.write(runner, "\n".join(runner_content), is_executable = True) + return [DefaultInfo( + runfiles = ctx.runfiles(files = inputs), + executable = runner, + )] + +_nogo_test = rule( + implementation = _nogo_test_impl, + attrs = { + "deps": attr.label_list(aspects = [nogo_aspect]), + }, + test = True, +) + +def nogo_test(**kwargs): + tags = kwargs.pop("tags", []) + ["nogo"] + _nogo_test(tags = tags, **kwargs) diff --git a/tools/nogo/io_bazel_rules_go-visibility.patch b/tools/nogo/io_bazel_rules_go-visibility.patch new file mode 100644 index 000000000..6b64b2e85 --- /dev/null +++ b/tools/nogo/io_bazel_rules_go-visibility.patch @@ -0,0 +1,25 @@ +diff --git a/third_party/org_golang_x_tools-extras.patch b/third_party/org_golang_x_tools-extras.patch +index 133fbccc..5f0d9a47 100644 +--- a/third_party/org_golang_x_tools-extras.patch ++++ b/third_party/org_golang_x_tools-extras.patch +@@ -32,7 +32,7 @@ diff -urN c/go/analysis/internal/facts/BUILD.bazel d/go/analysis/internal/facts/ + + go_library( + name = "go_default_library", +-@@ -14,6 +14,23 @@ ++@@ -14,6 +14,20 @@ + ], + ) + +@@ -43,10 +43,7 @@ diff -urN c/go/analysis/internal/facts/BUILD.bazel d/go/analysis/internal/facts/ + + "imports.go", + + ], + + importpath = "golang.org/x/tools/go/analysis/internal/facts", +-+ visibility = [ +-+ "//go/analysis:__subpackages__", +-+ "@io_bazel_rules_go//go/tools/builders:__pkg__", +-+ ], +++ visibility = ["//visibility:public"], + + deps = [ + + "//go/analysis:go_tool_library", + + "//go/types/objectpath:go_tool_library", diff --git a/tools/nogo/matchers.go b/tools/nogo/matchers.go new file mode 100644 index 000000000..bc5772303 --- /dev/null +++ b/tools/nogo/matchers.go @@ -0,0 +1,138 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nogo + +import ( + "go/token" + "path/filepath" + "regexp" + "strings" + + "golang.org/x/tools/go/analysis" +) + +type matcher interface { + ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool +} + +// pathRegexps excludes explicit paths. +type pathRegexps struct { + expr []*regexp.Regexp + whitelist bool +} + +// buildRegexps builds a list of regular expressions. +// +// This will panic on error. +func buildRegexps(prefix string, args ...string) []*regexp.Regexp { + result := make([]*regexp.Regexp, 0, len(args)) + for _, arg := range args { + result = append(result, regexp.MustCompile(filepath.Join(prefix, arg))) + } + return result +} + +// ShouldReport implements matcher.ShouldReport. +func (p *pathRegexps) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool { + fullPos := fs.Position(d.Pos).String() + for _, path := range p.expr { + if path.MatchString(fullPos) { + return p.whitelist + } + } + return !p.whitelist +} + +// internalExcluded excludes specific internal paths. +func internalExcluded(paths ...string) *pathRegexps { + return &pathRegexps{ + expr: buildRegexps(internalPrefix, paths...), + whitelist: false, + } +} + +// excludedExcluded excludes specific external paths. +func externalExcluded(paths ...string) *pathRegexps { + return &pathRegexps{ + expr: buildRegexps(externalPrefix, paths...), + whitelist: false, + } +} + +// internalMatches returns a path matcher for internal packages. +func internalMatches() *pathRegexps { + return &pathRegexps{ + expr: buildRegexps(internalPrefix, ".*"), + whitelist: true, + } +} + +// resultExcluded excludes explicit message contents. +type resultExcluded []string + +// ShouldReport implements matcher.ShouldReport. +func (r resultExcluded) ShouldReport(d analysis.Diagnostic, _ *token.FileSet) bool { + for _, str := range r { + if strings.Contains(d.Message, str) { + return false + } + } + return true // Not blacklisted. +} + +// andMatcher is a composite matcher. +type andMatcher struct { + first matcher + second matcher +} + +// ShouldReport implements matcher.ShouldReport. +func (a *andMatcher) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool { + return a.first.ShouldReport(d, fs) && a.second.ShouldReport(d, fs) +} + +// and is a syntactic convension for andMatcher. +func and(first matcher, second matcher) *andMatcher { + return &andMatcher{ + first: first, + second: second, + } +} + +// anyMatcher matches everything. +type anyMatcher struct{} + +// ShouldReport implements matcher.ShouldReport. +func (anyMatcher) ShouldReport(analysis.Diagnostic, *token.FileSet) bool { + return true +} + +// alwaysMatches returns an anyMatcher instance. +func alwaysMatches() anyMatcher { + return anyMatcher{} +} + +// neverMatcher will never match. +type neverMatcher struct{} + +// ShouldReport implements matcher.ShouldReport. +func (neverMatcher) ShouldReport(analysis.Diagnostic, *token.FileSet) bool { + return false +} + +// disableMatches returns a neverMatcher instance. +func disableMatches() neverMatcher { + return neverMatcher{} +} diff --git a/tools/nogo/nogo.go b/tools/nogo/nogo.go new file mode 100644 index 000000000..203cdf688 --- /dev/null +++ b/tools/nogo/nogo.go @@ -0,0 +1,316 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package nogo implements binary analysis similar to bazel's nogo, +// or the unitchecker package. It exists in order to provide additional +// facilities for analysis, namely plumbing through the output from +// dumping the generated binary (to analyze actual produced code). +package nogo + +import ( + "encoding/json" + "flag" + "fmt" + "go/ast" + "go/build" + "go/parser" + "go/token" + "go/types" + "io" + "io/ioutil" + "log" + "os" + "path/filepath" + "reflect" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/internal/facts" + "golang.org/x/tools/go/gcexportdata" + "gvisor.dev/gvisor/tools/nogo/data" +) + +// pkgConfig is serialized as the configuration. +// +// This contains everything required for the analysis. +type pkgConfig struct { + ImportPath string + GoFiles []string + NonGoFiles []string + Tags []string + GOOS string + GOARCH string + ImportMap map[string]string + FactMap map[string]string + FactOutput string + Objdump string +} + +// loadFacts finds and loads facts per FactMap. +func (c *pkgConfig) loadFacts(path string) ([]byte, error) { + realPath, ok := c.FactMap[path] + if !ok { + return nil, nil // No facts available. + } + + // Read the files file. + data, err := ioutil.ReadFile(realPath) + if err != nil { + return nil, err + } + return data, nil +} + +// shouldInclude indicates whether the file should be included. +// +// NOTE: This does only basic parsing of tags. +func (c *pkgConfig) shouldInclude(path string) (bool, error) { + ctx := build.Default + ctx.GOOS = c.GOOS + ctx.GOARCH = c.GOARCH + ctx.BuildTags = c.Tags + return ctx.MatchFile(filepath.Dir(path), filepath.Base(path)) +} + +// importer is an implementation of go/types.Importer. +// +// This wraps a configuration, which provides the map of package names to +// files, and the facts. Note that this importer implementation will always +// pass when a given package is not available. +type importer struct { + pkgConfig + fset *token.FileSet + cache map[string]*types.Package +} + +// Import implements types.Importer.Import. +func (i *importer) Import(path string) (*types.Package, error) { + if path == "unsafe" { + // Special case: go/types has pre-defined type information for + // unsafe. We ensure that this package is correct, in case any + // analyzers are specifically looking for this. + return types.Unsafe, nil + } + realPath, ok := i.ImportMap[path] + var ( + rc io.ReadCloser + err error + ) + if !ok { + // Not found in the import path. Attempt to find the package + // via the standard library. + rc, err = findStdPkg(path, i.GOOS, i.GOARCH) + } else { + // Open the file. + rc, err = os.Open(realPath) + } + if err != nil { + return nil, err + } + defer rc.Close() + + // Load all exported data. + r, err := gcexportdata.NewReader(rc) + if err != nil { + return nil, err + } + + return gcexportdata.Read(r, i.fset, i.cache, path) +} + +// checkPackage runs all analyzers. +// +// The implementation was adapted from [1], which was in turn adpated from [2]. +// This returns a list of matching analysis issues, or an error if the analysis +// could not be completed. +// +// [1] bazelbuid/rules_go/tools/builders/nogo_main.go +// [2] golang.org/x/tools/go/checker/internal/checker +func checkPackage(config pkgConfig) ([]string, error) { + imp := &importer{ + pkgConfig: config, + fset: token.NewFileSet(), + cache: make(map[string]*types.Package), + } + + // Load all source files. + var syntax []*ast.File + for _, file := range config.GoFiles { + include, err := config.shouldInclude(file) + if err != nil { + return nil, fmt.Errorf("error evaluating file %q: %v", file, err) + } + if !include { + continue + } + s, err := parser.ParseFile(imp.fset, file, nil, parser.ParseComments) + if err != nil { + return nil, fmt.Errorf("error parsing file %q: %v", file, err) + } + syntax = append(syntax, s) + } + + // Check type information. + typesSizes := types.SizesFor("gc", config.GOARCH) + typeConfig := types.Config{Importer: imp} + typesInfo := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Uses: make(map[*ast.Ident]types.Object), + Defs: make(map[*ast.Ident]types.Object), + Implicits: make(map[ast.Node]types.Object), + Scopes: make(map[ast.Node]*types.Scope), + Selections: make(map[*ast.SelectorExpr]*types.Selection), + } + types, err := typeConfig.Check(config.ImportPath, imp.fset, syntax, typesInfo) + if err != nil { + return nil, fmt.Errorf("error checking types: %v", err) + } + + // Load all package facts. + facts, err := facts.Decode(types, config.loadFacts) + if err != nil { + return nil, fmt.Errorf("error decoding facts: %v", err) + } + + // Set the binary global for use. + data.Objdump = config.Objdump + + // Register fact types and establish dependencies between analyzers. + // The visit closure will execute recursively, and populate results + // will all required analysis results. + diagnostics := make(map[*analysis.Analyzer][]analysis.Diagnostic) + results := make(map[*analysis.Analyzer]interface{}) + var visit func(*analysis.Analyzer) error // For recursion. + visit = func(a *analysis.Analyzer) error { + if _, ok := results[a]; ok { + return nil + } + + // Run recursively for all dependencies. + for _, req := range a.Requires { + if err := visit(req); err != nil { + return err + } + } + + // Prepare the matcher. + m := analyzerConfig[a] + report := func(d analysis.Diagnostic) { + if m.ShouldReport(d, imp.fset) { + diagnostics[a] = append(diagnostics[a], d) + } + } + + // Run the analysis. + factFilter := make(map[reflect.Type]bool) + for _, f := range a.FactTypes { + factFilter[reflect.TypeOf(f)] = true + } + p := &analysis.Pass{ + Analyzer: a, + Fset: imp.fset, + Files: syntax, + Pkg: types, + TypesInfo: typesInfo, + ResultOf: results, // All results. + Report: report, + ImportPackageFact: facts.ImportPackageFact, + ExportPackageFact: facts.ExportPackageFact, + ImportObjectFact: facts.ImportObjectFact, + ExportObjectFact: facts.ExportObjectFact, + AllPackageFacts: func() []analysis.PackageFact { return facts.AllPackageFacts(factFilter) }, + AllObjectFacts: func() []analysis.ObjectFact { return facts.AllObjectFacts(factFilter) }, + TypesSizes: typesSizes, + } + result, err := a.Run(p) + if err != nil { + return fmt.Errorf("error running analysis %s: %v", a, err) + } + + // Sanity check & save the result. + if got, want := reflect.TypeOf(result), a.ResultType; got != want { + return fmt.Errorf("error: analyzer %s returned a result of type %v, but declared ResultType %v", a, got, want) + } + results[a] = result + return nil // Success. + } + + // Visit all analysis recursively. + for a, _ := range analyzerConfig { + if err := visit(a); err != nil { + return nil, err // Already has context. + } + } + + // Write the output file. + if config.FactOutput != "" { + factData := facts.Encode() + if err := ioutil.WriteFile(config.FactOutput, factData, 0644); err != nil { + return nil, fmt.Errorf("error: unable to open facts output %q: %v", config.FactOutput, err) + } + } + + // Convert all diagnostics to strings. + findings := make([]string, 0, len(diagnostics)) + for a, ds := range diagnostics { + for _, d := range ds { + // Include the anlyzer name for debugability and configuration. + findings = append(findings, fmt.Sprintf("%s: %s: %s", a.Name, imp.fset.Position(d.Pos), d.Message)) + } + } + + // Return all findings. + return findings, nil +} + +var ( + configFile = flag.String("config", "", "configuration file (in JSON format)") +) + +// Main is the entrypoint; it should be called directly from main. +// +// N.B. This package registers it's own flags. +func Main() { + // Parse all flags. + flag.Parse() + + // Load the configuration. + f, err := os.Open(*configFile) + if err != nil { + log.Fatalf("unable to open configuration %q: %v", *configFile, err) + } + defer f.Close() + config := new(pkgConfig) + dec := json.NewDecoder(f) + dec.DisallowUnknownFields() + if err := dec.Decode(config); err != nil { + log.Fatalf("unable to decode configuration: %v", err) + } + + // Process the package. + findings, err := checkPackage(*config) + if err != nil { + log.Fatalf("error checking package: %v", err) + } + + // No findings? + if len(findings) == 0 { + os.Exit(0) + } + + // Print findings and exit with non-zero code. + for _, finding := range findings { + fmt.Fprintf(os.Stdout, "%s\n", finding) + } + os.Exit(1) +} diff --git a/tools/nogo/register.go b/tools/nogo/register.go new file mode 100644 index 000000000..62b499661 --- /dev/null +++ b/tools/nogo/register.go @@ -0,0 +1,64 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nogo + +import ( + "encoding/gob" + "log" + + "golang.org/x/tools/go/analysis" +) + +// analyzers returns all configured analyzers. +func analyzers() (all []*analysis.Analyzer) { + for a, _ := range analyzerConfig { + all = append(all, a) + } + return all +} + +func init() { + // Validate basic configuration. + if err := analysis.Validate(analyzers()); err != nil { + log.Fatalf("unable to validate analyzer: %v", err) + } + + // Register all fact types. + // + // N.B. This needs to be done recursively, because there may be + // analyzers in the Requires list that do not appear explicitly above. + registered := make(map[*analysis.Analyzer]struct{}) + var register func(*analysis.Analyzer) + register = func(a *analysis.Analyzer) { + if _, ok := registered[a]; ok { + return + } + + // Regsiter dependencies. + for _, da := range a.Requires { + register(da) + } + + // Register local facts. + for _, f := range a.FactTypes { + gob.Register(f) + } + + registered[a] = struct{}{} // Done. + } + for _, a := range analyzers() { + register(a) + } +} diff --git a/tools/images/BUILD b/tools/vm/BUILD index fe11f08a3..f7160c627 100644 --- a/tools/images/BUILD +++ b/tools/vm/BUILD @@ -1,19 +1,14 @@ load("//tools:defs.bzl", "cc_binary", "gtest") -load("//tools/images:defs.bzl", "vm_image", "vm_test") +load("//tools/vm:defs.bzl", "vm_image", "vm_test") package( default_visibility = ["//:sandbox"], licenses = ["notice"], ) -genrule( +sh_binary( name = "zone", - outs = ["zone.txt"], - cmd = "gcloud config get-value compute/zone > $@", - tags = [ - "local", - "manual", - ], + srcs = ["zone.sh"], ) sh_binary( @@ -42,27 +37,21 @@ vm_image( family = "ubuntu-1604-lts", project = "ubuntu-os-cloud", scripts = [ - "//tools/images/ubuntu1604", + "//tools/vm/ubuntu1604", ], ) -vm_test( - name = "ubuntu1604_test", - image = ":ubuntu1604", - targets = [":test"], -) - vm_image( name = "ubuntu1804", family = "ubuntu-1804-lts", project = "ubuntu-os-cloud", scripts = [ - "//tools/images/ubuntu1804", + "//tools/vm/ubuntu1804", ], ) vm_test( - name = "ubuntu1804_test", - image = ":ubuntu1804", + name = "vm_test", + shard_count = 2, targets = [":test"], ) diff --git a/tools/vm/README.md b/tools/vm/README.md new file mode 100644 index 000000000..898c95fca --- /dev/null +++ b/tools/vm/README.md @@ -0,0 +1,42 @@ +# VM Images & Tests + +All commands in this directory require the `gcloud` project to be set. + +For example: `gcloud config set project gvisor-kokoro-testing`. + +Images can be generated by using the `vm_image` rule. This rule will generate a +binary target that builds an image in an idempotent way, and can be referenced +from other rules. + +For example: + +``` +vm_image( + name = "ubuntu", + project = "ubuntu-1604-lts", + family = "ubuntu-os-cloud", + scripts = [ + "script.sh", + "other.sh", + ], +) +``` + +These images can be built manually by executing the target. The output on +`stdout` will be the image id (in the current project). + +Images are always named per the hash of all the hermetic input scripts. This +allows images to be memoized quickly and easily. + +The `vm_test` rule can be used to execute a command remotely. This is still +under development however, and will likely change over time. + +For example: + +``` +vm_test( + name = "mycommand", + image = ":ubuntu", + targets = [":test"], +) +``` diff --git a/tools/images/build.sh b/tools/vm/build.sh index be462d556..5d3dc0bbf 100755 --- a/tools/images/build.sh +++ b/tools/vm/build.sh @@ -19,7 +19,7 @@ # image_setup.sh script. This script should be idempotent, as we memoize the # setup script with a hash and check for that name. -set -xeou pipefail +set -eou pipefail # Parameters. declare -r USERNAME=${USERNAME:-test} @@ -34,68 +34,84 @@ declare -r INSTANCE_NAME=$(mktemp -u build-XXXXXX | tr A-Z a-z) # Hash inputs in order to memoize the produced image. declare -r SETUP_HASH=$( (echo ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && cat "$@") | sha256sum - | cut -d' ' -f1 | cut -c 1-16) -declare -r IMAGE_NAME=${IMAGE_FAMILY:-image-}${SETUP_HASH} +declare -r IMAGE_NAME=${IMAGE_FAMILY:-image}-${SETUP_HASH} # Does the image already exist? Skip the build. -declare -r existing=$(gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)") +declare -r existing=$(set -x; gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)") if ! [[ -z "${existing}" ]]; then echo "${existing}" exit 0 fi +# Standard arguments (applies only on script execution). +declare -ar SSH_ARGS=("-o" "ConnectTimeout=60" "--") + # gcloud has path errors; is this a result of being a genrule? export PATH=${PATH:-/bin:/usr/bin:/usr/local/bin} # Start a unique instance. Note that this instance will have a unique persistent # disk as it's boot disk with the same name as the instance. -gcloud compute instances create \ +(set -x; gcloud compute instances create \ --quiet \ --image-project "${IMAGE_PROJECT}" \ --image-family "${IMAGE_FAMILY}" \ --boot-disk-size "200GB" \ --zone "${ZONE}" \ - "${INSTANCE_NAME}" >/dev/null + "${INSTANCE_NAME}" >/dev/null) function cleanup { - gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}" + (set -x; gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}") } trap cleanup EXIT # Wait for the instance to become available (up to 5 minutes). +echo -n "Waiting for ${INSTANCE_NAME}" declare timeout=300 declare success=0 +declare internal="" declare -r start=$(date +%s) declare -r end=$((${start}+${timeout})) while [[ "$(date +%s)" -lt "${end}" ]] && [[ "${success}" -lt 3 ]]; do - if gcloud compute ssh --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- env - true 2>/dev/null; then + echo -n "." + if gcloud compute ssh --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then + success=$((${success}+1)) + elif gcloud compute ssh --internal-ip --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then success=$((${success}+1)) + internal="--internal-ip" fi done + if [[ "${success}" -eq "0" ]]; then echo "connect timed out after ${timeout} seconds." exit 1 +else + echo "done." fi # Run the install scripts provided. for arg; do - gcloud compute ssh --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- sudo bash - <"${arg}" >/dev/null + (set -x; gcloud compute ssh ${internal} \ + --zone "${ZONE}" \ + "${USERNAME}"@"${INSTANCE_NAME}" -- \ + "${SSH_ARGS[@]}" \ + sudo bash - <"${arg}" >/dev/null) done # Stop the instance; required before creating an image. -gcloud compute instances stop --quiet --zone "${ZONE}" "${INSTANCE_NAME}" >/dev/null +(set -x; gcloud compute instances stop --quiet --zone "${ZONE}" "${INSTANCE_NAME}" >/dev/null) # Create a snapshot of the instance disk. -gcloud compute disks snapshot \ +(set -x; gcloud compute disks snapshot \ --quiet \ --zone "${ZONE}" \ --snapshot-names="${SNAPSHOT_NAME}" \ - "${INSTANCE_NAME}" >/dev/null + "${INSTANCE_NAME}" >/dev/null) # Create the disk image. -gcloud compute images create \ +(set -x; gcloud compute images create \ --quiet \ --source-snapshot="${SNAPSHOT_NAME}" \ --licenses="https://www.googleapis.com/compute/v1/projects/vm-options/global/licenses/enable-vmx" \ - "${IMAGE_NAME}" >/dev/null + "${IMAGE_NAME}" >/dev/null) # Finish up. echo "${IMAGE_NAME}" diff --git a/tools/images/defs.bzl b/tools/vm/defs.bzl index de365d153..61feefcbc 100644 --- a/tools/images/defs.bzl +++ b/tools/vm/defs.bzl @@ -1,96 +1,108 @@ -"""Image configuration. - -Images can be generated by using the vm_image rule. For example, - - vm_image( - name = "ubuntu", - project = "...", - family = "...", - scripts = [ - "script.sh", - "other.sh", - ], - ) - -This will always create an vm_image in the current default gcloud project. The -rule has a text file as its output containing the image name. This will enforce -serialization for all dependent rules. - -Images are always named per the hash of all the hermetic input scripts. This -allows images to be memoized quickly and easily. - -The vm_test rule can be used to execute a command remotely. For example, - - vm_test( - name = "mycommand", - image = ":myimage", - targets = [":test"], - ) -""" +"""Image configuration. See README.md.""" load("//tools:defs.bzl", "default_installer") -def _vm_image_impl(ctx): +# vm_image_builder is a rule that will construct a shell script that actually +# generates a given VM image. Note that this does not _run_ the shell script +# (although it can be run manually). It will be run manually during generation +# of the vm_image target itself. This level of indirection is used so that the +# build system itself only runs the builder once when multiple targets depend +# on it, avoiding a set of races and conflicts. +def _vm_image_builder_impl(ctx): + # Generate a binary that actually builds the image. + builder = ctx.actions.declare_file(ctx.label.name) script_paths = [] for script in ctx.files.scripts: script_paths.append(script.short_path) + builder_content = "\n".join([ + "#!/bin/bash", + "export ZONE=$(%s)" % ctx.files.zone[0].short_path, + "export USERNAME=%s" % ctx.attr.username, + "export IMAGE_PROJECT=%s" % ctx.attr.project, + "export IMAGE_FAMILY=%s" % ctx.attr.family, + "%s %s" % (ctx.files._builder[0].short_path, " ".join(script_paths)), + "", + ]) + ctx.actions.write(builder, builder_content, is_executable = True) - resolved_inputs, argv, runfiles_manifests = ctx.resolve_command( - command = "USERNAME=%s ZONE=$(cat %s) IMAGE_PROJECT=%s IMAGE_FAMILY=%s %s %s > %s" % - ( - ctx.attr.username, - ctx.files.zone[0].path, - ctx.attr.project, - ctx.attr.family, - ctx.executable.builder.path, - " ".join(script_paths), - ctx.outputs.out.path, - ), - tools = [ctx.attr.builder] + ctx.attr.scripts, - ) - - ctx.actions.run_shell( - tools = resolved_inputs, - outputs = [ctx.outputs.out], - progress_message = "Building image...", - execution_requirements = {"local": "true"}, - command = argv, - input_manifests = runfiles_manifests, - ) + # Note that the scripts should only be files, and should not include any + # indirect transitive dependencies. The build script wouldn't work. return [DefaultInfo( - files = depset([ctx.outputs.out]), - runfiles = ctx.runfiles(files = [ctx.outputs.out]), + executable = builder, + runfiles = ctx.runfiles( + files = ctx.files.scripts + ctx.files._builder + ctx.files.zone, + ), )] -_vm_image = rule( +vm_image_builder = rule( attrs = { - "builder": attr.label( + "_builder": attr.label( executable = True, - default = "//tools/images:builder", + default = "//tools/vm:builder", cfg = "host", ), "username": attr.string(default = "$(whoami)"), "zone": attr.label( - default = "//tools/images:zone", + executable = True, + default = "//tools/vm:zone", cfg = "host", ), "family": attr.string(mandatory = True), "project": attr.string(mandatory = True), "scripts": attr.label_list(allow_files = True), }, - outputs = { - "out": "%{name}.txt", + executable = True, + implementation = _vm_image_builder_impl, +) + +# See vm_image_builder above. +def _vm_image_impl(ctx): + # Run the builder to generate our output. + echo = ctx.actions.declare_file(ctx.label.name) + resolved_inputs, argv, runfiles_manifests = ctx.resolve_command( + command = "echo -ne \"#!/bin/bash\\necho $(%s)\\n\" > %s && chmod 0755 %s" % ( + ctx.files.builder[0].path, + echo.path, + echo.path, + ), + tools = [ctx.attr.builder], + ) + ctx.actions.run_shell( + tools = resolved_inputs, + outputs = [echo], + progress_message = "Building image...", + execution_requirements = {"local": "true"}, + command = argv, + input_manifests = runfiles_manifests, + ) + + # Return just the echo command. All of the builder runfiles have been + # resolved and consumed in the generation of the trivial echo script. + return [DefaultInfo(executable = echo)] + +_vm_image_test = rule( + attrs = { + "builder": attr.label( + executable = True, + cfg = "host", + ), }, + test = True, implementation = _vm_image_impl, ) -def vm_image(**kwargs): - _vm_image( +def vm_image(name, **kwargs): + vm_image_builder( + name = name + "_builder", + **kwargs + ) + _vm_image_test( + name = name, + builder = ":" + name + "_builder", tags = [ "local", "manual", ], - **kwargs ) def _vm_test_impl(ctx): @@ -101,9 +113,9 @@ def _vm_test_impl(ctx): # they can be copied over for remote execution. runner_content = "\n".join([ "#!/bin/bash", - "export ZONE=$(cat %s)" % ctx.files.zone[0].short_path, + "export ZONE=$(%s)" % ctx.files.zone[0].short_path, "export USERNAME=%s" % ctx.attr.username, - "export IMAGE=$(cat %s)" % ctx.files.image[0].short_path, + "export IMAGE=$(%s)" % ctx.files.image[0].short_path, "export SUDO=%s" % "true" if ctx.attr.sudo else "false", "%s %s" % ( ctx.executable.executer.short_path, @@ -133,17 +145,19 @@ def _vm_test_impl(ctx): _vm_test = rule( attrs = { "image": attr.label( - mandatory = True, + executable = True, + default = "//tools/vm:ubuntu1804", cfg = "host", ), "executer": attr.label( executable = True, - default = "//tools/images:executer", + default = "//tools/vm:executer", cfg = "host", ), "username": attr.string(default = "$(whoami)"), "zone": attr.label( - default = "//tools/images:zone", + executable = True, + default = "//tools/vm:zone", cfg = "host", ), "sudo": attr.bool(default = True), @@ -159,7 +173,7 @@ _vm_test = rule( ) def vm_test( - installer = "//tools/installers:head", + installers = None, **kwargs): """Runs the given targets as a remote test. @@ -168,8 +182,12 @@ def vm_test( **kwargs: All test arguments. Should include targets and image. """ targets = kwargs.pop("targets", []) - if installer: - targets = [installer] + targets + if installers == None: + installers = [ + "//tools/installers:head", + "//tools/installers:images", + ] + targets = installers + targets if default_installer(): targets = [default_installer()] + targets _vm_test( diff --git a/tools/images/execute.sh b/tools/vm/execute.sh index ba4b1ac0e..1f1f3ce01 100755 --- a/tools/images/execute.sh +++ b/tools/vm/execute.sh @@ -31,6 +31,9 @@ declare -r MACHINE=${MACHINE:-n1-standard-1} declare -r ZONE=${ZONE:-us-central1-f} declare -r SUDO=${SUDO:-false} +# Standard arguments (applies only on script execution). +declare -ar SSH_ARGS=("-o" "ConnectTimeout=60" "--") + # This script is executed as a test rule, which will reset the value of HOME. # Unfortunately, it is needed to load the gconfig credentials. We will reset # HOME when we actually execute in the remote environment, defined below. @@ -81,7 +84,9 @@ tar czf - --dereference --exclude=.git . | gcloud compute ssh \ --ssh-key-file="${KEYNAME}" \ --zone "${ZONE}" \ - "${USERNAME}"@"${INSTANCE_NAME}" -- tar xzf - + "${USERNAME}"@"${INSTANCE_NAME}" -- \ + "${SSH_ARGS[@]}" \ + tar xzf - # Execute the command remotely. for cmd; do @@ -108,6 +113,7 @@ for cmd; do --ssh-key-file="${KEYNAME}" \ --zone "${ZONE}" \ "${USERNAME}"@"${INSTANCE_NAME}" -- \ + "${SSH_ARGS[@]}" \ mkdir -p "/tmp/${REMOTE_TMPDIR}" fi if [[ -v XML_OUTPUT_FILE ]]; then @@ -123,6 +129,7 @@ for cmd; do --ssh-key-file="${KEYNAME}" \ --zone "${ZONE}" \ "${USERNAME}"@"${INSTANCE_NAME}" -- \ + "${SSH_ARGS[@]}" \ "${PREFIX[@]}" "${cmd}" # Collect relevant results. @@ -147,6 +154,7 @@ for cmd; do --ssh-key-file="${KEYNAME}" \ --zone "${ZONE}" \ "${USERNAME}"@"${INSTANCE_NAME}" -- \ + "${SSH_ARGS[@]}" \ rm -rf "/tmp/${REMOTE_TMPDIR}" fi done diff --git a/tools/images/test.cc b/tools/vm/test.cc index 4f31d93c5..c0ceacda1 100644 --- a/tools/images/test.cc +++ b/tools/vm/test.cc @@ -16,8 +16,12 @@ namespace { -TEST(Image, Sanity) { - // Do nothing. +TEST(Image, Sanity0) { + // Do nothing (in shard 0). +} + +TEST(Image, Sanity1) { + // Do nothing (in shard 1). } } // namespace diff --git a/tools/images/ubuntu1604/10_core.sh b/tools/vm/ubuntu1604/10_core.sh index cd518d6ac..cd518d6ac 100755 --- a/tools/images/ubuntu1604/10_core.sh +++ b/tools/vm/ubuntu1604/10_core.sh diff --git a/tools/images/ubuntu1604/20_bazel.sh b/tools/vm/ubuntu1604/20_bazel.sh index bb7afa676..bb7afa676 100755 --- a/tools/images/ubuntu1604/20_bazel.sh +++ b/tools/vm/ubuntu1604/20_bazel.sh diff --git a/tools/images/ubuntu1604/25_docker.sh b/tools/vm/ubuntu1604/25_docker.sh index 11eea2d72..11eea2d72 100755 --- a/tools/images/ubuntu1604/25_docker.sh +++ b/tools/vm/ubuntu1604/25_docker.sh diff --git a/tools/images/ubuntu1604/30_containerd.sh b/tools/vm/ubuntu1604/30_containerd.sh index fb3699c12..fb3699c12 100755 --- a/tools/images/ubuntu1604/30_containerd.sh +++ b/tools/vm/ubuntu1604/30_containerd.sh diff --git a/tools/images/ubuntu1604/40_kokoro.sh b/tools/vm/ubuntu1604/40_kokoro.sh index 06a1e6c48..06a1e6c48 100755 --- a/tools/images/ubuntu1604/40_kokoro.sh +++ b/tools/vm/ubuntu1604/40_kokoro.sh diff --git a/tools/images/ubuntu1604/BUILD b/tools/vm/ubuntu1604/BUILD index ab1df0c4c..ab1df0c4c 100644 --- a/tools/images/ubuntu1604/BUILD +++ b/tools/vm/ubuntu1604/BUILD diff --git a/tools/images/ubuntu1804/BUILD b/tools/vm/ubuntu1804/BUILD index 7aa1ecdf7..0c8856dde 100644 --- a/tools/images/ubuntu1804/BUILD +++ b/tools/vm/ubuntu1804/BUILD @@ -2,6 +2,6 @@ package(licenses = ["notice"]) alias( name = "ubuntu1804", - actual = "//tools/images/ubuntu1604", + actual = "//tools/vm/ubuntu1604", visibility = ["//:sandbox"], ) diff --git a/tools/vm/zone.sh b/tools/vm/zone.sh new file mode 100755 index 000000000..79569fb19 --- /dev/null +++ b/tools/vm/zone.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +# Copyright 2020 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec gcloud config get-value compute/zone |