diff options
Diffstat (limited to 'tools')
130 files changed, 13091 insertions, 0 deletions
diff --git a/tools/BUILD b/tools/BUILD new file mode 100644 index 000000000..34b950644 --- /dev/null +++ b/tools/BUILD @@ -0,0 +1 @@ +package(licenses = ["notice"]) diff --git a/tools/bazel.mk b/tools/bazel.mk new file mode 100644 index 000000000..9f4a40669 --- /dev/null +++ b/tools/bazel.mk @@ -0,0 +1,124 @@ +#!/usr/bin/make -f + +# Copyright 2018 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# See base Makefile. +BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \ + git rev-parse --abbrev-ref HEAD 2>/dev/null) | \ + xargs -n 1 basename 2>/dev/null) + +# Bazel container configuration (see below). +USER ?= gvisor +HASH ?= $(shell readlink -m $(CURDIR) | md5sum | cut -c1-8) +DOCKER_NAME ?= gvisor-bazel-$(HASH) +DOCKER_PRIVILEGED ?= --privileged +BAZEL_CACHE := $(shell readlink -m ~/.cache/bazel/) +GCLOUD_CONFIG := $(shell readlink -m ~/.config/gcloud/) +DOCKER_SOCKET := /var/run/docker.sock + +# Non-configurable. +UID := $(shell id -u ${USER}) +GID := $(shell id -g ${USER}) +USERADD_OPTIONS := +FULL_DOCKER_RUN_OPTIONS := $(DOCKER_RUN_OPTIONS) +FULL_DOCKER_RUN_OPTIONS += -v "$(BAZEL_CACHE):$(BAZEL_CACHE)" +FULL_DOCKER_RUN_OPTIONS += -v "$(GCLOUD_CONFIG):$(GCLOUD_CONFIG)" +FULL_DOCKER_RUN_OPTIONS += -v "/tmp:/tmp" +ifneq ($(DOCKER_PRIVILEGED),) +FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)" +DOCKER_GROUP := $(shell stat -c '%g' $(DOCKER_SOCKET)) +ifneq ($(GID),$(DOCKER_GROUP)) +USERADD_OPTIONS += --groups $(DOCKER_GROUP) +GROUPADD_DOCKER += groupadd --gid $(DOCKER_GROUP) --non-unique docker-$(HASH) && +FULL_DOCKER_RUN_OPTIONS += --group-add $(DOCKER_GROUP) +endif +endif +SHELL=/bin/bash -o pipefail + +## +## Bazel helpers. +## +## This file supports targets that wrap bazel in a running Docker +## container to simplify development. Some options are available to +## control the behavior of this container: +## USER - The in-container user. +## DOCKER_RUN_OPTIONS - Options for the container (default: --privileged, required for tests). +## DOCKER_NAME - The container name (default: gvisor-bazel-HASH). +## BAZEL_CACHE - The bazel cache directory (default: detected). +## GCLOUD_CONFIG - The gcloud config directory (detect: detected). +## DOCKER_SOCKET - The Docker socket (default: detected). +## +bazel-server-start: load-default ## Starts the bazel server. + @mkdir -p $(BAZEL_CACHE) + @mkdir -p $(GCLOUD_CONFIG) + docker run -d --rm \ + --init \ + --name $(DOCKER_NAME) \ + --user 0:0 $(DOCKER_GROUP_OPTIONS) \ + -v "$(CURDIR):$(CURDIR)" \ + --workdir "$(CURDIR)" \ + --entrypoint "" \ + $(FULL_DOCKER_RUN_OPTIONS) \ + gvisor.dev/images/default \ + sh -c "groupadd --gid $(GID) --non-unique $(USER) && \ + $(GROUPADD_DOCKER) \ + useradd --uid $(UID) --non-unique --no-create-home --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) && \ + bazel version && \ + exec tail --pid=\$$(bazel info server_pid) -f /dev/null" + @while :; do if docker logs $(DOCKER_NAME) 2>/dev/null | grep "Build label:" >/dev/null; then break; fi; \ + if ! docker ps | grep $(DOCKER_NAME); then exit 1; else sleep 1; fi; done +.PHONY: bazel-server-start + +bazel-shutdown: ## Shuts down a running bazel server. + @docker exec --user $(UID):$(GID) $(DOCKER_NAME) bazel shutdown; rc=$$?; docker kill $(DOCKER_NAME) || [[ $$rc -ne 0 ]] +.PHONY: bazel-shutdown + +bazel-alias: ## Emits an alias that can be used within the shell. + @echo "alias bazel='docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) bazel'" +.PHONY: bazel-alias + +bazel-server: ## Ensures that the server exists. Used as an internal target. + @docker exec $(DOCKER_NAME) true || $(MAKE) bazel-server-start +.PHONY: bazel-server + +build_cmd = docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) sh -o pipefail -c 'bazel $(STARTUP_OPTIONS) build $(OPTIONS) $(TARGETS)' + +build_paths = $(build_cmd) 2>&1 \ + | tee /proc/self/fd/2 \ + | grep -E "^ bazel-bin/" \ + | awk "{print $$1;}" \ + | xargs -n 1 -I {} sh -c "$(1)" + +build: bazel-server + @$(call build_cmd) +.PHONY: build + +copy: bazel-server +ifeq (,$(DESTINATION)) + $(error Destination not provided.) +endif + @$(call build_paths,cp -a {} $(DESTINATION)) + +run: bazel-server + @$(call build_paths,{} $(ARGS)) +.PHONY: run + +sudo: bazel-server + @$(call build_paths,sudo -E {} $(ARGS)) +.PHONY: sudo + +test: bazel-server + @docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) bazel $(STARTUP_OPTIONS) test $(OPTIONS) $(TARGETS) +.PHONY: test diff --git a/tools/bazeldefs/BUILD b/tools/bazeldefs/BUILD new file mode 100644 index 000000000..f2f80bae1 --- /dev/null +++ b/tools/bazeldefs/BUILD @@ -0,0 +1,51 @@ +load("//tools:defs.bzl", "rbe_platform", "rbe_toolchain") + +package(licenses = ["notice"]) + +# In bazel, no special support is required for loopback networking. This is +# just a dummy data target that does not change the test environment. +genrule( + name = "loopback", + outs = ["loopback.txt"], + cmd = "touch $@", + visibility = ["//:sandbox"], +) + +# We need to define a bazel platform and toolchain to specify dockerPrivileged +# and dockerRunAsRoot options, they are required to run tests on the RBE +# cluster in Kokoro. +rbe_platform( + name = "rbe_ubuntu1604", + constraint_values = [ + "@bazel_tools//platforms:x86_64", + "@bazel_tools//platforms:linux", + "@bazel_tools//tools/cpp:clang", + "@bazel_toolchains//constraints:xenial", + "@bazel_toolchains//constraints/sanitizers:support_msan", + ], + remote_execution_properties = """ + properties: { + name: "container-image" + value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:93f7e127196b9b653d39830c50f8b05d49ef6fd8739a9b5b8ab16e1df5399e50" + } + properties: { + name: "dockerAddCapabilities" + value: "SYS_ADMIN" + } + properties: { + name: "dockerPrivileged" + value: "true" + } + """, +) + +rbe_toolchain( + name = "cc-toolchain-clang-x86_64-default", + exec_compatible_with = [], + tags = [ + "manual", + ], + target_compatible_with = [], + toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/10.0.0/bazel_2.0.0/cc:cc-compiler-k8", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl new file mode 100644 index 000000000..620c460de --- /dev/null +++ b/tools/bazeldefs/defs.bzl @@ -0,0 +1,182 @@ +"""Bazel implementations of standard rules.""" + +load("@bazel_gazelle//:def.bzl", _gazelle = "gazelle") +load("@bazel_skylib//rules:build_test.bzl", _build_test = "build_test") +load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", _cc_flags_supplier = "cc_flags_supplier") +load("@io_bazel_rules_go//go:def.bzl", "GoLibrary", _go_binary = "go_binary", _go_context = "go_context", _go_embed_data = "go_embed_data", _go_library = "go_library", _go_path = "go_path", _go_test = "go_test") +load("@io_bazel_rules_go//proto:def.bzl", _go_grpc_library = "go_grpc_library", _go_proto_library = "go_proto_library") +load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test") +load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") +load("@pydeps//:requirements.bzl", _py_requirement = "requirement") +load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", _cc_grpc_library = "cc_grpc_library") + +build_test = _build_test +cc_library = _cc_library +cc_flags_supplier = _cc_flags_supplier +cc_proto_library = _cc_proto_library +cc_test = _cc_test +cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain" +gazelle = _gazelle +go_embed_data = _go_embed_data +go_path = _go_path +gtest = "@com_google_googletest//:gtest" +grpcpp = "@com_github_grpc_grpc//:grpc++" +gbenchmark = "@com_google_benchmark//:benchmark" +loopback = "//tools/bazeldefs:loopback" +pkg_deb = _pkg_deb +pkg_tar = _pkg_tar +py_library = native.py_library +py_binary = native.py_binary +py_test = native.py_test +rbe_platform = native.platform +rbe_toolchain = native.toolchain +vdso_linker_option = "-fuse-ld=gold " + +def proto_library(name, has_services = None, **kwargs): + native.proto_library( + name = name, + **kwargs + ) + +def cc_grpc_library(name, **kwargs): + _cc_grpc_library(name = name, grpc_only = True, **kwargs) + +def _go_proto_or_grpc_library(go_library_func, name, **kwargs): + deps = [ + dep.replace("_proto", "_go_proto") + for dep in (kwargs.pop("deps", []) or []) + ] + go_library_func( + name = name + "_go_proto", + importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name + "_go_proto", + proto = ":" + name + "_proto", + deps = deps, + **kwargs + ) + +def go_proto_library(name, **kwargs): + _go_proto_or_grpc_library(_go_proto_library, name, **kwargs) + +def go_grpc_and_proto_libraries(name, **kwargs): + _go_proto_or_grpc_library(_go_grpc_library, name, **kwargs) + +def cc_binary(name, static = False, **kwargs): + """Run cc_binary. + + Args: + name: name of the target. + static: make a static binary if True + **kwargs: the rest of the args. + """ + if static: + # How to statically link a c++ program that uses threads, like for gRPC: + # https://gcc.gnu.org/legacy-ml/gcc-help/2010-05/msg00029.html + if "linkopts" not in kwargs: + kwargs["linkopts"] = [] + kwargs["linkopts"] += [ + "-static", + "-lstdc++", + "-Wl,--whole-archive", + "-lpthread", + "-Wl,--no-whole-archive", + ] + _cc_binary( + name = name, + **kwargs + ) + +def go_binary(name, static = False, pure = False, **kwargs): + """Build a go binary. + + Args: + name: name of the target. + static: build a static binary. + pure: build without cgo. + **kwargs: rest of the arguments are passed to _go_binary. + """ + if static: + kwargs["static"] = "on" + if pure: + kwargs["pure"] = "on" + _go_binary( + name = name, + **kwargs + ) + +def go_importpath(target): + """Returns the importpath for the target.""" + return target[GoLibrary].importpath + +def go_library(name, **kwargs): + _go_library( + name = name, + importpath = "gvisor.dev/gvisor/" + native.package_name(), + **kwargs + ) + +def go_test(name, pure = False, library = None, **kwargs): + """Build a go test. + + Args: + name: name of the output binary. + pure: should it be built without cgo. + library: the library to embed. + **kwargs: rest of the arguments to pass to _go_test. + """ + if pure: + kwargs["pure"] = "on" + if library: + kwargs["embed"] = [library] + _go_test( + name = name, + **kwargs + ) + +def go_rule(rule, implementation, **kwargs): + """Wraps a rule definition with Go attributes. + + Args: + rule: rule function (typically rule or aspect). + implementation: implementation function. + **kwargs: other arguments to pass to rule. + + Returns: + The result of invoking the rule. + """ + attrs = kwargs.pop("attrs", []) + attrs["_go_context_data"] = attr.label(default = "@io_bazel_rules_go//:go_context_data") + attrs["_stdlib"] = attr.label(default = "@io_bazel_rules_go//:stdlib") + toolchains = kwargs.get("toolchains", []) + ["@io_bazel_rules_go//go:toolchain"] + return rule(implementation, attrs = attrs, toolchains = toolchains, **kwargs) + +def go_context(ctx): + go_ctx = _go_context(ctx) + return struct( + go = go_ctx.go, + env = go_ctx.env, + runfiles = depset([go_ctx.go] + go_ctx.sdk.tools + go_ctx.stdlib.libs), + goos = go_ctx.sdk.goos, + goarch = go_ctx.sdk.goarch, + tags = go_ctx.tags, + ) + +def py_requirement(name, direct = True): + return _py_requirement(name) + +def select_arch(amd64 = "amd64", arm64 = "arm64", default = None, **kwargs): + values = { + "@bazel_tools//src/conditions:linux_x86_64": amd64, + "@bazel_tools//src/conditions:linux_aarch64": arm64, + } + if default: + values["//conditions:default"] = default + return select(values, **kwargs) + +def select_system(linux = ["__linux__"], **kwargs): + return linux # Only Linux supported. + +def default_installer(): + return None + +def default_net_util(): + return [] # Nothing needed. diff --git a/tools/bazeldefs/platforms.bzl b/tools/bazeldefs/platforms.bzl new file mode 100644 index 000000000..165b22311 --- /dev/null +++ b/tools/bazeldefs/platforms.bzl @@ -0,0 +1,9 @@ +"""List of platforms.""" + +# Platform to associated tags. +platforms = { + "ptrace": [], + "kvm": [], +} + +default_platform = "ptrace" diff --git a/tools/bazeldefs/tags.bzl b/tools/bazeldefs/tags.bzl new file mode 100644 index 000000000..f5d7a7b21 --- /dev/null +++ b/tools/bazeldefs/tags.bzl @@ -0,0 +1,56 @@ +"""List of special Go suffixes.""" + +def explode(tagset, suffixes): + """explode combines tagset and suffixes in all ways. + + Args: + tagset: Original suffixes. + suffixes: Suffixes to combine before and after. + + Returns: + The set of possible combinations. + """ + result = [t for t in tagset] + result += [s for s in suffixes] + for t in tagset: + result += [t + s for s in suffixes] + result += [s + t for s in suffixes] + return result + +archs = [ + "_386", + "_aarch64", + "_amd64", + "_arm", + "_arm64", + "_mips", + "_mips64", + "_mips64le", + "_mipsle", + "_ppc64", + "_ppc64le", + "_riscv64", + "_s390x", + "_sparc64", + "_x86", +] + +oses = [ + "_linux", +] + +generic = [ + "_impl", + "_race", + "_norace", + "_unsafe", + "_opts", +] + +# State explosion? Sure. This is approximately: +# len(archs) * (1 + 2 * len(oses) * (1 + 2 * len(generic)) +# +# This evaluates to 495 at the time of writing. So it's a lot of different +# combinations, but not so much that it will cause issues. We can probably add +# quite a few more variants before this becomes a genuine problem. +go_suffixes = explode(explode(archs, oses), generic) diff --git a/tools/bigquery/BUILD b/tools/bigquery/BUILD new file mode 100644 index 000000000..5748fb390 --- /dev/null +++ b/tools/bigquery/BUILD @@ -0,0 +1,10 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "bigquery", + testonly = 1, + srcs = ["bigquery.go"], + deps = ["@com_google_cloud_go_bigquery//:go_default_library"], +) diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go new file mode 100644 index 000000000..56f0dc5c9 --- /dev/null +++ b/tools/bigquery/bigquery.go @@ -0,0 +1,121 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package bigquery defines a BigQuery schema for benchmarks. +// +// This package contains a schema for BigQuery and methods for publishing +// benchmark data into tables. +package bigquery + +import ( + "context" + "fmt" + "strings" + "time" + + bq "cloud.google.com/go/bigquery" +) + +// Benchmark is the top level structure of recorded benchmark data. BigQuery +// will infer the schema from this. +type Benchmark struct { + Name string `bq:"name"` + Timestamp time.Time `bq:"timestamp"` + Official bool `bq:"official"` + Metric []*Metric `bq:"metric"` + Metadata *Metadata `bq:"metadata"` +} + +// Metric holds the actual metric data and unit information for this benchmark. +type Metric struct { + Name string `bq:"name"` + Unit string `bq:"unit"` + Sample float64 `bq:"sample"` +} + +// Metadata about this benchmark. +type Metadata struct { + CL string `bq:"changelist"` + IterationID string `bq:"iteration_id"` + PendingCL string `bq:"pending_cl"` + Workflow string `bq:"workflow"` + Platform string `bq:"platform"` + Gofer string `bq:"gofer"` +} + +// InitBigQuery initializes a BigQuery dataset/table in the project. If the dataset/table already exists, it is not duplicated. +func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string) error { + client, err := bq.NewClient(ctx, projectID) + if err != nil { + return fmt.Errorf("failed to initialize client on project %s: %v", projectID, err) + } + defer client.Close() + + dataset := client.Dataset(datasetID) + if err := dataset.Create(ctx, nil); err != nil && !checkDuplicateError(err) { + return fmt.Errorf("failed to create dataset: %s: %v", datasetID, err) + } + + table := dataset.Table(tableID) + schema, err := bq.InferSchema(Benchmark{}) + if err != nil { + return fmt.Errorf("failed to infer schema: %v", err) + } + + if err := table.Create(ctx, &bq.TableMetadata{Schema: schema}); err != nil && !checkDuplicateError(err) { + return fmt.Errorf("failed to create table: %s: %v", tableID, err) + } + return nil +} + +// AddMetric adds a metric to an existing Benchmark. +func (bm *Benchmark) AddMetric(metricName, unit string, sample float64) { + m := &Metric{ + Name: metricName, + Unit: unit, + Sample: sample, + } + bm.Metric = append(bm.Metric, m) +} + +// NewBenchmark initializes a new benchmark. +func NewBenchmark(name string, official bool) *Benchmark { + return &Benchmark{ + Name: name, + Timestamp: time.Now().UTC(), + Official: official, + Metric: make([]*Metric, 0), + } +} + +// SendBenchmarks sends the slice of benchmarks to the BigQuery dataset/table. +func SendBenchmarks(ctx context.Context, benchmarks []*Benchmark, projectID, datasetID, tableID string) error { + client, err := bq.NewClient(ctx, projectID) + if err != nil { + return fmt.Errorf("Failed to initialize client on project: %s: %v", projectID, err) + } + defer client.Close() + + uploader := client.Dataset(datasetID).Table(tableID).Uploader() + if err = uploader.Put(ctx, benchmarks); err != nil { + return fmt.Errorf("failed to upload benchmarks to proejct %s, table %s.%s: %v", projectID, datasetID, tableID, err) + } + + return nil +} + +// BigQuery will error "409" for duplicate tables and datasets. +func checkDuplicateError(err error) bool { + return strings.Contains(err.Error(), "googleapi: Error 409: Already Exists") +} diff --git a/tools/checkescape/BUILD b/tools/checkescape/BUILD new file mode 100644 index 000000000..b8c3ddf44 --- /dev/null +++ b/tools/checkescape/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "checkescape", + srcs = ["checkescape.go"], + nogo = False, + visibility = ["//tools/nogo:__subpackages__"], + deps = [ + "//tools/nogo/data", + "@org_golang_x_tools//go/analysis:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/buildssa:go_tool_library", + "@org_golang_x_tools//go/ssa:go_tool_library", + ], +) diff --git a/tools/checkescape/checkescape.go b/tools/checkescape/checkescape.go new file mode 100644 index 000000000..f8def4823 --- /dev/null +++ b/tools/checkescape/checkescape.go @@ -0,0 +1,726 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package checkescape allows recursive escape analysis for hot paths. +// +// The analysis tracks multiple types of escapes, in two categories. First, +// 'hard' escapes are explicit allocations. Second, 'soft' escapes are +// interface dispatches or dynamic function dispatches; these don't necessarily +// escape but they *may* escape. The analysis is capable of making assertions +// recursively: soft escapes cannot be analyzed in this way, and therefore +// count as escapes for recursive purposes. +// +// The different types of escapes are as follows, with the category in +// parentheses: +// +// heap: A direct allocation is made on the heap (hard). +// builtin: A call is made to a built-in allocation function (hard). +// stack: A stack split as part of a function preamble (soft). +// interface: A call is made via an interface whicy *may* escape (soft). +// dynamic: A dynamic function is dispatched which *may* escape (soft). +// +// To the use the package, annotate a function-level comment with either the +// line "// +checkescape" or "// +checkescape:OPTION[,OPTION]". In the second +// case, the OPTION field is either a type above, or one of: +// +// local: Escape analysis is limited to local hard escapes only. +// all: All the escapes are included. +// hard: All hard escapes are included. +// +// If the "// +checkescape" annotation is provided, this is equivalent to +// provided the local and hard options. +// +// Some examples of this syntax are: +// +// +checkescape:all - Analyzes for all escapes in this function and all calls. +// +checkescape:local - Analyzes only for default local hard escapes. +// +checkescape:heap - Only analyzes for heap escapes. +// +checkescape:interface,dynamic - Only checks for dynamic calls and interface calls. +// +checkescape - Does the same as +checkescape:local,hard. +// +// Note that all of the above can be inverted by using +mustescape. The +// +checkescape keyword will ensure failure if the class of escape occurs, +// whereas +mustescape will fail if the given class of escape does not occur. +// +// Local exemptions can be made by a comment of the form "// escapes: reason." +// This must appear on the line of the escape and will also apply to callers of +// the function as well (for non-local escape analysis). +package checkescape + +import ( + "bufio" + "bytes" + "fmt" + "go/ast" + "go/token" + "go/types" + "io" + "os" + "path/filepath" + "strconv" + "strings" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/buildssa" + "golang.org/x/tools/go/ssa" + "gvisor.dev/gvisor/tools/nogo/data" +) + +const ( + // magic is the magic annotation. + magic = "// +checkescape" + + // magicParams is the magic annotation with specific parameters. + magicParams = magic + ":" + + // testMagic is the test magic annotation (parameters required). + testMagic = "// +mustescape:" + + // exempt is the exemption annotation. + exempt = "// escapes" +) + +// escapingBuiltins are builtins known to escape. +// +// These are lowered at an earlier stage of compilation to explicit function +// calls, but are not available for recursive analysis. +var escapingBuiltins = []string{ + "append", + "makemap", + "newobject", + "mallocgc", +} + +// Analyzer defines the entrypoint. +var Analyzer = &analysis.Analyzer{ + Name: "checkescape", + Doc: "surfaces recursive escape analysis results", + Run: run, + Requires: []*analysis.Analyzer{buildssa.Analyzer}, + FactTypes: []analysis.Fact{(*packageEscapeFacts)(nil)}, +} + +// packageEscapeFacts is the set of all functions in a package, and whether or +// not they recursively pass escape analysis. +// +// All the type names for receivers are encoded in the full key. The key +// represents the fully qualified package and type name used at link time. +type packageEscapeFacts struct { + Funcs map[string][]Escape +} + +// AFact implements analysis.Fact.AFact. +func (*packageEscapeFacts) AFact() {} + +// CallSite is a single call site. +// +// These can be chained. +type CallSite struct { + LocalPos token.Pos + Resolved LinePosition +} + +// Escape is a single escape instance. +type Escape struct { + Reason EscapeReason + Detail string + Chain []CallSite +} + +// LinePosition is a low-resolution token.Position. +// +// This is used to match against possible exemptions placed in the source. +type LinePosition struct { + Filename string + Line int +} + +// String implements fmt.Stringer.String. +func (e *LinePosition) String() string { + return fmt.Sprintf("%s:%d", e.Filename, e.Line) +} + +// String implements fmt.Stringer.String. +// +// Note that this string will contain new lines. +func (e *Escape) String() string { + var b bytes.Buffer + fmt.Fprintf(&b, "%s", e.Reason.String()) + for i, cs := range e.Chain { + if i == len(e.Chain)-1 { + fmt.Fprintf(&b, "\n @ %s → %s", cs.Resolved.String(), e.Detail) + } else { + fmt.Fprintf(&b, "\n + %s", cs.Resolved.String()) + } + } + return b.String() +} + +// EscapeReason is an escape reason. +// +// This is a simple enum. +type EscapeReason int + +const ( + interfaceInvoke EscapeReason = iota + unknownPackage + allocation + builtin + dynamicCall + stackSplit + reasonCount // Count for below. +) + +// String returns the string for the EscapeReason. +// +// Note that this also implicitly defines the reverse string -> EscapeReason +// mapping, which is the word before the colon (computed below). +func (e EscapeReason) String() string { + switch e { + case interfaceInvoke: + return "interface: function invocation via interface" + case unknownPackage: + return "unknown: no package information available" + case allocation: + return "heap: call to runtime heap allocation" + case builtin: + return "builtin: call to runtime builtin" + case dynamicCall: + return "dynamic: call via dynamic function" + case stackSplit: + return "stack: stack split on function entry" + default: + panic(fmt.Sprintf("unknown reason: %d", e)) + } +} + +var hardReasons = []EscapeReason{ + allocation, + builtin, +} + +var softReasons = []EscapeReason{ + interfaceInvoke, + unknownPackage, + dynamicCall, + stackSplit, +} + +var allReasons = append(hardReasons, softReasons...) + +var escapeTypes = func() map[string]EscapeReason { + result := make(map[string]EscapeReason) + for _, r := range allReasons { + parts := strings.Split(r.String(), ":") + result[parts[0]] = r // Key before ':'. + } + return result +}() + +// EscapeCount counts escapes. +// +// It is used to avoid accumulating too many escapes for the same reason, for +// the same function. We limit each class to 3 instances (arbitrarily). +type EscapeCount struct { + byReason [reasonCount]uint32 +} + +// maxRecordsPerReason is the number of explicit records. +// +// See EscapeCount (and usage), and Record implementation. +const maxRecordsPerReason = 5 + +// Record records the reason or returns false if it should not be added. +func (ec *EscapeCount) Record(reason EscapeReason) bool { + ec.byReason[reason]++ + if ec.byReason[reason] > maxRecordsPerReason { + return false + } + return true +} + +// loadObjdump reads the objdump output. +// +// This records if there is a call any function for every source line. It is +// used only to remove false positives for escape analysis. The call will be +// elided if escape analysis is able to put the object on the heap exclusively. +func loadObjdump() (map[LinePosition]string, error) { + f, err := os.Open(data.Objdump) + if err != nil { + return nil, err + } + defer f.Close() + + // Build the map. + m := make(map[LinePosition]string) + r := bufio.NewReader(f) + var ( + lastField string + lastPos LinePosition + ) + for { + line, err := r.ReadString('\n') + if err != nil && err != io.EOF { + return nil, err + } + + // We recognize lines corresponding to actual code (not the + // symbol name or other metadata) and annotate them if they + // correspond to an explicit CALL instruction. We assume that + // the lack of a CALL for a given line is evidence that escape + // analysis has eliminated an allocation. + // + // Lines look like this (including the first space): + // gohacks_unsafe.go:33 0xa39 488b442408 MOVQ 0x8(SP), AX + if len(line) > 0 && line[0] == ' ' { + fields := strings.Fields(line) + if !strings.Contains(fields[3], "CALL") { + continue + } + + // Ignore strings containing duffzero, which is just + // used by stack allocations for types that are large + // enough to warrant Duff's device. + if strings.Contains(line, "runtime.duffzero") { + continue + } + + // Ignore the racefuncenter call, which is used for + // race builds. This does not escape. + if strings.Contains(line, "runtime.racefuncenter") { + continue + } + + // Calculate the filename and line. Note that per the + // example above, the filename is not a fully qualified + // base, just the basename (what we require). + if fields[0] != lastField { + parts := strings.SplitN(fields[0], ":", 2) + lineNum, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return nil, err + } + lastPos = LinePosition{ + Filename: parts[0], + Line: int(lineNum), + } + lastField = fields[0] + } + if _, ok := m[lastPos]; ok { + continue // Already marked. + } + + // Save the actual call for the detail. + m[lastPos] = strings.Join(fields[3:], " ") + } + if err == io.EOF { + break + } + } + + return m, nil +} + +// poser is a type that implements Pos. +type poser interface { + Pos() token.Pos +} + +// run performs the analysis. +func run(pass *analysis.Pass) (interface{}, error) { + calls, err := loadObjdump() + if err != nil { + return nil, err + } + pef := packageEscapeFacts{ + Funcs: make(map[string][]Escape), + } + linePosition := func(inst, parent poser) LinePosition { + p := pass.Fset.Position(inst.Pos()) + if (p.Filename == "" || p.Line == 0) && parent != nil { + p = pass.Fset.Position(parent.Pos()) + } + return LinePosition{ + Filename: filepath.Base(p.Filename), + Line: p.Line, + } + } + hasCall := func(inst poser) (string, bool) { + p := linePosition(inst, nil) + s, ok := calls[p] + return s, ok + } + callSite := func(inst ssa.Instruction) CallSite { + return CallSite{ + LocalPos: inst.Pos(), + Resolved: linePosition(inst, inst.Parent()), + } + } + escapes := func(reason EscapeReason, detail string, inst ssa.Instruction, ec *EscapeCount) []Escape { + if !ec.Record(reason) { + return nil // Skip. + } + es := Escape{ + Reason: reason, + Detail: detail, + Chain: []CallSite{callSite(inst)}, + } + return []Escape{es} + } + resolve := func(sub []Escape, inst ssa.Instruction, ec *EscapeCount) (es []Escape) { + for _, e := range sub { + if !ec.Record(e.Reason) { + continue // Skip. + } + es = append(es, Escape{ + Reason: e.Reason, + Detail: e.Detail, + Chain: append([]CallSite{callSite(inst)}, e.Chain...), + }) + } + return es + } + state := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA) + + var loadFunc func(*ssa.Function) []Escape // Used below. + + analyzeInstruction := func(inst ssa.Instruction, ec *EscapeCount) []Escape { + switch x := inst.(type) { + case *ssa.Call: + if x.Call.IsInvoke() { + // This is an interface dispatch. There is no + // way to know if this is actually escaping or + // not, since we don't know the underlying + // type. + call, _ := hasCall(inst) + return escapes(interfaceInvoke, call, inst, ec) + } + switch x := x.Call.Value.(type) { + case *ssa.Function: + if x.Pkg == nil { + // Can't resolve the package. + return escapes(unknownPackage, "no package", inst, ec) + } + + // Atomic functions are instrinics. We can + // assume that they don't escape. + if x.Pkg.Pkg.Name() == "atomic" { + return nil + } + + // Is this a local function? If yes, call the + // function to load the local function. The + // local escapes are the escapes found in the + // local function. + if x.Pkg.Pkg == pass.Pkg { + return resolve(loadFunc(x), inst, ec) + } + + // Recursively collect information from + // the other analyzers. + var imp packageEscapeFacts + if !pass.ImportPackageFact(x.Pkg.Pkg, &imp) { + // Unable to import the dependency; we must + // declare these as escaping. + return escapes(unknownPackage, "no analysis", inst, ec) + } + + // The escapes of this instruction are the + // escapes of the called function directly. + return resolve(imp.Funcs[x.RelString(x.Pkg.Pkg)], inst, ec) + case *ssa.Builtin: + // Ignore elided escapes. + if _, has := hasCall(inst); !has { + return nil + } + + // Check if the builtin is escaping. + for _, name := range escapingBuiltins { + if x.Name() == name { + return escapes(builtin, name, inst, ec) + } + } + default: + // All dynamic calls are counted as soft + // escapes. They are similar to interface + // dispatches. We cannot actually look up what + // this refers to using static analysis alone. + call, _ := hasCall(inst) + return escapes(dynamicCall, call, inst, ec) + } + case *ssa.Alloc: + // Ignore non-heap allocations. + if !x.Heap { + return nil + } + + // Ignore elided escapes. + call, has := hasCall(inst) + if !has { + return nil + } + + // This is a real heap allocation. + return escapes(allocation, call, inst, ec) + case *ssa.MakeMap: + return escapes(builtin, "makemap", inst, ec) + case *ssa.MakeSlice: + return escapes(builtin, "makeslice", inst, ec) + case *ssa.MakeClosure: + return escapes(builtin, "makeclosure", inst, ec) + case *ssa.MakeChan: + return escapes(builtin, "makechan", inst, ec) + } + return nil // No escapes. + } + + var analyzeBasicBlock func(*ssa.BasicBlock, *EscapeCount) []Escape // Recursive. + analyzeBasicBlock = func(block *ssa.BasicBlock, ec *EscapeCount) (rval []Escape) { + for _, inst := range block.Instrs { + rval = append(rval, analyzeInstruction(inst, ec)...) + } + return rval // N.B. may be empty. + } + + loadFunc = func(fn *ssa.Function) []Escape { + // Is this already available? + name := fn.RelString(pass.Pkg) + if es, ok := pef.Funcs[name]; ok { + return es + } + + // In the case of a true cycle, we assume that the current + // function itself has no escapes until the rest of the + // analysis is complete. This will trip the above in the case + // of a cycle of any kind. + pef.Funcs[name] = nil + + // Perform the basic analysis. + var ( + es []Escape + ec EscapeCount + ) + if fn.Recover != nil { + es = append(es, analyzeBasicBlock(fn.Recover, &ec)...) + } + for _, block := range fn.Blocks { + es = append(es, analyzeBasicBlock(block, &ec)...) + } + + // Check for a stack split. + if call, has := hasCall(fn); has { + es = append(es, Escape{ + Reason: stackSplit, + Detail: call, + Chain: []CallSite{CallSite{ + LocalPos: fn.Pos(), + Resolved: linePosition(fn, fn.Parent()), + }}, + }) + } + + // Save the result and return. + pef.Funcs[name] = es + return es + } + + // Complete all local functions. + for _, fn := range state.SrcFuncs { + loadFunc(fn) + } + + // Build the exception list. + exemptions := make(map[LinePosition]string) + for _, f := range pass.Files { + for _, cg := range f.Comments { + for _, c := range cg.List { + p := pass.Fset.Position(c.Slash) + if strings.HasPrefix(strings.ToLower(c.Text), exempt) { + exemptions[LinePosition{ + Filename: filepath.Base(p.Filename), + Line: p.Line, + }] = c.Text[len(exempt):] + } + } + } + } + + // Delete everything matching the excemtions. + // + // This has the implication that exceptions are applied recursively, + // since this now modified set is what will be saved. + for name, escapes := range pef.Funcs { + var newEscapes []Escape + for _, escape := range escapes { + isExempt := false + for line, _ := range exemptions { + // Note that an exemption applies if it is + // marked as an exemption anywhere in the call + // chain. It need not be marked as escapes in + // the function itself, nor in the top-level + // caller. + for _, callSite := range escape.Chain { + if callSite.Resolved == line { + isExempt = true + break + } + } + if isExempt { + break + } + } + if !isExempt { + // Record this escape; not an exception. + newEscapes = append(newEscapes, escape) + } + } + pef.Funcs[name] = newEscapes // Update. + } + + // Export all findings for future packages. + pass.ExportPackageFact(&pef) + + // Scan all functions for violations. + for _, f := range pass.Files { + // Scan all declarations. + for _, decl := range f.Decls { + fdecl, ok := decl.(*ast.FuncDecl) + // Function declaration? + if !ok { + continue + } + // Is there a comment? + if fdecl.Doc == nil { + continue + } + var ( + reasons []EscapeReason + found bool + local bool + testReasons = make(map[EscapeReason]bool) // reason -> local? + ) + // Does the comment contain a +checkescape line? + for _, c := range fdecl.Doc.List { + if !strings.HasPrefix(c.Text, magic) && !strings.HasPrefix(c.Text, testMagic) { + continue + } + if c.Text == magic { + // Default: hard reasons, local only. + reasons = hardReasons + local = true + } else if strings.HasPrefix(c.Text, magicParams) { + // Extract specific reasons. + types := strings.Split(c.Text[len(magicParams):], ",") + found = true // For below. + for i := 0; i < len(types); i++ { + if types[i] == "local" { + // Limit search to local escapes. + local = true + } else if types[i] == "all" { + // Append all reasons. + reasons = append(reasons, allReasons...) + } else if types[i] == "hard" { + // Append all hard reasons. + reasons = append(reasons, hardReasons...) + } else { + r, ok := escapeTypes[types[i]] + if !ok { + // This is not a valid escape reason. + pass.Reportf(fdecl.Pos(), "unknown reason: %v", types[i]) + continue + } + reasons = append(reasons, r) + } + } + } else if strings.HasPrefix(c.Text, testMagic) { + types := strings.Split(c.Text[len(testMagic):], ",") + local := false + for i := 0; i < len(types); i++ { + if types[i] == "local" { + local = true + } else { + r, ok := escapeTypes[types[i]] + if !ok { + // This is not a valid escape reason. + pass.Reportf(fdecl.Pos(), "unknown reason: %v", types[i]) + continue + } + if v, ok := testReasons[r]; ok && v { + // Already registered as local. + continue + } + testReasons[r] = local + } + } + } + } + if len(reasons) == 0 && found { + // A magic annotation was provided, but no reasons. + pass.Reportf(fdecl.Pos(), "no reasons provided") + continue + } + + // Scan for matches. + fn := pass.TypesInfo.Defs[fdecl.Name].(*types.Func) + name := state.Pkg.Prog.FuncValue(fn).RelString(pass.Pkg) + es, ok := pef.Funcs[name] + if !ok { + pass.Reportf(fdecl.Pos(), "internal error: function %s not found.", name) + continue + } + for _, e := range es { + for _, r := range reasons { + // Is does meet our local requirement? + if local && len(e.Chain) > 1 { + continue + } + // Does this match the reason? Emit + // with a full stack trace that + // explains why this violates our + // constraints. + if e.Reason == r { + pass.Reportf(e.Chain[0].LocalPos, "%s", e.String()) + } + } + } + + // Scan for test (required) matches. + testReasonsFound := make(map[EscapeReason]bool) + for _, e := range es { + // Is this local? + local, ok := testReasons[e.Reason] + wantLocal := len(e.Chain) == 1 + testReasonsFound[e.Reason] = wantLocal + if !ok { + continue + } + if local == wantLocal { + delete(testReasons, e.Reason) + } + } + for reason, local := range testReasons { + // We didn't find the escapes we wanted. + pass.Reportf(fdecl.Pos(), fmt.Sprintf("testescapes not found: reason=%s, local=%t", reason, local)) + } + if len(testReasons) > 0 { + // Dump all reasons found to help in debugging. + for _, e := range es { + pass.Reportf(e.Chain[0].LocalPos, "escape found: %s", e.String()) + } + } + } + } + + return nil, nil +} diff --git a/tools/checkescape/test1/BUILD b/tools/checkescape/test1/BUILD new file mode 100644 index 000000000..783403247 --- /dev/null +++ b/tools/checkescape/test1/BUILD @@ -0,0 +1,9 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "test1", + srcs = ["test1.go"], + visibility = ["//tools/checkescape/test2:__pkg__"], +) diff --git a/tools/checkescape/test1/test1.go b/tools/checkescape/test1/test1.go new file mode 100644 index 000000000..68d3f72cc --- /dev/null +++ b/tools/checkescape/test1/test1.go @@ -0,0 +1,195 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package test1 is a test package. +package test1 + +import ( + "fmt" + "reflect" +) + +// Interface is a generic interface. +type Interface interface { + Foo() +} + +// Type is a concrete implementation of Interface. +type Type struct { + A uint64 + B uint64 +} + +// Foo implements Interface.Foo. +//go:nosplit +func (t Type) Foo() { + fmt.Printf("%v", t) // Never executed. +} + +// +checkescape:all,hard +//go:nosplit +func InterfaceFunction(i Interface) { + // Do nothing; exported for tests. +} + +// +checkesacape:all,hard +//go:nosplit +func TypeFunction(t *Type) { +} + +// +mustescape:local,builtin +//go:noinline +//go:nosplit +func BuiltinMap(x int) map[string]bool { + return make(map[string]bool) +} + +// +mustescape:builtin +//go:noinline +//go:nosplit +func builtinMapRec(x int) map[string]bool { + return BuiltinMap(x) +} + +// +temustescapestescape:local,builtin +//go:noinline +//go:nosplit +func BuiltinClosure(x int) func() { + return func() { + fmt.Printf("%v", x) + } +} + +// +mustescape:builtin +//go:noinline +//go:nosplit +func builtinClosureRec(x int) func() { + return BuiltinClosure(x) +} + +// +mustescape:local,builtin +//go:noinline +//go:nosplit +func BuiltinMakeSlice(x int) []byte { + return make([]byte, x) +} + +// +mustescape:builtin +//go:noinline +//go:nosplit +func builtinMakeSliceRec(x int) []byte { + return BuiltinMakeSlice(x) +} + +// +mustescape:local,builtin +//go:noinline +//go:nosplit +func BuiltinAppend(x []byte) []byte { + return append(x, 0) +} + +// +mustescape:builtin +//go:noinline +//go:nosplit +func builtinAppendRec() []byte { + return BuiltinAppend(nil) +} + +// +mustescape:local,builtin +//go:noinline +//go:nosplit +func BuiltinChan() chan int { + return make(chan int) +} + +// +mustescape:builtin +//go:noinline +//go:nosplit +func builtinChanRec() chan int { + return BuiltinChan() +} + +// +mustescape:local,heap +//go:noinline +//go:nosplit +func Heap() *Type { + var t Type + return &t +} + +// +mustescape:heap +//go:noinline +//go:nosplit +func heapRec() *Type { + return Heap() +} + +// +mustescape:local,interface +//go:noinline +//go:nosplit +func Dispatch(i Interface) { + i.Foo() +} + +// +mustescape:interface +//go:noinline +//go:nosplit +func dispatchRec(i Interface) { + Dispatch(i) +} + +// +mustescape:local,dynamic +//go:noinline +//go:nosplit +func Dynamic(f func()) { + f() +} + +// +mustescape:dynamic +//go:noinline +//go:nosplit +func dynamicRec(f func()) { + Dynamic(f) +} + +// +mustescape:local,unknown +//go:noinline +//go:nosplit +func Unknown() { + _ = reflect.TypeOf((*Type)(nil)) // Does not actually escape. +} + +// +mustescape:unknown +//go:noinline +//go:nosplit +func unknownRec() { + Unknown() +} + +//go:noinline +//go:nosplit +func internalFunc() { +} + +// +mustescape:local,stack +//go:noinline +func Split() { + internalFunc() +} + +// +mustescape:stack +//go:noinline +func splitRec() { + Split() +} diff --git a/tools/checkescape/test2/BUILD b/tools/checkescape/test2/BUILD new file mode 100644 index 000000000..5a11e4b43 --- /dev/null +++ b/tools/checkescape/test2/BUILD @@ -0,0 +1,9 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "test2", + srcs = ["test2.go"], + deps = ["//tools/checkescape/test1"], +) diff --git a/tools/checkescape/test2/test2.go b/tools/checkescape/test2/test2.go new file mode 100644 index 000000000..7fce3e3be --- /dev/null +++ b/tools/checkescape/test2/test2.go @@ -0,0 +1,94 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package test2 is a test package that imports test1. +package test2 + +import ( + "gvisor.dev/gvisor/tools/checkescape/test1" +) + +// +checkescape:all +//go:nosplit +func interfaceFunctionCrossPkg() { + var i test1.Interface + test1.InterfaceFunction(i) +} + +// +checkesacape:all +//go:nosplit +func typeFunctionCrossPkg() { + var t test1.Type + test1.TypeFunction(&t) +} + +// +mustescape:builtin +//go:noinline +func builtinMapCrossPkg(x int) map[string]bool { + return test1.BuiltinMap(x) +} + +// +mustescape:builtin +//go:noinline +func builtinClosureCrossPkg(x int) func() { + return test1.BuiltinClosure(x) +} + +// +mustescape:builtin +//go:noinline +func builtinMakeSliceCrossPkg(x int) []byte { + return test1.BuiltinMakeSlice(x) +} + +// +mustescape:builtin +//go:noinline +func builtinAppendCrossPkg() []byte { + return test1.BuiltinAppend(nil) +} + +// +mustescape:builtin +//go:noinline +func builtinChanCrossPkg() chan int { + return test1.BuiltinChan() +} + +// +mustescape:heap +//go:noinline +func heapCrossPkg() *test1.Type { + return test1.Heap() +} + +// +mustescape:interface +//go:noinline +func dispatchCrossPkg(i test1.Interface) { + test1.Dispatch(i) +} + +// +mustescape:dynamic +//go:noinline +func dynamicCrossPkg(f func()) { + test1.Dynamic(f) +} + +// +mustescape:unknown +//go:noinline +func unknownCrossPkg() { + test1.Unknown() +} + +// +mustescape:stack +//go:noinline +func splitCrosssPkt() { + test1.Split() +} diff --git a/tools/checkunsafe/BUILD b/tools/checkunsafe/BUILD new file mode 100644 index 000000000..0c264151b --- /dev/null +++ b/tools/checkunsafe/BUILD @@ -0,0 +1,13 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "checkunsafe", + srcs = ["check_unsafe.go"], + nogo = False, + visibility = ["//tools/nogo:__subpackages__"], + deps = [ + "@org_golang_x_tools//go/analysis:go_tool_library", + ], +) diff --git a/tools/checkunsafe/check_unsafe.go b/tools/checkunsafe/check_unsafe.go new file mode 100644 index 000000000..4ccd7cc5a --- /dev/null +++ b/tools/checkunsafe/check_unsafe.go @@ -0,0 +1,56 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package checkunsafe allows unsafe imports only in files named appropriately. +package checkunsafe + +import ( + "fmt" + "path" + "strconv" + "strings" + + "golang.org/x/tools/go/analysis" +) + +// Analyzer defines the entrypoint. +var Analyzer = &analysis.Analyzer{ + Name: "checkunsafe", + Doc: "allows unsafe use only in specified files", + Run: run, +} + +func run(pass *analysis.Pass) (interface{}, error) { + for _, f := range pass.Files { + for _, imp := range f.Imports { + // Is this an unsafe import? + pkg, err := strconv.Unquote(imp.Path.Value) + if err != nil || pkg != "unsafe" { + continue + } + + // Extract the filename. + filename := pass.Fset.File(imp.Pos()).Name() + + // Allow files named _unsafe.go or _test.go to opt out. + if strings.HasSuffix(filename, "_unsafe.go") || strings.HasSuffix(filename, "_test.go") { + continue + } + + // Throw the error. + pass.Reportf(imp.Pos(), fmt.Sprintf("package unsafe imported by %s; must end with _unsafe.go", path.Base(filename))) + } + } + return nil, nil +} diff --git a/tools/defs.bzl b/tools/defs.bzl new file mode 100644 index 000000000..40afcdb79 --- /dev/null +++ b/tools/defs.bzl @@ -0,0 +1,254 @@ +"""Wrappers for common build rules. + +These wrappers apply common BUILD configurations (e.g., proto_library +automagically creating cc_ and go_ proto targets) and act as a single point of +change for Google-internal and bazel-compatible rules. +""" + +load("//tools/go_stateify:defs.bzl", "go_stateify") +load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") +load("//tools/bazeldefs:defs.bzl", _build_test = "build_test", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _default_installer = "default_installer", _default_net_util = "default_net_util", _gazelle = "gazelle", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_test = "go_test", _grpcpp = "grpcpp", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _vdso_linker_option = "vdso_linker_option") +load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms") +load("//tools/bazeldefs:tags.bzl", "go_suffixes") +load("//tools/nogo:defs.bzl", "nogo_test") + +# Delegate directly. +build_test = _build_test +cc_binary = _cc_binary +cc_flags_supplier = _cc_flags_supplier +cc_grpc_library = _cc_grpc_library +cc_library = _cc_library +cc_test = _cc_test +cc_toolchain = _cc_toolchain +default_installer = _default_installer +default_net_util = _default_net_util +gbenchmark = _gbenchmark +gazelle = _gazelle +go_embed_data = _go_embed_data +go_path = _go_path +go_test = _go_test +gtest = _gtest +grpcpp = _grpcpp +loopback = _loopback +pkg_deb = _pkg_deb +pkg_tar = _pkg_tar +py_binary = _py_binary +py_library = _py_library +py_requirement = _py_requirement +py_test = _py_test +select_arch = _select_arch +select_system = _select_system +rbe_platform = _rbe_platform +rbe_toolchain = _rbe_toolchain +vdso_linker_option = _vdso_linker_option + +# Platform options. +default_platform = _default_platform +platforms = _platforms + +def go_binary(name, **kwargs): + """Wraps the standard go_binary. + + Args: + name: the rule name. + **kwargs: standard go_binary arguments. + """ + _go_binary( + name = name, + **kwargs + ) + +def calculate_sets(srcs): + """Calculates special Go sets for templates. + + Args: + srcs: the full set of Go sources. + + Returns: + A dictionary of the form: + + "": [src1.go, src2.go] + "suffix": [src3suffix.go, src4suffix.go] + + Note that suffix will typically start with '_'. + """ + result = dict() + for file in srcs: + if not file.endswith(".go"): + continue + target = "" + for suffix in go_suffixes: + if file.endswith(suffix + ".go"): + target = suffix + if not target in result: + result[target] = [file] + else: + result[target].append(file) + return result + +def go_imports(name, src, out): + """Simplify a single Go source file by eliminating unused imports.""" + native.genrule( + name = name, + srcs = [src], + outs = [out], + tools = ["@org_golang_x_tools//cmd/goimports:goimports"], + cmd = ("$(location @org_golang_x_tools//cmd/goimports:goimports) $(SRCS) > $@"), + ) + +def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, marshal_debug = False, nogo = True, **kwargs): + """Wraps the standard go_library and does stateification and marshalling. + + The recommended way is to use this rule with mostly identical configuration as the native + go_library rule. + + These definitions provide additional flags (stateify, marshal) that can be used + with the generators to automatically supplement the library code. + + load("//tools:defs.bzl", "go_library") + + go_library( + name = "foo", + srcs = ["foo.go"], + ) + + Args: + name: the rule name. + srcs: the library sources. + deps: the library dependencies. + imports: imports required for stateify. + stateify: whether statify is enabled (default: true). + marshal: whether marshal is enabled (default: false). + marshal_debug: whether the gomarshal tools emits debugging output (default: false). + **kwargs: standard go_library arguments. + """ + all_srcs = srcs + all_deps = deps + dirname, _, _ = native.package_name().rpartition("/") + full_pkg = dirname + "/" + name + if stateify: + # Only do stateification for non-state packages without manual autogen. + # First, we need to segregate the input files via the special suffixes, + # and calculate the final output set. + state_sets = calculate_sets(srcs) + for (suffix, src_subset) in state_sets.items(): + go_stateify( + name = name + suffix + "_state_autogen_with_imports", + srcs = src_subset, + imports = imports, + package = full_pkg, + out = name + suffix + "_state_autogen_with_imports.go", + ) + go_imports( + name = name + suffix + "_state_autogen", + src = name + suffix + "_state_autogen_with_imports.go", + out = name + suffix + "_state_autogen.go", + ) + all_srcs = all_srcs + [ + name + suffix + "_state_autogen.go" + for suffix in state_sets.keys() + ] + if "//pkg/state" not in all_deps: + all_deps = all_deps + ["//pkg/state"] + + if marshal: + # See above. + marshal_sets = calculate_sets(srcs) + for (suffix, src_subset) in marshal_sets.items(): + go_marshal( + name = name + suffix + "_abi_autogen", + srcs = src_subset, + debug = select({ + "//tools/go_marshal:marshal_config_verbose": True, + "//conditions:default": marshal_debug, + }), + imports = imports, + package = name, + ) + extra_deps = [ + dep + for dep in marshal_deps + if not dep in all_deps + ] + all_deps = all_deps + extra_deps + all_srcs = all_srcs + [ + name + suffix + "_abi_autogen_unsafe.go" + for suffix in marshal_sets.keys() + ] + + _go_library( + name = name, + srcs = all_srcs, + deps = all_deps, + **kwargs + ) + if nogo: + nogo_test( + name = name + "_nogo", + deps = [":" + name], + ) + + if marshal: + # Ignore importpath for go_test. + kwargs.pop("importpath", None) + + # See above. + marshal_sets = calculate_sets(srcs) + for (suffix, _) in marshal_sets.items(): + _go_test( + name = name + suffix + "_abi_autogen_test", + srcs = [name + suffix + "_abi_autogen_test.go"], + library = ":" + name, + deps = marshal_test_deps, + **kwargs + ) + +def proto_library(name, srcs, deps = None, has_services = 0, **kwargs): + """Wraps the standard proto_library. + + Given a proto_library named "foo", this produces up to five different + targets: + - foo_proto: proto_library rule. + - foo_go_proto: go_proto_library rule. + - foo_cc_proto: cc_proto_library rule. + - foo_go_grpc_proto: go_grpc_library rule. + - foo_cc_grpc_proto: cc_grpc_library rule. + + Args: + name: the name to which _proto, _go_proto, etc, will be appended. + srcs: the proto sources. + deps: for the proto library and the go_proto_library. + has_services: 1 to build gRPC code, otherwise 0. + **kwargs: standard proto_library arguments. + """ + _proto_library( + name = name + "_proto", + srcs = srcs, + deps = deps, + has_services = has_services, + **kwargs + ) + if has_services: + _go_grpc_and_proto_libraries( + name = name, + deps = deps, + **kwargs + ) + else: + _go_proto_library( + name = name, + deps = deps, + **kwargs + ) + _cc_proto_library( + name = name + "_cc_proto", + deps = [":" + name + "_proto"], + **kwargs + ) + if has_services: + _cc_grpc_library( + name = name + "_cc_grpc_proto", + srcs = [":" + name + "_proto"], + deps = [":" + name + "_cc_proto"], + **kwargs + ) diff --git a/tools/go_branch.sh b/tools/go_branch.sh new file mode 100755 index 000000000..093de89b4 --- /dev/null +++ b/tools/go_branch.sh @@ -0,0 +1,101 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -xeo pipefail + +# Discovery the package name from the go.mod file. +declare -r module=$(cat go.mod | grep -E "^module" | cut -d' ' -f2) +declare -r origpwd=$(pwd) +declare -r othersrc=("go.mod" "go.sum" "AUTHORS" "LICENSE") + +# Check that gopath has been built. +declare -r gopath_dir="$(pwd)/bazel-bin/gopath/src/${module}" +if ! [ -d "${gopath_dir}" ]; then + echo "No gopath directory found; build the :gopath target." >&2 + exit 1 +fi + +# Create a temporary working directory, and ensure that this directory and all +# subdirectories are cleaned up upon exit. +declare -r tmp_dir=$(mktemp -d) +finish() { + cd # Leave tmp_dir. + rm -rf "${tmp_dir}" +} +trap finish EXIT + +# Record the current working commit. +declare -r head=$(git describe --always) + +# We expect to have an existing go branch that we will use as the basis for +# this commit. That branch may be empty, but it must exist. +git fetch --all +declare -r go_branch=$(git show-ref --hash go) + +# Clone the current repository to the temporary directory, and check out the +# current go_branch directory. We move to the new repository for convenience. +declare -r repo_orig="$(pwd)" +declare -r repo_new="${tmp_dir}/repository" +git clone . "${repo_new}" +cd "${repo_new}" + +# Setup the repository and checkout the branch. +git config user.email "gvisor-bot@google.com" +git config user.name "gVisor bot" +git fetch origin "${go_branch}" +git checkout -b go "${go_branch}" + +# Start working on a merge commit that combines the previous history with the +# current history. Note that we don't actually want any changes yet. +# +# N.B. The git behavior changed at some point and the relevant flag was added +# to allow for override, so try the only behavior first then pass the flag. +git merge --no-commit --strategy ours ${head} || \ + git merge --allow-unrelated-histories --no-commit --strategy ours ${head} + +# Sync the entire gopath_dir. +rsync --recursive --verbose --delete --exclude .git -L "${gopath_dir}/" . + +# Add additional files. +for file in "${othersrc[@]}"; do + cp "${origpwd}"/"${file}" . +done + +# Construct a new README.md. +cat > README.md <<EOF +# gVisor + +This branch is a synthetic branch, containing only Go sources, that is +compatible with standard Go tools. See the master branch for authoritative +sources and tests. +EOF + +# There are a few solitary files that can get left behind due to the way bazel +# constructs the gopath target. Note that we don't find all Go files here +# because they may correspond to unused templates, etc. +cp "${repo_orig}"/runsc/*.go runsc/ + +# Normalize all permissions. The way bazel constructs the :gopath tree may leave +# some strange permissions on files. We don't have anything in this tree that +# should be execution, only the Go source files, README.md, and ${othersrc}. +find . -type f -exec chmod 0644 {} \; +find . -type d -exec chmod 0755 {} \; + +# Update the current working set and commit. +git add . && git commit -m "Merge ${head} (automated)" + +# Push the branch back to the original repository. +git remote add orig "${repo_orig}" && git push -f orig go:go diff --git a/tools/go_generics/BUILD b/tools/go_generics/BUILD new file mode 100644 index 000000000..32a949c93 --- /dev/null +++ b/tools/go_generics/BUILD @@ -0,0 +1,38 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "go_generics", + srcs = [ + "generics.go", + "imports.go", + "remove.go", + ], + visibility = ["//:sandbox"], + deps = ["//tools/go_generics/globals"], +) + +genrule( + name = "go_generics_tests", + srcs = glob(["generics_tests/**"]) + [":go_generics"], + outs = ["go_generics_tests.tgz"], + cmd = "tar -czvhf $@ $(SRCS)", +) + +genrule( + name = "go_generics_test_bundle", + srcs = [ + ":go_generics_tests.tgz", + ":go_generics_unittest.sh", + ], + outs = ["go_generics_test.sh"], + cmd = "cat $(location :go_generics_unittest.sh) $(location :go_generics_tests.tgz) > $@", + executable = True, +) + +sh_test( + name = "go_generics_test", + size = "small", + srcs = ["go_generics_test.sh"], +) diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl new file mode 100644 index 000000000..8c9995fd4 --- /dev/null +++ b/tools/go_generics/defs.bzl @@ -0,0 +1,139 @@ +def _go_template_impl(ctx): + input = ctx.files.srcs + output = ctx.outputs.out + + args = ["-o=%s" % output.path] + [f.path for f in input] + + ctx.actions.run( + inputs = input, + outputs = [output], + mnemonic = "GoGenericsTemplate", + progress_message = "Building Go template %s" % ctx.label, + arguments = args, + executable = ctx.executable._tool, + ) + + return struct( + types = ctx.attr.types, + opt_types = ctx.attr.opt_types, + consts = ctx.attr.consts, + opt_consts = ctx.attr.opt_consts, + deps = ctx.attr.deps, + file = output, + ) + +""" +Generates a Go template from a set of Go files. + +A Go template is similar to a go library, except that it has certain types that +can be replaced before usage. For example, one could define a templatized List +struct, whose elements are of type T, then instantiate that template for +T=segment, where "segment" is the concrete type. + +Args: + name: the name of the template. + srcs: the list of source files that comprise the template. + types: the list of generic types in the template that are required to be specified. + opt_types: the list of generic types in the template that can but aren't required to be specified. + consts: the list of constants in the template that are required to be specified. + opt_consts: the list of constants in the template that can but aren't required to be specified. + deps: the list of dependencies. +""" +go_template = rule( + implementation = _go_template_impl, + attrs = { + "srcs": attr.label_list(mandatory = True, allow_files = True), + "deps": attr.label_list(allow_files = True), + "types": attr.string_list(), + "opt_types": attr.string_list(), + "consts": attr.string_list(), + "opt_consts": attr.string_list(), + "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_generics/go_merge")), + }, + outputs = { + "out": "%{name}_template.go", + }, +) + +def _go_template_instance_impl(ctx): + template = ctx.attr.template + output = ctx.outputs.out + + # Check that all required types are defined. + for t in template.types: + if t not in ctx.attr.types: + fail("Missing value for type %s in %s" % (t, ctx.attr.template.label)) + + # Check that all defined types are expected by the template. + for t in ctx.attr.types: + if (t not in template.types) and (t not in template.opt_types): + fail("Type %s it not a parameter to %s" % (t, ctx.attr.template.label)) + + # Check that all required consts are defined. + for t in template.consts: + if t not in ctx.attr.consts: + fail("Missing value for constant %s in %s" % (t, ctx.attr.template.label)) + + # Check that all defined consts are expected by the template. + for t in ctx.attr.consts: + if (t not in template.consts) and (t not in template.opt_consts): + fail("Const %s it not a parameter to %s" % (t, ctx.attr.template.label)) + + # Build the argument list. + args = ["-i=%s" % template.file.path, "-o=%s" % output.path] + args += ["-p=%s" % ctx.attr.package] + + if len(ctx.attr.prefix) > 0: + args += ["-prefix=%s" % ctx.attr.prefix] + + if len(ctx.attr.suffix) > 0: + args += ["-suffix=%s" % ctx.attr.suffix] + + args += [("-t=%s=%s" % (p[0], p[1])) for p in ctx.attr.types.items()] + args += [("-c=%s=%s" % (p[0], p[1])) for p in ctx.attr.consts.items()] + args += [("-import=%s=%s" % (p[0], p[1])) for p in ctx.attr.imports.items()] + + if ctx.attr.anon: + args += ["-anon"] + + ctx.actions.run( + inputs = [template.file], + outputs = [output], + mnemonic = "GoGenericsInstance", + progress_message = "Building Go template instance %s" % ctx.label, + arguments = args, + executable = ctx.executable._tool, + ) + + return struct( + files = depset([output]), + ) + +""" +Instantiates a Go template by replacing all generic types with concrete ones. + +Args: + name: the name of the template instance. + template: the label of the template to be instatiated. + prefix: a prefix to be added to globals in the template. + suffix: a suffix to be added to global in the template. + types: the map from generic type names to concrete ones. + consts: the map from constant names to their values. + imports: the map from imports used in types/consts to their import paths. + package: the name of the package the instantiated template will be compiled into. +""" +go_template_instance = rule( + implementation = _go_template_instance_impl, + attrs = { + "template": attr.label(mandatory = True, providers = ["types"]), + "prefix": attr.string(), + "suffix": attr.string(), + "types": attr.string_dict(), + "consts": attr.string_dict(), + "imports": attr.string_dict(), + "anon": attr.bool(mandatory = False, default = False), + "package": attr.string(mandatory = True), + "out": attr.output(mandatory = True), + "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_generics")), + }, +) diff --git a/tools/go_generics/generics.go b/tools/go_generics/generics.go new file mode 100644 index 000000000..0860ca9db --- /dev/null +++ b/tools/go_generics/generics.go @@ -0,0 +1,286 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// go_generics reads a Go source file and writes a new version of that file with +// a few transformations applied to each. Namely: +// +// 1. Global types can be explicitly renamed with the -t option. For example, +// if -t=A=B is passed in, all references to A will be replaced with +// references to B; a function declaration like: +// +// func f(arg *A) +// +// would be renamed to: +// +// func f(arg *B) +// +// 2. Global type definitions and their method sets will be removed when they're +// being renamed with -t. For example, if -t=A=B is passed in, the following +// definition and methods that existed in the input file wouldn't exist at +// all in the output file: +// +// type A struct{} +// +// func (*A) f() {} +// +// 3. All global types, variables, constants and functions (not methods) are +// prefixed and suffixed based on the option -prefix and -suffix arguments. +// For example, if -suffix=A is passed in, the following globals: +// +// func f() +// type t struct{} +// +// would be renamed to: +// +// func fA() +// type tA struct{} +// +// Some special tags are also modified. For example: +// +// "state:.(t)" +// +// would become: +// +// "state:.(tA)" +// +// 4. The package is renamed to the value via the -p argument. +// 5. Value of constants can be modified with -c argument. +// +// Note that not just the top-level declarations are renamed, all references to +// them are also properly renamed as well, taking into account visibility rules +// and shadowing. For example, if -suffix=A is passed in, the following: +// +// var b = 100 +// +// func f() { +// g(b) +// b := 0 +// g(b) +// } +// +// Would be replaced with: +// +// var bA = 100 +// +// func f() { +// g(bA) +// b := 0 +// g(b) +// } +// +// Note that the second call to g() kept "b" as an argument because it refers to +// the local variable "b". +// +// Note that go_generics can handle anonymous fields with renamed types if +// -anon is passed in, however it does not perform strict checking on parameter +// types that share the same name as the global type and therefore will rename +// them as well. +// +// You can see an example in the tools/go_generics/generics_tests/interface test. +package main + +import ( + "bytes" + "flag" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" + "io/ioutil" + "os" + "regexp" + "strings" + + "gvisor.dev/gvisor/tools/go_generics/globals" +) + +var ( + input = flag.String("i", "", "input `file`") + output = flag.String("o", "", "output `file`") + suffix = flag.String("suffix", "", "`suffix` to add to each global symbol") + prefix = flag.String("prefix", "", "`prefix` to add to each global symbol") + packageName = flag.String("p", "main", "output package `name`") + printAST = flag.Bool("ast", false, "prints the AST") + processAnon = flag.Bool("anon", false, "process anonymous fields") + types = make(mapValue) + consts = make(mapValue) + imports = make(mapValue) +) + +// mapValue implements flag.Value. We use a mapValue flag instead of a regular +// string flag when we want to allow more than one instance of the flag. For +// example, we allow several "-t A=B" arguments, and will rename them all. +type mapValue map[string]string + +func (m mapValue) String() string { + var b bytes.Buffer + first := true + for k, v := range m { + if !first { + b.WriteRune(',') + } else { + first = false + } + b.WriteString(k) + b.WriteRune('=') + b.WriteString(v) + } + return b.String() +} + +func (m mapValue) Set(s string) error { + sep := strings.Index(s, "=") + if sep == -1 { + return fmt.Errorf("missing '=' from '%s'", s) + } + + m[s[:sep]] = s[sep+1:] + + return nil +} + +// stateTagRegexp matches against the 'typed' state tags. +var stateTagRegexp = regexp.MustCompile(`^(.*[^a-z0-9_])state:"\.\(([^\)]*)\)"(.*)$`) + +var identifierRegexp = regexp.MustCompile(`^(.*[^a-zA-Z_])([a-zA-Z_][a-zA-Z0-9_]*)(.*)$`) + +func main() { + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0]) + flag.PrintDefaults() + } + + flag.Var(types, "t", "rename type A to B when `A=B` is passed in. Multiple such mappings are allowed.") + flag.Var(consts, "c", "reassign constant A to value B when `A=B` is passed in. Multiple such mappings are allowed.") + flag.Var(imports, "import", "specifies the import libraries to use when types are not local. `name=path` specifies that 'name', used in types as name.type, refers to the package living in 'path'.") + flag.Parse() + + if *input == "" || *output == "" { + flag.Usage() + os.Exit(1) + } + + // Parse the input file. + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, *input, nil, parser.ParseComments|parser.DeclarationErrors|parser.SpuriousErrors) + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } + + // Print the AST if requested. + if *printAST { + ast.Print(fset, f) + } + + cmap := ast.NewCommentMap(fset, f, f.Comments) + + // Update imports based on what's used in types and consts. + maps := []mapValue{types, consts} + importDecl, err := updateImports(maps, imports) + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } + types = maps[0] + consts = maps[1] + + // Reassign all specified constants. + for _, decl := range f.Decls { + d, ok := decl.(*ast.GenDecl) + if !ok || d.Tok != token.CONST { + continue + } + + for _, gs := range d.Specs { + s := gs.(*ast.ValueSpec) + for i, id := range s.Names { + if n, ok := consts[id.Name]; ok { + s.Values[i] = &ast.BasicLit{Value: n} + } + } + } + } + + // Go through all globals and their uses in the AST and rename the types + // with explicitly provided names, and rename all types, variables, + // consts and functions with the provided prefix and suffix. + globals.Visit(fset, f, func(ident *ast.Ident, kind globals.SymKind) { + if n, ok := types[ident.Name]; ok && kind == globals.KindType { + ident.Name = n + } else { + switch kind { + case globals.KindType, globals.KindVar, globals.KindConst, globals.KindFunction: + if ident.Name != "_" { + ident.Name = *prefix + ident.Name + *suffix + } + case globals.KindTag: + // Modify the state tag appropriately. + if m := stateTagRegexp.FindStringSubmatch(ident.Name); m != nil { + if t := identifierRegexp.FindStringSubmatch(m[2]); t != nil { + typeName := *prefix + t[2] + *suffix + if n, ok := types[t[2]]; ok { + typeName = n + } + ident.Name = m[1] + `state:".(` + t[1] + typeName + t[3] + `)"` + m[3] + } + } + } + } + }, *processAnon) + + // Remove the definition of all types that are being remapped. + set := make(typeSet) + for _, v := range types { + set[v] = struct{}{} + } + removeTypes(set, f) + + // Add the new imports, if any, to the top. + if importDecl != nil { + newDecls := make([]ast.Decl, 0, len(f.Decls)+1) + newDecls = append(newDecls, importDecl) + newDecls = append(newDecls, f.Decls...) + f.Decls = newDecls + } + + // Update comments to remove the ones potentially associated with the + // type T that we removed. + f.Comments = cmap.Filter(f).Comments() + + // If there are file (package) comments, delete them. + if f.Doc != nil { + for i, cg := range f.Comments { + if cg == f.Doc { + f.Comments = append(f.Comments[:i], f.Comments[i+1:]...) + break + } + } + } + + // Write the output file. + f.Name.Name = *packageName + + var buf bytes.Buffer + if err := format.Node(&buf, fset, f); err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } + + if err := ioutil.WriteFile(*output, buf.Bytes(), 0644); err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } +} diff --git a/tools/go_generics/generics_tests/all_stmts/input.go b/tools/go_generics/generics_tests/all_stmts/input.go new file mode 100644 index 000000000..4791d1ff1 --- /dev/null +++ b/tools/go_generics/generics_tests/all_stmts/input.go @@ -0,0 +1,290 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "sync" +) + +type T int + +func h(T) { +} + +type s struct { + a, b int + c []int +} + +func g(T) *s { + return &s{} +} + +func f() (T, []int) { + // Branch. + goto T + goto R + + // Labeled. +T: + _ = T(0) + + // Empty. +R: + ; + + // Assignment with definition. + a, b, c := T(1), T(2), T(3) + _, _, _ = a, b, c + + // Assignment without definition. + g(T(0)).a, g(T(1)).b, c = int(T(1)), int(T(2)), T(3) + _, _, _ = a, b, c + + // Block. + { + var T T + T = 0 + _ = T + } + + // Declarations. + type Type T + const Const T = 10 + var g1 func(T, int, ...T) (int, T) + var v T + var w = T(0) + { + var T struct { + f []T + } + _ = T + } + + // Defer. + defer g1(T(0), 1) + + // Expression. + h(v + w + T(1)) + + // For statements. + for i := T(0); i < T(10); i++ { + var T func(int) T + v := T(0) + _ = v + } + + for { + var T func(int) T + v := T(0) + _ = v + } + + // Go. + go g1(T(0), 1) + + // If statements. + if a != T(1) { + var T func(int) T + v := T(0) + _ = v + } + + if a := T(0); a != T(1) { + var T func(int) T + v := T(0) + _ = v + } + + if a := T(0); a != T(1) { + var T func(int) T + v := T(0) + _ = v + } else if b := T(0); b != T(1) { + var T func(int) T + v := T(0) + _ = v + } else if T := T(0); T != 1 { + T++ + } else { + T-- + } + + if a := T(0); a != T(1) { + var T func(int) T + v := T(0) + _ = v + } else { + var T func(int) T + v := T(0) + _ = v + } + + // Inc/Dec statements. + (*(*T)(nil))++ + (*(*T)(nil))-- + + // Range statements. + for g(T(0)).a, g(T(1)).b = range g(T(10)).c { + var d T + _ = d + } + + for T, b := range g(T(10)).c { + _ = T + _ = b + } + + // Select statement. + { + var fch func(T) chan int + + select { + case <-fch(T(30)): + var T T + T = 0 + _ = T + default: + var T T + T = 0 + _ = T + case T := <-fch(T(30)): + T = 0 + _ = T + case g(T(0)).a = <-fch(T(30)): + var T T + T = 0 + _ = T + case fch(T(30)) <- int(T(0)): + var T T + T = 0 + _ = T + } + } + + // Send statements. + { + var ch chan T + var fch func(T) chan int + + ch <- T(0) + fch(T(1)) <- g(T(10)).a + } + + // Switch statements. + { + var a T + var b int + switch { + case a == T(0): + var T T + T = 0 + _ = T + case a < T(0), b < g(T(10)).a: + var T T + T = 0 + _ = T + default: + var T T + T = 0 + _ = T + } + } + + switch T(g(T(10)).a) { + case T(0): + var T T + T = 0 + _ = T + case T(1), T(g(T(10)).a): + var T T + T = 0 + _ = T + default: + var T T + T = 0 + _ = T + } + + switch b := g(T(10)); T(b.a) + T(10) { + case T(0): + var T T + T = 0 + _ = T + case T(1), T(g(T(10)).a): + var T T + T = 0 + _ = T + default: + var T T + T = 0 + _ = T + } + + // Type switch statements. + { + var interfaceFunc func(T) interface{} + + switch interfaceFunc(T(0)).(type) { + case *T, T, int: + var T T + T = 0 + _ = T + case sync.Mutex, **T: + var T T + T = 0 + _ = T + default: + var T T + T = 0 + _ = T + } + + switch x := interfaceFunc(T(0)).(type) { + case *T, T, int: + var T T + T = 0 + _ = T + _ = x + case sync.Mutex, **T: + var T T + T = 0 + _ = T + default: + var T T + T = 0 + _ = T + } + + switch t := T(0); x := interfaceFunc(T(0) + t).(type) { + case *T, T, int: + var T T + T = 0 + _ = T + _ = x + case sync.Mutex, **T: + var T T + T = 0 + _ = T + default: + var T T + T = 0 + _ = T + } + } + + // Return statement. + return T(10), g(T(11)).c +} diff --git a/tools/go_generics/generics_tests/all_stmts/opts.txt b/tools/go_generics/generics_tests/all_stmts/opts.txt new file mode 100644 index 000000000..c9d0e09bf --- /dev/null +++ b/tools/go_generics/generics_tests/all_stmts/opts.txt @@ -0,0 +1 @@ +-t=T=Q diff --git a/tools/go_generics/generics_tests/all_stmts/output/output.go b/tools/go_generics/generics_tests/all_stmts/output/output.go new file mode 100644 index 000000000..a53d84535 --- /dev/null +++ b/tools/go_generics/generics_tests/all_stmts/output/output.go @@ -0,0 +1,288 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "sync" +) + +func h(Q) { +} + +type s struct { + a, b int + c []int +} + +func g(Q) *s { + return &s{} +} + +func f() (Q, []int) { + // Branch. + goto T + goto R + + // Labeled. +T: + _ = Q(0) + + // Empty. +R: + ; + + // Assignment with definition. + a, b, c := Q(1), Q(2), Q(3) + _, _, _ = a, b, c + + // Assignment without definition. + g(Q(0)).a, g(Q(1)).b, c = int(Q(1)), int(Q(2)), Q(3) + _, _, _ = a, b, c + + // Block. + { + var T Q + T = 0 + _ = T + } + + // Declarations. + type Type Q + const Const Q = 10 + var g1 func(Q, int, ...Q) (int, Q) + var v Q + var w = Q(0) + { + var T struct { + f []Q + } + _ = T + } + + // Defer. + defer g1(Q(0), 1) + + // Expression. + h(v + w + Q(1)) + + // For statements. + for i := Q(0); i < Q(10); i++ { + var T func(int) Q + v := T(0) + _ = v + } + + for { + var T func(int) Q + v := T(0) + _ = v + } + + // Go. + go g1(Q(0), 1) + + // If statements. + if a != Q(1) { + var T func(int) Q + v := T(0) + _ = v + } + + if a := Q(0); a != Q(1) { + var T func(int) Q + v := T(0) + _ = v + } + + if a := Q(0); a != Q(1) { + var T func(int) Q + v := T(0) + _ = v + } else if b := Q(0); b != Q(1) { + var T func(int) Q + v := T(0) + _ = v + } else if T := Q(0); T != 1 { + T++ + } else { + T-- + } + + if a := Q(0); a != Q(1) { + var T func(int) Q + v := T(0) + _ = v + } else { + var T func(int) Q + v := T(0) + _ = v + } + + // Inc/Dec statements. + (*(*Q)(nil))++ + (*(*Q)(nil))-- + + // Range statements. + for g(Q(0)).a, g(Q(1)).b = range g(Q(10)).c { + var d Q + _ = d + } + + for T, b := range g(Q(10)).c { + _ = T + _ = b + } + + // Select statement. + { + var fch func(Q) chan int + + select { + case <-fch(Q(30)): + var T Q + T = 0 + _ = T + default: + var T Q + T = 0 + _ = T + case T := <-fch(Q(30)): + T = 0 + _ = T + case g(Q(0)).a = <-fch(Q(30)): + var T Q + T = 0 + _ = T + case fch(Q(30)) <- int(Q(0)): + var T Q + T = 0 + _ = T + } + } + + // Send statements. + { + var ch chan Q + var fch func(Q) chan int + + ch <- Q(0) + fch(Q(1)) <- g(Q(10)).a + } + + // Switch statements. + { + var a Q + var b int + switch { + case a == Q(0): + var T Q + T = 0 + _ = T + case a < Q(0), b < g(Q(10)).a: + var T Q + T = 0 + _ = T + default: + var T Q + T = 0 + _ = T + } + } + + switch Q(g(Q(10)).a) { + case Q(0): + var T Q + T = 0 + _ = T + case Q(1), Q(g(Q(10)).a): + var T Q + T = 0 + _ = T + default: + var T Q + T = 0 + _ = T + } + + switch b := g(Q(10)); Q(b.a) + Q(10) { + case Q(0): + var T Q + T = 0 + _ = T + case Q(1), Q(g(Q(10)).a): + var T Q + T = 0 + _ = T + default: + var T Q + T = 0 + _ = T + } + + // Type switch statements. + { + var interfaceFunc func(Q) interface{} + + switch interfaceFunc(Q(0)).(type) { + case *Q, Q, int: + var T Q + T = 0 + _ = T + case sync.Mutex, **Q: + var T Q + T = 0 + _ = T + default: + var T Q + T = 0 + _ = T + } + + switch x := interfaceFunc(Q(0)).(type) { + case *Q, Q, int: + var T Q + T = 0 + _ = T + _ = x + case sync.Mutex, **Q: + var T Q + T = 0 + _ = T + default: + var T Q + T = 0 + _ = T + } + + switch t := Q(0); x := interfaceFunc(Q(0) + t).(type) { + case *Q, Q, int: + var T Q + T = 0 + _ = T + _ = x + case sync.Mutex, **Q: + var T Q + T = 0 + _ = T + default: + var T Q + T = 0 + _ = T + } + } + + // Return statement. + return Q(10), g(Q(11)).c +} diff --git a/tools/go_generics/generics_tests/all_types/input.go b/tools/go_generics/generics_tests/all_types/input.go new file mode 100644 index 000000000..3575d02ec --- /dev/null +++ b/tools/go_generics/generics_tests/all_types/input.go @@ -0,0 +1,43 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import "./lib" + +type T int + +type newType struct { + a T + b lib.T + c *T + d (T) + e chan T + f <-chan T + g chan<- T + h []T + i [10]T + j map[T]T + k func(T, T) (T, T) + l interface { + f(T) + } + m struct { + T + a T + } +} + +func f(...T) { +} diff --git a/tools/go_generics/generics_tests/all_types/lib/lib.go b/tools/go_generics/generics_tests/all_types/lib/lib.go new file mode 100644 index 000000000..988786496 --- /dev/null +++ b/tools/go_generics/generics_tests/all_types/lib/lib.go @@ -0,0 +1,17 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lib + +type T int32 diff --git a/tools/go_generics/generics_tests/all_types/opts.txt b/tools/go_generics/generics_tests/all_types/opts.txt new file mode 100644 index 000000000..c9d0e09bf --- /dev/null +++ b/tools/go_generics/generics_tests/all_types/opts.txt @@ -0,0 +1 @@ +-t=T=Q diff --git a/tools/go_generics/generics_tests/all_types/output/output.go b/tools/go_generics/generics_tests/all_types/output/output.go new file mode 100644 index 000000000..41fd147a1 --- /dev/null +++ b/tools/go_generics/generics_tests/all_types/output/output.go @@ -0,0 +1,41 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import "./lib" + +type newType struct { + a Q + b lib.T + c *Q + d (Q) + e chan Q + f <-chan Q + g chan<- Q + h []Q + i [10]Q + j map[Q]Q + k func(Q, Q) (Q, Q) + l interface { + f(Q) + } + m struct { + Q + a Q + } +} + +func f(...Q) { +} diff --git a/tools/go_generics/generics_tests/anon/input.go b/tools/go_generics/generics_tests/anon/input.go new file mode 100644 index 000000000..44086d522 --- /dev/null +++ b/tools/go_generics/generics_tests/anon/input.go @@ -0,0 +1,46 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +type T interface { + Apply(T) T +} + +type Foo struct { + T + Bar map[string]T `json:"bar,omitempty"` +} + +type Baz struct { + T someTypeNotT +} + +func (f Foo) GetBar(name string) T { + b, ok := f.Bar[name] + if ok { + b = f.Apply(b) + } else { + b = f.T + } + return b +} + +func foobar() { + a := Baz{} + a.T = 0 // should not be renamed, this is a limitation + + b := otherpkg.UnrelatedType{} + b.T = 0 // should not be renamed, this is a limitation +} diff --git a/tools/go_generics/generics_tests/anon/opts.txt b/tools/go_generics/generics_tests/anon/opts.txt new file mode 100644 index 000000000..a5e9d26de --- /dev/null +++ b/tools/go_generics/generics_tests/anon/opts.txt @@ -0,0 +1 @@ +-t=T=Q -suffix=New -anon diff --git a/tools/go_generics/generics_tests/anon/output/output.go b/tools/go_generics/generics_tests/anon/output/output.go new file mode 100644 index 000000000..160cddf79 --- /dev/null +++ b/tools/go_generics/generics_tests/anon/output/output.go @@ -0,0 +1,42 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +type FooNew struct { + Q + Bar map[string]Q `json:"bar,omitempty"` +} + +type BazNew struct { + T someTypeNotT +} + +func (f FooNew) GetBar(name string) Q { + b, ok := f.Bar[name] + if ok { + b = f.Apply(b) + } else { + b = f.Q + } + return b +} + +func foobarNew() { + a := BazNew{} + a.Q = 0 // should not be renamed, this is a limitation + + b := otherpkg.UnrelatedType{} + b.Q = 0 // should not be renamed, this is a limitation +} diff --git a/tools/go_generics/generics_tests/consts/input.go b/tools/go_generics/generics_tests/consts/input.go new file mode 100644 index 000000000..04b95fcc6 --- /dev/null +++ b/tools/go_generics/generics_tests/consts/input.go @@ -0,0 +1,26 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +const c1 = 10 +const x, y, z = 100, 200, 300 +const v float32 = 1.0 + 2.0 +const s = "abc" +const ( + A = 10 + B, C, D = 10, 20, 30 + S = "abc" + T, U, V string = "abc", "def", "ghi" +) diff --git a/tools/go_generics/generics_tests/consts/opts.txt b/tools/go_generics/generics_tests/consts/opts.txt new file mode 100644 index 000000000..4fb59dce8 --- /dev/null +++ b/tools/go_generics/generics_tests/consts/opts.txt @@ -0,0 +1 @@ +-c=c1=20 -c=z=600 -c=v=3.3 -c=s="def" -c=A=20 -c=C=100 -c=S="def" -c=T="ABC" diff --git a/tools/go_generics/generics_tests/consts/output/output.go b/tools/go_generics/generics_tests/consts/output/output.go new file mode 100644 index 000000000..18d316cc9 --- /dev/null +++ b/tools/go_generics/generics_tests/consts/output/output.go @@ -0,0 +1,26 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +const c1 = 20 +const x, y, z = 100, 200, 600 +const v float32 = 3.3 +const s = "def" +const ( + A = 20 + B, C, D = 10, 100, 30 + S = "def" + T, U, V string = "ABC", "def", "ghi" +) diff --git a/tools/go_generics/generics_tests/imports/input.go b/tools/go_generics/generics_tests/imports/input.go new file mode 100644 index 000000000..0f032c2a1 --- /dev/null +++ b/tools/go_generics/generics_tests/imports/input.go @@ -0,0 +1,24 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +type T int + +var global T + +const ( + m = 0 + n = 0 +) diff --git a/tools/go_generics/generics_tests/imports/opts.txt b/tools/go_generics/generics_tests/imports/opts.txt new file mode 100644 index 000000000..87324be79 --- /dev/null +++ b/tools/go_generics/generics_tests/imports/opts.txt @@ -0,0 +1 @@ +-t=T=sync.Mutex -c=n=math.Uint32 -c=m=math.Uint64 -import=sync=sync -import=math=mymathpath diff --git a/tools/go_generics/generics_tests/imports/output/output.go b/tools/go_generics/generics_tests/imports/output/output.go new file mode 100644 index 000000000..2488ca58c --- /dev/null +++ b/tools/go_generics/generics_tests/imports/output/output.go @@ -0,0 +1,27 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + __generics_imported1 "mymathpath" + __generics_imported0 "sync" +) + +var global __generics_imported0.Mutex + +const ( + m = __generics_imported1.Uint64 + n = __generics_imported1.Uint32 +) diff --git a/tools/go_generics/generics_tests/remove_typedef/input.go b/tools/go_generics/generics_tests/remove_typedef/input.go new file mode 100644 index 000000000..cf632bae7 --- /dev/null +++ b/tools/go_generics/generics_tests/remove_typedef/input.go @@ -0,0 +1,37 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +func f(T) Q { + return Q{} +} + +type T struct{} + +type Q struct{} + +func (*T) f() { +} + +func (T) g() { +} + +func (*Q) f(T) T { + return T{} +} + +func (*Q) g(T) *T { + return nil +} diff --git a/tools/go_generics/generics_tests/remove_typedef/opts.txt b/tools/go_generics/generics_tests/remove_typedef/opts.txt new file mode 100644 index 000000000..9c8ecaada --- /dev/null +++ b/tools/go_generics/generics_tests/remove_typedef/opts.txt @@ -0,0 +1 @@ +-t=T=U diff --git a/tools/go_generics/generics_tests/remove_typedef/output/output.go b/tools/go_generics/generics_tests/remove_typedef/output/output.go new file mode 100644 index 000000000..d44fd8e1c --- /dev/null +++ b/tools/go_generics/generics_tests/remove_typedef/output/output.go @@ -0,0 +1,29 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +func f(U) Q { + return Q{} +} + +type Q struct{} + +func (*Q) f(U) U { + return U{} +} + +func (*Q) g(U) *U { + return nil +} diff --git a/tools/go_generics/generics_tests/simple/input.go b/tools/go_generics/generics_tests/simple/input.go new file mode 100644 index 000000000..2a917f16c --- /dev/null +++ b/tools/go_generics/generics_tests/simple/input.go @@ -0,0 +1,45 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +type T int + +var global T + +func f(_ T, a int) { +} + +func g(a T, b int) { + var c T + _ = c + + d := (*T)(nil) + _ = d +} + +type R struct { + T + a *T +} + +var ( + Z *T = (*T)(nil) +) + +const ( + X T = (T)(0) +) + +type Y T diff --git a/tools/go_generics/generics_tests/simple/opts.txt b/tools/go_generics/generics_tests/simple/opts.txt new file mode 100644 index 000000000..7832ef66f --- /dev/null +++ b/tools/go_generics/generics_tests/simple/opts.txt @@ -0,0 +1 @@ +-t=T=Q -suffix=New diff --git a/tools/go_generics/generics_tests/simple/output/output.go b/tools/go_generics/generics_tests/simple/output/output.go new file mode 100644 index 000000000..6bfa0b25b --- /dev/null +++ b/tools/go_generics/generics_tests/simple/output/output.go @@ -0,0 +1,43 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +var globalNew Q + +func fNew(_ Q, a int) { +} + +func gNew(a Q, b int) { + var c Q + _ = c + + d := (*Q)(nil) + _ = d +} + +type RNew struct { + Q + a *Q +} + +var ( + ZNew *Q = (*Q)(nil) +) + +const ( + XNew Q = (Q)(0) +) + +type YNew Q diff --git a/tools/go_generics/globals/BUILD b/tools/go_generics/globals/BUILD new file mode 100644 index 000000000..38caa3ce7 --- /dev/null +++ b/tools/go_generics/globals/BUILD @@ -0,0 +1,13 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "globals", + srcs = [ + "globals_visitor.go", + "scope.go", + ], + stateify = False, + visibility = ["//tools/go_generics:__pkg__"], +) diff --git a/tools/go_generics/globals/globals_visitor.go b/tools/go_generics/globals/globals_visitor.go new file mode 100644 index 000000000..883f21ebe --- /dev/null +++ b/tools/go_generics/globals/globals_visitor.go @@ -0,0 +1,597 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package globals provides an AST visitor that calls the visit function for all +// global identifiers. +package globals + +import ( + "fmt" + + "go/ast" + "go/token" + "path/filepath" + "strconv" +) + +// globalsVisitor holds the state used while traversing the nodes of a file in +// search of globals. +// +// The visitor does two passes on the global declarations: the first one adds +// all globals to the global scope (since Go allows references to globals that +// haven't been declared yet), and the second one calls f() for the definition +// and uses of globals found in the first pass. +// +// The implementation correctly handles cases when globals are aliased by +// locals; in such cases, f() is not called. +type globalsVisitor struct { + // file is the file whose nodes are being visited. + file *ast.File + + // fset is the file set the file being visited belongs to. + fset *token.FileSet + + // f is the visit function to be called when a global symbol is reached. + f func(*ast.Ident, SymKind) + + // scope is the current scope as nodes are visited. + scope *scope + + // processAnon indicates whether we should process anonymous struct fields. + // It does not perform strict checking on parameter types that share the same name + // as the global type and therefore will rename them as well. + processAnon bool +} + +// unexpected is called when an unexpected node appears in the AST. It dumps +// the location of the associated token and panics because this should only +// happen when there is a bug in the traversal code. +func (v *globalsVisitor) unexpected(p token.Pos) { + panic(fmt.Sprintf("Unable to parse at %v", v.fset.Position(p))) +} + +// pushScope creates a new scope and pushes it to the top of the scope stack. +func (v *globalsVisitor) pushScope() { + v.scope = newScope(v.scope) +} + +// popScope removes the scope created by the last call to pushScope. +func (v *globalsVisitor) popScope() { + v.scope = v.scope.outer +} + +// visitType is called when an expression is known to be a type, for example, +// on the first argument of make(). It visits all children nodes and reports +// any globals. +func (v *globalsVisitor) visitType(ge ast.Expr) { + switch e := ge.(type) { + case *ast.Ident: + if s := v.scope.deepLookup(e.Name); s != nil && s.scope.isGlobal() { + v.f(e, s.kind) + } + + case *ast.SelectorExpr: + id := GetIdent(e.X) + if id == nil { + v.unexpected(e.X.Pos()) + } + + case *ast.StarExpr: + v.visitType(e.X) + case *ast.ParenExpr: + v.visitType(e.X) + case *ast.ChanType: + v.visitType(e.Value) + case *ast.Ellipsis: + v.visitType(e.Elt) + case *ast.ArrayType: + v.visitExpr(e.Len) + v.visitType(e.Elt) + case *ast.MapType: + v.visitType(e.Key) + v.visitType(e.Value) + case *ast.StructType: + v.visitFields(e.Fields, KindUnknown) + case *ast.FuncType: + v.visitFields(e.Params, KindUnknown) + v.visitFields(e.Results, KindUnknown) + case *ast.InterfaceType: + v.visitFields(e.Methods, KindUnknown) + default: + v.unexpected(ge.Pos()) + } +} + +// visitFields visits all fields, and add symbols if kind isn't KindUnknown. +func (v *globalsVisitor) visitFields(l *ast.FieldList, kind SymKind) { + if l == nil { + return + } + + for _, f := range l.List { + if kind != KindUnknown { + for _, n := range f.Names { + v.scope.add(n.Name, kind, n.Pos()) + } + } + v.visitType(f.Type) + if f.Tag != nil { + tag := ast.NewIdent(f.Tag.Value) + v.f(tag, KindTag) + // Replace the tag if updated. + if tag.Name != f.Tag.Value { + f.Tag.Value = tag.Name + } + } + } +} + +// visitGenDecl is called when a generic declaration is encountered, for example, +// on variable, constant and type declarations. It adds all newly defined +// symbols to the current scope and reports them if the current scope is the +// global one. +func (v *globalsVisitor) visitGenDecl(d *ast.GenDecl) { + switch d.Tok { + case token.IMPORT: + case token.TYPE: + for _, gs := range d.Specs { + s := gs.(*ast.TypeSpec) + v.scope.add(s.Name.Name, KindType, s.Name.Pos()) + if v.scope.isGlobal() { + v.f(s.Name, KindType) + } + v.visitType(s.Type) + } + case token.CONST, token.VAR: + kind := KindConst + if d.Tok == token.VAR { + kind = KindVar + } + + for _, gs := range d.Specs { + s := gs.(*ast.ValueSpec) + if s.Type != nil { + v.visitType(s.Type) + } + + for _, e := range s.Values { + v.visitExpr(e) + } + + for _, n := range s.Names { + if v.scope.isGlobal() { + v.f(n, kind) + } + v.scope.add(n.Name, kind, n.Pos()) + } + } + default: + v.unexpected(d.Pos()) + } +} + +// isViableType determines if the given expression is a viable type expression, +// that is, if it could be interpreted as a type, for example, sync.Mutex, +// myType, func(int)int, as opposed to -1, 2 * 2, a + b, etc. +func (v *globalsVisitor) isViableType(expr ast.Expr) bool { + switch e := expr.(type) { + case *ast.Ident: + // This covers the plain identifier case. When we see it, we + // have to check if it resolves to a type; if the symbol is not + // known, we'll claim it's viable as a type. + s := v.scope.deepLookup(e.Name) + return s == nil || s.kind == KindType + + case *ast.ChanType, *ast.ArrayType, *ast.MapType, *ast.StructType, *ast.FuncType, *ast.InterfaceType, *ast.Ellipsis: + // This covers the following cases: + // 1. ChanType: + // chan T + // <-chan T + // chan<- T + // 2. ArrayType: + // [Expr]T + // 3. MapType: + // map[T]U + // 4. StructType: + // struct { Fields } + // 5. FuncType: + // func(Fields)Returns + // 6. Interface: + // interface { Fields } + // 7. Ellipsis: + // ...T + return true + + case *ast.SelectorExpr: + // The only case in which an expression involving a selector can + // be a type is if it has the following form X.T, where X is an + // import, and T is a type exported by X. + // + // There's no way to know whether T is a type because we don't + // parse imports. So we just claim that this is a viable type; + // it doesn't affect the general result because we don't visit + // imported symbols. + id := GetIdent(e.X) + if id == nil { + return false + } + + s := v.scope.deepLookup(id.Name) + return s != nil && s.kind == KindImport + + case *ast.StarExpr: + // This covers the *T case. The expression is a viable type if + // T is. + return v.isViableType(e.X) + + case *ast.ParenExpr: + // This covers the (T) case. The expression is a viable type if + // T is. + return v.isViableType(e.X) + + default: + return false + } +} + +// visitCallExpr visits a "call expression" which can be either a +// function/method call (e.g., f(), pkg.f(), obj.f(), etc.) call or a type +// conversion (e.g., int32(1), (*sync.Mutex)(ptr), etc.). +func (v *globalsVisitor) visitCallExpr(e *ast.CallExpr) { + if v.isViableType(e.Fun) { + v.visitType(e.Fun) + } else { + v.visitExpr(e.Fun) + } + + // If the function being called is new or make, the first argument is + // a type, so it needs to be visited as such. + first := 0 + if id := GetIdent(e.Fun); id != nil && (id.Name == "make" || id.Name == "new") { + if len(e.Args) > 0 { + v.visitType(e.Args[0]) + } + first = 1 + } + + for i := first; i < len(e.Args); i++ { + v.visitExpr(e.Args[i]) + } +} + +// visitExpr visits all nodes of an expression, and reports any globals that it +// finds. +func (v *globalsVisitor) visitExpr(ge ast.Expr) { + switch e := ge.(type) { + case nil: + case *ast.Ident: + if s := v.scope.deepLookup(e.Name); s != nil && s.scope.isGlobal() { + v.f(e, s.kind) + } + + case *ast.BasicLit: + case *ast.CompositeLit: + v.visitType(e.Type) + for _, ne := range e.Elts { + v.visitExpr(ne) + } + case *ast.FuncLit: + v.pushScope() + v.visitFields(e.Type.Params, KindParameter) + v.visitFields(e.Type.Results, KindResult) + v.visitBlockStmt(e.Body) + v.popScope() + + case *ast.BinaryExpr: + v.visitExpr(e.X) + v.visitExpr(e.Y) + + case *ast.CallExpr: + v.visitCallExpr(e) + + case *ast.IndexExpr: + v.visitExpr(e.X) + v.visitExpr(e.Index) + + case *ast.KeyValueExpr: + v.visitExpr(e.Value) + + case *ast.ParenExpr: + v.visitExpr(e.X) + + case *ast.SelectorExpr: + v.visitExpr(e.X) + if v.processAnon { + v.visitExpr(e.Sel) + } + + case *ast.SliceExpr: + v.visitExpr(e.X) + v.visitExpr(e.Low) + v.visitExpr(e.High) + v.visitExpr(e.Max) + + case *ast.StarExpr: + v.visitExpr(e.X) + + case *ast.TypeAssertExpr: + v.visitExpr(e.X) + if e.Type != nil { + v.visitType(e.Type) + } + + case *ast.UnaryExpr: + v.visitExpr(e.X) + + default: + v.unexpected(ge.Pos()) + } +} + +// GetIdent returns the identifier associated with the given expression by +// removing parentheses if needed. +func GetIdent(expr ast.Expr) *ast.Ident { + switch e := expr.(type) { + case *ast.Ident: + return e + case *ast.ParenExpr: + return GetIdent(e.X) + default: + return nil + } +} + +// visitStmt visits all nodes of a statement, and reports any globals that it +// finds. It also adds to the current scope new symbols defined/declared. +func (v *globalsVisitor) visitStmt(gs ast.Stmt) { + switch s := gs.(type) { + case nil, *ast.BranchStmt, *ast.EmptyStmt: + case *ast.AssignStmt: + for _, e := range s.Rhs { + v.visitExpr(e) + } + + // We visit the LHS after the RHS because the symbols we'll + // potentially add to the table aren't meant to be visible to + // the RHS. + for _, e := range s.Lhs { + if s.Tok == token.DEFINE { + if n := GetIdent(e); n != nil { + v.scope.add(n.Name, KindVar, n.Pos()) + } + } + v.visitExpr(e) + } + + case *ast.BlockStmt: + v.visitBlockStmt(s) + + case *ast.DeclStmt: + v.visitGenDecl(s.Decl.(*ast.GenDecl)) + + case *ast.DeferStmt: + v.visitCallExpr(s.Call) + + case *ast.ExprStmt: + v.visitExpr(s.X) + + case *ast.ForStmt: + v.pushScope() + v.visitStmt(s.Init) + v.visitExpr(s.Cond) + v.visitStmt(s.Post) + v.visitBlockStmt(s.Body) + v.popScope() + + case *ast.GoStmt: + v.visitCallExpr(s.Call) + + case *ast.IfStmt: + v.pushScope() + v.visitStmt(s.Init) + v.visitExpr(s.Cond) + v.visitBlockStmt(s.Body) + v.visitStmt(s.Else) + v.popScope() + + case *ast.IncDecStmt: + v.visitExpr(s.X) + + case *ast.LabeledStmt: + v.visitStmt(s.Stmt) + + case *ast.RangeStmt: + v.pushScope() + v.visitExpr(s.X) + if s.Tok == token.DEFINE { + if n := GetIdent(s.Key); n != nil { + v.scope.add(n.Name, KindVar, n.Pos()) + } + + if n := GetIdent(s.Value); n != nil { + v.scope.add(n.Name, KindVar, n.Pos()) + } + } + v.visitExpr(s.Key) + v.visitExpr(s.Value) + v.visitBlockStmt(s.Body) + v.popScope() + + case *ast.ReturnStmt: + for _, r := range s.Results { + v.visitExpr(r) + } + + case *ast.SelectStmt: + for _, ns := range s.Body.List { + c := ns.(*ast.CommClause) + + v.pushScope() + v.visitStmt(c.Comm) + for _, bs := range c.Body { + v.visitStmt(bs) + } + v.popScope() + } + + case *ast.SendStmt: + v.visitExpr(s.Chan) + v.visitExpr(s.Value) + + case *ast.SwitchStmt: + v.pushScope() + v.visitStmt(s.Init) + v.visitExpr(s.Tag) + for _, ns := range s.Body.List { + c := ns.(*ast.CaseClause) + v.pushScope() + for _, ce := range c.List { + v.visitExpr(ce) + } + for _, bs := range c.Body { + v.visitStmt(bs) + } + v.popScope() + } + v.popScope() + + case *ast.TypeSwitchStmt: + v.pushScope() + v.visitStmt(s.Init) + v.visitStmt(s.Assign) + for _, ns := range s.Body.List { + c := ns.(*ast.CaseClause) + v.pushScope() + for _, ce := range c.List { + v.visitType(ce) + } + for _, bs := range c.Body { + v.visitStmt(bs) + } + v.popScope() + } + v.popScope() + + default: + v.unexpected(gs.Pos()) + } +} + +// visitBlockStmt visits all statements in the block, adding symbols to a newly +// created scope. +func (v *globalsVisitor) visitBlockStmt(s *ast.BlockStmt) { + v.pushScope() + for _, c := range s.List { + v.visitStmt(c) + } + v.popScope() +} + +// visitFuncDecl is called when a function or method declaration is encountered. +// it creates a new scope for the function [optional] receiver, parameters and +// results, and visits all children nodes. +func (v *globalsVisitor) visitFuncDecl(d *ast.FuncDecl) { + // We don't report methods. + if d.Recv == nil { + v.f(d.Name, KindFunction) + } + + v.pushScope() + v.visitFields(d.Recv, KindReceiver) + v.visitFields(d.Type.Params, KindParameter) + v.visitFields(d.Type.Results, KindResult) + if d.Body != nil { + v.visitBlockStmt(d.Body) + } + v.popScope() +} + +// globalsFromDecl is called in the first, and adds symbols to global scope. +func (v *globalsVisitor) globalsFromGenDecl(d *ast.GenDecl) { + switch d.Tok { + case token.IMPORT: + for _, gs := range d.Specs { + s := gs.(*ast.ImportSpec) + if s.Name == nil { + str, _ := strconv.Unquote(s.Path.Value) + v.scope.add(filepath.Base(str), KindImport, s.Path.Pos()) + } else if s.Name.Name != "_" { + v.scope.add(s.Name.Name, KindImport, s.Name.Pos()) + } + } + case token.TYPE: + for _, gs := range d.Specs { + s := gs.(*ast.TypeSpec) + v.scope.add(s.Name.Name, KindType, s.Name.Pos()) + } + case token.CONST, token.VAR: + kind := KindConst + if d.Tok == token.VAR { + kind = KindVar + } + + for _, s := range d.Specs { + for _, n := range s.(*ast.ValueSpec).Names { + v.scope.add(n.Name, kind, n.Pos()) + } + } + default: + v.unexpected(d.Pos()) + } +} + +// visit implements the visiting of globals. It does performs the two passes +// described in the description of the globalsVisitor struct. +func (v *globalsVisitor) visit() { + // Gather all symbols in the global scope. This excludes methods. + v.pushScope() + for _, gd := range v.file.Decls { + switch d := gd.(type) { + case *ast.GenDecl: + v.globalsFromGenDecl(d) + case *ast.FuncDecl: + if d.Recv == nil { + v.scope.add(d.Name.Name, KindFunction, d.Name.Pos()) + } + default: + v.unexpected(gd.Pos()) + } + } + + // Go through the contents of the declarations. + for _, gd := range v.file.Decls { + switch d := gd.(type) { + case *ast.GenDecl: + v.visitGenDecl(d) + case *ast.FuncDecl: + v.visitFuncDecl(d) + } + } +} + +// Visit traverses the provided AST and calls f() for each identifier that +// refers to global names. The global name must be defined in the file itself. +// +// The function f() is allowed to modify the identifier, for example, to rename +// uses of global references. +func Visit(fset *token.FileSet, file *ast.File, f func(*ast.Ident, SymKind), processAnon bool) { + v := globalsVisitor{ + fset: fset, + file: file, + f: f, + processAnon: processAnon, + } + + v.visit() +} diff --git a/tools/go_generics/globals/scope.go b/tools/go_generics/globals/scope.go new file mode 100644 index 000000000..eec93534b --- /dev/null +++ b/tools/go_generics/globals/scope.go @@ -0,0 +1,84 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package globals + +import ( + "go/token" +) + +// SymKind specifies the kind of a global symbol. For example, a variable, const +// function, etc. +type SymKind int + +// Constants for different kinds of symbols. +const ( + KindUnknown SymKind = iota + KindImport + KindType + KindVar + KindConst + KindFunction + KindReceiver + KindParameter + KindResult + KindTag +) + +type symbol struct { + kind SymKind + pos token.Pos + scope *scope +} + +type scope struct { + outer *scope + syms map[string]*symbol +} + +func newScope(outer *scope) *scope { + return &scope{ + outer: outer, + syms: make(map[string]*symbol), + } +} + +func (s *scope) isGlobal() bool { + return s.outer == nil +} + +func (s *scope) lookup(n string) *symbol { + return s.syms[n] +} + +func (s *scope) deepLookup(n string) *symbol { + for x := s; x != nil; x = x.outer { + if sym := x.lookup(n); sym != nil { + return sym + } + } + return nil +} + +func (s *scope) add(name string, kind SymKind, pos token.Pos) { + if s.syms[name] != nil { + return + } + + s.syms[name] = &symbol{ + kind: kind, + pos: pos, + scope: s, + } +} diff --git a/tools/go_generics/go_generics_unittest.sh b/tools/go_generics/go_generics_unittest.sh new file mode 100755 index 000000000..44b22db91 --- /dev/null +++ b/tools/go_generics/go_generics_unittest.sh @@ -0,0 +1,70 @@ +#!/bin/bash + +# Copyright 2018 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Bash "safe-mode": Treat command failures as fatal (even those that occur in +# pipes), and treat unset variables as errors. +set -eu -o pipefail + +# This file will be generated as a self-extracting shell script in order to +# eliminate the need for any runtime dependencies. The tarball at the end will +# include the go_generics binary, as well as a subdirectory named +# generics_tests. See the BUILD file for more information. +declare -r temp=$(mktemp -d) +function cleanup() { + rm -rf "${temp}" +} +# trap cleanup EXIT + +# Print message in "$1" then exit with status 1. +function die () { + echo "$1" 1>&2 + exit 1 +} + +# This prints the line number of __BUNDLE__ below, that should be the last line +# of this script. After that point, the concatenated archive will be the +# contents. +declare -r tgz=`awk '/^__BUNDLE__/ {print NR + 1; exit 0; }' $0` +tail -n+"${tgz}" $0 | tar -xzv -C "${temp}" + +# The target for the test. +declare -r binary="$(find ${temp} -type f -a -name go_generics)" +declare -r input_dirs="$(find ${temp} -type d -a -name generics_tests)/*" + +# Go through all test cases. +for f in ${input_dirs}; do + base=$(basename "${f}") + + # Run go_generics on the input file. + opts=$(head -n 1 ${f}/opts.txt) + out="${f}/output/generated.go" + expected="${f}/output/output.go" + ${binary} ${opts} "-i=${f}/input.go" "-o=${out}" || die "go_generics failed for test case \"${base}\"" + + # Compare the outputs. + diff ${expected} ${out} + if [ $? -ne 0 ]; then + echo "Expected:" + cat ${expected} + echo "Actual:" + cat ${out} + die "Actual output is different from expected for test \"${base}\"" + fi +done + +echo "PASS" +exit 0 +__BUNDLE__ diff --git a/tools/go_generics/go_merge/BUILD b/tools/go_generics/go_merge/BUILD new file mode 100644 index 000000000..2fd5a200d --- /dev/null +++ b/tools/go_generics/go_merge/BUILD @@ -0,0 +1,9 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "go_merge", + srcs = ["main.go"], + visibility = ["//:sandbox"], +) diff --git a/tools/go_generics/go_merge/main.go b/tools/go_generics/go_merge/main.go new file mode 100644 index 000000000..f6a331123 --- /dev/null +++ b/tools/go_generics/go_merge/main.go @@ -0,0 +1,139 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bytes" + "flag" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" + "io/ioutil" + "os" + "path/filepath" + "strconv" +) + +var ( + output = flag.String("o", "", "output `file`") +) + +func fatalf(s string, args ...interface{}) { + fmt.Fprintf(os.Stderr, s, args...) + os.Exit(1) +} + +func main() { + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s [options] <input1> [<input2> ...]\n", os.Args[0]) + flag.PrintDefaults() + } + + flag.Parse() + if *output == "" || len(flag.Args()) == 0 { + flag.Usage() + os.Exit(1) + } + + // Load all files. + files := make(map[string]*ast.File) + fset := token.NewFileSet() + var name string + for _, fname := range flag.Args() { + f, err := parser.ParseFile(fset, fname, nil, parser.ParseComments|parser.DeclarationErrors|parser.SpuriousErrors) + if err != nil { + fatalf("%v\n", err) + } + + files[fname] = f + if name == "" { + name = f.Name.Name + } else if name != f.Name.Name { + fatalf("Expected '%s' for package name instead of '%s'.\n", name, f.Name.Name) + } + } + + // Merge all files into one. + pkg := &ast.Package{ + Name: name, + Files: files, + } + f := ast.MergePackageFiles(pkg, ast.FilterUnassociatedComments|ast.FilterFuncDuplicates|ast.FilterImportDuplicates) + + // Create a new declaration slice with all imports at the top, merging any + // redundant imports. + imports := make(map[string]*ast.ImportSpec) + var anonImports []*ast.ImportSpec + for _, d := range f.Decls { + if g, ok := d.(*ast.GenDecl); ok && g.Tok == token.IMPORT { + for _, s := range g.Specs { + i := s.(*ast.ImportSpec) + p, _ := strconv.Unquote(i.Path.Value) + var n string + if i.Name == nil { + n = filepath.Base(p) + } else { + n = i.Name.Name + } + if n == "_" { + anonImports = append(anonImports, i) + } else { + if i2, ok := imports[n]; ok { + if first, second := i.Path.Value, i2.Path.Value; first != second { + fatalf("Conflicting paths for import name '%s': '%s' vs. '%s'\n", n, first, second) + } + } else { + imports[n] = i + } + } + } + } + } + newDecls := make([]ast.Decl, 0, len(f.Decls)) + if l := len(imports) + len(anonImports); l > 0 { + // Non-NoPos Lparen is needed for Go to recognize more than one spec in + // ast.GenDecl.Specs. + d := &ast.GenDecl{ + Tok: token.IMPORT, + Lparen: token.NoPos + 1, + Specs: make([]ast.Spec, 0, l), + } + for _, i := range imports { + d.Specs = append(d.Specs, i) + } + for _, i := range anonImports { + d.Specs = append(d.Specs, i) + } + newDecls = append(newDecls, d) + } + for _, d := range f.Decls { + if g, ok := d.(*ast.GenDecl); !ok || g.Tok != token.IMPORT { + newDecls = append(newDecls, d) + } + } + f.Decls = newDecls + + // Write the output file. + var buf bytes.Buffer + if err := format.Node(&buf, fset, f); err != nil { + fatalf("%v\n", err) + } + + if err := ioutil.WriteFile(*output, buf.Bytes(), 0644); err != nil { + fatalf("%v\n", err) + } +} diff --git a/tools/go_generics/imports.go b/tools/go_generics/imports.go new file mode 100644 index 000000000..148dc7216 --- /dev/null +++ b/tools/go_generics/imports.go @@ -0,0 +1,150 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bytes" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" + "strconv" + + "gvisor.dev/gvisor/tools/go_generics/globals" +) + +type importedPackage struct { + newName string + path string +} + +// updateImportIdent modifies the given import identifier with the new name +// stored in the used map. If the identifier doesn't exist in the used map yet, +// a new name is generated and inserted into the map. +func updateImportIdent(orig string, imports mapValue, id *ast.Ident, used map[string]*importedPackage) error { + importName := id.Name + + // If the name is already in the table, just use the new name. + m := used[importName] + if m != nil { + id.Name = m.newName + return nil + } + + // Create a new entry in the used map. + path := imports[importName] + if path == "" { + return fmt.Errorf("Unknown path to package '%s', used in '%s'", importName, orig) + } + + m = &importedPackage{ + newName: fmt.Sprintf("__generics_imported%d", len(used)), + path: strconv.Quote(path), + } + used[importName] = m + + id.Name = m.newName + + return nil +} + +// convertExpression creates a new string that is a copy of the input one with +// all imports references renamed to the names in the "used" map. If the +// referenced import isn't in "used" yet, a new one is created based on the path +// in "imports" and stored in "used". For example, if string s is +// "math.MaxUint32-math.MaxUint16+10", it would be converted to +// "x.MaxUint32-x.MathUint16+10", where x is a generated name. +func convertExpression(s string, imports mapValue, used map[string]*importedPackage) (string, error) { + // Parse the expression in the input string. + expr, err := parser.ParseExpr(s) + if err != nil { + return "", fmt.Errorf("Unable to parse \"%s\": %v", s, err) + } + + // Go through the AST and update references. + var retErr error + ast.Inspect(expr, func(n ast.Node) bool { + switch x := n.(type) { + case *ast.SelectorExpr: + if id := globals.GetIdent(x.X); id != nil { + if err := updateImportIdent(s, imports, id, used); err != nil { + retErr = err + } + return false + } + } + return true + }) + if retErr != nil { + return "", retErr + } + + // Convert the modified AST back to a string. + fset := token.NewFileSet() + var buf bytes.Buffer + if err := format.Node(&buf, fset, expr); err != nil { + return "", err + } + + return string(buf.Bytes()), nil +} + +// updateImports replaces all maps in the input slice with copies where the +// mapped values have had all references to imported packages renamed to +// generated names. It also returns an import declaration for all the renamed +// import packages. +// +// For example, if the input maps contains A=math.B and C=math.D, the updated +// maps will instead contain A=__generics_imported0.B and +// C=__generics_imported0.C, and the 'import __generics_imported0 "math"' would +// be returned as the import declaration. +func updateImports(maps []mapValue, imports mapValue) (ast.Decl, error) { + importsUsed := make(map[string]*importedPackage) + + // Update all maps. + for i, m := range maps { + newMap := make(mapValue) + for n, e := range m { + updated, err := convertExpression(e, imports, importsUsed) + if err != nil { + return nil, err + } + + newMap[n] = updated + } + maps[i] = newMap + } + + // Nothing else to do if no imports are used in the expressions. + if len(importsUsed) == 0 { + return nil, nil + } + + // Create spec array for each new import. + specs := make([]ast.Spec, 0, len(importsUsed)) + for _, i := range importsUsed { + specs = append(specs, &ast.ImportSpec{ + Name: &ast.Ident{Name: i.newName}, + Path: &ast.BasicLit{Value: i.path}, + }) + } + + return &ast.GenDecl{ + Tok: token.IMPORT, + Specs: specs, + Lparen: token.NoPos + 1, + }, nil +} diff --git a/tools/go_generics/remove.go b/tools/go_generics/remove.go new file mode 100644 index 000000000..568a6bbd3 --- /dev/null +++ b/tools/go_generics/remove.go @@ -0,0 +1,105 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "go/ast" + "go/token" +) + +type typeSet map[string]struct{} + +// isTypeOrPointerToType determines if the given AST expression represents a +// type or a pointer to a type that exists in the provided type set. +func isTypeOrPointerToType(set typeSet, expr ast.Expr, starCount int) bool { + switch e := expr.(type) { + case *ast.Ident: + _, ok := set[e.Name] + return ok + case *ast.StarExpr: + if starCount > 1 { + return false + } + return isTypeOrPointerToType(set, e.X, starCount+1) + case *ast.ParenExpr: + return isTypeOrPointerToType(set, e.X, starCount) + default: + return false + } +} + +// isMethodOf determines if the given function declaration is a method of one +// of the types in the provided type set. To do that, it checks if the function +// has a receiver and that its type is either T or *T, where T is a type that +// exists in the set. This is per the spec: +// +// That parameter section must declare a single parameter, the receiver. Its +// type must be of the form T or *T (possibly using parentheses) where T is a +// type name. The type denoted by T is called the receiver base type; it must +// not be a pointer or interface type and it must be declared in the same +// package as the method. +func isMethodOf(set typeSet, f *ast.FuncDecl) bool { + // If the function doesn't have exactly one receiver, then it's + // definitely not a method. + if f.Recv == nil || len(f.Recv.List) != 1 { + return false + } + + return isTypeOrPointerToType(set, f.Recv.List[0].Type, 0) +} + +// removeTypeDefinitions removes the definition of all types contained in the +// provided type set. +func removeTypeDefinitions(set typeSet, d *ast.GenDecl) { + if d.Tok != token.TYPE { + return + } + + i := 0 + for _, gs := range d.Specs { + s := gs.(*ast.TypeSpec) + if _, ok := set[s.Name.Name]; !ok { + d.Specs[i] = gs + i++ + } + } + + d.Specs = d.Specs[:i] +} + +// removeTypes removes from the AST the definition of all types and their +// method sets that are contained in the provided type set. +func removeTypes(set typeSet, f *ast.File) { + // Go through the top-level declarations. + i := 0 + for _, decl := range f.Decls { + keep := true + switch d := decl.(type) { + case *ast.GenDecl: + countBefore := len(d.Specs) + removeTypeDefinitions(set, d) + keep = countBefore == 0 || len(d.Specs) > 0 + case *ast.FuncDecl: + keep = !isMethodOf(set, d) + } + + if keep { + f.Decls[i] = decl + i++ + } + } + + f.Decls = f.Decls[:i] +} diff --git a/tools/go_generics/rules_tests/BUILD b/tools/go_generics/rules_tests/BUILD new file mode 100644 index 000000000..8a329dfc6 --- /dev/null +++ b/tools/go_generics/rules_tests/BUILD @@ -0,0 +1,43 @@ +load("//tools:defs.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "instance", + out = "instance_test.go", + consts = { + "n": "20", + "m": "\"test\"", + "o": "math.MaxUint64", + }, + imports = { + "math": "math", + }, + package = "template_test", + template = ":test_template", + types = { + "t": "int", + }, +) + +go_template( + name = "test_template", + srcs = [ + "template.go", + ], + opt_consts = [ + "n", + "m", + "o", + ], + opt_types = ["t"], +) + +go_test( + name = "template_test", + srcs = [ + "instance_test.go", + "template_test.go", + ], +) diff --git a/tools/go_generics/rules_tests/template.go b/tools/go_generics/rules_tests/template.go new file mode 100644 index 000000000..aace61da1 --- /dev/null +++ b/tools/go_generics/rules_tests/template.go @@ -0,0 +1,42 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package template + +type t float + +const ( + n t = 10.1 + m = "abc" + o = 0 +) + +func max(a, b t) t { + if a > b { + return a + } + return b +} + +func add(a t) t { + return a + n +} + +func getName() string { + return m +} + +func getMax() uint64 { + return o +} diff --git a/tools/go_generics/rules_tests/template_test.go b/tools/go_generics/rules_tests/template_test.go new file mode 100644 index 000000000..b2a3446ef --- /dev/null +++ b/tools/go_generics/rules_tests/template_test.go @@ -0,0 +1,48 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package template_test + +import ( + "math" + "testing" +) + +func TestMax(t *testing.T) { + var a int = max(10, 20) + if a != 20 { + t.Errorf("Bad result of max, got %v, want %v", a, 20) + } +} + +func TestIntConst(t *testing.T) { + var a int = add(10) + if a != 30 { + t.Errorf("Bad result of add, got %v, want %v", a, 30) + } +} + +func TestStrConst(t *testing.T) { + v := getName() + if v != "test" { + t.Errorf("Bad name, got %v, want %v", v, "test") + } +} + +func TestImport(t *testing.T) { + v := getMax() + if v != math.MaxUint64 { + t.Errorf("Bad max value, got %v, want %v", v, uint64(math.MaxUint64)) + } +} diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD new file mode 100644 index 000000000..be49cf9c8 --- /dev/null +++ b/tools/go_marshal/BUILD @@ -0,0 +1,19 @@ +load("//tools:defs.bzl", "go_binary") + +licenses(["notice"]) + +go_binary( + name = "go_marshal", + srcs = ["main.go"], + visibility = [ + "//:sandbox", + ], + deps = [ + "//tools/go_marshal/gomarshal", + ], +) + +config_setting( + name = "marshal_config_verbose", + values = {"define": "gomarshal=verbose"}, +) diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md new file mode 100644 index 000000000..4886efddf --- /dev/null +++ b/tools/go_marshal/README.md @@ -0,0 +1,116 @@ +This package implements the go_marshal utility. + +# Overview + +`go_marshal` is a code generation utility similar to `go_stateify` for +automatically generating code to marshal go data structures to memory. + +`go_marshal` attempts to improve on `binary.Write` and the sentry's +`binary.Marshal` by moving the go runtime reflection necessary to marshal a +struct to compile-time. + +`go_marshal` automatically generates implementations for `abi.Marshallable` and +`safemem.{Reader,Writer}`. Call-sites for serialization (typically syscall +implementations) can directly invoke `safemem.Reader.ReadToBlocks` and +`safemem.Writer.WriteFromBlocks`. Data structures that require custom +serialization will have manual implementations for these interfaces. + +Data structures can be flagged for code generation by adding a struct-level +comment `// +marshal`. + +# Usage + +See `defs.bzl`: a new rule is provided, `go_marshal`. + +Under the hood, the `go_marshal` rule is used to generate a file that will +appear in a Go target; the output file should appear explicitly in a srcs list. +For example (note that the above is the preferred method): + +``` +load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_marshal") + +go_marshal( + name = "foo_abi", + srcs = ["foo.go"], + out = "foo_abi.go", + package = "foo", +) + +go_library( + name = "foo", + srcs = [ + "foo.go", + "foo_abi.go", + ], + ... +) +``` + +As part of the interface generation, `go_marshal` also generates some tests for +sanity checking the struct definitions for potential alignment issues, and a +simple round-trip test through Marshal/Unmarshal to verify the implementation. +These tests use reflection to verify properties of the ABI struct, and should be +considered part of the generated interfaces (but are too expensive to execute at +runtime). Ensure these tests run at some point. + +# Restrictions + +Not all valid go type definitions can be used with `go_marshal`. `go_marshal` is +intended for ABI structs, which have these additional restrictions: + +- At the moment, `go_marshal` only supports struct declarations. + +- Structs are marshalled as packed types. This means no implicit padding is + inserted between fields shorter than the platform register size. For + alignment, manually insert padding fields. + +- Structs used with `go_marshal` must have a compile-time static size. This + means no dynamically sizes fields like slices or strings. Use statically + sized array (byte arrays for strings) instead. + +- No pointers, channel, map or function pointer fields, and no fields that are + arrays of these types. These don't make sense in an ABI data structure. + +- We could support opaque pointers as `uintptr`, but this is currently not + implemented. Implementing this would require handling the architecture + dependent native pointer size. + +- Fields must either be a primitive integer type (`byte`, + `[u]int{8,16,32,64}`), or of a type that implements abi.Marshallable. + +- `int` and `uint` fields are not allowed. Use an explicitly-sized numeric + type. + +- `float*` fields are currently not supported, but could be if necessary. + +# Appendix + +## Working with Non-Packed Structs + +ABI structs must generally be packed types, meaning they should have no implicit +padding between short fields. However, if a field is tagged +`marshal:"unaligned"`, `go_marshal` will fall back to a safer but slower +mechanism to deal with potentially unaligned fields. + +Note that the non-packed property is inheritted by any other struct that embeds +this struct, since the `go_marshal` tool currently can't reason about alignments +for embedded structs that are not aligned. + +Because of this, it's generally best to avoid using `marshal:"unaligned"` and +insert explicit padding fields instead. + +## Modifying the `go_marshal` Tool + +The following are some guidelines for modifying the `go_marshal` tool: + +- The `go_marshal` tool currently does a single pass over all types requesting + code generation, in arbitrary order. This means the generated code can't + directly obtain information about embedded marshallable types at + compile-time. One way to work around this restriction is to add a new + Marshallable interface method providing this piece of information, and + calling it from the generated code. Use this sparingly, as we want to rely + on compile-time information as much as possible for performance. + +- No runtime reflection in the code generated for the marshallable interface. + The entire point of the tool is to avoid runtime reflection. The generated + tests may use reflection. diff --git a/tools/go_marshal/analysis/BUILD b/tools/go_marshal/analysis/BUILD new file mode 100644 index 000000000..c2a4d45c4 --- /dev/null +++ b/tools/go_marshal/analysis/BUILD @@ -0,0 +1,12 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "analysis", + testonly = 1, + srcs = ["analysis_unsafe.go"], + visibility = [ + "//:sandbox", + ], +) diff --git a/tools/go_marshal/analysis/analysis_unsafe.go b/tools/go_marshal/analysis/analysis_unsafe.go new file mode 100644 index 000000000..cd55cf5cb --- /dev/null +++ b/tools/go_marshal/analysis/analysis_unsafe.go @@ -0,0 +1,179 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package analysis implements common functionality used by generated +// go_marshal tests. +package analysis + +// All functions in this package are unsafe and are not intended for general +// consumption. They contain sharp edge cases and the caller is responsible for +// ensuring none of them are hit. Callers must be carefully to pass in only sane +// arguments. Failure to do so may cause panics at best and arbitrary memory +// corruption at worst. +// +// Never use outside of tests. + +import ( + "fmt" + "math/rand" + "reflect" + "testing" + "unsafe" +) + +// RandomizeValue assigns random value(s) to an abitrary type. This is intended +// for used with ABI structs from go_marshal, meaning the typical restrictions +// apply (fixed-size types, no pointers, maps, channels, etc), and should only +// be used on zeroed values to avoid overwriting pointers to active go objects. +// +// Internally, we populate the type with random data by doing an unsafe cast to +// access the underlying memory of the type and filling it as if it were a byte +// slice. This almost gets us what we want, but padding fields named "_" are +// normally not accessible, so we walk the type and recursively zero all "_" +// fields. +// +// Precondition: x must be a pointer. x must not contain any valid +// pointers to active go objects (pointer fields aren't allowed in ABI +// structs anyways), or we'd be violating the go runtime contract and +// the GC may malfunction. +func RandomizeValue(x interface{}) { + v := reflect.Indirect(reflect.ValueOf(x)) + if !v.CanSet() { + panic("RandomizeType() called with an unaddressable value. You probably need to pass a pointer to the argument") + } + + // Cast the underlying memory for the type into a byte slice. + var b []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + // Note: v.UnsafeAddr panics if x is passed by value. x should be a pointer. + hdr.Data = v.UnsafeAddr() + hdr.Len = int(v.Type().Size()) + hdr.Cap = hdr.Len + + // Fill the byte slice with random data, which in effect fills the type with + // random values. + n, err := rand.Read(b) + if err != nil || n != len(b) { + panic("unreachable") + } + + // Normally, padding fields are not accessible, so zero them out. + reflectZeroPaddingFields(v.Type(), b, false) +} + +// reflectZeroPaddingFields assigns zero values to padding fields for the value +// of type r, represented by the memory in data. Padding fields are defined as +// fields with the name "_". If zero is true, the immediate value itself is +// zeroed. In addition, the type is recursively scanned for padding fields in +// inner types. +// +// This is used for zeroing padding fields after calling RandomizeValue. +func reflectZeroPaddingFields(r reflect.Type, data []byte, zero bool) { + if zero { + for i, _ := range data { + data[i] = 0 + } + } + switch r.Kind() { + case reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64: + // These types are explicitly allowed in an ABI type, but we don't need + // to recurse further as they're scalar types. + case reflect.Struct: + for i, numFields := 0, r.NumField(); i < numFields; i++ { + f := r.Field(i) + off := f.Offset + len := f.Type.Size() + window := data[off : off+len] + reflectZeroPaddingFields(f.Type, window, f.Name == "_") + } + case reflect.Array: + eLen := int(r.Elem().Size()) + if int(r.Size()) != eLen*r.Len() { + panic("Array has unexpected size?") + } + for i, n := 0, r.Len(); i < n; i++ { + reflectZeroPaddingFields(r.Elem(), data[i*eLen:(i+1)*eLen], false) + } + default: + panic(fmt.Sprintf("Type %v not allowed in ABI struct", r.Kind())) + + } +} + +// AlignmentCheck ensures the definition of the type represented by typ doesn't +// cause the go compiler to emit implicit padding between elements of the type +// (i.e. fields in a struct). +// +// AlignmentCheck doesn't explicitly recurse for embedded structs because any +// struct present in an ABI struct must also be Marshallable, and therefore +// they're aligned by definition (or their alignment check would have failed). +func AlignmentCheck(t *testing.T, typ reflect.Type) (ok bool, delta uint64) { + switch typ.Kind() { + case reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64: + // Primitive types are always considered well aligned. Primitive types + // that are fields in structs are checked independently, this branch + // exists to handle recursive calls to alignmentCheck. + case reflect.Struct: + xOff := 0 + nextXOff := 0 + skipNext := false + for i, numFields := 0, typ.NumField(); i < numFields; i++ { + xOff = nextXOff + f := typ.Field(i) + fmt.Printf("Checking alignment of %s.%s @ %d [+%d]...\n", typ.Name(), f.Name, f.Offset, f.Type.Size()) + nextXOff = int(f.Offset + f.Type.Size()) + + if f.Name == "_" { + // Padding fields need not be aligned. + fmt.Printf("Padding field of type %v\n", f.Type) + continue + } + + if tag, ok := f.Tag.Lookup("marshal"); ok && tag == "unaligned" { + skipNext = true + continue + } + + if skipNext { + skipNext = false + fmt.Printf("Skipping alignment check for field %s.%s explicitly marked as unaligned.\n", typ.Name(), f.Name) + continue + } + + if xOff != int(f.Offset) { + implicitPad := int(f.Offset) - xOff + t.Fatalf("Suspect offset for field %s.%s, detected an implicit %d byte padding from offset %d to %d; either add %d bytes of explicit padding before this field or tag it as `marshal:\"unaligned\"`.", typ.Name(), f.Name, implicitPad, xOff, f.Offset, implicitPad) + } + } + + // Ensure structs end on a byte explicitly defined by the type. + if typ.NumField() > 0 && nextXOff != int(typ.Size()) { + implicitPad := int(typ.Size()) - nextXOff + f := typ.Field(typ.NumField() - 1) // Final field + if tag, ok := f.Tag.Lookup("marshal"); ok && tag == "unaligned" { + // Final field explicitly marked unaligned. + break + } + t.Fatalf("Suspect offset for field %s.%s at the end of %s, detected an implicit %d byte padding from offset %d to %d at the end of the struct; either add %d bytes of explict padding at end of the struct or tag the final field %s as `marshal:\"unaligned\"`.", + typ.Name(), f.Name, typ.Name(), implicitPad, nextXOff, typ.Size(), implicitPad, f.Name) + } + case reflect.Array: + // Independent arrays are also always considered well aligned. We only + // need to worry about their alignment when they're embedded in structs, + // which we handle above. + default: + t.Fatalf("Unsupported type in ABI struct while checking for field alignment for type: %v", typ.Kind()) + } + return true, uint64(typ.Size()) +} diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl new file mode 100644 index 000000000..323e33882 --- /dev/null +++ b/tools/go_marshal/defs.bzl @@ -0,0 +1,65 @@ +"""Marshal is a tool for generating marshalling interfaces for Go types.""" + +def _go_marshal_impl(ctx): + """Execute the go_marshal tool.""" + output = ctx.outputs.lib + output_test = ctx.outputs.test + + # Run the marshal command. + args = ["-output=%s" % output.path] + args += ["-pkg=%s" % ctx.attr.package] + args += ["-output_test=%s" % output_test.path] + + if ctx.attr.debug: + args += ["-debug"] + + args += ["--"] + for src in ctx.attr.srcs: + args += [f.path for f in src.files.to_list()] + ctx.actions.run( + inputs = ctx.files.srcs, + outputs = [output, output_test], + mnemonic = "GoMarshal", + progress_message = "go_marshal: %s" % ctx.label, + arguments = args, + executable = ctx.executable._tool, + ) + +# Generates save and restore logic from a set of Go files. +# +# Args: +# name: the name of the rule. +# srcs: the input source files. These files should include all structs in the +# package that need to be saved. +# imports: an optional list of extra, non-aliased, Go-style absolute import +# paths. +# out: the name of the generated file output. This must not conflict with any +# other files and must be added to the srcs of the relevant go_library. +# package: the package name for the input sources. +go_marshal = rule( + implementation = _go_marshal_impl, + attrs = { + "srcs": attr.label_list(mandatory = True, allow_files = True), + "imports": attr.string_list(mandatory = False), + "package": attr.string(mandatory = True), + "debug": attr.bool(doc = "enable debugging output from the go_marshal tool"), + "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_marshal:go_marshal")), + }, + outputs = { + "lib": "%{name}_unsafe.go", + "test": "%{name}_test.go", + }, +) + +# marshal_deps are the dependencies requied by generated code. +marshal_deps = [ + "//pkg/gohacks", + "//pkg/safecopy", + "//pkg/usermem", + "//tools/go_marshal/marshal", +] + +# marshal_test_deps are required by test targets. +marshal_test_deps = [ + "//tools/go_marshal/analysis", +] diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD new file mode 100644 index 000000000..44cb33ae4 --- /dev/null +++ b/tools/go_marshal/gomarshal/BUILD @@ -0,0 +1,21 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "gomarshal", + srcs = [ + "generator.go", + "generator_interfaces.go", + "generator_interfaces_array_newtype.go", + "generator_interfaces_primitive_newtype.go", + "generator_interfaces_struct.go", + "generator_tests.go", + "util.go", + ], + stateify = False, + visibility = [ + "//:sandbox", + ], + deps = ["//tools/tags"], +) diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go new file mode 100644 index 000000000..177013dbb --- /dev/null +++ b/tools/go_marshal/gomarshal/generator.go @@ -0,0 +1,499 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package gomarshal implements the go_marshal code generator. See README.md. +package gomarshal + +import ( + "bytes" + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "sort" + "strings" + + "gvisor.dev/gvisor/tools/tags" +) + +// List of identifiers we use in generated code that may conflict with a +// similarly-named source identifier. Abort gracefully when we see these to +// avoid potentially confusing compilation failures in generated code. +// +// This only applies to import aliases at the moment. All other identifiers +// are qualified by a receiver argument, since they're struct fields. +// +// All recievers are single letters, so we don't allow import aliases to be a +// single letter. +var badIdents = []string{ + "addr", "blk", "buf", "dst", "dsts", "count", "err", "hdr", "idx", "inner", + "length", "limit", "ptr", "size", "src", "srcs", "task", "val", + // All single-letter identifiers. +} + +// Constructed fromt badIdents in init(). +var badIdentsMap map[string]struct{} + +func init() { + badIdentsMap = make(map[string]struct{}) + for _, ident := range badIdents { + badIdentsMap[ident] = struct{}{} + } +} + +// Generator drives code generation for a single invocation of the go_marshal +// utility. +// +// The Generator holds arguments passed to the tool, and drives parsing, +// processing and code Generator for all types marked with +marshal declared in +// the input files. +// +// See Generator.run() as the entry point. +type Generator struct { + // Paths to input go source files. + inputs []string + // Output file to write generated go source. + output *os.File + // Output file to write generated tests. + outputTest *os.File + // Package name for the generated file. + pkg string + // Set of extra packages to import in the generated file. + imports *importTable +} + +// NewGenerator creates a new code Generator. +func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*Generator, error) { + f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err) + } + fTest, err := os.OpenFile(outTest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return nil, fmt.Errorf("Couldn't open test output file %q: %v", out, err) + } + g := Generator{ + inputs: srcs, + output: f, + outputTest: fTest, + pkg: pkg, + imports: newImportTable(), + } + for _, i := range imports { + // All imports on the extra imports list are unconditionally marked as + // used, so that they're always added to the generated code. + g.imports.add(i).markUsed() + } + + // The following imports may or may not be used by the generated code, + // depending on what's required for the target types. Don't mark these as + // used by default. + g.imports.add("io") + g.imports.add("reflect") + g.imports.add("runtime") + g.imports.add("unsafe") + g.imports.add("gvisor.dev/gvisor/pkg/gohacks") + g.imports.add("gvisor.dev/gvisor/pkg/safecopy") + g.imports.add("gvisor.dev/gvisor/pkg/usermem") + g.imports.add("gvisor.dev/gvisor/tools/go_marshal/marshal") + + return &g, nil +} + +// writeHeader writes the header for the generated source file. The header +// includes the package name, package level comments and import statements. +func (g *Generator) writeHeader() error { + var b sourceBuffer + b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n") + + // Emit build tags. + if t := tags.Aggregate(g.inputs); len(t) > 0 { + b.emit(strings.Join(t.Lines(), "\n")) + b.emit("\n\n") + } + + // Package header. + b.emit("package %s\n\n", g.pkg) + if err := b.write(g.output); err != nil { + return err + } + + return g.imports.write(g.output) +} + +// writeTypeChecks writes a statement to force the compiler to perform a type +// check for all Marshallable types referenced by the generated code. +func (g *Generator) writeTypeChecks(ms map[string]struct{}) error { + if len(ms) == 0 { + return nil + } + + msl := make([]string, 0, len(ms)) + for m, _ := range ms { + msl = append(msl, m) + } + sort.Strings(msl) + + var buf bytes.Buffer + fmt.Fprint(&buf, "// Marshallable types used by this file.\n") + + for _, m := range msl { + fmt.Fprintf(&buf, "var _ marshal.Marshallable = (*%s)(nil)\n", m) + } + fmt.Fprint(&buf, "\n") + + _, err := fmt.Fprint(g.output, buf.String()) + return err +} + +// parse processes all input files passed this generator and produces a set of +// parsed go ASTs. +func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) { + debugf("go_marshal invoked with %d input files:\n", len(g.inputs)) + for _, path := range g.inputs { + debugf(" %s\n", path) + } + + files := make([]*ast.File, 0, len(g.inputs)) + fsets := make([]*token.FileSet, 0, len(g.inputs)) + + for _, path := range g.inputs { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, path, nil, parser.ParseComments) + if err != nil { + // Not a valid input file? + return nil, nil, fmt.Errorf("Input %q can't be parsed: %v", path, err) + } + + if debugEnabled() { + debugf("AST for %q:\n", path) + ast.Print(fset, f) + } + + files = append(files, f) + fsets = append(fsets, fset) + } + + return files, fsets, nil +} + +// sliceAPI carries information about the '+marshal slice' directive. +type sliceAPI struct { + // Comment node in the AST containing the +marshal tag. + comment *ast.Comment + // Identifier fragment to use when naming generated functions for the slice + // API. + ident string + // Whether the generated functions should reference the newtype name, or the + // inner type name. Only meaningful on newtype declarations on primitives. + inner bool +} + +// marshallableType carries information about a type marked with the '+marshal' +// directive. +type marshallableType struct { + spec *ast.TypeSpec + slice *sliceAPI +} + +func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) marshallableType { + mt := marshallableType{ + spec: spec, + slice: nil, + } + + var unhandledTags []string + + for _, tag := range strings.Fields(strings.TrimPrefix(tagLine.Text, "// +marshal")) { + if strings.HasPrefix(tag, "slice:") { + tokens := strings.Split(tag, ":") + if len(tokens) < 2 || len(tokens) > 3 { + abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive has invalid 'slice' clause. Expecting format 'slice:<IDENTIFIER>[:inner]', got '%v'", tag)) + } + if len(tokens[1]) == 0 { + abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has empty identifier argument. Expecting '+marshal slice:identifier'") + } + + sa := &sliceAPI{ + comment: tagLine, + ident: tokens[1], + } + mt.slice = sa + + if len(tokens) == 3 { + if tokens[2] != "inner" { + abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has an invalid argument. Expecting '+marshal slice:<IDENTIFIER>[:inner]'") + } + sa.inner = true + } + + continue + } + + unhandledTags = append(unhandledTags, tag) + } + + if len(unhandledTags) > 0 { + abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive contained the following unknown clauses: %v", strings.Join(unhandledTags, " "))) + } + + return mt +} + +// collectMarshallableTypes walks the parsed AST and collects a list of type +// declarations for which we need to generate the Marshallable interface. +func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []marshallableType { + var types []marshallableType + for _, decl := range a.Decls { + gdecl, ok := decl.(*ast.GenDecl) + // Type declaration? + if !ok || gdecl.Tok != token.TYPE { + debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n") + continue + } + // Does it have a comment? + if gdecl.Doc == nil { + debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment.\n") + continue + } + // Does the comment contain a "+marshal" line? + marked := false + var tagLine *ast.Comment + for _, c := range gdecl.Doc.List { + if strings.HasPrefix(c.Text, "// +marshal") { + marked = true + tagLine = c + break + } + } + if !marked { + debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment containing +marshal line.\n") + continue + } + for _, spec := range gdecl.Specs { + // We already confirmed we're in a type declaration earlier, so this + // cast will succeed. + t := spec.(*ast.TypeSpec) + switch t.Type.(type) { + case *ast.StructType: + debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name) + case *ast.Ident: // Newtype on primitive. + debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name) + case *ast.ArrayType: // Newtype on array. + debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name) + default: + // A user specifically requested marshalling on this type, but we + // don't support it. + abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name)) + } + types = append(types, newMarshallableType(f, tagLine, t)) + + } + } + return types +} + +// collectImports collects all imports from all input source files. Some of +// these imports are copied to the generated output, if they're referenced by +// the generated code. +// +// collectImports de-duplicates imports while building the list, and ensures +// identifiers in the generated code don't conflict with any imported package +// names. +func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt { + is := make(map[string]importStmt) + for _, decl := range a.Decls { + gdecl, ok := decl.(*ast.GenDecl) + // Import statement? + if !ok || gdecl.Tok != token.IMPORT { + continue + } + for _, spec := range gdecl.Specs { + i := g.imports.addFromSpec(spec.(*ast.ImportSpec), f) + debugf("Collected import '%s' as '%s'\n", i.path, i.name) + + // Make sure we have an import that doesn't use any local names that + // would conflict with identifiers in the generated code. + if len(i.name) == 1 && i.name != "_" { + abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name)) + } + if _, ok := badIdentsMap[i.name]; ok { + abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name)) + } + } + } + return is + +} + +func (g *Generator) generateOne(t marshallableType, fset *token.FileSet) *interfaceGenerator { + i := newInterfaceGenerator(t.spec, fset) + switch ty := t.spec.Type.(type) { + case *ast.StructType: + i.validateStruct(t.spec, ty) + i.emitMarshallableForStruct(ty) + if t.slice != nil { + i.emitMarshallableSliceForStruct(ty, t.slice) + } + case *ast.Ident: + i.validatePrimitiveNewtype(ty) + i.emitMarshallableForPrimitiveNewtype(ty) + if t.slice != nil { + i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice) + } + case *ast.ArrayType: + i.validateArrayNewtype(t.spec.Name, ty) + // After validate, we can safely call arrayLen. + i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident)) + if t.slice != nil { + abortAt(fset.Position(t.slice.comment.Slash), fmt.Sprintf("Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?")) + } + default: + // This should've been filtered out by collectMarshallabeTypes. + panic(fmt.Sprintf("Unexpected type %+v", ty)) + } + return i +} + +// generateOneTestSuite generates a test suite for the automatically generated +// implementations type t. +func (g *Generator) generateOneTestSuite(t marshallableType) *testGenerator { + i := newTestGenerator(t.spec) + i.emitTests(t.slice) + return i +} + +// Run is the entry point to code generation using g. +// +// Run parses all input source files specified in g and emits generated code. +func (g *Generator) Run() error { + // Parse our input source files into ASTs and token sets. + asts, fsets, err := g.parse() + if err != nil { + return err + } + + if len(asts) != len(fsets) { + panic("ASTs and FileSets don't match") + } + + // Map of imports in source files; key = local package name, value = import + // path. + is := make(map[string]importStmt) + for i, a := range asts { + // Collect all imports from the source files. We may need to copy some + // of these to the generated code if they're referenced. This has to be + // done before the loop below because we need to process all ASTs before + // we start requesting imports to be copied one by one as we encounter + // them in each generated source. + for name, i := range g.collectImports(a, fsets[i]) { + is[name] = i + } + } + + var impls []*interfaceGenerator + var ts []*testGenerator + // Set of Marshallable types referenced by generated code. + ms := make(map[string]struct{}) + for i, a := range asts { + // Collect type declarations marked for code generation and generate + // Marshallable interfaces. + for _, t := range g.collectMarshallableTypes(a, fsets[i]) { + impl := g.generateOne(t, fsets[i]) + // Collect Marshallable types referenced by the generated code. + for ref, _ := range impl.ms { + ms[ref] = struct{}{} + } + impls = append(impls, impl) + // Collect imports referenced by the generated code and add them to + // the list of imports we need to copy to the generated code. + for name, _ := range impl.is { + if !g.imports.markUsed(name) { + panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name)) + } + } + ts = append(ts, g.generateOneTestSuite(t)) + } + } + + // Write output file header. These include things like package name and + // import statements. + if err := g.writeHeader(); err != nil { + return err + } + + // Write type checks for referenced marshallable types to output file. + if err := g.writeTypeChecks(ms); err != nil { + return err + } + + // Write generated interfaces to output file. + for _, i := range impls { + if err := i.write(g.output); err != nil { + return err + } + } + + // Write generated tests to test file. + return g.writeTests(ts) +} + +// writeTests outputs tests for the generated interface implementations to a go +// source file. +func (g *Generator) writeTests(ts []*testGenerator) error { + var b sourceBuffer + b.emit("package %s\n\n", g.pkg) + if err := b.write(g.outputTest); err != nil { + return err + } + + // Collect and write test import statements. + imports := newImportTable() + for _, t := range ts { + imports.merge(t.imports) + } + + if err := imports.write(g.outputTest); err != nil { + return err + } + + // Write test functions. + + // If we didn't generate any Marshallable implementations, we can't just + // emit an empty test file, since that causes the build to fail with "no + // tests/benchmarks/examples found". Unfortunately we can't signal bazel to + // omit the entire package since the outputs are already defined before + // go-marshal is called. If we'd otherwise emit an empty test suite, emit an + // empty example instead. + if len(ts) == 0 { + b.reset() + b.emit("func Example() {\n") + b.inIndent(func() { + b.emit("// This example is intentionally empty to ensure this file contains at least\n") + b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n") + b.emit("// is marked marshallable, but emitting a test file with no entities results\n") + b.emit("// in a build failure.\n") + }) + b.emit("}\n") + return b.write(g.outputTest) + } + + for _, t := range ts { + if err := t.write(g.outputTest); err != nil { + return err + } + } + return nil +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go new file mode 100644 index 000000000..e3c3dac63 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -0,0 +1,276 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomarshal + +import ( + "fmt" + "go/ast" + "go/token" + "strings" +) + +// interfaceGenerator generates marshalling interfaces for a single type. +// +// getState is not thread-safe. +type interfaceGenerator struct { + sourceBuffer + + // The type we're serializing. + t *ast.TypeSpec + + // Receiver argument for generated methods. + r string + + // FileSet containing the tokens for the type we're processing. + f *token.FileSet + + // is records external packages referenced by the generated implementation. + is map[string]struct{} + + // ms records Marshallable types referenced by the generated implementation + // of t's interfaces. + ms map[string]struct{} + + // as records embedded fields in t that are potentially not packed. The key + // is the accessor for the field. + as map[string]struct{} +} + +// typeName returns the name of the type this g represents. +func (g *interfaceGenerator) typeName() string { + return g.t.Name.Name +} + +// newinterfaceGenerator creates a new interface generator. +func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { + g := &interfaceGenerator{ + t: t, + r: receiverName(t), + f: fset, + is: make(map[string]struct{}), + ms: make(map[string]struct{}), + as: make(map[string]struct{}), + } + g.recordUsedMarshallable(g.typeName()) + return g +} + +func (g *interfaceGenerator) recordUsedMarshallable(m string) { + g.ms[m] = struct{}{} + +} + +func (g *interfaceGenerator) recordUsedImport(i string) { + g.is[i] = struct{}{} +} + +func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) { + g.as[fieldName] = struct{}{} +} + +// abortAt aborts the go_marshal tool with the given error message, with a +// reference position to the input source. Same as abortAt, but uses g to +// resolve p to position. +func (g *interfaceGenerator) abortAt(p token.Pos, msg string) { + abortAt(g.f.Position(p), msg) +} + +// scalarSize returns the size of type identified by t. If t isn't a primitive +// type, the size isn't known at code generation time, and must be resolved via +// the marshal.Marshallable interface. +func (g *interfaceGenerator) scalarSize(t *ast.Ident) (size int, unknownSize bool) { + switch t.Name { + case "int8", "uint8", "byte": + return 1, false + case "int16", "uint16": + return 2, false + case "int32", "uint32": + return 4, false + case "int64", "uint64": + return 8, false + default: + return 0, true + } +} + +func (g *interfaceGenerator) shift(bufVar string, n int) { + g.emit("%s = %s[%d:]\n", bufVar, bufVar, n) +} + +func (g *interfaceGenerator) shiftDynamic(bufVar, name string) { + g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name) +} + +// marshalScalar writes a single scalar to a byte slice. +func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) { + switch typ { + case "int8", "uint8", "byte": + g.emit("%s[0] = byte(%s)\n", bufVar, accessor) + g.shift(bufVar, 1) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(%s))\n", bufVar, accessor) + g.shift(bufVar, 2) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(%s))\n", bufVar, accessor) + g.shift(bufVar, 4) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor) + g.shift(bufVar, 8) + default: + g.emit("%s.MarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) + g.shiftDynamic(bufVar, accessor) + } +} + +// unmarshalScalar reads a single scalar from a byte slice. +func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) { + switch typ { + case "byte": + g.emit("%s = %s[0]\n", accessor, bufVar) + g.shift(bufVar, 1) + case "int8", "uint8": + g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar) + g.shift(bufVar, 1) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("%s = %s(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar) + g.shift(bufVar, 2) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("%s = %s(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar) + g.shift(bufVar, 4) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("%s = %s(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar) + g.shift(bufVar, 8) + default: + g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) + g.shiftDynamic(bufVar, accessor) + g.recordPotentiallyNonPackedField(accessor) + } +} + +// emitCastToByteSlice unsafely casts an arbitrary type's underlying memory to a +// byte slice, bypassing escape analysis. The caller is responsible for ensuring +// srcPtr lives until they're done with dstVar, the runtime does not consider +// dstVar dependent on srcPtr due to the escape analysis bypass. +// +// srcPtr must be a pointer. +// +// This function uses internally uses the identifier "hdr", and cannot be used +// in a context where it is already bound. +func (g *interfaceGenerator) emitCastToByteSlice(srcPtr, dstVar, lenExpr string) { + g.recordUsedImport("gohacks") + g.emit("// Construct a slice backed by dst's underlying memory.\n") + g.emit("var %s []byte\n", dstVar) + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar) + g.emit("hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(%s)))\n", srcPtr) + g.emit("hdr.Len = %s\n", lenExpr) + g.emit("hdr.Cap = %s\n\n", lenExpr) +} + +// emitCastToByteSlice unsafely casts a slice with elements of an abitrary type +// to a byte slice. As part of the cast, the byte slice is made to look +// independent of the src slice by bypassing escape analysis. This means the +// byte slice can be used without causing the source to escape. The caller is +// responsible for ensuring srcPtr lives until they're done with dstVar, as the +// runtime no longer considers dstVar dependent on srcPtr and is free to GC it. +// +// srcPtr must be a pointer. +// +// This function uses internally uses the identifiers "ptr", "val" and "hdr", +// and cannot be used in a context where these identifiers are already bound. +func (g *interfaceGenerator) emitCastSliceToByteSlice(srcPtr, dstVar, lenExpr string) { + g.emitNoEscapeSliceDataPointer(srcPtr, "val") + + g.emit("// Construct a slice backed by dst's underlying memory.\n") + g.emit("var %s []byte\n", dstVar) + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar) + g.emit("hdr.Data = uintptr(val)\n") + g.emit("hdr.Len = %s\n", lenExpr) + g.emit("hdr.Cap = %s\n\n", lenExpr) +} + +// emitNoEscapeSliceDataPointer unsafely casts a slice's data pointer to an +// unsafe.Pointer, bypassing escape analysis. The caller is responsible for +// ensuring srcPtr lives until they're done with dstVar, as the runtime no +// longer considers dstVar dependent on srcPtr and is free to GC it. +// +// srcPtr must be a pointer. +// +// This function uses internally uses the identifier "ptr" cannot be used in a +// context where this identifier is already bound. +func (g *interfaceGenerator) emitNoEscapeSliceDataPointer(srcPtr, dstVar string) { + g.recordUsedImport("gohacks") + g.emit("ptr := unsafe.Pointer(%s)\n", srcPtr) + g.emit("%s := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data))\n\n", dstVar) +} + +func (g *interfaceGenerator) emitKeepAlive(ptrVar string) { + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", ptrVar) + g.emit("// must live until the use above.\n") + g.emit("runtime.KeepAlive(%s)\n", ptrVar) +} + +func (g *interfaceGenerator) expandBinaryExpr(b *strings.Builder, e *ast.BinaryExpr) { + switch x := e.X.(type) { + case *ast.BinaryExpr: + // Recursively expand sub-expression. + g.expandBinaryExpr(b, x) + case *ast.Ident: + fmt.Fprintf(b, "%s", x.Name) + case *ast.BasicLit: + fmt.Fprintf(b, "%s", x.Value) + default: + g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers") + } + + fmt.Fprintf(b, "%s", e.Op) + + switch y := e.Y.(type) { + case *ast.BinaryExpr: + // Recursively expand sub-expression. + g.expandBinaryExpr(b, y) + case *ast.Ident: + fmt.Fprintf(b, "%s", y.Name) + case *ast.BasicLit: + fmt.Fprintf(b, "%s", y.Value) + default: + g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers") + } +} + +// arrayLenExpr returns a string containing a valid golang expression +// representing the length of array a. The returned expression should be treated +// as a single value, and will be already parenthesized as required. +func (g *interfaceGenerator) arrayLenExpr(a *ast.ArrayType) string { + var b strings.Builder + + switch l := a.Len.(type) { + case *ast.Ident: + fmt.Fprintf(&b, "%s", l.Name) + case *ast.BasicLit: + fmt.Fprintf(&b, "%s", l.Value) + case *ast.BinaryExpr: + g.expandBinaryExpr(&b, l) + return fmt.Sprintf("(%s)", b.String()) + default: + g.abortAt(l.Pos(), "Cannot convert this array len expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers") + } + return b.String() +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go new file mode 100644 index 000000000..72ef03a22 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go @@ -0,0 +1,146 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the bits of the code generator specific to marshalling +// newtypes on arrays. + +package gomarshal + +import ( + "fmt" + "go/ast" +) + +func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType) { + if a.Len == nil { + g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name)) + } + + if _, ok := a.Elt.(*ast.Ident); !ok { + g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt))) + } +} + +func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *ast.ArrayType, elt *ast.Ident) { + g.recordUsedImport("io") + g.recordUsedImport("marshal") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + g.recordUsedImport("usermem") + + lenExpr := g.arrayLenExpr(a) + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + if size, dynamic := g.scalarSize(elt); !dynamic { + g.emit("return %d * %s\n", size, lenExpr) + } else { + g.emit("return (*%s)(nil).SizeBytes() * %s\n", n.Name, lenExpr) + } + }) + g.emit("}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) + g.inIndent(func() { + g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst") + }) + g.emit("}\n") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) + g.inIndent(func() { + g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src") + }) + g.emit("}\n") + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Array newtypes are always packed.\n") + g.emit("return true\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := w.Write(buf)\n") + g.emitKeepAlive(g.r) + g.emit("return int64(length), err\n") + + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go new file mode 100644 index 000000000..39f654ea8 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go @@ -0,0 +1,289 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the bits of the code generator specific to marshalling +// newtypes on primitives. + +package gomarshal + +import ( + "fmt" + "go/ast" +) + +// marshalPrimitiveScalar writes a single primitive variable to a byte +// slice. +func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) { + switch typ { + case "int8", "uint8", "byte": + g.emit("%s[0] = byte(*%s)\n", bufVar, accessor) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor) + default: + g.emit("// Explicilty cast to the underlying type before dispatching to\n") + g.emit("// MarshalBytes, so we don't recursively call %s.MarshalBytes\n", accessor) + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + +// unmarshalPrimitiveScalar read a single primitive variable from a byte slice. +func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) { + switch typ { + case "byte": + g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar) + case "int8", "uint8": + g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar) + default: + g.emit("// Explicilty cast to the underlying type before dispatching to\n") + g.emit("// UnmarshalBytes, so we don't recursively call %s.UnmarshalBytes\n", accessor) + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + +func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) { + switch t.Name { + case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": + // These are the only primitive types we're allow. Below, we provide + // suggestions for some disallowed types and reject them, then attempt + // to marshal any remaining types by invoking the marshal.Marshallable + // interface on them. If these types don't actually implement + // marshal.Marshallable, compilation of the generated code will fail + // with an appropriate error message. + return + case "int": + g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64") + case "uint": + g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") + case "string": + g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") + default: + debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) + } +} + +// emitMarshallableForPrimitiveNewtype outputs code to implement the +// marshal.Marshallable interface for a newtype on a primitive. Primitive +// newtypes are always packed, so we can omit the various fallbacks required for +// non-packed structs. +func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) { + g.recordUsedImport("io") + g.recordUsedImport("marshal") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + g.recordUsedImport("usermem") + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + if size, dynamic := g.scalarSize(nt); !dynamic { + g.emit("return %d\n", size) + } else { + g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name) + } + }) + g.emit("}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.marshalPrimitiveScalar(g.r, nt.Name, "dst") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName()) + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Scalar newtypes are always packed.\n") + g.emit("return true\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := w.Write(buf)\n") + g.emitKeepAlive(g.r) + g.emit("return int64(length), err\n") + + }) + g.emit("}\n\n") +} + +func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Ident, slice *sliceAPI) { + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + + eltType := g.typeName() + if slice.inner { + eltType = nt.Name + } + + g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, eltType) + g.emit("//go:nosplit\n") + g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, eltType) + g.inIndent(func() { + g.emit("count := len(dst)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emitCastSliceToByteSlice("&dst", "buf", "size * count") + + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive("dst") + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, eltType) + g.emit("//go:nosplit\n") + g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, eltType) + g.inIndent(func() { + g.emit("count := len(src)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emitCastSliceToByteSlice("&src", "buf", "size * count") + + g.emit("length, err := task.CopyOutBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive("src") + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) + g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(src)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emitNoEscapeSliceDataPointer("&src", "val") + + g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n") + g.emitKeepAlive("src") + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) + g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(dst)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emitNoEscapeSliceDataPointer("&dst", "val") + + g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n") + g.emitKeepAlive("dst") + g.emit("return length, err\n") + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go new file mode 100644 index 000000000..9cd3c9579 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go @@ -0,0 +1,618 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the bits of the code generator specific to marshalling +// structs. + +package gomarshal + +import ( + "fmt" + "go/ast" + "strings" +) + +func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string { + return fmt.Sprintf("%s.%s", g.r, n.Name) +} + +// areFieldsPackedExpression returns a go expression checking whether g.t's fields are +// packed. Returns "", false if g.t has no fields that may be potentially +// packed, otherwise returns <clause>, true, where <clause> is an expression +// like "t.a.Packed() && t.b.Packed() && t.c.Packed()". +func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { + if len(g.as) == 0 { + return "", false + } + + cs := make([]string, 0, len(g.as)) + for accessor, _ := range g.as { + cs = append(cs, fmt.Sprintf("%s.Packed()", accessor)) + } + return strings.Join(cs, " && "), true +} + +// validateStruct ensures the type we're working with can be marshalled. These +// checks are done ahead of time and in one place so we can make assumptions +// later. +func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType) { + forEachStructField(st, func(f *ast.Field) { + if len(f.Names) == 0 { + g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") + } + }) + + forEachStructField(st, func(f *ast.Field) { + fieldDispatcher{ + primitive: func(_, t *ast.Ident) { + g.validatePrimitiveNewtype(t) + }, + selector: func(_, _, _ *ast.Ident) { + // No validation to perform on selector fields. However this + // callback must still be provided. + }, + array: func(n *ast.Ident, a *ast.ArrayType, _ *ast.Ident) { + g.validateArrayNewtype(n, a) + }, + unhandled: func(_ *ast.Ident) { + g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type))) + }, + }.dispatch(f) + }) +} + +func (g *interfaceGenerator) isStructPacked(st *ast.StructType) bool { + packed := true + forEachStructField(st, func(f *ast.Field) { + if f.Tag != nil { + if f.Tag.Value == "`marshal:\"unaligned\"`" { + if packed { + debugfAt(g.f.Position(g.t.Pos()), + fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name)) + packed = false + } + } + } + }) + return packed +} + +func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { + thisPacked := g.isStructPacked(st) + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + primitiveSize := 0 + var dynamicSizeTerms []string + + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if size, dynamic := g.scalarSize(t); !dynamic { + primitiveSize += size + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name)) + } + }, + selector: func(n, tX, tSel *ast.Ident) { + tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) + g.recordUsedImport(tX.Name) + g.recordUsedMarshallable(tName) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) + }, + array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + lenExpr := g.arrayLenExpr(a) + if size, dynamic := g.scalarSize(t); !dynamic { + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr)) + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr)) + } + }, + }.dispatch) + g.emit("return %d", primitiveSize) + if len(dynamicSizeTerms) > 0 { + g.incIndent() + } + { + for _, d := range dynamicSizeTerms { + g.emitNoIndent(" +\n") + g.emit(d) + } + } + if len(dynamicSizeTerms) > 0 { + g.decIndent() + } + }) + g.emit("\n}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("dst", len) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) + } + return + } + g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") + }, + selector: func(n, tX, tSel *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name) + g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name) + return + } + g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") + }, + array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + lenExpr := g.arrayLenExpr(a) + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name) + if size, dynamic := g.scalarSize(t); !dynamic { + g.emit("dst = dst[%d*(%s):]\n", size, lenExpr) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr) + } + return + } + + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) + g.inIndent(func() { + g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("src", len) + } else { + // We don't have an instance of the dynamic type we can + // reference here (since the version in this struct is + // anonymous). Use a typed nil pointer to call + // SizeBytes() instead. + g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name)) + g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) + } + return + } + g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") + }, + selector: func(n, tX, tSel *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name) + g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name) + g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name)) + return + } + g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") + }, + array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + lenExpr := g.arrayLenExpr(a) + if n.Name == "_" { + g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr) + if size, dynamic := g.scalarSize(t); !dynamic { + g.emit("src = src[%d*(%s):]\n", size, lenExpr) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can referece here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr) + } + return + } + + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) + g.inIndent(func() { + g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + expr, fieldsMaybePacked := g.areFieldsPackedExpression() + switch { + case !thisPacked: + g.emit("return false\n") + case fieldsMaybePacked: + g.emit("return %s\n", expr) + default: + g.emit("return true\n") + + } + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + if thisPacked { + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if %s {\n", cond) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("} else {\n") + g.inIndent(func() { + g.emit("%s.MarshalBytes(dst)\n", g.r) + }) + g.emit("}\n") + } else { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + } + } else { + g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) + g.emit("%s.MarshalBytes(dst)\n", g.r) + } + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + if thisPacked { + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if %s {\n", cond) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("} else {\n") + g.inIndent(func() { + g.emit("%s.UnmarshalBytes(src)\n", g.r) + }) + g.emit("}\n") + } else { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + } + } else { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("%s.UnmarshalBytes(src)\n", g.r) + } + }) + g.emit("}\n\n") + + g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") + g.emit("//go:nosplit\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) + g.emit("%s.MarshalBytes(buf) // escapes: fallback.\n", g.r) + g.emit("return task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("//go:nosplit\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("//go:nosplit\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emit("// Unmarshal unconditionally. If we had a short copy-in, this results in a\n") + g.emit("// partially unmarshalled struct.\n") + g.emit("%s.UnmarshalBytes(buf) // escapes: fallback.\n", g.r) + g.emit("return length, err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast deserialization. + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.recordUsedImport("io") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) + g.emit("%s.MarshalBytes(buf)\n", g.r) + g.emit("length, err := w.Write(buf)\n") + g.emit("return int64(length), err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := w.Write(buf)\n") + g.emitKeepAlive(g.r) + g.emit("return int64(length), err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") +} + +func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, slice *sliceAPI) { + thisPacked := g.isStructPacked(st) + + if slice.inner { + abortAt(g.f.Position(slice.comment.Slash), fmt.Sprintf("The ':inner' argument to '+marshal slice:%s:inner' is only applicable to newtypes on primitives. Remove it from this struct declaration.", slice.ident)) + } + + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + + g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, g.typeName()) + g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(dst)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(size * count)\n") + g.emit("length, err := task.CopyInBytes(addr, buf)\n\n") + + g.emit("// Unmarshal as much as possible, even on error. First handle full objects.\n") + g.emit("limit := length/size\n") + g.emit("for idx := 0; idx < limit; idx++ {\n") + g.inIndent(func() { + g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n") + }) + g.emit("}\n\n") + + g.emit("// Handle any final partial object.\n") + g.emit("if length < size*count && length%size != 0 {\n") + g.inIndent(func() { + g.emit("idx := limit\n") + g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n") + }) + g.emit("}\n\n") + + g.emit("return length, err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if _, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !dst[0].Packed() {\n") + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast deserialization. + g.emitCastSliceToByteSlice("&dst", "buf", "size * count") + + g.emit("length, err := task.CopyInBytes(addr, buf)\n") + g.emitKeepAlive("dst") + g.emit("return length, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, g.typeName()) + g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(src)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(size * count)\n") + g.emit("for idx := 0; idx < count; idx++ {\n") + g.inIndent(func() { + g.emit("src[idx].MarshalBytes(buf[size*idx:size*(idx+1)])\n") + }) + g.emit("}\n") + g.emit("return task.CopyOutBytes(addr, buf)\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if _, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !src[0].Packed() {\n") + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emitCastSliceToByteSlice("&src", "buf", "size * count") + + g.emit("length, err := task.CopyOutBytes(addr, buf)\n") + g.emitKeepAlive("src") + g.emit("return length, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) + g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(src)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("for idx := 0; idx < count; idx++ {\n") + g.inIndent(func() { + g.emit("src[idx].MarshalBytes(dst[size*idx:(size)*(idx+1)])\n") + }) + g.emit("}\n") + g.emit("return size * count, nil\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if _, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !src[0].Packed() {\n") + g.inIndent(fallback) + g.emit("}\n\n") + } + g.emitNoEscapeSliceDataPointer("&src", "val") + + g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n") + g.emitKeepAlive("src") + g.emit("return length, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) + g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(dst)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("for idx := 0; idx < count; idx++ {\n") + g.inIndent(func() { + g.emit("dst[idx].UnmarshalBytes(src[size*idx:size*(idx+1)])\n") + }) + g.emit("}\n") + g.emit("return size * count, nil\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if _, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !dst[0].Packed() {\n") + g.inIndent(fallback) + g.emit("}\n\n") + } + g.emitNoEscapeSliceDataPointer("&dst", "val") + + g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n") + g.emitKeepAlive("dst") + g.emit("return length, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go new file mode 100644 index 000000000..631295373 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -0,0 +1,233 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomarshal + +import ( + "fmt" + "go/ast" + "io" + "strings" +) + +var standardImports = []string{ + "bytes", + "fmt", + "reflect", + "testing", + + "gvisor.dev/gvisor/tools/go_marshal/analysis", +} + +var sliceAPIImports = []string{ + "encoding/binary", + "gvisor.dev/gvisor/pkg/usermem", +} + +type testGenerator struct { + sourceBuffer + + // The type we're serializing. + t *ast.TypeSpec + + // Receiver argument for generated methods. + r string + + // Imports used by generated code. + imports *importTable + + // Import statement for the package declaring the type we generated code + // for. We need this to construct test instances for the type, since the + // tests aren't written in the same package. + decl *importStmt +} + +func newTestGenerator(t *ast.TypeSpec) *testGenerator { + g := &testGenerator{ + t: t, + r: receiverName(t), + imports: newImportTable(), + } + + for _, i := range standardImports { + g.imports.add(i).markUsed() + } + // These imports are used if a type requests the slice API. Don't + // mark them as used by default. + for _, i := range sliceAPIImports { + g.imports.add(i) + } + + return g +} + +func (g *testGenerator) typeName() string { + return g.t.Name.Name +} + +func (g *testGenerator) testFuncName(base string) string { + return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name)) +} + +func (g *testGenerator) inTestFunction(name string, body func()) { + g.emit("func %s(t *testing.T) {\n", g.testFuncName(name)) + g.inIndent(body) + g.emit("}\n\n") +} + +func (g *testGenerator) emitTestNonZeroSize() { + g.inTestFunction("TestSizeNonZero", func() { + g.emit("var x %v\n", g.typeName()) + g.emit("if x.SizeBytes() == 0 {\n") + g.inIndent(func() { + g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTestSuspectAlignment() { + g.inTestFunction("TestSuspectAlignment", func() { + g.emit("var x %v\n", g.typeName()) + g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n") + }) +} + +func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() { + g.inTestFunction("TestSafeMarshalUnmarshalPreservesData", func() { + g.emit("var x, y, z, yUnsafe, zUnsafe %s\n", g.typeName()) + g.emit("analysis.RandomizeValue(&x)\n\n") + + g.emit("buf := make([]byte, x.SizeBytes())\n") + g.emit("x.MarshalBytes(buf)\n") + g.emit("bufUnsafe := make([]byte, x.SizeBytes())\n") + g.emit("x.MarshalUnsafe(bufUnsafe)\n\n") + + g.emit("y.UnmarshalBytes(buf)\n") + g.emit("if !reflect.DeepEqual(x, y) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") + }) + g.emit("}\n") + g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n") + g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") + }) + g.emit("}\n\n") + + g.emit("z.UnmarshalUnsafe(buf)\n") + g.emit("if !reflect.DeepEqual(x, z) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, z))\n") + }) + g.emit("}\n") + g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n") + g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, zUnsafe))\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTestMarshalUnmarshalSlicePreservesData(slice *sliceAPI) { + for _, name := range []string{"binary", "usermem"} { + if !g.imports.markUsed(name) { + panic(fmt.Sprintf("Generated test for '%s' referenced a non-existent import with local name '%s'", g.typeName(), name)) + } + } + + g.inTestFunction("TestSafeMarshalUnmarshalSlicePreservesData", func() { + g.emit("var x, y, yUnsafe [8]%s\n", g.typeName()) + g.emit("analysis.RandomizeValue(&x)\n\n") + g.emit("size := (*%s)(nil).SizeBytes() * len(x)\n", g.typeName()) + g.emit("buf := bytes.NewBuffer(make([]byte, size))\n") + g.emit("buf.Reset()\n") + g.emit("if err := binary.Write(buf, usermem.ByteOrder, x[:]); err != nil {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"binary.Write failed: %v\", err))\n") + }) + g.emit("}\n") + g.emit("bufUnsafe := make([]byte, size)\n") + g.emit("MarshalUnsafe%s(x[:], bufUnsafe)\n\n", slice.ident) + + g.emit("UnmarshalUnsafe%s(y[:], buf.Bytes())\n", slice.ident) + g.emit("if !reflect.DeepEqual(x, y) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across binary.Write/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") + }) + g.emit("}\n") + g.emit("UnmarshalUnsafe%s(yUnsafe[:], bufUnsafe)\n", slice.ident) + g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafeSlice/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") + }) + g.emit("}\n\n") + }) +} + +func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() { + g.inTestFunction("TestWriteToUnmarshalPreservesData", func() { + g.emit("var x, y, yUnsafe %s\n", g.typeName()) + g.emit("analysis.RandomizeValue(&x)\n\n") + + g.emit("var buf bytes.Buffer\n\n") + + g.emit("x.WriteTo(&buf)\n") + g.emit("y.UnmarshalBytes(buf.Bytes())\n\n") + g.emit("yUnsafe.UnmarshalUnsafe(buf.Bytes())\n\n") + + g.emit("if !reflect.DeepEqual(x, y) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") + }) + g.emit("}\n") + g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() { + g.inTestFunction("TestSizeBytesOnTypedNilPtr", func() { + g.emit("var x %s\n", g.typeName()) + g.emit("sizeFromConcrete := x.SizeBytes()\n") + g.emit("sizeFromTypedNilPtr := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n") + g.inIndent(func() { + g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTests(slice *sliceAPI) { + g.emitTestNonZeroSize() + g.emitTestSuspectAlignment() + g.emitTestMarshalUnmarshalPreservesData() + g.emitTestWriteToUnmarshalPreservesData() + g.emitTestSizeBytesOnTypedNilPtr() + + if slice != nil { + g.emitTestMarshalUnmarshalSlicePreservesData(slice) + } +} + +func (g *testGenerator) write(out io.Writer) error { + return g.sourceBuffer.write(out) +} diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go new file mode 100644 index 000000000..d94314302 --- /dev/null +++ b/tools/go_marshal/gomarshal/util.go @@ -0,0 +1,491 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomarshal + +import ( + "bytes" + "flag" + "fmt" + "go/ast" + "go/token" + "io" + "os" + "path" + "reflect" + "sort" + "strings" +) + +var debug = flag.Bool("debug", false, "enables debugging output") + +// receiverName returns an appropriate receiver name given a type spec. +func receiverName(t *ast.TypeSpec) string { + if len(t.Name.Name) < 1 { + // Zero length type name? + panic("unreachable") + } + return strings.ToLower(t.Name.Name[:1]) +} + +// kindString returns a user-friendly representation of an AST expr type. +func kindString(e ast.Expr) string { + switch e.(type) { + case *ast.Ident: + return "scalar" + case *ast.ArrayType: + return "array" + case *ast.StructType: + return "struct" + case *ast.StarExpr: + return "pointer" + case *ast.FuncType: + return "function" + case *ast.InterfaceType: + return "interface" + case *ast.MapType: + return "map" + case *ast.ChanType: + return "channel" + default: + return reflect.TypeOf(e).String() + } +} + +func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) { + for _, field := range st.Fields.List { + fn(field) + } +} + +// fieldDispatcher is a collection of callbacks for handling different types of +// fields in a struct declaration. +type fieldDispatcher struct { + primitive func(n, t *ast.Ident) + selector func(n, tX, tSel *ast.Ident) + array func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) + unhandled func(n *ast.Ident) +} + +// Precondition: All dispatch callbacks that will be invoked must be +// provided. Embedded fields are not allowed, len(f.Names) >= 1. +func (fd fieldDispatcher) dispatch(f *ast.Field) { + // Each field declaration may actually be multiple declarations of the same + // type. For example, consider: + // + // type Point struct { + // x, y, z int + // } + // + // We invoke the call-backs once per such instance. Embedded fields are not + // allowed, and results in a panic. + if len(f.Names) < 1 { + panic("Precondition not met: attempted to dispatch on embedded field") + } + + for _, name := range f.Names { + switch v := f.Type.(type) { + case *ast.Ident: + fd.primitive(name, v) + case *ast.SelectorExpr: + fd.selector(name, v.X.(*ast.Ident), v.Sel) + case *ast.ArrayType: + switch t := v.Elt.(type) { + case *ast.Ident: + fd.array(name, v, t) + default: + // Should be handled with a better error message during validate. + panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t)) + } + default: + fd.unhandled(name) + } + } +} + +// debugEnabled indicates whether debugging is enabled for gomarshal. +func debugEnabled() bool { + return *debug +} + +// abort aborts the go_marshal tool with the given error message. +func abort(msg string) { + if !strings.HasSuffix(msg, "\n") { + msg += "\n" + } + fmt.Print(msg) + os.Exit(1) +} + +// abortAt aborts the go_marshal tool with the given error message, with +// a reference position to the input source. +func abortAt(p token.Position, msg string) { + abort(fmt.Sprintf("%v:\n %s\n", p, msg)) +} + +// debugf conditionally prints a debug message. +func debugf(f string, a ...interface{}) { + if debugEnabled() { + fmt.Printf(f, a...) + } +} + +// debugfAt conditionally prints a debug message with a reference to a position +// in the input source. +func debugfAt(p token.Position, f string, a ...interface{}) { + if debugEnabled() { + fmt.Printf("%s:\n %s", p, fmt.Sprintf(f, a...)) + } +} + +// emit generates a line of code in the output file. +// +// emit is a wrapper around writing a formatted string to the output +// buffer. emit can be invoked in one of two ways: +// +// (1) emit("some string") +// When emit is called with a single string argument, it is simply copied to +// the output buffer without any further formatting. +// (2) emit(fmtString, args...) +// emit can also be invoked in a similar fashion to *Printf() functions, +// where the first argument is a format string. +// +// Calling emit with a single argument that is not a string will result in a +// panic, as the caller's intent is ambiguous. +func emit(out io.Writer, indent int, a ...interface{}) { + const spacesPerIndentLevel = 4 + + if len(a) < 1 { + panic("emit() called with no arguments") + } + + if indent > 0 { + if _, err := fmt.Fprint(out, strings.Repeat(" ", indent*spacesPerIndentLevel)); err != nil { + // Writing to the emit output should not fail. Typically the output + // is a byte.Buffer; writes to these never fail. + panic(err) + } + } + + first, ok := a[0].(string) + if !ok { + // First argument must be either the string to emit (case 1 from + // function-level comment), or a format string (case 2). + panic(fmt.Sprintf("First argument to emit() is not a string: %+v", a[0])) + } + + if len(a) == 1 { + // Single string argument. Assume no formatting requested. + if _, err := fmt.Fprint(out, first); err != nil { + // Writing to out should not fail. + panic(err) + } + return + + } + + // Formatting requested. + if _, err := fmt.Fprintf(out, first, a[1:]...); err != nil { + // Writing to out should not fail. + panic(err) + } +} + +// sourceBuffer represents fragments of generated go source code. +// +// sourceBuffer provides a convenient way to build up go souce fragments in +// memory. May be safely zero-value initialized. Not thread-safe. +type sourceBuffer struct { + // Current indentation level. + indent int + + // Memory buffer containing contents while they're being generated. + b bytes.Buffer +} + +func (b *sourceBuffer) reset() { + b.indent = 0 + b.b.Reset() +} + +func (b *sourceBuffer) incIndent() { + b.indent++ +} + +func (b *sourceBuffer) decIndent() { + if b.indent <= 0 { + panic("decIndent() without matching incIndent()") + } + b.indent-- +} + +func (b *sourceBuffer) emit(a ...interface{}) { + emit(&b.b, b.indent, a...) +} + +func (b *sourceBuffer) emitNoIndent(a ...interface{}) { + emit(&b.b, 0 /*indent*/, a...) +} + +func (b *sourceBuffer) inIndent(body func()) { + b.incIndent() + body() + b.decIndent() +} + +func (b *sourceBuffer) write(out io.Writer) error { + _, err := fmt.Fprint(out, b.b.String()) + return err +} + +// Write implements io.Writer.Write. +func (b *sourceBuffer) Write(buf []byte) (int, error) { + return (b.b.Write(buf)) +} + +// importStmt represents a single import statement. +type importStmt struct { + // Local name of the imported package. + name string + // Import path. + path string + // Indicates whether the local name is an alias, or simply the final + // component of the path. + aliased bool + // Indicates whether this import was referenced by generated code. + used bool + // AST node and file set representing the import statement, if any. These + // are only non-nil if the import statement originates from an input source + // file. + spec *ast.ImportSpec + fset *token.FileSet +} + +func newImport(p string) *importStmt { + name := path.Base(p) + return &importStmt{ + name: name, + path: p, + aliased: false, + } +} + +func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { + p := spec.Path.Value[1 : len(spec.Path.Value)-1] // Strip the " quotes around path. + name := path.Base(p) + if name == "" || name == "/" || name == "." { + panic(fmt.Sprintf("Couldn't process local package name for import at %s, (processed as %s)", + f.Position(spec.Path.Pos()), name)) + } + if spec.Name != nil { + name = spec.Name.Name + } + return &importStmt{ + name: name, + path: p, + aliased: spec.Name != nil, + spec: spec, + fset: f, + } +} + +// String implements fmt.Stringer.String. This generates a string for the import +// statement appropriate for writing directly to generated code. +func (i *importStmt) String() string { + if i.aliased { + return fmt.Sprintf("%s %q", i.name, i.path) + } + return fmt.Sprintf("%q", i.path) +} + +// debugString returns a debug string representing an import statement. This +// representation is not valid golang code and is used for debugging output. +func (i *importStmt) debugString() string { + if i.spec != nil && i.fset != nil { + return fmt.Sprintf("%s: %s", i.fset.Position(i.spec.Path.Pos()), i) + } + return fmt.Sprintf("(go-marshal import): %s", i) +} + +func (i *importStmt) markUsed() { + i.used = true +} + +func (i *importStmt) equivalent(other *importStmt) bool { + return i.name == other.name && i.path == other.path && i.aliased == other.aliased +} + +// importTable represents a collection of importStmts. +// +// An importTable may contain multiple import statements referencing the same +// local name. All import statements aliasing to the same local name are +// technically ambiguous, as if such an import name is used in the generated +// code, it's not clear which import statement it refers to. We ignore any +// potential collisions until actually writing the import table to the generated +// source file. See importTable.write. +// +// Given the following import statements across all the files comprising a +// package marshalled: +// +// "sync" +// "pkg/sync" +// "pkg/sentry/kernel" +// ktime "pkg/sentry/kernel/time" +// +// An importTable representing them would look like this: +// +// importTable { +// is: map[string][]*importStmt { +// "sync": []*importStmt{ +// importStmt{name:"sync", path:"sync", aliased:false} +// importStmt{name:"sync", path:"pkg/sync", aliased:false} +// }, +// "kernel": []*importStmt{importStmt{ +// name: "kernel", +// path: "pkg/sentry/kernel", +// aliased: false +// }}, +// "ktime": []*importStmt{importStmt{ +// name: "ktime", +// path: "pkg/sentry/kernel/time", +// aliased: true, +// }}, +// } +// } +// +// Note that the local name "sync" is assigned to two different import +// statements. This is possible if the import statements are from different +// source files in the same package. +// +// Since go-marshal generates a single output file per package regardless of the +// number of input files, if "sync" is referenced by any generated code, it's +// unclear which import statement "sync" refers to. While it's theoretically +// possible to resolve this by assigning a unique local alias to each instance +// of the sync package, go-marshal currently aborts when it encounters such an +// ambiguity. +// +// TODO(b/151478251): importTable considers the final component of an import +// path to be the package name, but this is only a convention. The actual +// package name is determined by the package statement in the source files for +// the package. +type importTable struct { + // Map of imports and whether they should be copied to the output. + is map[string][]*importStmt +} + +func newImportTable() *importTable { + return &importTable{ + is: make(map[string][]*importStmt), + } +} + +// Merges import statements from other into i. +func (i *importTable) merge(other *importTable) { + for name, ims := range other.is { + i.is[name] = append(i.is[name], ims...) + } +} + +func (i *importTable) addStmt(s *importStmt) *importStmt { + i.is[s.name] = append(i.is[s.name], s) + return s +} + +func (i *importTable) add(s string) *importStmt { + n := newImport(s) + return i.addStmt(n) +} + +func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { + return i.addStmt(newImportFromSpec(spec, f)) +} + +// Marks the import named n as used. If no such import is in the table, returns +// false. +func (i *importTable) markUsed(n string) bool { + if ns, ok := i.is[n]; ok { + for _, n := range ns { + n.markUsed() + } + return true + } + return false +} + +func (i *importTable) clear() { + for _, is := range i.is { + for _, i := range is { + i.used = false + } + } +} + +func (i *importTable) write(out io.Writer) error { + if len(i.is) == 0 { + // Nothing to import, we're done. + return nil + } + + imports := make([]string, 0, len(i.is)) + for name, is := range i.is { + var lastUsed *importStmt + var ambiguous bool + + for _, i := range is { + if i.used { + if lastUsed != nil { + if !i.equivalent(lastUsed) { + ambiguous = true + } + } + lastUsed = i + } + } + + if ambiguous { + // We have two or more import statements across the different source + // files that share a local name, and at least one of these imports + // are used by the generated code. This ambiguity can't be resolved + // by go-marshal and requires the user intervention. Dump a list of + // the colliding import statements and let the user modify the input + // files as appropriate. + var b strings.Builder + fmt.Fprintf(&b, "The imported name %q is used by one of the types marked for marshalling, and which import statement the code refers to is ambiguous. Perhaps give the imports unique local names?\n\n", name) + fmt.Fprintf(&b, "The following %d import statements are ambiguous for the local name %q:\n", len(is), name) + // Note: len(is) is guaranteed to be 1 or greater or ambiguous can't + // be true. Therefore the slicing below is safe. + for _, i := range is[:len(is)-1] { + fmt.Fprintf(&b, " %v\n", i.debugString()) + } + fmt.Fprintf(&b, " %v", is[len(is)-1].debugString()) + panic(b.String()) + } + + if lastUsed != nil { + imports = append(imports, lastUsed.String()) + } + } + sort.Strings(imports) + + var b sourceBuffer + b.emit("import (\n") + b.incIndent() + for _, i := range imports { + b.emit("%s\n", i) + } + b.decIndent() + b.emit(")\n\n") + + return b.write(out) +} diff --git a/tools/go_marshal/main.go b/tools/go_marshal/main.go new file mode 100644 index 000000000..f74be5c29 --- /dev/null +++ b/tools/go_marshal/main.go @@ -0,0 +1,72 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// go_marshal is a code generation utility for automatically generating code to +// marshal go data structures to memory. +// +// This binary is typically run as part of the build process, and is invoked by +// the go_marshal bazel rule defined in defs.bzl. +// +// See README.md. +package main + +import ( + "flag" + "fmt" + "os" + "strings" + + "gvisor.dev/gvisor/tools/go_marshal/gomarshal" +) + +var ( + pkg = flag.String("pkg", "", "output package") + output = flag.String("output", "", "output file") + outputTest = flag.String("output_test", "", "output file for tests") + imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code") +) + +func main() { + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s <input go src files>\n", os.Args[0]) + flag.PrintDefaults() + } + flag.Parse() + if len(flag.Args()) == 0 { + flag.Usage() + os.Exit(1) + } + + if *pkg == "" { + flag.Usage() + fmt.Fprint(os.Stderr, "Flag -pkg must be provided.\n") + os.Exit(1) + } + + var extraImports []string + if len(*imports) > 0 { + // Note: strings.Split(s, sep) returns s if sep doesn't exist in s. Thus + // we check for an empty imports list to avoid emitting an empty string + // as an import. + extraImports = strings.Split(*imports, ",") + } + g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, extraImports) + if err != nil { + panic(err) + } + + if err := g.Run(); err != nil { + panic(err) + } +} diff --git a/tools/go_marshal/marshal/BUILD b/tools/go_marshal/marshal/BUILD new file mode 100644 index 000000000..bacfaa5a4 --- /dev/null +++ b/tools/go_marshal/marshal/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "marshal", + srcs = [ + "marshal.go", + ], + visibility = [ + "//:sandbox", + ], + deps = [ + "//pkg/usermem", + ], +) diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go new file mode 100644 index 000000000..cb2166252 --- /dev/null +++ b/tools/go_marshal/marshal/marshal.go @@ -0,0 +1,187 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package marshal defines the Marshallable interface for +// serialize/deserializing go data structures to/from memory, according to the +// Linux ABI. +// +// Implementations of this interface are typically automatically generated by +// tools/go_marshal. See the go_marshal README for details. +package marshal + +import ( + "io" + + "gvisor.dev/gvisor/pkg/usermem" +) + +// Task provides a subset of kernel.Task, used in marshalling. We don't import +// the kernel package directly to avoid circular dependency. +type Task interface { + // CopyScratchBuffer provides a task goroutine-local scratch buffer. See + // kernel.CopyScratchBuffer. + CopyScratchBuffer(size int) []byte + + // CopyOutBytes writes the contents of b to the task's memory. See + // kernel.CopyOutBytes. + CopyOutBytes(addr usermem.Addr, b []byte) (int, error) + + // CopyInBytes reads the contents of the task's memory to b. See + // kernel.CopyInBytes. + CopyInBytes(addr usermem.Addr, b []byte) (int, error) +} + +// Marshallable represents operations on a type that can be marshalled to and +// from memory. +// +// go-marshal automatically generates implementations for this interface for +// types marked as '+marshal'. +type Marshallable interface { + io.WriterTo + + // SizeBytes is the size of the memory representation of a type in + // marshalled form. + // + // SizeBytes must handle a nil receiver. Practically, this means SizeBytes + // cannot deference any fields on the object implementing it (but will + // likely make use of the type of these fields). + SizeBytes() int + + // MarshalBytes serializes a copy of a type to dst. dst may be smaller than + // SizeBytes(), which results in a part of the struct being marshalled. Note + // that this may have unexpected results for non-packed types, as implicit + // padding needs to be taken into account when reasoning about how much of + // the type is serialized. + MarshalBytes(dst []byte) + + // UnmarshalBytes deserializes a type from src. src may be smaller than + // SizeBytes(), which results in a partially deserialized struct. Note that + // this may have unexpected results for non-packed types, as implicit + // padding needs to be taken into account when reasoning about how much of + // the type is deserialized. + UnmarshalBytes(src []byte) + + // Packed returns true if the marshalled size of the type is the same as the + // size it occupies in memory. This happens when the type has no fields + // starting at unaligned addresses (should always be true by default for ABI + // structs, verified by automatically generated tests when using + // go_marshal), and has no fields marked `marshal:"unaligned"`. + // + // Packed must return the same result for all possible values of the type + // implementing it. Violating this constraint implies the type doesn't have + // a static memory layout, and will lead to memory corruption. + // Go-marshal-generated code reuses the result of Packed for multiple values + // of the same type. + Packed() bool + + // MarshalUnsafe serializes a type by bulk copying its in-memory + // representation to the dst buffer. This is only safe to do when the type + // has no implicit padding, see Marshallable.Packed. When Packed would + // return false, MarshalUnsafe should fall back to the safer but slower + // MarshalBytes. dst may be smaller than SizeBytes(), see comment for + // MarshalBytes for implications. + MarshalUnsafe(dst []byte) + + // UnmarshalUnsafe deserializes a type by directly copying to the underlying + // memory allocated for the object by the runtime. + // + // This allows much faster unmarshalling of types which have no implicit + // padding, see Marshallable.Packed. When Packed would return false, + // UnmarshalUnsafe should fall back to the safer but slower unmarshal + // mechanism implemented in UnmarshalBytes. src may be smaller than + // SizeBytes(), see comment for UnmarshalBytes for implications. + UnmarshalUnsafe(src []byte) + + // CopyIn deserializes a Marshallable type from a task's memory. This may + // only be called from a task goroutine. This is more efficient than calling + // UnmarshalUnsafe on Marshallable.Packed types, as the type being + // marshalled does not escape. The implementation should avoid creating + // extra copies in memory by directly deserializing to the object's + // underlying memory. + // + // If the copy-in from the task memory is only partially successful, CopyIn + // should still attempt to deserialize as much data as possible. See comment + // for UnmarshalBytes. + CopyIn(task Task, addr usermem.Addr) (int, error) + + // CopyOut serializes a Marshallable type to a task's memory. This may only + // be called from a task goroutine. This is more efficient than calling + // MarshalUnsafe on Marshallable.Packed types, as the type being serialized + // does not escape. The implementation should avoid creating extra copies in + // memory by directly serializing from the object's underlying memory. + // + // The copy-out to the task memory may be partially successful, in which + // case CopyOut returns how much data was serialized. See comment for + // MarshalBytes for implications. + CopyOut(task Task, addr usermem.Addr) (int, error) + + // CopyOutN is like CopyOut, but explicitly requests a partial + // copy-out. Note that this may yield unexpected results for non-packed + // types and the caller may only want to allow this for packed types. See + // comment on MarshalBytes. + // + // The limit must be less than or equal to SizeBytes(). + CopyOutN(task Task, addr usermem.Addr, limit int) (int, error) +} + +// go-marshal generates additional functions for a type based on additional +// clauses to the +marshal directive. They are documented below. +// +// Slice API +// ========= +// +// Adding a "slice" clause to the +marshal directive for structs or newtypes on +// primitives like this: +// +// // +marshal slice:FooSlice +// type Foo struct { ... } +// +// Generates four additional functions for marshalling slices of Foos like this: +// +// // MarshalUnsafeFooSlice is like Foo.MarshalUnsafe, buf for a []Foo. It's +// // more efficient that repeatedly calling calling Foo.MarshalUnsafe over a +// // []Foo in a loop. +// func MarshalUnsafeFooSlice(src []Foo, dst []byte) (int, error) { ... } +// +// // UnmarshalUnsafeFooSlice is like Foo.UnmarshalUnsafe, buf for a []Foo. It's +// // more efficient that repeatedly calling calling Foo.UnmarshalUnsafe over a +// // []Foo in a loop. +// func UnmarshalUnsafeFooSlice(dst []Foo, src []byte) (int, error) { ... } +// +// // CopyFooSliceIn copies in a slice of Foo objects from the task's memory. +// func CopyFooSliceIn(task marshal.Task, addr usermem.Addr, dst []Foo) (int, error) { ... } +// +// // CopyFooSliceIn copies out a slice of Foo objects to the task's memory. +// func CopyFooSliceOut(task marshal.Task, addr usermem.Addr, src []Foo) (int, error) { ... } +// +// The name of the functions are of the format "Copy%sIn" and "Copy%sOut", where +// %s is the first argument to the slice clause. This directive is not supported +// for newtypes on arrays. +// +// The slice clause also takes an optional second argument, which must be the +// value "inner": +// +// // +marshal slice:Int32Slice:inner +// type Int32 int32 +// +// This is only valid on newtypes on primitives, and causes the generated +// functions to accept slices of the inner type instead: +// +// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []int32) (int, error) { ... } +// +// Without "inner", they would instead be: +// +// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []Int32) (int, error) { ... } +// +// This may help avoid a cast depending on how the generated functions are used. diff --git a/tools/go_marshal/primitive/BUILD b/tools/go_marshal/primitive/BUILD new file mode 100644 index 000000000..cc08ba63a --- /dev/null +++ b/tools/go_marshal/primitive/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "primitive", + srcs = [ + "primitive.go", + ], + marshal = True, + visibility = [ + "//:sandbox", + ], + deps = [ + "//pkg/usermem", + "//tools/go_marshal/marshal", + ], +) diff --git a/tools/go_marshal/primitive/primitive.go b/tools/go_marshal/primitive/primitive.go new file mode 100644 index 000000000..ebcf130ae --- /dev/null +++ b/tools/go_marshal/primitive/primitive.go @@ -0,0 +1,175 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package primitive defines marshal.Marshallable implementations for primitive +// types. +package primitive + +import ( + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" +) + +// Int16 is a marshal.Marshallable implementation for int16. +// +// +marshal slice:Int16Slice:inner +type Int16 int16 + +// Uint16 is a marshal.Marshallable implementation for uint16. +// +// +marshal slice:Uint16Slice:inner +type Uint16 uint16 + +// Int32 is a marshal.Marshallable implementation for int32. +// +// +marshal slice:Int32Slice:inner +type Int32 int32 + +// Uint32 is a marshal.Marshallable implementation for uint32. +// +// +marshal slice:Uint32Slice:inner +type Uint32 uint32 + +// Int64 is a marshal.Marshallable implementation for int64. +// +// +marshal slice:Int64Slice:inner +type Int64 int64 + +// Uint64 is a marshal.Marshallable implementation for uint64. +// +// +marshal slice:Uint64Slice:inner +type Uint64 uint64 + +// Below, we define some convenience functions for marshalling primitive types +// using the newtypes above, without requiring superfluous casts. + +// 16-bit integers + +// CopyInt16In is a convenient wrapper for copying in an int16 from the task's +// memory. +func CopyInt16In(task marshal.Task, addr usermem.Addr, dst *int16) (int, error) { + var buf Int16 + n, err := buf.CopyIn(task, addr) + if err != nil { + return n, err + } + *dst = int16(buf) + return n, nil +} + +// CopyInt16Out is a convenient wrapper for copying out an int16 to the task's +// memory. +func CopyInt16Out(task marshal.Task, addr usermem.Addr, src int16) (int, error) { + srcP := Int16(src) + return srcP.CopyOut(task, addr) +} + +// CopyUint16In is a convenient wrapper for copying in a uint16 from the task's +// memory. +func CopyUint16In(task marshal.Task, addr usermem.Addr, dst *uint16) (int, error) { + var buf Uint16 + n, err := buf.CopyIn(task, addr) + if err != nil { + return n, err + } + *dst = uint16(buf) + return n, nil +} + +// CopyUint16Out is a convenient wrapper for copying out a uint16 to the task's +// memory. +func CopyUint16Out(task marshal.Task, addr usermem.Addr, src uint16) (int, error) { + srcP := Uint16(src) + return srcP.CopyOut(task, addr) +} + +// 32-bit integers + +// CopyInt32In is a convenient wrapper for copying in an int32 from the task's +// memory. +func CopyInt32In(task marshal.Task, addr usermem.Addr, dst *int32) (int, error) { + var buf Int32 + n, err := buf.CopyIn(task, addr) + if err != nil { + return n, err + } + *dst = int32(buf) + return n, nil +} + +// CopyInt32Out is a convenient wrapper for copying out an int32 to the task's +// memory. +func CopyInt32Out(task marshal.Task, addr usermem.Addr, src int32) (int, error) { + srcP := Int32(src) + return srcP.CopyOut(task, addr) +} + +// CopyUint32In is a convenient wrapper for copying in a uint32 from the task's +// memory. +func CopyUint32In(task marshal.Task, addr usermem.Addr, dst *uint32) (int, error) { + var buf Uint32 + n, err := buf.CopyIn(task, addr) + if err != nil { + return n, err + } + *dst = uint32(buf) + return n, nil +} + +// CopyUint32Out is a convenient wrapper for copying out a uint32 to the task's +// memory. +func CopyUint32Out(task marshal.Task, addr usermem.Addr, src uint32) (int, error) { + srcP := Uint32(src) + return srcP.CopyOut(task, addr) +} + +// 64-bit integers + +// CopyInt64In is a convenient wrapper for copying in an int64 from the task's +// memory. +func CopyInt64In(task marshal.Task, addr usermem.Addr, dst *int64) (int, error) { + var buf Int64 + n, err := buf.CopyIn(task, addr) + if err != nil { + return n, err + } + *dst = int64(buf) + return n, nil +} + +// CopyInt64Out is a convenient wrapper for copying out an int64 to the task's +// memory. +func CopyInt64Out(task marshal.Task, addr usermem.Addr, src int64) (int, error) { + srcP := Int64(src) + return srcP.CopyOut(task, addr) +} + +// CopyUint64In is a convenient wrapper for copying in a uint64 from the task's +// memory. +func CopyUint64In(task marshal.Task, addr usermem.Addr, dst *uint64) (int, error) { + var buf Uint64 + n, err := buf.CopyIn(task, addr) + if err != nil { + return n, err + } + *dst = uint64(buf) + return n, nil +} + +// CopyUint64Out is a convenient wrapper for copying out a uint64 to the task's +// memory. +func CopyUint64Out(task marshal.Task, addr usermem.Addr, src uint64) (int, error) { + srcP := Uint64(src) + return srcP.CopyOut(task, addr) +} diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD new file mode 100644 index 000000000..2fbcc8a03 --- /dev/null +++ b/tools/go_marshal/test/BUILD @@ -0,0 +1,44 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +licenses(["notice"]) + +package_group( + name = "gomarshal_test", + packages = [ + "//tools/go_marshal/test/...", + ], +) + +go_test( + name = "benchmark_test", + srcs = ["benchmark_test.go"], + deps = [ + ":test", + "//pkg/binary", + "//pkg/usermem", + "//tools/go_marshal/analysis", + ], +) + +go_library( + name = "test", + testonly = 1, + srcs = ["test.go"], + marshal = True, + visibility = ["//tools/go_marshal/test:__subpackages__"], + deps = ["//tools/go_marshal/test/external"], +) + +go_test( + name = "marshal_test", + size = "small", + srcs = ["marshal_test.go"], + deps = [ + ":test", + "//pkg/syserror", + "//pkg/usermem", + "//tools/go_marshal/analysis", + "//tools/go_marshal/marshal", + "@com_github_google_go-cmp//cmp:go_default_library", + ], +) diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go new file mode 100644 index 000000000..224d308c7 --- /dev/null +++ b/tools/go_marshal/test/benchmark_test.go @@ -0,0 +1,220 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package benchmark_test + +import ( + "bytes" + encbin "encoding/binary" + "fmt" + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/analysis" + "gvisor.dev/gvisor/tools/go_marshal/test" +) + +// Marshalling using the standard encoding/binary package. +func BenchmarkEncodingBinary(b *testing.B) { + var s1, s2 test.Stat + analysis.RandomizeValue(&s1) + + size := encbin.Size(&s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := bytes.NewBuffer(make([]byte, size)) + buf.Reset() + if err := encbin.Write(buf, usermem.ByteOrder, &s1); err != nil { + b.Error("Write:", err) + } + if err := encbin.Read(buf, usermem.ByteOrder, &s2); err != nil { + b.Error("Read:", err) + } + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} + +// Marshalling using the sentry's binary.Marshal. +func BenchmarkBinary(b *testing.B) { + var s1, s2 test.Stat + analysis.RandomizeValue(&s1) + + size := binary.Size(s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := make([]byte, 0, size) + buf = binary.Marshal(buf, usermem.ByteOrder, &s1) + binary.Unmarshal(buf, usermem.ByteOrder, &s2) + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} + +// Marshalling field-by-field with manually-written code. +func BenchmarkMarshalManual(b *testing.B) { + var s1, s2 test.Stat + analysis.RandomizeValue(&s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := make([]byte, 0, s1.SizeBytes()) + + // Marshal + buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Dev) + buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Ino) + buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Nlink) + buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.Mode) + buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.UID) + buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.GID) + buf = binary.AppendUint32(buf, usermem.ByteOrder, 0) + buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Rdev) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Size)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blksize)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blocks)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Sec)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Nsec)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Sec)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Nsec)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Sec)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Nsec)) + + // Unmarshal + s2.Dev = usermem.ByteOrder.Uint64(buf[0:8]) + s2.Ino = usermem.ByteOrder.Uint64(buf[8:16]) + s2.Nlink = usermem.ByteOrder.Uint64(buf[16:24]) + s2.Mode = usermem.ByteOrder.Uint32(buf[24:28]) + s2.UID = usermem.ByteOrder.Uint32(buf[28:32]) + s2.GID = usermem.ByteOrder.Uint32(buf[32:36]) + // Padding: buf[36:40] + s2.Rdev = usermem.ByteOrder.Uint64(buf[40:48]) + s2.Size = int64(usermem.ByteOrder.Uint64(buf[48:56])) + s2.Blksize = int64(usermem.ByteOrder.Uint64(buf[56:64])) + s2.Blocks = int64(usermem.ByteOrder.Uint64(buf[64:72])) + s2.ATime.Sec = int64(usermem.ByteOrder.Uint64(buf[72:80])) + s2.ATime.Nsec = int64(usermem.ByteOrder.Uint64(buf[80:88])) + s2.MTime.Sec = int64(usermem.ByteOrder.Uint64(buf[88:96])) + s2.MTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[96:104])) + s2.CTime.Sec = int64(usermem.ByteOrder.Uint64(buf[104:112])) + s2.CTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[112:120])) + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} + +// Marshalling with the go_marshal safe API. +func BenchmarkGoMarshalSafe(b *testing.B) { + var s1, s2 test.Stat + analysis.RandomizeValue(&s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := make([]byte, s1.SizeBytes()) + s1.MarshalBytes(buf) + s2.UnmarshalBytes(buf) + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} + +// Marshalling with the go_marshal unsafe API. +func BenchmarkGoMarshalUnsafe(b *testing.B) { + var s1, s2 test.Stat + analysis.RandomizeValue(&s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := make([]byte, s1.SizeBytes()) + s1.MarshalUnsafe(buf) + s2.UnmarshalUnsafe(buf) + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} + +func BenchmarkBinarySlice(b *testing.B) { + var s1, s2 [64]test.Stat + analysis.RandomizeValue(&s1) + + size := binary.Size(s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := make([]byte, 0, size) + buf = binary.Marshal(buf, usermem.ByteOrder, &s1) + binary.Unmarshal(buf, usermem.ByteOrder, &s2) + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} + +func BenchmarkGoMarshalUnsafeSlice(b *testing.B) { + var s1, s2 [64]test.Stat + analysis.RandomizeValue(&s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := make([]byte, (*test.Stat)(nil).SizeBytes()*len(s1)) + test.MarshalUnsafeStatSlice(s1[:], buf) + test.UnmarshalUnsafeStatSlice(s2[:], buf) + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} diff --git a/tools/go_marshal/test/escape/BUILD b/tools/go_marshal/test/escape/BUILD new file mode 100644 index 000000000..f74e6ffae --- /dev/null +++ b/tools/go_marshal/test/escape/BUILD @@ -0,0 +1,14 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "escape", + testonly = 1, + srcs = ["escape.go"], + deps = [ + "//pkg/usermem", + "//tools/go_marshal/marshal", + "//tools/go_marshal/test", + ], +) diff --git a/tools/go_marshal/test/escape/escape.go b/tools/go_marshal/test/escape/escape.go new file mode 100644 index 000000000..6a46ddbf8 --- /dev/null +++ b/tools/go_marshal/test/escape/escape.go @@ -0,0 +1,95 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package escape + +import ( + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/test" +) + +// dummyTask implements marshal.Task. +type dummyTask struct { +} + +func (*dummyTask) CopyScratchBuffer(size int) []byte { + return make([]byte, size) +} + +func (*dummyTask) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) { + return len(b), nil +} + +func (*dummyTask) CopyInBytes(addr usermem.Addr, b []byte) (int, error) { + return len(b), nil +} + +func (t *dummyTask) MarshalBytes(addr usermem.Addr, marshallable marshal.Marshallable) { + buf := t.CopyScratchBuffer(marshallable.SizeBytes()) + marshallable.MarshalBytes(buf) + t.CopyOutBytes(addr, buf) +} + +func (t *dummyTask) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marshallable) { + buf := t.CopyScratchBuffer(marshallable.SizeBytes()) + marshallable.MarshalUnsafe(buf) + t.CopyOutBytes(addr, buf) +} + +// +checkescape:all +//go:nosplit +func doCopyIn(t *dummyTask) { + var stat test.Stat + stat.CopyIn(t, usermem.Addr(0xf000ba12)) +} + +// +checkescape:all +//go:nosplit +func doCopyOut(t *dummyTask) { + var stat test.Stat + stat.CopyOut(t, usermem.Addr(0xf000ba12)) +} + +// +mustescape:builtin +// +mustescape:stack +func doMarshalBytesDirect(t *dummyTask) { + var stat test.Stat + buf := t.CopyScratchBuffer(stat.SizeBytes()) + stat.MarshalBytes(buf) + t.CopyOutBytes(usermem.Addr(0xf000ba12), buf) +} + +// +mustescape:builtin +// +mustescape:stack +func doMarshalUnsafeDirect(t *dummyTask) { + var stat test.Stat + buf := t.CopyScratchBuffer(stat.SizeBytes()) + stat.MarshalUnsafe(buf) + t.CopyOutBytes(usermem.Addr(0xf000ba12), buf) +} + +// +mustescape:local,heap +// +mustescape:stack +func doMarshalBytesViaMarshallable(t *dummyTask) { + var stat test.Stat + t.MarshalBytes(usermem.Addr(0xf000ba12), &stat) +} + +// +mustescape:local,heap +// +mustescape:stack +func doMarshalUnsafeViaMarshallable(t *dummyTask) { + var stat test.Stat + t.MarshalUnsafe(usermem.Addr(0xf000ba12), &stat) +} diff --git a/tools/go_marshal/test/external/BUILD b/tools/go_marshal/test/external/BUILD new file mode 100644 index 000000000..0cf6da603 --- /dev/null +++ b/tools/go_marshal/test/external/BUILD @@ -0,0 +1,11 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "external", + testonly = 1, + srcs = ["external.go"], + marshal = True, + visibility = ["//tools/go_marshal/test:gomarshal_test"], +) diff --git a/tools/go_marshal/test/external/external.go b/tools/go_marshal/test/external/external.go new file mode 100644 index 000000000..26fe8e0c8 --- /dev/null +++ b/tools/go_marshal/test/external/external.go @@ -0,0 +1,31 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package external defines types we can import for testing. +package external + +// External is a public Marshallable type for use in testing. +// +// +marshal +type External struct { + j int64 +} + +// NotPacked is an unaligned Marshallable type for use in testing. +// +// +marshal +type NotPacked struct { + a int32 + b byte `marshal:"unaligned"` +} diff --git a/tools/go_marshal/test/marshal_test.go b/tools/go_marshal/test/marshal_test.go new file mode 100644 index 000000000..16829ee45 --- /dev/null +++ b/tools/go_marshal/test/marshal_test.go @@ -0,0 +1,515 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package marshal_test contains manual tests for the marshal interface. These +// are intended to test behaviour not covered by the automatically generated +// tests. +package marshal_test + +import ( + "bytes" + "encoding/binary" + "fmt" + "reflect" + "runtime" + "testing" + "unsafe" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/analysis" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/test" +) + +var simulatedErr error = syserror.EFAULT + +// mockTask implements marshal.Task. +type mockTask struct { + taskMem usermem.BytesIO +} + +// populate fills the task memory with the contents of val. +func (t *mockTask) populate(val interface{}) { + var buf bytes.Buffer + // Use binary.Write so we aren't testing go-marshal against its own + // potentially buggy implementation. + if err := binary.Write(&buf, usermem.ByteOrder, val); err != nil { + panic(err) + } + t.taskMem.Bytes = buf.Bytes() +} + +func (t *mockTask) setLimit(n int) { + if len(t.taskMem.Bytes) < n { + grown := make([]byte, n) + copy(grown, t.taskMem.Bytes) + t.taskMem.Bytes = grown + return + } + t.taskMem.Bytes = t.taskMem.Bytes[:n] +} + +// CopyScratchBuffer implements marshal.Task.CopyScratchBuffer. +func (t *mockTask) CopyScratchBuffer(size int) []byte { + return make([]byte, size) +} + +// CopyOutBytes implements marshal.Task.CopyOutBytes. The implementation +// completely ignores the target address and stores a copy of b in its +// internally buffer, overriding any previous contents. +func (t *mockTask) CopyOutBytes(_ usermem.Addr, b []byte) (int, error) { + return t.taskMem.CopyOut(nil, 0, b, usermem.IOOpts{}) +} + +// CopyInBytes implements marshal.Task.CopyInBytes. The implementation +// completely ignores the source address and always fills b from the begining of +// its internal buffer. +func (t *mockTask) CopyInBytes(_ usermem.Addr, b []byte) (int, error) { + return t.taskMem.CopyIn(nil, 0, b, usermem.IOOpts{}) +} + +// unsafeMemory returns the underlying memory for m. The returned slice is only +// valid for the lifetime for m. The garbage collector isn't aware that the +// returned slice is related to m, the caller must ensure m lives long enough. +func unsafeMemory(m marshal.Marshallable) []byte { + if !m.Packed() { + // We can't return a slice pointing to the underlying memory + // since the layout isn't packed. Allocate a temporary buffer + // and marshal instead. + var buf bytes.Buffer + if err := binary.Write(&buf, usermem.ByteOrder, m); err != nil { + panic(err) + } + return buf.Bytes() + } + + // reflect.ValueOf(m) + // .Elem() // Unwrap interface to inner concrete object + // .Addr() // Pointer value to object + // .Pointer() // Actual address from the pointer value + ptr := reflect.ValueOf(m).Elem().Addr().Pointer() + + size := m.SizeBytes() + + var mem []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&mem)) + hdr.Data = ptr + hdr.Len = size + hdr.Cap = size + + return mem +} + +// unsafeMemorySlice returns the underlying memory for m. The returned slice is +// only valid for the lifetime for m. The garbage collector isn't aware that the +// returned slice is related to m, the caller must ensure m lives long enough. +// +// Precondition: m must be a slice. +func unsafeMemorySlice(m interface{}, elt marshal.Marshallable) []byte { + kind := reflect.TypeOf(m).Kind() + if kind != reflect.Slice { + panic("unsafeMemorySlice called on non-slice") + } + + if !elt.Packed() { + // We can't return a slice pointing to the underlying memory + // since the layout isn't packed. Allocate a temporary buffer + // and marshal instead. + var buf bytes.Buffer + if err := binary.Write(&buf, usermem.ByteOrder, m); err != nil { + panic(err) + } + return buf.Bytes() + } + + v := reflect.ValueOf(m) + length := v.Len() * elt.SizeBytes() + + var mem []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&mem)) + hdr.Data = v.Pointer() // This is a pointer to the first elem for slices. + hdr.Len = length + hdr.Cap = length + + return mem +} + +func isZeroes(buf []byte) bool { + for _, b := range buf { + if b != 0 { + return false + } + } + return true +} + +// compareMemory compares the first n bytes of two chuncks of memory represented +// by expected and actual. +func compareMemory(t *testing.T, expected, actual []byte, n int) { + t.Logf("Expected (%d): %v (%d) + (%d) %v\n", len(expected), expected[:n], n, len(expected)-n, expected[n:]) + t.Logf("Actual (%d): %v (%d) + (%d) %v\n", len(actual), actual[:n], n, len(actual)-n, actual[n:]) + + if diff := cmp.Diff(expected[:n], actual[:n]); diff != "" { + t.Errorf("Memory buffers don't match:\n--- expected only\n+++ actual only\n%v", diff) + } +} + +// limitedCopyIn populates task memory with src, then unmarshals task memory to +// dst. The task signals an error at limit bytes during copy-in, which should +// result in a truncated unmarshalling. +func limitedCopyIn(t *testing.T, src, dst marshal.Marshallable, limit int) { + var task mockTask + task.populate(src) + task.setLimit(limit) + + n, err := dst.CopyIn(&task, usermem.Addr(0)) + if n != limit { + t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n) + } + if err != simulatedErr { + t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err) + } + + expectedMem := unsafeMemory(src) + defer runtime.KeepAlive(src) + actualMem := unsafeMemory(dst) + defer runtime.KeepAlive(dst) + + compareMemory(t, expectedMem, actualMem, n) + + // The last n bytes should be zero for actual, since actual was + // zero-initialized, and CopyIn shouldn't have touched those bytes. However + // we can only guarantee we didn't touch anything in the last n bytes if the + // layout is packed. + if dst.Packed() && !isZeroes(actualMem[n:]) { + t.Errorf("Expected the last %d bytes of copied in object to be zeroes, got %v\n", dst.SizeBytes()-n, actualMem) + } +} + +// limitedCopyOut marshals src to task memory. The task signals an error at +// limit bytes during copy-out, which should result in a truncated marshalling. +func limitedCopyOut(t *testing.T, src marshal.Marshallable, limit int) { + var task mockTask + task.setLimit(limit) + + n, err := src.CopyOut(&task, usermem.Addr(0)) + if n != limit { + t.Errorf("CopyOut copied unexpected number of bytes, expected %d, got %d", limit, n) + } + if err != simulatedErr { + t.Errorf("CopyOut returned unexpected error, expected %v, got %v", simulatedErr, err) + } + + expectedMem := unsafeMemory(src) + defer runtime.KeepAlive(src) + actualMem := task.taskMem.Bytes + + compareMemory(t, expectedMem, actualMem, n) +} + +// copyOutN marshals src to task memory, requesting the marshalling to be +// limited to limit bytes. +func copyOutN(t *testing.T, src marshal.Marshallable, limit int) { + var task mockTask + task.setLimit(limit) + + n, err := src.CopyOutN(&task, usermem.Addr(0), limit) + if err != nil { + t.Errorf("CopyOut returned unexpected error: %v", err) + } + if n != limit { + t.Errorf("CopyOut copied unexpected number of bytes, expected %d, got %d", limit, n) + } + + expectedMem := unsafeMemory(src) + defer runtime.KeepAlive(src) + actualMem := task.taskMem.Bytes + + t.Logf("Expected: %v + %v\n", expectedMem[:n], expectedMem[n:]) + t.Logf("Actual : %v + %v\n", actualMem[:n], actualMem[n:]) + + compareMemory(t, expectedMem, actualMem, n) +} + +// TestLimitedMarshalling verifies marshalling/unmarshalling succeeds when the +// underyling copy in/out operations partially succeed. +func TestLimitedMarshalling(t *testing.T) { + types := []reflect.Type{ + // Packed types. + reflect.TypeOf((*test.Type2)(nil)), + reflect.TypeOf((*test.Type3)(nil)), + reflect.TypeOf((*test.Timespec)(nil)), + reflect.TypeOf((*test.Stat)(nil)), + reflect.TypeOf((*test.InetAddr)(nil)), + reflect.TypeOf((*test.SignalSet)(nil)), + reflect.TypeOf((*test.SignalSetAlias)(nil)), + // Non-packed types. + reflect.TypeOf((*test.Type1)(nil)), + reflect.TypeOf((*test.Type4)(nil)), + reflect.TypeOf((*test.Type5)(nil)), + reflect.TypeOf((*test.Type6)(nil)), + reflect.TypeOf((*test.Type7)(nil)), + reflect.TypeOf((*test.Type8)(nil)), + } + + for _, tyPtr := range types { + // Remove one level of pointer-indirection from the type. We get this + // back when we pass the type to reflect.New. + ty := tyPtr.Elem() + + // Partial copy-in. + t.Run(fmt.Sprintf("PartialCopyIn_%v", ty), func(t *testing.T) { + expected := reflect.New(ty).Interface().(marshal.Marshallable) + actual := reflect.New(ty).Interface().(marshal.Marshallable) + analysis.RandomizeValue(expected) + + limitedCopyIn(t, expected, actual, expected.SizeBytes()/2) + }) + + // Partial copy-out. + t.Run(fmt.Sprintf("PartialCopyOut_%v", ty), func(t *testing.T) { + expected := reflect.New(ty).Interface().(marshal.Marshallable) + analysis.RandomizeValue(expected) + + limitedCopyOut(t, expected, expected.SizeBytes()/2) + }) + + // Explicitly request partial copy-out. + t.Run(fmt.Sprintf("PartialCopyOutN_%v", ty), func(t *testing.T) { + expected := reflect.New(ty).Interface().(marshal.Marshallable) + analysis.RandomizeValue(expected) + + copyOutN(t, expected, expected.SizeBytes()/2) + }) + } +} + +// TestLimitedMarshalling verifies marshalling/unmarshalling of slices of +// marshallable types succeed when the underyling copy in/out operations +// partially succeed. +func TestLimitedSliceMarshalling(t *testing.T) { + types := []struct { + arrayPtrType reflect.Type + copySliceIn func(task marshal.Task, addr usermem.Addr, dstSlice interface{}) (int, error) + copySliceOut func(task marshal.Task, addr usermem.Addr, srcSlice interface{}) (int, error) + unsafeMemory func(arrPtr interface{}) []byte + }{ + // Packed types. + { + reflect.TypeOf((*[20]test.Stat)(nil)), + func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + slice := dst.(*[20]test.Stat)[:] + return test.CopyStatSliceIn(task, addr, slice) + }, + func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + slice := src.(*[20]test.Stat)[:] + return test.CopyStatSliceOut(task, addr, slice) + }, + func(a interface{}) []byte { + slice := a.(*[20]test.Stat)[:] + return unsafeMemorySlice(slice, &slice[0]) + }, + }, + { + reflect.TypeOf((*[1]test.Stat)(nil)), + func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + slice := dst.(*[1]test.Stat)[:] + return test.CopyStatSliceIn(task, addr, slice) + }, + func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + slice := src.(*[1]test.Stat)[:] + return test.CopyStatSliceOut(task, addr, slice) + }, + func(a interface{}) []byte { + slice := a.(*[1]test.Stat)[:] + return unsafeMemorySlice(slice, &slice[0]) + }, + }, + { + reflect.TypeOf((*[5]test.SignalSetAlias)(nil)), + func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + slice := dst.(*[5]test.SignalSetAlias)[:] + return test.CopySignalSetAliasSliceIn(task, addr, slice) + }, + func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + slice := src.(*[5]test.SignalSetAlias)[:] + return test.CopySignalSetAliasSliceOut(task, addr, slice) + }, + func(a interface{}) []byte { + slice := a.(*[5]test.SignalSetAlias)[:] + return unsafeMemorySlice(slice, &slice[0]) + }, + }, + // Non-packed types. + { + reflect.TypeOf((*[20]test.Type1)(nil)), + func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + slice := dst.(*[20]test.Type1)[:] + return test.CopyType1SliceIn(task, addr, slice) + }, + func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + slice := src.(*[20]test.Type1)[:] + return test.CopyType1SliceOut(task, addr, slice) + }, + func(a interface{}) []byte { + slice := a.(*[20]test.Type1)[:] + return unsafeMemorySlice(slice, &slice[0]) + }, + }, + { + reflect.TypeOf((*[1]test.Type1)(nil)), + func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + slice := dst.(*[1]test.Type1)[:] + return test.CopyType1SliceIn(task, addr, slice) + }, + func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + slice := src.(*[1]test.Type1)[:] + return test.CopyType1SliceOut(task, addr, slice) + }, + func(a interface{}) []byte { + slice := a.(*[1]test.Type1)[:] + return unsafeMemorySlice(slice, &slice[0]) + }, + }, + { + reflect.TypeOf((*[7]test.Type8)(nil)), + func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + slice := dst.(*[7]test.Type8)[:] + return test.CopyType8SliceIn(task, addr, slice) + }, + func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + slice := src.(*[7]test.Type8)[:] + return test.CopyType8SliceOut(task, addr, slice) + }, + func(a interface{}) []byte { + slice := a.(*[7]test.Type8)[:] + return unsafeMemorySlice(slice, &slice[0]) + }, + }, + } + + for _, tt := range types { + // The body of this loop is generic over the type tt.arrayPtrType, with + // the help of reflection. To aid in readability, comments below show + // the equivalent go code assuming + // tt.arrayPtrType = typeof(*[20]test.Stat). + + // Equivalent: + // var x *[20]test.Stat + // arrayTy := reflect.TypeOf(*x) + arrayTy := tt.arrayPtrType.Elem() + + // Partial copy-in of slices. + t.Run(fmt.Sprintf("PartialCopySliceIn_%v", arrayTy), func(t *testing.T) { + // Equivalent: + // var x [20]test.Stat + // length := len(x) + length := arrayTy.Len() + if length < 1 { + panic("Test type can't be zero-length array") + } + // Equivalent: + // elem := new(test.Stat).(marshal.Marshallable) + elem := reflect.New(arrayTy.Elem()).Interface().(marshal.Marshallable) + + // Equivalent: + // var expected, actual interface{} + // expected = new([20]test.Stat) + // actual = new([20]test.Stat) + expected := reflect.New(arrayTy).Interface() + actual := reflect.New(arrayTy).Interface() + + analysis.RandomizeValue(expected) + + limit := (length * elem.SizeBytes()) / 2 + // Also make sure the limit is partially inside one of the elements. + limit += elem.SizeBytes() / 2 + analysis.RandomizeValue(expected) + + var task mockTask + task.populate(expected) + task.setLimit(limit) + + n, err := tt.copySliceIn(&task, usermem.Addr(0), actual) + if n != limit { + t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n) + } + if n < length*elem.SizeBytes() && err != simulatedErr { + t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err) + } + + expectedMem := tt.unsafeMemory(expected) + defer runtime.KeepAlive(expected) + actualMem := tt.unsafeMemory(actual) + defer runtime.KeepAlive(actual) + + compareMemory(t, expectedMem, actualMem, n) + + // The last n bytes should be zero for actual, since actual was + // zero-initialized, and CopyIn shouldn't have touched those bytes. However + // we can only guarantee we didn't touch anything in the last n bytes if the + // layout is packed. + if elem.Packed() && !isZeroes(actualMem[n:]) { + t.Errorf("Expected the last %d bytes of copied in object to be zeroes, got %v\n", (elem.SizeBytes()*length)-n, actualMem) + } + }) + + // Partial copy-out of slices. + t.Run(fmt.Sprintf("PartialCopySliceOut_%v", arrayTy), func(t *testing.T) { + // Equivalent: + // var x [20]test.Stat + // length := len(x) + length := arrayTy.Len() + if length < 1 { + panic("Test type can't be zero-length array") + } + // Equivalent: + // elem := new(test.Stat).(marshal.Marshallable) + elem := reflect.New(arrayTy.Elem()).Interface().(marshal.Marshallable) + + // Equivalent: + // var expected, actual interface{} + // expected = new([20]test.Stat) + // actual = new([20]test.Stat) + expected := reflect.New(arrayTy).Interface() + + analysis.RandomizeValue(expected) + + limit := (length * elem.SizeBytes()) / 2 + // Also make sure the limit is partially inside one of the elements. + limit += elem.SizeBytes() / 2 + analysis.RandomizeValue(expected) + + var task mockTask + task.populate(expected) + task.setLimit(limit) + + n, err := tt.copySliceOut(&task, usermem.Addr(0), expected) + if n != limit { + t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n) + } + if n < length*elem.SizeBytes() && err != simulatedErr { + t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err) + } + + expectedMem := tt.unsafeMemory(expected) + defer runtime.KeepAlive(expected) + actualMem := task.taskMem.Bytes + + compareMemory(t, expectedMem, actualMem, n) + }) + } +} diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go new file mode 100644 index 000000000..f75ca1b7f --- /dev/null +++ b/tools/go_marshal/test/test.go @@ -0,0 +1,176 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package test contains data structures for testing the go_marshal tool. +package test + +import ( + // We're intentionally using a package name alias here even though it's not + // necessary to test the code generator's ability to handle package aliases. + ex "gvisor.dev/gvisor/tools/go_marshal/test/external" +) + +// Type1 is a test data type. +// +// +marshal slice:Type1Slice +type Type1 struct { + a Type2 + x, y int64 // Multiple field names. + b byte `marshal:"unaligned"` // Short field. + c uint64 + _ uint32 // Unnamed scalar field. + _ [6]byte // Unnamed vector field, typical padding. + _ [2]byte + xs [8]int32 + as [10]Type2 `marshal:"unaligned"` // Array of Marshallable objects. + ss Type3 +} + +// Type2 is a test data type. +// +// +marshal +type Type2 struct { + n int64 + c byte + _ [7]byte + m int64 + a int64 +} + +// Type3 is a test data type. +// +// +marshal +type Type3 struct { + s int64 + x ex.External // Type defined in another package. +} + +// Type4 is a test data type. +// +// +marshal +type Type4 struct { + c byte + x int64 `marshal:"unaligned"` + d byte + _ [7]byte +} + +// Type5 is a test data type. +// +// +marshal +type Type5 struct { + n int64 + t Type4 + m int64 +} + +// Type6 is a test data type ends mid-word. +// +// +marshal +type Type6 struct { + a int64 + b int64 + // If c isn't marked unaligned, analysis fails (as it should, since + // the unsafe API corrupts Type7). + c byte `marshal:"unaligned"` +} + +// Type7 is a test data type that contains a child struct that ends +// mid-word. +// +marshal +type Type7 struct { + x Type6 + y int64 +} + +// Type8 is a test data type which contains an external non-packed field. +// +// +marshal slice:Type8Slice +type Type8 struct { + a int64 + np ex.NotPacked + b int64 +} + +// Timespec represents struct timespec in <time.h>. +// +// +marshal +type Timespec struct { + Sec int64 + Nsec int64 +} + +// Stat represents struct stat. +// +// +marshal slice:StatSlice +type Stat struct { + Dev uint64 + Ino uint64 + Nlink uint64 + Mode uint32 + UID uint32 + GID uint32 + _ int32 + Rdev uint64 + Size int64 + Blksize int64 + Blocks int64 + ATime Timespec + MTime Timespec + CTime Timespec + _ [3]int64 +} + +// InetAddr is an example marshallable newtype on an array. +// +// +marshal +type InetAddr [4]byte + +// SignalSet is an example marshallable newtype on a primitive. +// +// +marshal slice:SignalSetSlice:inner +type SignalSet uint64 + +// SignalSetAlias is an example newtype on another marshallable type. +// +// +marshal slice:SignalSetAliasSlice +type SignalSetAlias SignalSet + +const sizeA = 64 +const sizeB = 8 + +// TestArray is a test data structure on an array with a constant length. +// +// +marshal +type TestArray [sizeA]int32 + +// TestArray2 is a newtype on an array with a simple arithmetic expression of +// constants for the array length. +// +// +marshal +type TestArray2 [sizeA * sizeB]int32 + +// TestArray2 is a newtype on an array with a simple arithmetic expression of +// mixed constants and literals for the array length. +// +// +marshal +type TestArray3 [sizeA*sizeB + 12]int32 + +// Type9 is a test data type containing an array with a non-literal length. +// +// +marshal +type Type9 struct { + x int64 + y [sizeA]int32 +} diff --git a/tools/go_mod.sh b/tools/go_mod.sh new file mode 100755 index 000000000..84b779d6d --- /dev/null +++ b/tools/go_mod.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# Copyright 2020 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -eo pipefail + +# Build the :gopath target. +bazel build //:gopath +declare -r gopathdir="bazel-bin/gopath/src/gvisor.dev/gvisor/" + +# Copy go.mod and execute the command. +cp -a go.mod go.sum "${gopathdir}" +(cd "${gopathdir}" && go mod "$@") +cp -a "${gopathdir}/go.mod" "${gopathdir}/go.sum" . + +# Cleanup the WORKSPACE file. +bazel run //:gazelle -- update-repos -from_file=go.mod diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD new file mode 100644 index 000000000..503cdf2e5 --- /dev/null +++ b/tools/go_stateify/BUILD @@ -0,0 +1,10 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "stateify", + srcs = ["main.go"], + visibility = ["//:sandbox"], + deps = ["//tools/tags"], +) diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl new file mode 100644 index 000000000..6a5e666f0 --- /dev/null +++ b/tools/go_stateify/defs.bzl @@ -0,0 +1,60 @@ +"""Stateify is a tool for generating state wrappers for Go types.""" + +def _go_stateify_impl(ctx): + """Implementation for the stateify tool.""" + output = ctx.outputs.out + + # Run the stateify command. + args = ["-output=%s" % output.path] + args.append("-fullpkg=%s" % ctx.attr.package) + if ctx.attr._statepkg: + args.append("-statepkg=%s" % ctx.attr._statepkg) + if ctx.attr.imports: + args.append("-imports=%s" % ",".join(ctx.attr.imports)) + args.append("--") + for src in ctx.attr.srcs: + args += [f.path for f in src.files.to_list()] + ctx.actions.run( + inputs = ctx.files.srcs, + outputs = [output], + mnemonic = "GoStateify", + progress_message = "Generating state library %s" % ctx.label, + arguments = args, + executable = ctx.executable._tool, + ) + +go_stateify = rule( + implementation = _go_stateify_impl, + doc = "Generates save and restore logic from a set of Go files.", + attrs = { + "srcs": attr.label_list( + doc = """ +The input source files. These files should include all structs in the package +that need to be saved. +""", + mandatory = True, + allow_files = True, + ), + "imports": attr.string_list( + doc = """ +An optional list of extra non-aliased, Go-style absolute import paths required +for statified types. +""", + mandatory = False, + ), + "package": attr.string( + doc = "The fully qualified package name for the input sources.", + mandatory = True, + ), + "out": attr.output( + doc = "Name of the generator output file.", + mandatory = True, + ), + "_tool": attr.label( + executable = True, + cfg = "host", + default = Label("//tools/go_stateify:stateify"), + ), + "_statepkg": attr.string(default = "gvisor.dev/gvisor/pkg/state"), + }, +) diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go new file mode 100644 index 000000000..4f6ed208a --- /dev/null +++ b/tools/go_stateify/main.go @@ -0,0 +1,476 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Stateify provides a simple way to generate Load/Save methods based on +// existing types and struct tags. +package main + +import ( + "flag" + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "path/filepath" + "reflect" + "strings" + "sync" + + "gvisor.dev/gvisor/tools/tags" +) + +var ( + fullPkg = flag.String("fullpkg", "", "fully qualified output package") + imports = flag.String("imports", "", "extra imports for the output file") + output = flag.String("output", "", "output file") + statePkg = flag.String("statepkg", "", "state import package; defaults to empty") +) + +// resolveTypeName returns a qualified type name. +func resolveTypeName(name string, typ ast.Expr) (field string, qualified string) { + for done := false; !done; { + // Resolve star expressions. + switch rs := typ.(type) { + case *ast.StarExpr: + qualified += "*" + typ = rs.X + case *ast.ArrayType: + if rs.Len == nil { + // Slice type declaration. + qualified += "[]" + } else { + // Array type declaration. + qualified += "[" + rs.Len.(*ast.BasicLit).Value + "]" + } + typ = rs.Elt + default: + // No more descent. + done = true + } + } + + // Resolve a package selector. + sel, ok := typ.(*ast.SelectorExpr) + if ok { + qualified = qualified + sel.X.(*ast.Ident).Name + "." + typ = sel.Sel + } + + // Figure out actual type name. + ident, ok := typ.(*ast.Ident) + if !ok { + panic(fmt.Sprintf("type not supported: %s (involves anonymous types?)", name)) + } + field = ident.Name + qualified = qualified + field + return +} + +// extractStateTag pulls the relevant state tag. +func extractStateTag(tag *ast.BasicLit) string { + if tag == nil { + return "" + } + if len(tag.Value) < 2 { + return "" + } + return reflect.StructTag(tag.Value[1 : len(tag.Value)-1]).Get("state") +} + +// scanFunctions is a set of functions passed to scanFields. +type scanFunctions struct { + zerovalue func(name string) + normal func(name string) + wait func(name string) + value func(name, typName string) +} + +// scanFields scans the fields of a struct. +// +// Each provided function will be applied to appropriately tagged fields, or +// skipped if nil. +// +// Fields tagged nosave are skipped. +func scanFields(ss *ast.StructType, prefix string, fn scanFunctions) { + if ss.Fields.List == nil { + // No fields. + return + } + + // Scan all fields. + for _, field := range ss.Fields.List { + // Calculate the name. + name := "" + if field.Names != nil { + // It's a named field; override. + name = field.Names[0].Name + } else { + // Anonymous types can't be embedded, so we don't need + // to worry about providing a useful name here. + name, _ = resolveTypeName("", field.Type) + } + + // Skip _ fields. + if name == "_" { + continue + } + + // Is this a anonymous struct? If yes, then continue the + // recursion with the given prefix. We don't pay attention to + // any tags on the top-level struct field. + tag := extractStateTag(field.Tag) + if anon, ok := field.Type.(*ast.StructType); ok && tag == "" { + scanFields(anon, name+".", fn) + continue + } + + switch tag { + case "zerovalue": + if fn.zerovalue != nil { + fn.zerovalue(name) + } + + case "": + if fn.normal != nil { + fn.normal(name) + } + + case "wait": + if fn.wait != nil { + fn.wait(name) + } + + case "manual", "nosave", "ignore": + // Do nothing. + + default: + if strings.HasPrefix(tag, ".(") && strings.HasSuffix(tag, ")") { + if fn.value != nil { + fn.value(name, tag[2:len(tag)-1]) + } + } + } + } +} + +func camelCased(name string) string { + return strings.ToUpper(name[:1]) + name[1:] +} + +func main() { + // Parse flags. + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0]) + flag.PrintDefaults() + } + flag.Parse() + if len(flag.Args()) == 0 { + flag.Usage() + os.Exit(1) + } + if *fullPkg == "" { + fmt.Fprintf(os.Stderr, "Error: package required.") + os.Exit(1) + } + + // Open the output file. + var ( + outputFile *os.File + err error + ) + if *output == "" || *output == "-" { + outputFile = os.Stdout + } else { + outputFile, err = os.OpenFile(*output, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + fmt.Fprintf(os.Stderr, "Error opening output %q: %v", *output, err) + } + defer outputFile.Close() + } + + // Set the statePrefix for below, depending on the import. + statePrefix := "" + if *statePkg != "" { + parts := strings.Split(*statePkg, "/") + statePrefix = parts[len(parts)-1] + "." + } + + // initCalls is dumped at the end. + var initCalls []string + + // Common closures. + emitRegister := func(name string) { + initCalls = append(initCalls, fmt.Sprintf("%sRegister((*%s)(nil))", statePrefix, name)) + } + emitZeroCheck := func(name string) { + fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { %sFailf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, statePrefix, name, name) + } + + // Automated warning. + fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n") + + // Emit build tags. + if t := tags.Aggregate(flag.Args()); len(t) > 0 { + fmt.Fprintf(outputFile, "%s\n\n", strings.Join(t.Lines(), "\n")) + } + + // Emit the package name. + _, pkg := filepath.Split(*fullPkg) + fmt.Fprintf(outputFile, "package %s\n\n", pkg) + + // Emit the imports lazily. + var once sync.Once + maybeEmitImports := func() { + once.Do(func() { + // Emit the imports. + fmt.Fprint(outputFile, "import (\n") + if *statePkg != "" { + fmt.Fprintf(outputFile, " \"%s\"\n", *statePkg) + } + if *imports != "" { + for _, i := range strings.Split(*imports, ",") { + fmt.Fprintf(outputFile, " \"%s\"\n", i) + } + } + fmt.Fprint(outputFile, ")\n\n") + }) + } + + files := make([]*ast.File, 0, len(flag.Args())) + + // Parse the input files. + for _, filename := range flag.Args() { + // Parse the file. + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) + if err != nil { + // Not a valid input file? + fmt.Fprintf(os.Stderr, "Input %q can't be parsed: %v\n", filename, err) + os.Exit(1) + } + + files = append(files, f) + } + + type method struct { + receiver string + name string + } + + // Search for and add all methods with a pointer receiver and no other + // arguments to a set. We support auto-detecting the existence of + // several different methods with this signature. + simpleMethods := map[method]struct{}{} + for _, f := range files { + + // Go over all functions. + for _, decl := range f.Decls { + d, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + if d.Name == nil || d.Recv == nil || d.Type == nil { + // Not a named method. + continue + } + if len(d.Recv.List) != 1 { + // Wrong number of receivers? + continue + } + if d.Type.Params != nil && len(d.Type.Params.List) != 0 { + // Has argument(s). + continue + } + if d.Type.Results != nil && len(d.Type.Results.List) != 0 { + // Has return(s). + continue + } + + pt, ok := d.Recv.List[0].Type.(*ast.StarExpr) + if !ok { + // Not a pointer receiver. + continue + } + + t, ok := pt.X.(*ast.Ident) + if !ok { + // This shouldn't happen with valid Go. + continue + } + + simpleMethods[method{t.Name, d.Name.Name}] = struct{}{} + } + } + + for _, f := range files { + // Go over all named types. + for _, decl := range f.Decls { + d, ok := decl.(*ast.GenDecl) + if !ok || d.Tok != token.TYPE { + continue + } + + // Only generate code for types marked "// +stateify + // savable" in one of the proceeding comment lines. If + // the line is marked "// +stateify type" then only + // generate type information and register the type. + if d.Doc == nil { + continue + } + var ( + generateTypeInfo = false + generateSaverLoader = false + ) + for _, l := range d.Doc.List { + if l.Text == "// +stateify savable" { + generateTypeInfo = true + generateSaverLoader = true + break + } + if l.Text == "// +stateify type" { + generateTypeInfo = true + } + } + if !generateTypeInfo && !generateSaverLoader { + continue + } + + for _, gs := range d.Specs { + ts := gs.(*ast.TypeSpec) + switch x := ts.Type.(type) { + case *ast.StructType: + maybeEmitImports() + + // Record the slot for each field. + fieldCount := 0 + fields := make(map[string]int) + emitField := func(name string) { + fmt.Fprintf(outputFile, " \"%s\",\n", name) + fields[name] = fieldCount + fieldCount++ + } + emitFieldValue := func(name string, _ string) { + emitField(name) + } + emitLoadValue := func(name, typName string) { + fmt.Fprintf(outputFile, " m.LoadValue(%d, new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", fields[name], typName, camelCased(name), typName) + } + emitLoad := func(name string) { + fmt.Fprintf(outputFile, " m.Load(%d, &x.%s)\n", fields[name], name) + } + emitLoadWait := func(name string) { + fmt.Fprintf(outputFile, " m.LoadWait(%d, &x.%s)\n", fields[name], name) + } + emitSaveValue := func(name, typName string) { + fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name)) + fmt.Fprintf(outputFile, " m.SaveValue(%d, %s)\n", fields[name], name) + } + emitSave := func(name string) { + fmt.Fprintf(outputFile, " m.Save(%d, &x.%s)\n", fields[name], name) + } + + // Generate the type name method. + fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) + fmt.Fprintf(outputFile, "}\n\n") + + // Generate the fields method. + fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return []string{\n") + scanFields(x, "", scanFunctions{ + normal: emitField, + wait: emitField, + value: emitFieldValue, + }) + fmt.Fprintf(outputFile, " }\n") + fmt.Fprintf(outputFile, "}\n\n") + + // Define beforeSave if a definition was not found. This prevents + // the code from compiling if a custom beforeSave was defined in a + // file not provided to this binary and prevents inherited methods + // from being called multiple times by overriding them. + if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok && generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n\n", ts.Name.Name) + } + + // Generate the save method. + // + // N.B. For historical reasons, we perform the value saves first, + // and perform the value loads last. There should be no dependency + // on this specific behavior, but the ability to specify slots + // allows a manual implementation to be order-dependent. + if generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) StateSave(m %sSink) {\n", ts.Name.Name, statePrefix) + fmt.Fprintf(outputFile, " x.beforeSave()\n") + scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck}) + scanFields(x, "", scanFunctions{value: emitSaveValue}) + scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave}) + fmt.Fprintf(outputFile, "}\n\n") + } + + // Define afterLoad if a definition was not found. We do this for + // the same reason that we do it for beforeSave. + _, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}] + if !hasAfterLoad && generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n\n", ts.Name.Name) + } + + // Generate the load method. + // + // N.B. See the comment above for the save method. + if generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) StateLoad(m %sSource) {\n", ts.Name.Name, statePrefix) + scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait}) + scanFields(x, "", scanFunctions{value: emitLoadValue}) + if hasAfterLoad { + // The call to afterLoad is made conditionally, because when + // AfterLoad is called, the object encodes a dependency on + // referred objects (i.e. fields). This means that afterLoad + // will not be called until the other afterLoads are called. + fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n") + } + fmt.Fprintf(outputFile, "}\n\n") + } + + // Add to our registration. + emitRegister(ts.Name.Name) + + case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType: + maybeEmitImports() + + // Generate the info methods. + fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) + fmt.Fprintf(outputFile, "}\n\n") + fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return nil\n") + fmt.Fprintf(outputFile, "}\n\n") + + // See above. + emitRegister(ts.Name.Name) + } + } + } + } + + if len(initCalls) > 0 { + // Emit the init() function. + fmt.Fprintf(outputFile, "func init() {\n") + for _, ic := range initCalls { + fmt.Fprintf(outputFile, " %s\n", ic) + } + fmt.Fprintf(outputFile, "}\n") + } +} diff --git a/tools/installers/BUILD b/tools/installers/BUILD new file mode 100644 index 000000000..caa7b1983 --- /dev/null +++ b/tools/installers/BUILD @@ -0,0 +1,35 @@ +# Installers for use by the tools/vm_test rules. + +package( + default_visibility = ["//:sandbox"], + licenses = ["notice"], +) + +filegroup( + name = "runsc", + srcs = ["//runsc"], +) + +sh_binary( + name = "head", + srcs = ["head.sh"], + data = [":runsc"], +) + +sh_binary( + name = "images", + srcs = ["images.sh"], + data = [ + "//images", + ], +) + +sh_binary( + name = "master", + srcs = ["master.sh"], +) + +sh_binary( + name = "shim", + srcs = ["shim.sh"], +) diff --git a/tools/installers/head.sh b/tools/installers/head.sh new file mode 100755 index 000000000..7fc566ebd --- /dev/null +++ b/tools/installers/head.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Install our runtime. +$(find . -executable -type f -name runsc) install + +# Restart docker. +service docker restart || true diff --git a/tools/installers/images.sh b/tools/installers/images.sh new file mode 100755 index 000000000..52e750f57 --- /dev/null +++ b/tools/installers/images.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# Copyright 2020 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -xeuo pipefail + +# Find the images directory. +for images in $(find . -type d -name images); do + if [[ -f "${images}"/Makefile ]]; then + make -C "${images}" load-all-images + fi +done diff --git a/tools/installers/master.sh b/tools/installers/master.sh new file mode 100755 index 000000000..2c6001c6c --- /dev/null +++ b/tools/installers/master.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Install runsc from the master branch. +set -e + +curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add - +add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main" + +while true; do + if (apt-get update && apt-get install -y runsc); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done + +runsc install +service docker restart diff --git a/tools/installers/shim.sh b/tools/installers/shim.sh new file mode 100755 index 000000000..f7dd790a1 --- /dev/null +++ b/tools/installers/shim.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Reinstall the latest containerd shim. +declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim" +declare -r latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX) +declare -r shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX) +wget --no-verbose "${base}"/latest -O ${latest} +wget --no-verbose "${base}"/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path} +chmod +x ${shim_path} +mv ${shim_path} /usr/local/bin/gvisor-containerd-shim diff --git a/tools/issue_reviver/BUILD b/tools/issue_reviver/BUILD new file mode 100644 index 000000000..4ef1a3124 --- /dev/null +++ b/tools/issue_reviver/BUILD @@ -0,0 +1,12 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "issue_reviver", + srcs = ["main.go"], + deps = [ + "//tools/issue_reviver/github", + "//tools/issue_reviver/reviver", + ], +) diff --git a/tools/issue_reviver/github/BUILD b/tools/issue_reviver/github/BUILD new file mode 100644 index 000000000..da4133472 --- /dev/null +++ b/tools/issue_reviver/github/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "github", + srcs = ["github.go"], + visibility = [ + "//tools/issue_reviver:__subpackages__", + ], + deps = [ + "//tools/issue_reviver/reviver", + "@com_github_google_go-github//github:go_default_library", + "@org_golang_x_oauth2//:go_default_library", + ], +) diff --git a/tools/issue_reviver/github/github.go b/tools/issue_reviver/github/github.go new file mode 100644 index 000000000..e07949c8f --- /dev/null +++ b/tools/issue_reviver/github/github.go @@ -0,0 +1,164 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package github implements reviver.Bugger interface on top of Github issues. +package github + +import ( + "context" + "fmt" + "strconv" + "strings" + "time" + + "github.com/google/go-github/github" + "golang.org/x/oauth2" + "gvisor.dev/gvisor/tools/issue_reviver/reviver" +) + +// Bugger implements reviver.Bugger interface for github issues. +type Bugger struct { + owner string + repo string + dryRun bool + + client *github.Client + issues map[int]*github.Issue +} + +// NewBugger creates a new Bugger. +func NewBugger(token, owner, repo string, dryRun bool) (*Bugger, error) { + b := &Bugger{ + owner: owner, + repo: repo, + dryRun: dryRun, + issues: map[int]*github.Issue{}, + } + if err := b.load(token); err != nil { + return nil, err + } + return b, nil +} + +func (b *Bugger) load(token string) error { + ctx := context.Background() + if len(token) == 0 { + fmt.Print("No OAUTH token provided, using unauthenticated account.\n") + b.client = github.NewClient(nil) + } else { + ts := oauth2.StaticTokenSource( + &oauth2.Token{AccessToken: token}, + ) + tc := oauth2.NewClient(ctx, ts) + b.client = github.NewClient(tc) + } + + err := processAllPages(func(listOpts github.ListOptions) (*github.Response, error) { + opts := &github.IssueListByRepoOptions{State: "open", ListOptions: listOpts} + tmps, resp, err := b.client.Issues.ListByRepo(ctx, b.owner, b.repo, opts) + if err != nil { + return resp, err + } + for _, issue := range tmps { + b.issues[issue.GetNumber()] = issue + } + return resp, nil + }) + if err != nil { + return err + } + + fmt.Printf("Loaded %d issues from github.com/%s/%s\n", len(b.issues), b.owner, b.repo) + return nil +} + +// Activate implements reviver.Bugger. +func (b *Bugger) Activate(todo *reviver.Todo) (bool, error) { + const prefix = "gvisor.dev/issue/" + + // First check if I can handle the TODO. + idStr := strings.TrimPrefix(todo.Issue, prefix) + if len(todo.Issue) == len(idStr) { + return false, nil + } + + id, err := strconv.Atoi(idStr) + if err != nil { + return true, err + } + + // Check against active issues cache. + if _, ok := b.issues[id]; ok { + fmt.Printf("%q is active: OK\n", todo.Issue) + return true, nil + } + + fmt.Printf("%q is not active: reopening issue %d\n", todo.Issue, id) + + // Format comment with TODO locations and search link. + comment := strings.Builder{} + fmt.Fprintln(&comment, "There are TODOs still referencing this issue:") + for _, l := range todo.Locations { + fmt.Fprintf(&comment, + "1. [%s:%d](https://github.com/%s/%s/blob/HEAD/%s#%d): %s\n", + l.File, l.Line, b.owner, b.repo, l.File, l.Line, l.Comment) + } + fmt.Fprintf(&comment, + "\n\nSearch [TODO](https://github.com/%s/%s/search?q=%%22%s%d%%22)", b.owner, b.repo, prefix, id) + + if b.dryRun { + fmt.Printf("[dry-run: skipping change to issue %d]\n%s\n=======================\n", id, comment.String()) + return true, nil + } + + ctx := context.Background() + req := &github.IssueRequest{State: github.String("open")} + _, _, err = b.client.Issues.Edit(ctx, b.owner, b.repo, id, req) + if err != nil { + return true, fmt.Errorf("failed to reactivate issue %d: %v", id, err) + } + + cmt := &github.IssueComment{ + Body: github.String(comment.String()), + Reactions: &github.Reactions{Confused: github.Int(1)}, + } + if _, _, err := b.client.Issues.CreateComment(ctx, b.owner, b.repo, id, cmt); err != nil { + return true, fmt.Errorf("failed to add comment to issue %d: %v", id, err) + } + + return true, nil +} + +func processAllPages(fn func(github.ListOptions) (*github.Response, error)) error { + opts := github.ListOptions{PerPage: 1000} + for { + resp, err := fn(opts) + if err != nil { + if rateErr, ok := err.(*github.RateLimitError); ok { + duration := rateErr.Rate.Reset.Sub(time.Now()) + if duration > 5*time.Minute { + return fmt.Errorf("Rate limited for too long: %v", duration) + } + fmt.Printf("Rate limited, sleeping for: %v\n", duration) + time.Sleep(duration) + continue + } + return err + } + if resp.NextPage == 0 { + return nil + } + opts.Page = resp.NextPage + } +} diff --git a/tools/issue_reviver/main.go b/tools/issue_reviver/main.go new file mode 100644 index 000000000..47c796b8a --- /dev/null +++ b/tools/issue_reviver/main.go @@ -0,0 +1,100 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package main is the entry point for issue_reviver. +package main + +import ( + "flag" + "fmt" + "io/ioutil" + "os" + "strings" + + "gvisor.dev/gvisor/tools/issue_reviver/github" + "gvisor.dev/gvisor/tools/issue_reviver/reviver" +) + +var ( + owner string + repo string + tokenFile string + path string + dryRun bool +) + +// Keep the options simple for now. Supports only a single path and repo. +func init() { + flag.StringVar(&owner, "owner", "", "Github project org/owner to look for issues") + flag.StringVar(&repo, "repo", "", "Github repo to look for issues") + flag.StringVar(&tokenFile, "oauth-token-file", "", "Path to file containing the OAUTH token to be used as credential to github") + flag.StringVar(&path, "path", ".", "Path to scan for TODOs") + flag.BoolVar(&dryRun, "dry-run", false, "If set to true, no changes are made to issues") +} + +func main() { + // Set defaults from the environment. + repository := os.Getenv("GITHUB_REPOSITORY") + if parts := strings.SplitN(repository, "/", 2); len(parts) == 2 { + owner = parts[0] + repo = parts[1] + } + + // Parse flags. + flag.Parse() + + // Check for mandatory parameters. + if len(owner) == 0 { + fmt.Println("missing --owner option.") + flag.Usage() + os.Exit(1) + } + if len(repo) == 0 { + fmt.Println("missing --repo option.") + flag.Usage() + os.Exit(1) + } + if len(path) == 0 { + fmt.Println("missing --path option.") + flag.Usage() + os.Exit(1) + } + + // The access token may be passed as a file so it doesn't show up in + // command line arguments. It also may be provided through the + // environment to faciliate use through GitHub's CI system. + token := os.Getenv("GITHUB_TOKEN") + if len(tokenFile) != 0 { + bytes, err := ioutil.ReadFile(tokenFile) + if err != nil { + fmt.Println(err.Error()) + os.Exit(1) + } + token = string(bytes) + } + + bugger, err := github.NewBugger(token, owner, repo, dryRun) + if err != nil { + fmt.Fprintln(os.Stderr, "Error getting github issues:", err) + os.Exit(1) + } + rev := reviver.New([]string{path}, []reviver.Bugger{bugger}) + if errs := rev.Run(); len(errs) > 0 { + fmt.Fprintf(os.Stderr, "Encountered %d errors:\n", len(errs)) + for _, err := range errs { + fmt.Fprintf(os.Stderr, "\t%v\n", err) + } + os.Exit(1) + } +} diff --git a/tools/issue_reviver/reviver/BUILD b/tools/issue_reviver/reviver/BUILD new file mode 100644 index 000000000..d262932bd --- /dev/null +++ b/tools/issue_reviver/reviver/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "reviver", + srcs = ["reviver.go"], + visibility = [ + "//tools/issue_reviver:__subpackages__", + ], +) + +go_test( + name = "reviver_test", + size = "small", + srcs = ["reviver_test.go"], + library = ":reviver", +) diff --git a/tools/issue_reviver/reviver/reviver.go b/tools/issue_reviver/reviver/reviver.go new file mode 100644 index 000000000..682db0c01 --- /dev/null +++ b/tools/issue_reviver/reviver/reviver.go @@ -0,0 +1,192 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package reviver scans the code looking for TODOs and pass them to registered +// Buggers to ensure TODOs point to active issues. +package reviver + +import ( + "bufio" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "regexp" + "sync" +) + +// This is how a TODO looks like. +var regexTodo = regexp.MustCompile(`(\/\/|#)\s*(TODO|FIXME)\(([a-zA-Z0-9.\/]+)\):\s*(.+)`) + +// Bugger interface is called for every TODO found in the code. If it can handle +// the TODO, it must return true. If it returns false, the next Bugger is +// called. If no Bugger handles the TODO, it's dropped on the floor. +type Bugger interface { + Activate(todo *Todo) (bool, error) +} + +// Location saves the location where the TODO was found. +type Location struct { + Comment string + File string + Line uint +} + +// Todo represents a unique TODO. There can be several TODOs pointing to the +// same issue in the code. They are all grouped together. +type Todo struct { + Issue string + Locations []Location +} + +// Reviver scans the given paths for TODOs and calls Buggers to handle them. +type Reviver struct { + paths []string + buggers []Bugger + + mu sync.Mutex + todos map[string]*Todo + errs []error +} + +// New create a new Reviver. +func New(paths []string, buggers []Bugger) *Reviver { + return &Reviver{ + paths: paths, + buggers: buggers, + todos: map[string]*Todo{}, + } +} + +// Run runs. It returns all errors found during processing, it doesn't stop +// on errors. +func (r *Reviver) Run() []error { + // Process each directory in parallel. + wg := sync.WaitGroup{} + for _, path := range r.paths { + wg.Add(1) + go func(path string) { + defer wg.Done() + r.processPath(path, &wg) + }(path) + } + + wg.Wait() + + r.mu.Lock() + defer r.mu.Unlock() + + fmt.Printf("Processing %d TODOs (%d errors)...\n", len(r.todos), len(r.errs)) + dropped := 0 + for _, todo := range r.todos { + ok, err := r.processTodo(todo) + if err != nil { + r.errs = append(r.errs, err) + } + if !ok { + dropped++ + } + } + fmt.Printf("Processed %d TODOs, %d were skipped (%d errors)\n", len(r.todos)-dropped, dropped, len(r.errs)) + + return r.errs +} + +func (r *Reviver) processPath(path string, wg *sync.WaitGroup) { + fmt.Printf("Processing dir %q\n", path) + fis, err := ioutil.ReadDir(path) + if err != nil { + r.addErr(fmt.Errorf("error processing dir %q: %v", path, err)) + return + } + + for _, fi := range fis { + childPath := filepath.Join(path, fi.Name()) + switch { + case fi.Mode().IsDir(): + wg.Add(1) + go func() { + defer wg.Done() + r.processPath(childPath, wg) + }() + + case fi.Mode().IsRegular(): + file, err := os.Open(childPath) + if err != nil { + r.addErr(err) + continue + } + + scanner := bufio.NewScanner(file) + lineno := uint(0) + for scanner.Scan() { + lineno++ + line := scanner.Text() + if todo := r.processLine(line, childPath, lineno); todo != nil { + r.addTodo(todo) + } + } + } + } +} + +func (r *Reviver) processLine(line, path string, lineno uint) *Todo { + matches := regexTodo.FindStringSubmatch(line) + if matches == nil { + return nil + } + if len(matches) != 5 { + panic(fmt.Sprintf("regex returned wrong matches for %q: %v", line, matches)) + } + return &Todo{ + Issue: matches[3], + Locations: []Location{ + { + File: path, + Line: lineno, + Comment: matches[4], + }, + }, + } +} + +func (r *Reviver) addTodo(newTodo *Todo) { + r.mu.Lock() + defer r.mu.Unlock() + + if todo := r.todos[newTodo.Issue]; todo == nil { + r.todos[newTodo.Issue] = newTodo + } else { + todo.Locations = append(todo.Locations, newTodo.Locations...) + } +} + +func (r *Reviver) addErr(err error) { + r.mu.Lock() + defer r.mu.Unlock() + r.errs = append(r.errs, err) +} + +func (r *Reviver) processTodo(todo *Todo) (bool, error) { + for _, bugger := range r.buggers { + ok, err := bugger.Activate(todo) + if err != nil { + return false, err + } + if ok { + return true, nil + } + } + return false, nil +} diff --git a/tools/issue_reviver/reviver/reviver_test.go b/tools/issue_reviver/reviver/reviver_test.go new file mode 100644 index 000000000..a9fb1f9f1 --- /dev/null +++ b/tools/issue_reviver/reviver/reviver_test.go @@ -0,0 +1,88 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reviver + +import ( + "testing" +) + +func TestProcessLine(t *testing.T) { + for _, tc := range []struct { + line string + want *Todo + }{ + { + line: "// TODO(foobar.com/issue/123): comment, bla. blabla.", + want: &Todo{ + Issue: "foobar.com/issue/123", + Locations: []Location{ + {Comment: "comment, bla. blabla."}, + }, + }, + }, + { + line: "// FIXME(b/123): internal bug", + want: &Todo{ + Issue: "b/123", + Locations: []Location{ + {Comment: "internal bug"}, + }, + }, + }, + { + line: "TODO(issue): not todo", + }, + { + line: "FIXME(issue): not todo", + }, + { + line: "// TODO (issue): not todo", + }, + { + line: "// TODO(issue) not todo", + }, + { + line: "// todo(issue): not todo", + }, + { + line: "// TODO(issue):", + }, + } { + t.Logf("Testing: %s", tc.line) + r := Reviver{} + got := r.processLine(tc.line, "test", 0) + if got == nil { + if tc.want != nil { + t.Errorf("failed to process line, want: %+v", tc.want) + } + } else { + if tc.want == nil { + t.Errorf("expected error, got: %+v", got) + continue + } + if got.Issue != tc.want.Issue { + t.Errorf("wrong issue, got: %v, want: %v", got.Issue, tc.want.Issue) + } + if len(got.Locations) != len(tc.want.Locations) { + t.Errorf("wrong number of locations, got: %v, want: %v, locations: %+v", len(got.Locations), len(tc.want.Locations), got.Locations) + } + for i, wantLoc := range tc.want.Locations { + if got.Locations[i].Comment != wantLoc.Comment { + t.Errorf("wrong comment, got: %v, want: %v", got.Locations[i].Comment, wantLoc.Comment) + } + } + } + } +} diff --git a/tools/make_apt.sh b/tools/make_apt.sh new file mode 100755 index 000000000..3fb1066e5 --- /dev/null +++ b/tools/make_apt.sh @@ -0,0 +1,139 @@ +#!/bin/bash + +# Copyright 2018 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [[ "$#" -le 3 ]]; then + echo "usage: $0 <private-key> <suite> <root> <packages...>" + exit 1 +fi +declare -r private_key=$(readlink -e "$1"); shift +declare -r suite="$1"; shift +declare -r root="$1"; shift + +# Ensure that we have the correct packages installed. +function apt_install() { + while true; do + sudo apt-get update && + sudo apt-get install -y "$@" && + true + result="${?}" + case $result in + 0) + break + ;; + 100) + # 100 is the error code that apt-get returns. + ;; + *) + exit $result + ;; + esac + done +} +dpkg-sig --help >/dev/null 2>&1 || apt_install dpkg-sig +apt-ftparchive --help >/dev/null 2>&1 || apt_install apt-utils +xz --help >/dev/null 2>&1 || apt_install xz-utils + +# Verbose from this point. +set -xeo pipefail + +# Create a directory for the release. +declare -r release="${root}/dists/${suite}" +mkdir -p "${release}" + +# Create a temporary keyring, and ensure it is cleaned up. +declare -r keyring=$(mktemp /tmp/keyringXXXXXX.gpg) +cleanup() { + rm -f "${keyring}" +} +trap cleanup EXIT + +# We attempt the import twice because the first one will fail if the public key +# is not found. This isn't actually a failure for us, because we don't require +# the public (this may be stored separately). The second import will succeed +# because, in reality, the first import succeeded and it's a no-op. +gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}" || \ + gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}" + +# Copy the packages into the root. +for pkg in "$@"; do + ext=${pkg##*.} + name=$(basename "${pkg}" ".${ext}") + arch=${name##*_} + if [[ "${name}" == "${arch}" ]]; then + continue # Not a regular package. + fi + if [[ "${pkg}" =~ ^.*\.deb$ ]]; then + # Extract from the debian file. + version=$(dpkg --info "${pkg}" | grep -E 'Version:' | cut -d':' -f2) + elif [[ "${pkg}" =~ ^.*\.changes$ ]]; then + # Extract from the changes file. + version=$(grep -E 'Version:' "${pkg}" | cut -d':' -f2) + else + # Unsupported file type. + echo "Unknown file type: ${pkg}" + exit 1 + fi + + # The package may already exist, in which case we leave it alone. + version=${version// /} # Trim whitespace. + destdir="${root}/pool/${version}/binary-${arch}" + target="${destdir}/${name}.${ext}" + if [[ -f "${target}" ]]; then + continue + fi + + # Copy & sign the package. + mkdir -p "${destdir}" + cp -a "${pkg}" "${target}" + chmod 0644 "${target}" + if [[ "${ext}" == "deb" ]]; then + dpkg-sig -g "--no-default-keyring --keyring ${keyring}" --sign builder "${target}" + fi +done + +# Build the package list. +declare arches=() +for dir in "${root}"/pool/*/binary-*; do + name=$(basename "${dir}") + arch=${name##binary-} + arches+=("${arch}") + repo_packages="${release}"/main/"${name}" + mkdir -p "${repo_packages}" + (cd "${root}" && apt-ftparchive --arch "${arch}" packages pool > "${repo_packages}"/Packages) + (cd "${repo_packages}" && cat Packages | gzip > Packages.gz) + (cd "${repo_packages}" && cat Packages | xz > Packages.xz) +done + +# Build the release list. +cat > "${release}"/apt.conf <<EOF +APT { + FTPArchive { + Release { + Architectures "${arches[@]}"; + Suite "${suite}"; + Components "main"; + }; + }; +}; +EOF +(cd "${release}" && apt-ftparchive -c=apt.conf release . > Release) +rm "${release}"/apt.conf + +# Sign the release. +declare -r digest_opts=("--digest-algo" "SHA512" "--cert-digest-algo" "SHA512") +(cd "${release}" && rm -f Release.gpg InRelease) +(cd "${release}" && gpg --no-default-keyring --keyring "${keyring}" --clearsign "${digest_opts[@]}" -o InRelease Release) +(cd "${release}" && gpg --no-default-keyring --keyring "${keyring}" -abs "${digest_opts[@]}" -o Release.gpg Release) diff --git a/tools/make_release.sh b/tools/make_release.sh new file mode 100755 index 000000000..b1cdd47b0 --- /dev/null +++ b/tools/make_release.sh @@ -0,0 +1,82 @@ +#!/bin/bash + +# Copyright 2018 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [[ "$#" -le 2 ]]; then + echo "usage: $0 <private-key> <root> <binaries & packages...>" + echo "The environment variable NIGHTLY may be set to control" + echo "whether the nightly packages are produced or not." + exit 1 +fi + +set -xeo pipefail +declare -r private_key="$1"; shift +declare -r root="$1"; shift +declare -a binaries +declare -a pkgs + +# Collect binaries & packages. +for arg in "$@"; do + if [[ "${arg}" == *.deb ]] || [[ "${arg}" == *.changes ]]; then + pkgs+=("${arg}") + else + binaries+=("${arg}") + fi +done + +# install_raw installs raw artifacts. +install_raw() { + mkdir -p "${root}/$1" + for binary in "${binaries[@]}"; do + # Copy the raw file & generate a sha512sum. + name=$(basename "${binary}") + cp -f "${binary}" "${root}/$1" + sha512sum "${root}/$1/${name}" | \ + awk "{print $$1 \" ${name}\"}" > "${root}/$1/${name}.sha512" + done +} + +# install_apt installs an apt repository. +install_apt() { + tools/make_apt.sh "${private_key}" "$1" "${root}" "${pkgs[@]}" +} + +# If nightly, install only nightly artifacts. +if [[ "${NIGHTLY:-false}" == "true" ]]; then + # The "latest" directory and current date. + stamp="$(date -Idate)" + install_raw "nightly/latest" + install_raw "nightly/${stamp}" + install_apt "nightly" +else + # Is it a tagged release? Build that. + tags="$(git tag --points-at HEAD 2>/dev/null || true)" + if ! [[ -z "${tags}" ]]; then + # Note that a given commit can match any number of tags. We have to iterate + # through all possible tags and produce associated artifacts. + for tag in ${tags}; do + name=$(echo "${tag}" | cut -d'-' -f2) + base=$(echo "${name}" | cut -d'.' -f1) + install_raw "release/${name}" + install_raw "release/latest" + install_apt "release" + install_apt "${base}" + done + else + # Otherwise, assume it is a raw master commit. + install_raw "master/latest" + install_apt "master" + fi +fi diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD new file mode 100644 index 000000000..c21b09511 --- /dev/null +++ b/tools/nogo/BUILD @@ -0,0 +1,49 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "nogo", + srcs = [ + "build.go", + "config.go", + "matchers.go", + "nogo.go", + "register.go", + ], + nogo = False, + visibility = ["//:sandbox"], + deps = [ + "//tools/checkescape", + "//tools/checkunsafe", + "//tools/nogo/data", + "@org_golang_x_tools//go/analysis:go_tool_library", + "@org_golang_x_tools//go/analysis/internal/facts:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/asmdecl:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/assign:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/atomic:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/bools:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/buildtag:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/cgocall:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/composite:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/copylock:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/errorsas:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/httpresponse:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/loopclosure:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/lostcancel:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/nilfunc:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/nilness:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/printf:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/shadow:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/shift:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/stdmethods:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/stringintconv:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/structtag:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/tests:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/unmarshal:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/unreachable:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/unsafeptr:go_tool_library", + "@org_golang_x_tools//go/analysis/passes/unusedresult:go_tool_library", + "@org_golang_x_tools//go/gcexportdata:go_tool_library", + ], +) diff --git a/tools/nogo/README.md b/tools/nogo/README.md new file mode 100644 index 000000000..6e4db18de --- /dev/null +++ b/tools/nogo/README.md @@ -0,0 +1,31 @@ +# Extended "nogo" analysis + +This package provides a build aspect that perform nogo analysis. This will be +automatically injected to all relevant libraries when using the default +`go_binary` and `go_library` rules. + +It exists for several reasons. + +* The default `nogo` provided by bazel is insufficient with respect to the + possibility of binary analysis. This package allows us to analyze the + generated binary in addition to using the standard analyzers. + +* The configuration provided in this package is much richer than the standard + `nogo` JSON blob. Specifically, it allows us to exclude specific structures + from the composite rules (such as the Ranges that are common with the set + types). + +* The bazel version of `nogo` is run directly against the `go_library` and + `go_binary` targets, meaning that any change to the configuration requires a + rebuild from scratch (for some reason included all C++ source files in the + process). Using an aspect is more efficient in this regard. + +* The checks supported by this package are exported as tests, which makes it + easier to reason about and plumb into the build system. + +* For uninteresting reasons, it is impossible to integrate the default `nogo` + analyzer provided by bazel with internal Google tooling. To provide a + consistent experience, this package allows those systems to be unified. + +To use this package, import `nogo_test` from `defs.bzl` and add a single +dependency which is a `go_binary` or `go_library` rule. diff --git a/tools/nogo/build.go b/tools/nogo/build.go new file mode 100644 index 000000000..1c0d08661 --- /dev/null +++ b/tools/nogo/build.go @@ -0,0 +1,36 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nogo + +import ( + "fmt" + "io" + "os" +) + +var ( + // internalPrefix is the internal path prefix. Note that this is not + // special, as paths should be passed relative to the repository root + // and should not have any special prefix applied. + internalPrefix = fmt.Sprintf("^") + + // externalPrefix is external workspace packages. + externalPrefix = "^external/" +) + +// findStdPkg needs to find the bundled standard library packages. +func findStdPkg(path, GOOS, GOARCH string) (io.ReadCloser, error) { + return os.Open(fmt.Sprintf("external/go_sdk/pkg/%s_%s/%s.a", GOOS, GOARCH, path)) +} diff --git a/tools/nogo/check/BUILD b/tools/nogo/check/BUILD new file mode 100644 index 000000000..e2d76cd5c --- /dev/null +++ b/tools/nogo/check/BUILD @@ -0,0 +1,12 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +# Note that the check binary must be public, since an aspect may be applied +# across lots of different rules in different repositories. +go_binary( + name = "check", + srcs = ["main.go"], + visibility = ["//visibility:public"], + deps = ["//tools/nogo"], +) diff --git a/tools/nogo/check/main.go b/tools/nogo/check/main.go new file mode 100644 index 000000000..3828edf3a --- /dev/null +++ b/tools/nogo/check/main.go @@ -0,0 +1,24 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary check is the nogo entrypoint. +package main + +import ( + "gvisor.dev/gvisor/tools/nogo" +) + +func main() { + nogo.Main() +} diff --git a/tools/nogo/config.go b/tools/nogo/config.go new file mode 100644 index 000000000..6958fca69 --- /dev/null +++ b/tools/nogo/config.go @@ -0,0 +1,116 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nogo + +import ( + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/asmdecl" + "golang.org/x/tools/go/analysis/passes/assign" + "golang.org/x/tools/go/analysis/passes/atomic" + "golang.org/x/tools/go/analysis/passes/bools" + "golang.org/x/tools/go/analysis/passes/buildtag" + "golang.org/x/tools/go/analysis/passes/cgocall" + "golang.org/x/tools/go/analysis/passes/composite" + "golang.org/x/tools/go/analysis/passes/copylock" + "golang.org/x/tools/go/analysis/passes/errorsas" + "golang.org/x/tools/go/analysis/passes/httpresponse" + "golang.org/x/tools/go/analysis/passes/loopclosure" + "golang.org/x/tools/go/analysis/passes/lostcancel" + "golang.org/x/tools/go/analysis/passes/nilfunc" + "golang.org/x/tools/go/analysis/passes/nilness" + "golang.org/x/tools/go/analysis/passes/printf" + "golang.org/x/tools/go/analysis/passes/shadow" + "golang.org/x/tools/go/analysis/passes/shift" + "golang.org/x/tools/go/analysis/passes/stdmethods" + "golang.org/x/tools/go/analysis/passes/stringintconv" + "golang.org/x/tools/go/analysis/passes/structtag" + "golang.org/x/tools/go/analysis/passes/tests" + "golang.org/x/tools/go/analysis/passes/unmarshal" + "golang.org/x/tools/go/analysis/passes/unreachable" + "golang.org/x/tools/go/analysis/passes/unsafeptr" + "golang.org/x/tools/go/analysis/passes/unusedresult" + + "gvisor.dev/gvisor/tools/checkescape" + "gvisor.dev/gvisor/tools/checkunsafe" +) + +var analyzerConfig = map[*analysis.Analyzer]matcher{ + // Standard analyzers. + asmdecl.Analyzer: alwaysMatches(), + assign.Analyzer: externalExcluded( + ".*gazelle/walk/walk.go", // False positive. + ), + atomic.Analyzer: alwaysMatches(), + bools.Analyzer: alwaysMatches(), + buildtag.Analyzer: alwaysMatches(), + cgocall.Analyzer: alwaysMatches(), + composite.Analyzer: and( + disableMatches(), // Disabled for now. + resultExcluded{ + "Object_", + "Range{", + }, + ), + copylock.Analyzer: internalMatches(), // Common external issues (e.g. protos). + errorsas.Analyzer: alwaysMatches(), + httpresponse.Analyzer: alwaysMatches(), + loopclosure.Analyzer: alwaysMatches(), + lostcancel.Analyzer: internalMatches(), // Common external issues. + nilfunc.Analyzer: alwaysMatches(), + nilness.Analyzer: and( + internalMatches(), // Common "tautological checks". + internalExcluded( + "pkg/sentry/platform/kvm/kvm_test.go", // Intentional. + "tools/bigquery/bigquery.go", // False positive. + ), + ), + printf.Analyzer: alwaysMatches(), + shift.Analyzer: alwaysMatches(), + stdmethods.Analyzer: internalMatches(), // Common external issues (e.g. methods named "Write"). + stringintconv.Analyzer: and( + internalExcluded(), + externalExcluded( + ".*protobuf/.*.go", // Bad conversions. + ".*flate/huffman_bit_writer.go", // Bad conversion. + ), + ), + shadow.Analyzer: disableMatches(), // Disabled for now. + structtag.Analyzer: internalMatches(), // External not subject to rules. + tests.Analyzer: alwaysMatches(), + unmarshal.Analyzer: alwaysMatches(), + unreachable.Analyzer: internalMatches(), + unsafeptr.Analyzer: and( + internalMatches(), + internalExcluded( + ".*_test.go", // Exclude tests. + "pkg/flipcall/.*_unsafe.go", // Special case. + "pkg/gohacks/gohacks_unsafe.go", // Special case. + "pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go", // Special case. + "pkg/sentry/platform/kvm/bluepill_unsafe.go", // Special case. + "pkg/sentry/platform/kvm/machine_unsafe.go", // Special case. + "pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go", // Special case. + "pkg/sentry/platform/safecopy/safecopy_unsafe.go", // Special case. + "pkg/sentry/vfs/mount_unsafe.go", // Special case. + "pkg/sentry/platform/systrap/stub_unsafe.go", // Special case. + "pkg/sentry/platform/systrap/switchto_google_unsafe.go", // Special case. + "pkg/sentry/platform/systrap/sysmsg_thread_unsafe.go", // Special case. + ), + ), + unusedresult.Analyzer: alwaysMatches(), + + // Internal analyzers: external packages not subject. + checkescape.Analyzer: internalMatches(), + checkunsafe.Analyzer: internalMatches(), +} diff --git a/tools/nogo/data/BUILD b/tools/nogo/data/BUILD new file mode 100644 index 000000000..b7564cc44 --- /dev/null +++ b/tools/nogo/data/BUILD @@ -0,0 +1,10 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "data", + srcs = ["data.go"], + nogo = False, + visibility = ["//tools:__subpackages__"], +) diff --git a/tools/nogo/data/data.go b/tools/nogo/data/data.go new file mode 100644 index 000000000..eb84d0d27 --- /dev/null +++ b/tools/nogo/data/data.go @@ -0,0 +1,21 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package data contains shared data for nogo analysis. +// +// This is used to break a dependency cycle. +package data + +// Objdump is the dumped binary under analysis. +var Objdump string diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl new file mode 100644 index 000000000..6560b57c8 --- /dev/null +++ b/tools/nogo/defs.bzl @@ -0,0 +1,172 @@ +"""Nogo rules.""" + +load("//tools/bazeldefs:defs.bzl", "go_context", "go_importpath", "go_rule") + +# NogoInfo is the serialized set of package facts for a nogo analysis. +# +# Each go_library rule will generate a corresponding nogo rule, which will run +# with the source files as input. Note however, that the individual nogo rules +# are simply stubs that enter into the shadow dependency tree (the "aspect"). +NogoInfo = provider( + fields = { + "facts": "serialized package facts", + "importpath": "package import path", + "binaries": "package binary files", + }, +) + +def _nogo_aspect_impl(target, ctx): + # If this is a nogo rule itself (and not the shadow of a go_library or + # go_binary rule created by such a rule), then we simply return nothing. + # All work is done in the shadow properties for go rules. For a proto + # library, we simply skip the analysis portion but still need to return a + # valid NogoInfo to reference the generated binary. + if ctx.rule.kind == "go_library": + srcs = ctx.rule.files.srcs + elif ctx.rule.kind == "go_proto_library" or ctx.rule.kind == "go_wrap_cc": + srcs = [] + else: + return [NogoInfo()] + + # Construct the Go environment from the go_context.env dictionary. + env_prefix = " ".join(["%s=%s" % (key, value) for (key, value) in go_context(ctx).env.items()]) + + # Start with all target files and srcs as input. + inputs = target.files.to_list() + srcs + + # Generate a shell script that dumps the binary. Annoyingly, this seems + # necessary as the context in which a run_shell command runs does not seem + # to cleanly allow us redirect stdout to the actual output file. Perhaps + # I'm missing something here, but the intermediate script does work. + binaries = target.files.to_list() + disasm_file = ctx.actions.declare_file(target.label.name + ".out") + dumper = ctx.actions.declare_file("%s-dumper" % ctx.label.name) + ctx.actions.write(dumper, "\n".join([ + "#!/bin/bash", + "%s %s tool objdump %s > %s\n" % ( + env_prefix, + go_context(ctx).go.path, + [f.path for f in binaries if f.path.endswith(".a")][0], + disasm_file.path, + ), + ]), is_executable = True) + ctx.actions.run( + inputs = binaries, + outputs = [disasm_file], + tools = go_context(ctx).runfiles, + mnemonic = "GoObjdump", + progress_message = "Objdump %s" % target.label, + executable = dumper, + ) + inputs.append(disasm_file) + + # Extract the importpath for this package. + importpath = go_importpath(target) + + # The nogo tool requires a configfile serialized in JSON format to do its + # work. This must line up with the nogo.Config fields. + facts = ctx.actions.declare_file(target.label.name + ".facts") + config = struct( + ImportPath = importpath, + GoFiles = [src.path for src in srcs if src.path.endswith(".go")], + NonGoFiles = [src.path for src in srcs if not src.path.endswith(".go")], + GOOS = go_context(ctx).goos, + GOARCH = go_context(ctx).goarch, + Tags = go_context(ctx).tags, + FactMap = {}, # Constructed below. + ImportMap = {}, # Constructed below. + FactOutput = facts.path, + Objdump = disasm_file.path, + ) + + # Collect all info from shadow dependencies. + for dep in ctx.rule.attr.deps: + # There will be no file attribute set for all transitive dependencies + # that are not go_library or go_binary rules, such as a proto rules. + # This is handled by the ctx.rule.kind check above. + info = dep[NogoInfo] + if not hasattr(info, "facts"): + continue + + # Configure where to find the binary & fact files. Note that this will + # use .x and .a regardless of whether this is a go_binary rule, since + # these dependencies must be go_library rules. + x_files = [f.path for f in info.binaries if f.path.endswith(".x")] + if not len(x_files): + x_files = [f.path for f in info.binaries if f.path.endswith(".a")] + config.ImportMap[info.importpath] = x_files[0] + config.FactMap[info.importpath] = info.facts.path + + # Ensure the above are available as inputs. + inputs.append(info.facts) + inputs += info.binaries + + # Write the configuration and run the tool. + config_file = ctx.actions.declare_file(target.label.name + ".cfg") + ctx.actions.write(config_file, config.to_json()) + inputs.append(config_file) + + # Run the nogo tool itself. + ctx.actions.run( + inputs = inputs, + outputs = [facts], + tools = go_context(ctx).runfiles, + executable = ctx.files._nogo[0], + mnemonic = "GoStaticAnalysis", + progress_message = "Analyzing %s" % target.label, + arguments = ["-config=%s" % config_file.path], + ) + + # Return the package facts as output. + return [NogoInfo( + facts = facts, + importpath = importpath, + binaries = binaries, + )] + +nogo_aspect = go_rule( + aspect, + implementation = _nogo_aspect_impl, + attr_aspects = ["deps"], + attrs = { + "_nogo": attr.label( + default = "//tools/nogo/check:check", + allow_single_file = True, + ), + }, +) + +def _nogo_test_impl(ctx): + """Check nogo findings.""" + + # Build a runner that checks for the existence of the facts file. Note that + # the actual build will fail in the case of a broken analysis. We things + # this way so that any test applied is effectively pushed down to all + # upstream dependencies through the aspect. + inputs = [] + runner = ctx.actions.declare_file("%s-executer" % ctx.label.name) + runner_content = ["#!/bin/bash"] + for dep in ctx.attr.deps: + info = dep[NogoInfo] + inputs.append(info.facts) + + # Draw a sweet unicode checkmark with the package name (in green). + runner_content.append("echo -e \"\\033[0;32m\\xE2\\x9C\\x94\\033[0;31m\\033[0m %s\"" % info.importpath) + runner_content.append("exit 0\n") + ctx.actions.write(runner, "\n".join(runner_content), is_executable = True) + return [DefaultInfo( + runfiles = ctx.runfiles(files = inputs), + executable = runner, + )] + +_nogo_test = rule( + implementation = _nogo_test_impl, + attrs = { + "deps": attr.label_list(aspects = [nogo_aspect]), + }, + test = True, +) + +def nogo_test(**kwargs): + tags = kwargs.pop("tags", []) + ["nogo"] + _nogo_test(tags = tags, **kwargs) diff --git a/tools/nogo/io_bazel_rules_go-visibility.patch b/tools/nogo/io_bazel_rules_go-visibility.patch new file mode 100644 index 000000000..6b64b2e85 --- /dev/null +++ b/tools/nogo/io_bazel_rules_go-visibility.patch @@ -0,0 +1,25 @@ +diff --git a/third_party/org_golang_x_tools-extras.patch b/third_party/org_golang_x_tools-extras.patch +index 133fbccc..5f0d9a47 100644 +--- a/third_party/org_golang_x_tools-extras.patch ++++ b/third_party/org_golang_x_tools-extras.patch +@@ -32,7 +32,7 @@ diff -urN c/go/analysis/internal/facts/BUILD.bazel d/go/analysis/internal/facts/ + + go_library( + name = "go_default_library", +-@@ -14,6 +14,23 @@ ++@@ -14,6 +14,20 @@ + ], + ) + +@@ -43,10 +43,7 @@ diff -urN c/go/analysis/internal/facts/BUILD.bazel d/go/analysis/internal/facts/ + + "imports.go", + + ], + + importpath = "golang.org/x/tools/go/analysis/internal/facts", +-+ visibility = [ +-+ "//go/analysis:__subpackages__", +-+ "@io_bazel_rules_go//go/tools/builders:__pkg__", +-+ ], +++ visibility = ["//visibility:public"], + + deps = [ + + "//go/analysis:go_tool_library", + + "//go/types/objectpath:go_tool_library", diff --git a/tools/nogo/matchers.go b/tools/nogo/matchers.go new file mode 100644 index 000000000..57a250501 --- /dev/null +++ b/tools/nogo/matchers.go @@ -0,0 +1,143 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nogo + +import ( + "go/token" + "path/filepath" + "regexp" + "strings" + + "golang.org/x/tools/go/analysis" +) + +type matcher interface { + ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool +} + +// pathRegexps filters explicit paths. +type pathRegexps struct { + expr []*regexp.Regexp + + // include, if true, indicates that paths matching any regexp in expr + // match. + // + // If false, paths matching no regexps in expr match. + include bool +} + +// buildRegexps builds a list of regular expressions. +// +// This will panic on error. +func buildRegexps(prefix string, args ...string) []*regexp.Regexp { + result := make([]*regexp.Regexp, 0, len(args)) + for _, arg := range args { + result = append(result, regexp.MustCompile(filepath.Join(prefix, arg))) + } + return result +} + +// ShouldReport implements matcher.ShouldReport. +func (p *pathRegexps) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool { + fullPos := fs.Position(d.Pos).String() + for _, path := range p.expr { + if path.MatchString(fullPos) { + return p.include + } + } + return !p.include +} + +// internalExcluded excludes specific internal paths. +func internalExcluded(paths ...string) *pathRegexps { + return &pathRegexps{ + expr: buildRegexps(internalPrefix, paths...), + include: false, + } +} + +// excludedExcluded excludes specific external paths. +func externalExcluded(paths ...string) *pathRegexps { + return &pathRegexps{ + expr: buildRegexps(externalPrefix, paths...), + include: false, + } +} + +// internalMatches returns a path matcher for internal packages. +func internalMatches() *pathRegexps { + return &pathRegexps{ + expr: buildRegexps(internalPrefix, ".*"), + include: true, + } +} + +// resultExcluded excludes explicit message contents. +type resultExcluded []string + +// ShouldReport implements matcher.ShouldReport. +func (r resultExcluded) ShouldReport(d analysis.Diagnostic, _ *token.FileSet) bool { + for _, str := range r { + if strings.Contains(d.Message, str) { + return false + } + } + return true // Not excluded. +} + +// andMatcher is a composite matcher. +type andMatcher struct { + first matcher + second matcher +} + +// ShouldReport implements matcher.ShouldReport. +func (a *andMatcher) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool { + return a.first.ShouldReport(d, fs) && a.second.ShouldReport(d, fs) +} + +// and is a syntactic convension for andMatcher. +func and(first matcher, second matcher) *andMatcher { + return &andMatcher{ + first: first, + second: second, + } +} + +// anyMatcher matches everything. +type anyMatcher struct{} + +// ShouldReport implements matcher.ShouldReport. +func (anyMatcher) ShouldReport(analysis.Diagnostic, *token.FileSet) bool { + return true +} + +// alwaysMatches returns an anyMatcher instance. +func alwaysMatches() anyMatcher { + return anyMatcher{} +} + +// neverMatcher will never match. +type neverMatcher struct{} + +// ShouldReport implements matcher.ShouldReport. +func (neverMatcher) ShouldReport(analysis.Diagnostic, *token.FileSet) bool { + return false +} + +// disableMatches returns a neverMatcher instance. +func disableMatches() neverMatcher { + return neverMatcher{} +} diff --git a/tools/nogo/nogo.go b/tools/nogo/nogo.go new file mode 100644 index 000000000..203cdf688 --- /dev/null +++ b/tools/nogo/nogo.go @@ -0,0 +1,316 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package nogo implements binary analysis similar to bazel's nogo, +// or the unitchecker package. It exists in order to provide additional +// facilities for analysis, namely plumbing through the output from +// dumping the generated binary (to analyze actual produced code). +package nogo + +import ( + "encoding/json" + "flag" + "fmt" + "go/ast" + "go/build" + "go/parser" + "go/token" + "go/types" + "io" + "io/ioutil" + "log" + "os" + "path/filepath" + "reflect" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/internal/facts" + "golang.org/x/tools/go/gcexportdata" + "gvisor.dev/gvisor/tools/nogo/data" +) + +// pkgConfig is serialized as the configuration. +// +// This contains everything required for the analysis. +type pkgConfig struct { + ImportPath string + GoFiles []string + NonGoFiles []string + Tags []string + GOOS string + GOARCH string + ImportMap map[string]string + FactMap map[string]string + FactOutput string + Objdump string +} + +// loadFacts finds and loads facts per FactMap. +func (c *pkgConfig) loadFacts(path string) ([]byte, error) { + realPath, ok := c.FactMap[path] + if !ok { + return nil, nil // No facts available. + } + + // Read the files file. + data, err := ioutil.ReadFile(realPath) + if err != nil { + return nil, err + } + return data, nil +} + +// shouldInclude indicates whether the file should be included. +// +// NOTE: This does only basic parsing of tags. +func (c *pkgConfig) shouldInclude(path string) (bool, error) { + ctx := build.Default + ctx.GOOS = c.GOOS + ctx.GOARCH = c.GOARCH + ctx.BuildTags = c.Tags + return ctx.MatchFile(filepath.Dir(path), filepath.Base(path)) +} + +// importer is an implementation of go/types.Importer. +// +// This wraps a configuration, which provides the map of package names to +// files, and the facts. Note that this importer implementation will always +// pass when a given package is not available. +type importer struct { + pkgConfig + fset *token.FileSet + cache map[string]*types.Package +} + +// Import implements types.Importer.Import. +func (i *importer) Import(path string) (*types.Package, error) { + if path == "unsafe" { + // Special case: go/types has pre-defined type information for + // unsafe. We ensure that this package is correct, in case any + // analyzers are specifically looking for this. + return types.Unsafe, nil + } + realPath, ok := i.ImportMap[path] + var ( + rc io.ReadCloser + err error + ) + if !ok { + // Not found in the import path. Attempt to find the package + // via the standard library. + rc, err = findStdPkg(path, i.GOOS, i.GOARCH) + } else { + // Open the file. + rc, err = os.Open(realPath) + } + if err != nil { + return nil, err + } + defer rc.Close() + + // Load all exported data. + r, err := gcexportdata.NewReader(rc) + if err != nil { + return nil, err + } + + return gcexportdata.Read(r, i.fset, i.cache, path) +} + +// checkPackage runs all analyzers. +// +// The implementation was adapted from [1], which was in turn adpated from [2]. +// This returns a list of matching analysis issues, or an error if the analysis +// could not be completed. +// +// [1] bazelbuid/rules_go/tools/builders/nogo_main.go +// [2] golang.org/x/tools/go/checker/internal/checker +func checkPackage(config pkgConfig) ([]string, error) { + imp := &importer{ + pkgConfig: config, + fset: token.NewFileSet(), + cache: make(map[string]*types.Package), + } + + // Load all source files. + var syntax []*ast.File + for _, file := range config.GoFiles { + include, err := config.shouldInclude(file) + if err != nil { + return nil, fmt.Errorf("error evaluating file %q: %v", file, err) + } + if !include { + continue + } + s, err := parser.ParseFile(imp.fset, file, nil, parser.ParseComments) + if err != nil { + return nil, fmt.Errorf("error parsing file %q: %v", file, err) + } + syntax = append(syntax, s) + } + + // Check type information. + typesSizes := types.SizesFor("gc", config.GOARCH) + typeConfig := types.Config{Importer: imp} + typesInfo := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Uses: make(map[*ast.Ident]types.Object), + Defs: make(map[*ast.Ident]types.Object), + Implicits: make(map[ast.Node]types.Object), + Scopes: make(map[ast.Node]*types.Scope), + Selections: make(map[*ast.SelectorExpr]*types.Selection), + } + types, err := typeConfig.Check(config.ImportPath, imp.fset, syntax, typesInfo) + if err != nil { + return nil, fmt.Errorf("error checking types: %v", err) + } + + // Load all package facts. + facts, err := facts.Decode(types, config.loadFacts) + if err != nil { + return nil, fmt.Errorf("error decoding facts: %v", err) + } + + // Set the binary global for use. + data.Objdump = config.Objdump + + // Register fact types and establish dependencies between analyzers. + // The visit closure will execute recursively, and populate results + // will all required analysis results. + diagnostics := make(map[*analysis.Analyzer][]analysis.Diagnostic) + results := make(map[*analysis.Analyzer]interface{}) + var visit func(*analysis.Analyzer) error // For recursion. + visit = func(a *analysis.Analyzer) error { + if _, ok := results[a]; ok { + return nil + } + + // Run recursively for all dependencies. + for _, req := range a.Requires { + if err := visit(req); err != nil { + return err + } + } + + // Prepare the matcher. + m := analyzerConfig[a] + report := func(d analysis.Diagnostic) { + if m.ShouldReport(d, imp.fset) { + diagnostics[a] = append(diagnostics[a], d) + } + } + + // Run the analysis. + factFilter := make(map[reflect.Type]bool) + for _, f := range a.FactTypes { + factFilter[reflect.TypeOf(f)] = true + } + p := &analysis.Pass{ + Analyzer: a, + Fset: imp.fset, + Files: syntax, + Pkg: types, + TypesInfo: typesInfo, + ResultOf: results, // All results. + Report: report, + ImportPackageFact: facts.ImportPackageFact, + ExportPackageFact: facts.ExportPackageFact, + ImportObjectFact: facts.ImportObjectFact, + ExportObjectFact: facts.ExportObjectFact, + AllPackageFacts: func() []analysis.PackageFact { return facts.AllPackageFacts(factFilter) }, + AllObjectFacts: func() []analysis.ObjectFact { return facts.AllObjectFacts(factFilter) }, + TypesSizes: typesSizes, + } + result, err := a.Run(p) + if err != nil { + return fmt.Errorf("error running analysis %s: %v", a, err) + } + + // Sanity check & save the result. + if got, want := reflect.TypeOf(result), a.ResultType; got != want { + return fmt.Errorf("error: analyzer %s returned a result of type %v, but declared ResultType %v", a, got, want) + } + results[a] = result + return nil // Success. + } + + // Visit all analysis recursively. + for a, _ := range analyzerConfig { + if err := visit(a); err != nil { + return nil, err // Already has context. + } + } + + // Write the output file. + if config.FactOutput != "" { + factData := facts.Encode() + if err := ioutil.WriteFile(config.FactOutput, factData, 0644); err != nil { + return nil, fmt.Errorf("error: unable to open facts output %q: %v", config.FactOutput, err) + } + } + + // Convert all diagnostics to strings. + findings := make([]string, 0, len(diagnostics)) + for a, ds := range diagnostics { + for _, d := range ds { + // Include the anlyzer name for debugability and configuration. + findings = append(findings, fmt.Sprintf("%s: %s: %s", a.Name, imp.fset.Position(d.Pos), d.Message)) + } + } + + // Return all findings. + return findings, nil +} + +var ( + configFile = flag.String("config", "", "configuration file (in JSON format)") +) + +// Main is the entrypoint; it should be called directly from main. +// +// N.B. This package registers it's own flags. +func Main() { + // Parse all flags. + flag.Parse() + + // Load the configuration. + f, err := os.Open(*configFile) + if err != nil { + log.Fatalf("unable to open configuration %q: %v", *configFile, err) + } + defer f.Close() + config := new(pkgConfig) + dec := json.NewDecoder(f) + dec.DisallowUnknownFields() + if err := dec.Decode(config); err != nil { + log.Fatalf("unable to decode configuration: %v", err) + } + + // Process the package. + findings, err := checkPackage(*config) + if err != nil { + log.Fatalf("error checking package: %v", err) + } + + // No findings? + if len(findings) == 0 { + os.Exit(0) + } + + // Print findings and exit with non-zero code. + for _, finding := range findings { + fmt.Fprintf(os.Stdout, "%s\n", finding) + } + os.Exit(1) +} diff --git a/tools/nogo/register.go b/tools/nogo/register.go new file mode 100644 index 000000000..62b499661 --- /dev/null +++ b/tools/nogo/register.go @@ -0,0 +1,64 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nogo + +import ( + "encoding/gob" + "log" + + "golang.org/x/tools/go/analysis" +) + +// analyzers returns all configured analyzers. +func analyzers() (all []*analysis.Analyzer) { + for a, _ := range analyzerConfig { + all = append(all, a) + } + return all +} + +func init() { + // Validate basic configuration. + if err := analysis.Validate(analyzers()); err != nil { + log.Fatalf("unable to validate analyzer: %v", err) + } + + // Register all fact types. + // + // N.B. This needs to be done recursively, because there may be + // analyzers in the Requires list that do not appear explicitly above. + registered := make(map[*analysis.Analyzer]struct{}) + var register func(*analysis.Analyzer) + register = func(a *analysis.Analyzer) { + if _, ok := registered[a]; ok { + return + } + + // Regsiter dependencies. + for _, da := range a.Requires { + register(da) + } + + // Register local facts. + for _, f := range a.FactTypes { + gob.Register(f) + } + + registered[a] = struct{}{} // Done. + } + for _, a := range analyzers() { + register(a) + } +} diff --git a/tools/tag_release.sh b/tools/tag_release.sh new file mode 100755 index 000000000..b0bab74b4 --- /dev/null +++ b/tools/tag_release.sh @@ -0,0 +1,82 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script will optionally map a PiperOrigin-RevId to a given commit, +# validate a provided release name, create a tag and push it. It must be +# run manually when a release is created. + +set -xeuo pipefail + +# Check arguments. +if [[ "$#" -ne 3 ]]; then + echo "usage: $0 <commit|revid> <release.rc> <message-file>" + exit 1 +fi + +declare -r target_commit="$1" +declare -r release="$2" +declare -r message_file="$3" + +if [[ -z "${target_commit}" ]]; then + echo "error: <commit|revid> is empty." +fi +if [[ -z "${release}" ]]; then + echo "error: <release.rc> is empty." +fi +if ! [[ -r "${message_file}" ]]; then + echo "error: message file '${message_file}' is not readable." + exit 1 +fi + +closest_commit() { + while read line; do + if [[ "$line" =~ "commit " ]]; then + current_commit="${line#commit }" + continue + elif [[ "$line" =~ "PiperOrigin-RevId: " ]]; then + revid="${line#PiperOrigin-RevId: }" + [[ "${revid}" -le "$1" ]] && break + fi + done + echo "${current_commit}" +} + +# Is the passed identifier a sha commit? +if ! git show "${target_commit}" &> /dev/null; then + # Extract the commit given a piper ID. + declare -r commit="$(git log | closest_commit "${target_commit}")" +else + declare -r commit="${target_commit}" +fi +if ! git show "${commit}" &> /dev/null; then + echo "unknown commit: ${target_commit}" + exit 1 +fi + +# Is the release name sane? Must be a date with patch/rc. +if ! [[ "${release}" =~ ^20[0-9]{6}\.[0-9]+$ ]]; then + declare -r expected="$(date +%Y%m%d.0)" # Use today's date. + echo "unexpected release format: ${release}" + echo " ... expected like ${expected}" + exit 1 +fi + +# Tag the given commit (annotated, to record the committer). Note that the tag +# here is applied as a force, in case the tag already exists and is the same. +# The push will fail in this case (because it is not forced). +declare -r tag="release-${release}" +git tag -f -F "${message_file}" -a "${tag}" "${commit}" && \ + git push origin tag "${tag}" diff --git a/tools/tags/BUILD b/tools/tags/BUILD new file mode 100644 index 000000000..1c02e2c89 --- /dev/null +++ b/tools/tags/BUILD @@ -0,0 +1,11 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "tags", + srcs = ["tags.go"], + marshal = False, + stateify = False, + visibility = ["//tools:__subpackages__"], +) diff --git a/tools/tags/tags.go b/tools/tags/tags.go new file mode 100644 index 000000000..f35904e0a --- /dev/null +++ b/tools/tags/tags.go @@ -0,0 +1,89 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tags is a utility for parsing build tags. +package tags + +import ( + "fmt" + "io/ioutil" + "strings" +) + +// OrSet is a set of tags on a single line. +// +// Note that tags may include ",", and we don't distinguish this case in the +// logic below. Ideally, this constraints can be split into separate top-level +// build tags in order to resolve any issues. +type OrSet []string + +// Line returns the line for this or. +func (or OrSet) Line() string { + return fmt.Sprintf("// +build %s", strings.Join([]string(or), " ")) +} + +// AndSet is the set of all OrSets. +type AndSet []OrSet + +// Lines returns the lines to be printed. +func (and AndSet) Lines() (ls []string) { + for _, or := range and { + ls = append(ls, or.Line()) + } + return +} + +// Join joins this AndSet with another. +func (and AndSet) Join(other AndSet) AndSet { + return append(and, other...) +} + +// Tags returns the unique set of +build tags. +// +// Derived form the runtime's canBuild. +func Tags(file string) (tags AndSet) { + data, err := ioutil.ReadFile(file) + if err != nil { + return nil + } + // Check file contents for // +build lines. + for _, p := range strings.Split(string(data), "\n") { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if !strings.HasPrefix(p, "//") { + break + } + if !strings.Contains(p, "+build") { + continue + } + fields := strings.Fields(p[2:]) + if len(fields) < 1 || fields[0] != "+build" { + continue + } + tags = append(tags, OrSet(fields[1:])) + } + return tags +} + +// Aggregate aggregates all tags from a set of files. +// +// Note that these may be in conflict, in which case the build will fail. +func Aggregate(files []string) (tags AndSet) { + for _, file := range files { + tags = tags.Join(Tags(file)) + } + return tags +} diff --git a/tools/vm/BUILD b/tools/vm/BUILD new file mode 100644 index 000000000..f7160c627 --- /dev/null +++ b/tools/vm/BUILD @@ -0,0 +1,57 @@ +load("//tools:defs.bzl", "cc_binary", "gtest") +load("//tools/vm:defs.bzl", "vm_image", "vm_test") + +package( + default_visibility = ["//:sandbox"], + licenses = ["notice"], +) + +sh_binary( + name = "zone", + srcs = ["zone.sh"], +) + +sh_binary( + name = "builder", + srcs = ["build.sh"], +) + +sh_binary( + name = "executer", + srcs = ["execute.sh"], +) + +cc_binary( + name = "test", + testonly = 1, + srcs = ["test.cc"], + linkstatic = 1, + deps = [ + gtest, + "//test/util:test_main", + ], +) + +vm_image( + name = "ubuntu1604", + family = "ubuntu-1604-lts", + project = "ubuntu-os-cloud", + scripts = [ + "//tools/vm/ubuntu1604", + ], +) + +vm_image( + name = "ubuntu1804", + family = "ubuntu-1804-lts", + project = "ubuntu-os-cloud", + scripts = [ + "//tools/vm/ubuntu1804", + ], +) + +vm_test( + name = "vm_test", + shard_count = 2, + targets = [":test"], +) diff --git a/tools/vm/README.md b/tools/vm/README.md new file mode 100644 index 000000000..898c95fca --- /dev/null +++ b/tools/vm/README.md @@ -0,0 +1,42 @@ +# VM Images & Tests + +All commands in this directory require the `gcloud` project to be set. + +For example: `gcloud config set project gvisor-kokoro-testing`. + +Images can be generated by using the `vm_image` rule. This rule will generate a +binary target that builds an image in an idempotent way, and can be referenced +from other rules. + +For example: + +``` +vm_image( + name = "ubuntu", + project = "ubuntu-1604-lts", + family = "ubuntu-os-cloud", + scripts = [ + "script.sh", + "other.sh", + ], +) +``` + +These images can be built manually by executing the target. The output on +`stdout` will be the image id (in the current project). + +Images are always named per the hash of all the hermetic input scripts. This +allows images to be memoized quickly and easily. + +The `vm_test` rule can be used to execute a command remotely. This is still +under development however, and will likely change over time. + +For example: + +``` +vm_test( + name = "mycommand", + image = ":ubuntu", + targets = [":test"], +) +``` diff --git a/tools/vm/build.sh b/tools/vm/build.sh new file mode 100755 index 000000000..752b2b77b --- /dev/null +++ b/tools/vm/build.sh @@ -0,0 +1,117 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script is responsible for building a new GCP image that: 1) has nested +# virtualization enabled, and 2) has been completely set up with the +# image_setup.sh script. This script should be idempotent, as we memoize the +# setup script with a hash and check for that name. + +set -eou pipefail + +# Parameters. +declare -r USERNAME=${USERNAME:-test} +declare -r IMAGE_PROJECT=${IMAGE_PROJECT:-ubuntu-os-cloud} +declare -r IMAGE_FAMILY=${IMAGE_FAMILY:-ubuntu-1604-lts} +declare -r ZONE=${ZONE:-us-central1-f} + +# Random names. +declare -r DISK_NAME=$(mktemp -u disk-XXXXXX | tr A-Z a-z) +declare -r SNAPSHOT_NAME=$(mktemp -u snapshot-XXXXXX | tr A-Z a-z) +declare -r INSTANCE_NAME=$(mktemp -u build-XXXXXX | tr A-Z a-z) + +# Hash inputs in order to memoize the produced image. +declare -r SETUP_HASH=$( (echo ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && cat "$@") | sha256sum - | cut -d' ' -f1 | cut -c 1-16) +declare -r IMAGE_NAME=${IMAGE_FAMILY:-image}-${SETUP_HASH} + +# Does the image already exist? Skip the build. +declare -r existing=$(set -x; gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)") +if ! [[ -z "${existing}" ]]; then + echo "${existing}" + exit 0 +fi + +# Standard arguments (applies only on script execution). +declare -ar SSH_ARGS=("-o" "ConnectTimeout=60" "--") + +# gcloud has path errors; is this a result of being a genrule? +export PATH=${PATH:-/bin:/usr/bin:/usr/local/bin} + +# Start a unique instance. Note that this instance will have a unique persistent +# disk as it's boot disk with the same name as the instance. +(set -x; gcloud compute instances create \ + --quiet \ + --image-project "${IMAGE_PROJECT}" \ + --image-family "${IMAGE_FAMILY}" \ + --boot-disk-size "200GB" \ + --zone "${ZONE}" \ + "${INSTANCE_NAME}" >/dev/null) +function cleanup { + (set -x; gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}") +} +trap cleanup EXIT + +# Wait for the instance to become available (up to 5 minutes). +echo -n "Waiting for ${INSTANCE_NAME}" >&2 +declare timeout=300 +declare success=0 +declare internal="" +declare -r start=$(date +%s) +declare -r end=$((${start}+${timeout})) +while [[ "$(date +%s)" -lt "${end}" ]] && [[ "${success}" -lt 3 ]]; do + echo -n "." >&2 + if gcloud compute ssh --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then + success=$((${success}+1)) + elif gcloud compute ssh --internal-ip --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then + success=$((${success}+1)) + internal="--internal-ip" + fi +done + +if [[ "${success}" -eq "0" ]]; then + echo "connect timed out after ${timeout} seconds." >&2 + exit 1 +else + echo "done." >&2 +fi + +# Run the install scripts provided. +for arg; do + (set -x; gcloud compute ssh ${internal} \ + --zone "${ZONE}" \ + "${USERNAME}"@"${INSTANCE_NAME}" -- \ + "${SSH_ARGS[@]}" \ + sudo bash - <"${arg}" >/dev/null) +done + +# Stop the instance; required before creating an image. +(set -x; gcloud compute instances stop --quiet --zone "${ZONE}" "${INSTANCE_NAME}" >/dev/null) + +# Create a snapshot of the instance disk. +(set -x; gcloud compute disks snapshot \ + --quiet \ + --zone "${ZONE}" \ + --snapshot-names="${SNAPSHOT_NAME}" \ + "${INSTANCE_NAME}" >/dev/null) + +# Create the disk image. +(set -x; gcloud compute images create \ + --quiet \ + --source-snapshot="${SNAPSHOT_NAME}" \ + --licenses="https://www.googleapis.com/compute/v1/projects/vm-options/global/licenses/enable-vmx" \ + "${IMAGE_NAME}" >/dev/null) + +# Finish up. +echo "${IMAGE_NAME}" diff --git a/tools/vm/defs.bzl b/tools/vm/defs.bzl new file mode 100644 index 000000000..0f67cfa92 --- /dev/null +++ b/tools/vm/defs.bzl @@ -0,0 +1,201 @@ +"""Image configuration. See README.md.""" + +load("//tools:defs.bzl", "default_installer") + +# vm_image_builder is a rule that will construct a shell script that actually +# generates a given VM image. Note that this does not _run_ the shell script +# (although it can be run manually). It will be run manually during generation +# of the vm_image target itself. This level of indirection is used so that the +# build system itself only runs the builder once when multiple targets depend +# on it, avoiding a set of races and conflicts. +def _vm_image_builder_impl(ctx): + # Generate a binary that actually builds the image. + builder = ctx.actions.declare_file(ctx.label.name) + script_paths = [] + for script in ctx.files.scripts: + script_paths.append(script.short_path) + builder_content = "\n".join([ + "#!/bin/bash", + "export ZONE=$(%s)" % ctx.files.zone[0].short_path, + "export USERNAME=%s" % ctx.attr.username, + "export IMAGE_PROJECT=%s" % ctx.attr.project, + "export IMAGE_FAMILY=%s" % ctx.attr.family, + "%s %s" % (ctx.files._builder[0].short_path, " ".join(script_paths)), + "", + ]) + ctx.actions.write(builder, builder_content, is_executable = True) + + # Note that the scripts should only be files, and should not include any + # indirect transitive dependencies. The build script wouldn't work. + return [DefaultInfo( + executable = builder, + runfiles = ctx.runfiles( + files = ctx.files.scripts + ctx.files._builder + ctx.files.zone, + ), + )] + +vm_image_builder = rule( + attrs = { + "_builder": attr.label( + executable = True, + default = "//tools/vm:builder", + cfg = "host", + ), + "username": attr.string(default = "$(whoami)"), + "zone": attr.label( + executable = True, + default = "//tools/vm:zone", + cfg = "host", + ), + "family": attr.string(mandatory = True), + "project": attr.string(mandatory = True), + "scripts": attr.label_list(allow_files = True), + }, + executable = True, + implementation = _vm_image_builder_impl, +) + +# See vm_image_builder above. +def _vm_image_impl(ctx): + # Run the builder to generate our output. + echo = ctx.actions.declare_file(ctx.label.name) + resolved_inputs, argv, runfiles_manifests = ctx.resolve_command( + command = "echo -ne \"#!/bin/bash\\nset -e\\nimage=$(%s)\\necho ${image}\\n\" > %s && chmod 0755 %s" % ( + ctx.files.builder[0].path, + echo.path, + echo.path, + ), + tools = [ctx.attr.builder], + ) + ctx.actions.run_shell( + tools = resolved_inputs, + outputs = [echo], + progress_message = "Building image...", + execution_requirements = {"local": "true"}, + command = argv, + input_manifests = runfiles_manifests, + ) + + # Return just the echo command. All of the builder runfiles have been + # resolved and consumed in the generation of the trivial echo script. + return [DefaultInfo(executable = echo)] + +_vm_image_test = rule( + attrs = { + "builder": attr.label( + executable = True, + cfg = "host", + ), + }, + test = True, + implementation = _vm_image_impl, +) + +def vm_image(name, **kwargs): + vm_image_builder( + name = name + "_builder", + **kwargs + ) + _vm_image_test( + name = name, + builder = ":" + name + "_builder", + tags = [ + "local", + "manual", + ], + ) + +def _vm_test_impl(ctx): + runner = ctx.actions.declare_file("%s-executer" % ctx.label.name) + + # Note that the remote execution case must actually generate an + # intermediate target in order to collect all the relevant runfiles so that + # they can be copied over for remote execution. + runner_content = "\n".join([ + "#!/bin/bash", + "export ZONE=$(%s)" % ctx.files.zone[0].short_path, + "export USERNAME=%s" % ctx.attr.username, + "export IMAGE=$(%s)" % ctx.files.image[0].short_path, + "export SUDO=%s" % "true" if ctx.attr.sudo else "false", + "%s %s" % ( + ctx.executable.executer.short_path, + " ".join([ + target.files_to_run.executable.short_path + for target in ctx.attr.targets + ]), + ), + "", + ]) + ctx.actions.write(runner, runner_content, is_executable = True) + + # Return with all transitive files. + runfiles = ctx.runfiles( + transitive_files = depset(transitive = [ + depset(target.data_runfiles.files) + for target in ctx.attr.targets + if hasattr(target, "data_runfiles") + ]), + files = ctx.files.executer + ctx.files.zone + ctx.files.image + + ctx.files.targets, + collect_default = True, + collect_data = True, + ) + return [DefaultInfo(executable = runner, runfiles = runfiles)] + +_vm_test = rule( + attrs = { + "image": attr.label( + executable = True, + default = "//tools/vm:ubuntu1804", + cfg = "host", + ), + "executer": attr.label( + executable = True, + default = "//tools/vm:executer", + cfg = "host", + ), + "username": attr.string(default = "$(whoami)"), + "zone": attr.label( + executable = True, + default = "//tools/vm:zone", + cfg = "host", + ), + "sudo": attr.bool(default = True), + "machine": attr.string(default = "n1-standard-1"), + "targets": attr.label_list( + mandatory = True, + allow_empty = False, + cfg = "target", + ), + }, + test = True, + implementation = _vm_test_impl, +) + +def vm_test( + installers = None, + **kwargs): + """Runs the given targets as a remote test. + + Args: + installer: Script to run before all targets. + **kwargs: All test arguments. Should include targets and image. + """ + targets = kwargs.pop("targets", []) + if installers == None: + installers = [ + "//tools/installers:head", + "//tools/installers:images", + ] + targets = installers + targets + if default_installer(): + targets = [default_installer()] + targets + _vm_test( + tags = [ + "local", + "manual", + ], + targets = targets, + local = 1, + **kwargs + ) diff --git a/tools/vm/execute.sh b/tools/vm/execute.sh new file mode 100755 index 000000000..1f1f3ce01 --- /dev/null +++ b/tools/vm/execute.sh @@ -0,0 +1,160 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -xeo pipefail + +# Required input. +if ! [[ -v IMAGE ]]; then + echo "no image provided: set IMAGE." + exit 1 +fi + +# Parameters. +declare -r USERNAME=${USERNAME:-test} +declare -r KEYNAME=$(mktemp --tmpdir -u key-XXXXXX) +declare -r SSHKEYS=$(mktemp --tmpdir -u sshkeys-XXXXXX) +declare -r INSTANCE_NAME=$(mktemp -u test-XXXXXX | tr A-Z a-z) +declare -r MACHINE=${MACHINE:-n1-standard-1} +declare -r ZONE=${ZONE:-us-central1-f} +declare -r SUDO=${SUDO:-false} + +# Standard arguments (applies only on script execution). +declare -ar SSH_ARGS=("-o" "ConnectTimeout=60" "--") + +# This script is executed as a test rule, which will reset the value of HOME. +# Unfortunately, it is needed to load the gconfig credentials. We will reset +# HOME when we actually execute in the remote environment, defined below. +export HOME=$(eval echo ~$(whoami)) + +# Generate unique keys for this test. +[[ -f "${KEYNAME}" ]] || ssh-keygen -t rsa -N "" -f "${KEYNAME}" -C "${USERNAME}" +cat > "${SSHKEYS}" <<EOF +${USERNAME}:$(cat ${KEYNAME}.pub) +EOF + +# Start a unique instance. This means that we first generate a unique set of ssh +# keys to ensure that only we have access to this instance. Note that we must +# constrain ourselves to Haswell or greater in order to have nested +# virtualization available. +gcloud compute instances create \ + --min-cpu-platform "Intel Haswell" \ + --preemptible \ + --no-scopes \ + --metadata block-project-ssh-keys=TRUE \ + --metadata-from-file ssh-keys="${SSHKEYS}" \ + --machine-type "${MACHINE}" \ + --image "${IMAGE}" \ + --zone "${ZONE}" \ + "${INSTANCE_NAME}" +function cleanup { + gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}" +} +trap cleanup EXIT + +# Wait for the instance to become available (up to 5 minutes). +declare timeout=300 +declare success=0 +declare -r start=$(date +%s) +declare -r end=$((${start}+${timeout})) +while [[ "$(date +%s)" -lt "${end}" ]] && [[ "${success}" -lt 3 ]]; do + if gcloud compute ssh --ssh-key-file="${KEYNAME}" --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then + success=$((${success}+1)) + fi +done +if [[ "${success}" -eq "0" ]]; then + echo "connect timed out after ${timeout} seconds." + exit 1 +fi + +# Copy the local directory over. +tar czf - --dereference --exclude=.git . | + gcloud compute ssh \ + --ssh-key-file="${KEYNAME}" \ + --zone "${ZONE}" \ + "${USERNAME}"@"${INSTANCE_NAME}" -- \ + "${SSH_ARGS[@]}" \ + tar xzf - + +# Execute the command remotely. +for cmd; do + # Setup relevant environment. + # + # N.B. This is not a complete test environment, but is complete enough to + # provide rudimentary sharding and test output support. + declare -a PREFIX=( "env" ) + if [[ -v TEST_SHARD_INDEX ]]; then + PREFIX+=( "TEST_SHARD_INDEX=${TEST_SHARD_INDEX}" ) + fi + if [[ -v TEST_SHARD_STATUS_FILE ]]; then + SHARD_STATUS_FILE=$(mktemp -u test-shard-status-XXXXXX) + PREFIX+=( "TEST_SHARD_STATUS_FILE=/tmp/${SHARD_STATUS_FILE}" ) + fi + if [[ -v TEST_TOTAL_SHARDS ]]; then + PREFIX+=( "TEST_TOTAL_SHARDS=${TEST_TOTAL_SHARDS}" ) + fi + if [[ -v TEST_TMPDIR ]]; then + REMOTE_TMPDIR=$(mktemp -u test-XXXXXX) + PREFIX+=( "TEST_TMPDIR=/tmp/${REMOTE_TMPDIR}" ) + # Create remotely. + gcloud compute ssh \ + --ssh-key-file="${KEYNAME}" \ + --zone "${ZONE}" \ + "${USERNAME}"@"${INSTANCE_NAME}" -- \ + "${SSH_ARGS[@]}" \ + mkdir -p "/tmp/${REMOTE_TMPDIR}" + fi + if [[ -v XML_OUTPUT_FILE ]]; then + TEST_XML_OUTPUT=$(mktemp -u xml-output-XXXXXX) + PREFIX+=( "XML_OUTPUT_FILE=/tmp/${TEST_XML_OUTPUT}" ) + fi + if [[ "${SUDO}" == "true" ]]; then + PREFIX+=( "sudo" "-E" ) + fi + + # Execute the command. + gcloud compute ssh \ + --ssh-key-file="${KEYNAME}" \ + --zone "${ZONE}" \ + "${USERNAME}"@"${INSTANCE_NAME}" -- \ + "${SSH_ARGS[@]}" \ + "${PREFIX[@]}" "${cmd}" + + # Collect relevant results. + if [[ -v TEST_SHARD_STATUS_FILE ]]; then + gcloud compute scp \ + --ssh-key-file="${KEYNAME}" \ + --zone "${ZONE}" \ + "${USERNAME}"@"${INSTANCE_NAME}":/tmp/"${SHARD_STATUS_FILE}" \ + "${TEST_SHARD_STATUS_FILE}" 2>/dev/null || true # Allowed to fail. + fi + if [[ -v XML_OUTPUT_FILE ]]; then + gcloud compute scp \ + --ssh-key-file="${KEYNAME}" \ + --zone "${ZONE}" \ + "${USERNAME}"@"${INSTANCE_NAME}":/tmp/"${TEST_XML_OUTPUT}" \ + "${XML_OUTPUT_FILE}" 2>/dev/null || true # Allowed to fail. + fi + + # Clean up the temporary directory. + if [[ -v TEST_TMPDIR ]]; then + gcloud compute ssh \ + --ssh-key-file="${KEYNAME}" \ + --zone "${ZONE}" \ + "${USERNAME}"@"${INSTANCE_NAME}" -- \ + "${SSH_ARGS[@]}" \ + rm -rf "/tmp/${REMOTE_TMPDIR}" + fi +done diff --git a/tools/vm/test.cc b/tools/vm/test.cc new file mode 100644 index 000000000..c0ceacda1 --- /dev/null +++ b/tools/vm/test.cc @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" + +namespace { + +TEST(Image, Sanity0) { + // Do nothing (in shard 0). +} + +TEST(Image, Sanity1) { + // Do nothing (in shard 1). +} + +} // namespace diff --git a/tools/vm/ubuntu1604/10_core.sh b/tools/vm/ubuntu1604/10_core.sh new file mode 100755 index 000000000..629f7cf7a --- /dev/null +++ b/tools/vm/ubuntu1604/10_core.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -xeo pipefail + +# Install all essential build tools. +while true; do + if (apt-get update && apt-get install -y \ + make \ + git-core \ + build-essential \ + linux-headers-$(uname -r) \ + pkg-config); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done + +# Install a recent go toolchain. +if ! [[ -d /usr/local/go ]]; then + wget https://dl.google.com/go/go1.13.5.linux-amd64.tar.gz + tar -xvf go1.13.5.linux-amd64.tar.gz + mv go /usr/local +fi + +# Link the Go binary from /usr/bin; replacing anything there. +(cd /usr/bin && rm -f go && ln -fs /usr/local/go/bin/go go) diff --git a/tools/vm/ubuntu1604/15_gcloud.sh b/tools/vm/ubuntu1604/15_gcloud.sh new file mode 100755 index 000000000..bc2e5eccc --- /dev/null +++ b/tools/vm/ubuntu1604/15_gcloud.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -xeo pipefail + +# Install all essential build tools. +while true; do + if (apt-get update && apt-get install -y \ + apt-transport-https \ + ca-certificates \ + gnupg); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done + +# Add gcloud repositories. +echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | \ + tee -a /etc/apt/sources.list.d/google-cloud-sdk.list + +# Add the appropriate key. +curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | \ + apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - + +# Install the gcloud SDK. +while true; do + if (apt-get update && apt-get install -y google-cloud-sdk); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done diff --git a/tools/vm/ubuntu1604/20_bazel.sh b/tools/vm/ubuntu1604/20_bazel.sh new file mode 100755 index 000000000..bb7afa676 --- /dev/null +++ b/tools/vm/ubuntu1604/20_bazel.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -xeo pipefail + +declare -r BAZEL_VERSION=2.0.0 + +# Install bazel dependencies. +while true; do + if (apt-get update && apt-get install -y \ + openjdk-8-jdk-headless \ + unzip); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done + +# Use the release installer. +curl -L -o bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh +chmod a+x bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh +./bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh +rm -f bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh diff --git a/tools/vm/ubuntu1604/25_docker.sh b/tools/vm/ubuntu1604/25_docker.sh new file mode 100755 index 000000000..53d8ca588 --- /dev/null +++ b/tools/vm/ubuntu1604/25_docker.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Add dependencies. +while true; do + if (apt-get update && apt-get install -y \ + apt-transport-https \ + ca-certificates \ + curl \ + gnupg-agent \ + software-properties-common); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done + +# Install the key. +curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add - + +# Add the repository. +add-apt-repository \ + "deb [arch=amd64] https://download.docker.com/linux/ubuntu \ + $(lsb_release -cs) \ + stable" + +# Install docker. +while true; do + if (apt-get update && apt-get install -y \ + docker-ce \ + docker-ce-cli \ + containerd.io); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done + +# Enable Docker IPv6. +cat > /etc/docker/daemon.json <<EOF +{ + "fixed-cidr-v6": "2001:db8:1::/64", + "ipv6": true +} +EOF +# Docker's IPv6 support is lacking and does not work the same way as IPv4. We +# can use NAT so containers can reach the outside world. +ip6tables -t nat -A POSTROUTING -s 2001:db8:1::/64 ! -o docker0 -j MASQUERADE diff --git a/tools/vm/ubuntu1604/30_containerd.sh b/tools/vm/ubuntu1604/30_containerd.sh new file mode 100755 index 000000000..fb3699c12 --- /dev/null +++ b/tools/vm/ubuntu1604/30_containerd.sh @@ -0,0 +1,86 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -xeo pipefail + +# Helper for Go packages below. +install_helper() { + PACKAGE="${1}" + TAG="${2}" + GOPATH="${3}" + + # Clone the repository. + mkdir -p "${GOPATH}"/src/$(dirname "${PACKAGE}") && \ + git clone https://"${PACKAGE}" "${GOPATH}"/src/"${PACKAGE}" + + # Checkout and build the repository. + (cd "${GOPATH}"/src/"${PACKAGE}" && \ + git checkout "${TAG}" && \ + GOPATH="${GOPATH}" make && \ + GOPATH="${GOPATH}" make install) +} + +# Install dependencies for the crictl tests. +while true; do + if (apt-get update && apt-get install -y \ + btrfs-tools \ + libseccomp-dev); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done + +# Install containerd & cri-tools. +GOPATH=$(mktemp -d --tmpdir gopathXXXXX) +install_helper github.com/containerd/containerd v1.2.2 "${GOPATH}" +install_helper github.com/kubernetes-sigs/cri-tools v1.11.0 "${GOPATH}" + +# Install gvisor-containerd-shim. +declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim" +declare -r latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX) +declare -r shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX) +wget --no-verbose "${base}"/latest -O ${latest} +wget --no-verbose "${base}"/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path} +chmod +x ${shim_path} +mv ${shim_path} /usr/local/bin + +# Configure containerd-shim. +declare -r shim_config_path=/etc/containerd +declare -r shim_config_tmp_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX.toml) +mkdir -p ${shim_config_path} +cat > ${shim_config_tmp_path} <<-EOF + runc_shim = "/usr/local/bin/containerd-shim" + +[runsc_config] + debug = "true" + debug-log = "/tmp/runsc-logs/" + strace = "true" + file-access = "shared" +EOF +mv ${shim_config_tmp_path} ${shim_config_path} + +# Configure CNI. +(cd "${GOPATH}" && GOPATH="${GOPATH}" \ + src/github.com/containerd/containerd/script/setup/install-cni) + +# Cleanup the above. +rm -rf "${GOPATH}" +rm -rf "${latest}" +rm -rf "${shim_path}" +rm -rf "${shim_config_tmp_path}" diff --git a/tools/vm/ubuntu1604/40_kokoro.sh b/tools/vm/ubuntu1604/40_kokoro.sh new file mode 100755 index 000000000..2974f156c --- /dev/null +++ b/tools/vm/ubuntu1604/40_kokoro.sh @@ -0,0 +1,72 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -xeo pipefail + +# Declare kokoro's required public keys. +declare -r ssh_public_keys=( + "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDg7L/ZaEauETWrPklUTky3kvxqQfe2Ax/2CsSqhNIGNMnK/8d79CHlmY9+dE1FFQ/RzKNCaltgy7XcN/fCYiCZr5jm2ZtnLuGNOTzupMNhaYiPL419qmL+5rZXt4/dWTrsHbFRACxT8j51PcRMO5wgbL0Bg2XXimbx8kDFaurL2gqduQYqlu4lxWCaJqOL71WogcimeL63Nq/yeH5PJPWpqE4P9VUQSwAzBWFK/hLeds/AiP3MgVS65qHBnhq0JsHy8JQsqjZbG7Iidt/Ll0+gqzEbi62gDIcczG4KC0iOVzDDP/1BxDtt1lKeA23ll769Fcm3rJyoBMYxjvdw1TDx sabujp@trigger.mtv.corp.google.com" + "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBNgGK/hCdjmulHfRE3hp4rZs38NCR8yAh0eDsztxqGcuXnuSnL7jOlRrbcQpremJ84omD4eKrIpwJUs+YokMdv4= sabujp@trigger.svl.corp.google.com" +) + +# Install dependencies. +while true; do + if (apt-get update && apt-get install -y \ + rsync \ + coreutils \ + python-psutil \ + qemu-kvm \ + python-pip \ + python3-pip \ + zip); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done + +# junitparser is used to merge junit xml files. +pip install junitparser + +# We need a kbuilder user, which may already exist. +useradd -c "kbuilder user" -m -s /bin/bash kbuilder || true + +# We need to provision appropriate keys. +mkdir -p ~kbuilder/.ssh +(IFS=$'\n'; echo "${ssh_public_keys[*]}") > ~kbuilder/.ssh/authorized_keys +chmod 0600 ~kbuilder/.ssh/authorized_keys +chown -R kbuilder ~kbuilder/.ssh + +# Give passwordless sudo access. +cat > /etc/sudoers.d/kokoro <<EOF +kbuilder ALL=(ALL) NOPASSWD:ALL +EOF + +# Ensure we can run Docker without sudo. +usermod -aG docker kbuilder + +# Ensure that we can access kvm. +usermod -aG kvm kbuilder + +# Ensure that /tmpfs exists and is writable by kokoro. +# +# Note that kokoro will typically attach a second disk (sdb) to the instance +# that is used for the /tmpfs volume. In the future we could setup an init +# script that formats and mounts this here; however, we don't expect our build +# artifacts to be that large. +mkdir -p /tmpfs && chmod 0777 /tmpfs && touch /tmpfs/READY diff --git a/tools/vm/ubuntu1604/BUILD b/tools/vm/ubuntu1604/BUILD new file mode 100644 index 000000000..ab1df0c4c --- /dev/null +++ b/tools/vm/ubuntu1604/BUILD @@ -0,0 +1,7 @@ +package(licenses = ["notice"]) + +filegroup( + name = "ubuntu1604", + srcs = glob(["*.sh"]), + visibility = ["//:sandbox"], +) diff --git a/tools/vm/ubuntu1804/BUILD b/tools/vm/ubuntu1804/BUILD new file mode 100644 index 000000000..0c8856dde --- /dev/null +++ b/tools/vm/ubuntu1804/BUILD @@ -0,0 +1,7 @@ +package(licenses = ["notice"]) + +alias( + name = "ubuntu1804", + actual = "//tools/vm/ubuntu1604", + visibility = ["//:sandbox"], +) diff --git a/tools/vm/zone.sh b/tools/vm/zone.sh new file mode 100755 index 000000000..79569fb19 --- /dev/null +++ b/tools/vm/zone.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +# Copyright 2020 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec gcloud config get-value compute/zone diff --git a/tools/workspace_status.sh b/tools/workspace_status.sh new file mode 100755 index 000000000..a22c8c9f2 --- /dev/null +++ b/tools/workspace_status.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +# Copyright 2018 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The STABLE_ prefix will trigger a re-link if it changes. +echo STABLE_VERSION $(git describe --always --tags --abbrev=12 --dirty || echo 0.0.0) |