summaryrefslogtreecommitdiffhomepage
path: root/tools/worker
diff options
context:
space:
mode:
Diffstat (limited to 'tools/worker')
-rw-r--r--tools/worker/BUILD21
-rw-r--r--tools/worker/worker.go325
2 files changed, 346 insertions, 0 deletions
diff --git a/tools/worker/BUILD b/tools/worker/BUILD
new file mode 100644
index 000000000..dc03ce11e
--- /dev/null
+++ b/tools/worker/BUILD
@@ -0,0 +1,21 @@
+# go_library and the bazel_worker_proto dependency label are both provided by
+# //tools:defs.bzl.
+load("//tools:defs.bzl", "bazel_worker_proto", "go_library")
+
+package(licenses = ["notice"])
+
+# For Google-tooling.
+# @unused
+glaze_ignore = [
+ "worker.go",
+]
+
+# Go library built from worker.go. It is visible only to packages under
+# //tools.
+go_library(
+ name = "worker",
+ srcs = ["worker.go"],
+ visibility = ["//tools:__subpackages__"],
+ deps = [
+ bazel_worker_proto,
+ "@org_golang_google_protobuf//encoding/protowire:go_default_library",
+ "@org_golang_google_protobuf//proto:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/tools/worker/worker.go b/tools/worker/worker.go
new file mode 100644
index 000000000..669a5f203
--- /dev/null
+++ b/tools/worker/worker.go
@@ -0,0 +1,325 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package worker provides an implementation of the bazel worker protocol.
+//
+// Tools may be written as a normal command line utility, except the passed
+// run function may be invoked multiple times.
+package worker
+
+import (
+ "bufio"
+ "bytes"
+ "flag"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net"
+ "net/http"
+ "os"
+ "path/filepath"
+ "sort"
+ "strings"
+ "time"
+
+ _ "net/http/pprof" // For profiling.
+
+ "golang.org/x/sys/unix"
+ "google.golang.org/protobuf/encoding/protowire"
+ "google.golang.org/protobuf/proto"
+ wpb "gvisor.dev/bazel/worker_protocol_go_proto"
+)
+
+// Command-line flags consumed by Work.
+var (
+ // persistentWorker selects the mode of Work: when false the run function
+ // is invoked once and the process exits; when true Work loops serving
+ // requests read from stdin.
+ persistentWorker = flag.Bool("persistent_worker", false, "enable persistent worker.")
+ // workerDebug makes Work bind an HTTP server on a random localhost port
+ // (for the pprof handlers) and append cache statistics to each response.
+ workerDebug = flag.Bool("worker_debug", false, "debug persistent workers.")
+ // maximumCacheUsage is the threshold on totalCacheUsage (default 1 GiB).
+ // Once it is exceeded, Cache.Lookup drops the entries of every active
+ // cache.
+ maximumCacheUsage = flag.Int64("maximum_cache_usage", 1024*1024*1024, "maximum cache size.")
+)
+
+var (
+ // inputFiles is the last set of input files.
+ //
+ // This is used for cache invalidation. The key is the *absolute* path
+ // name, and the value is the digest in the current run. Work replaces
+ // this map with a fresh one for every request it receives.
+ inputFiles = make(map[string]string)
+
+ // activeCaches is the set of active caches. A cache is added here by
+ // Cache.Lookup whenever it (re)allocates its entry map.
+ activeCaches = make(map[*Cache]struct{})
+
+ // totalCacheUsage is the total usage of all caches, i.e. the sum of the
+ // Size of every stored entry. It is reset to zero when the caches are
+ // flushed.
+ totalCacheUsage int64
+)
+
+// mustAbs resolves filename to an absolute path.
+//
+// Failure to resolve the path is fatal: the error is logged with log.Fatalf,
+// which terminates the process.
+func mustAbs(filename string) string {
+	resolved, err := filepath.Abs(filename)
+	if err == nil {
+		return resolved
+	}
+	log.Fatalf("error getting absolute path: %v", err)
+	return "" // Not reached; log.Fatalf exits.
+}
+
+// updateInputFile records digest as the current digest of filename.
+//
+// The entry in inputFiles is keyed by the absolute form of filename.
+func updateInputFile(filename, digest string) {
+	key := mustAbs(filename)
+	inputFiles[key] = digest
+}
+
+// Sizer is implemented by values that can report their own size.
+//
+// Cache adds the reported size of each stored entry to its usage accounting.
+type Sizer interface {
+	Size() int64
+}
+
+// CacheBytes is a byte slice that can be stored in a Cache; it serves as an
+// example of a Sizer.
+type CacheBytes []byte
+
+// Size implements Sizer.Size, reporting the length of the slice.
+func (cb CacheBytes) Size() int64 {
+	n := len(cb)
+	return int64(n)
+}
+
+// Cache is a worker cache.
+//
+// They can be created via NewCache. Entries are populated and retrieved
+// through Lookup.
+type Cache struct {
+ // name identifies the cache in the statistics output.
+ name string
+ // entries maps a cache key (built by Lookup from the digests of the
+ // given files) to the stored value. It is allocated on first Lookup and
+ // set back to nil when the caches are flushed.
+ entries map[string]Sizer
+ // size is the sum of Size() over the non-nil stored entries.
+ size int64
+ // hits counts Lookup calls answered from entries.
+ hits int64
+ // misses counts Lookup calls that had to generate a new entry.
+ misses int64
+}
+
+// NewCache returns a new, empty cache with the given name.
+//
+// The entry map is not allocated here; Lookup creates it on first use.
+func NewCache(name string) *Cache {
+	c := new(Cache)
+	c.name = name
+	return c
+}
+
+// Lookup returns the cache entry associated with the given files, calling
+// generate to produce (and store) it if it is not cached yet.
+//
+// The cache key is a function of the given files: each file contributes its
+// digest from inputFiles, the contributions are sorted, and then joined with
+// "+". The order of filenames therefore does not matter.
+func (c *Cache) Lookup(filenames []string, generate func() Sizer) Sizer {
+	parts := make([]string, len(filenames))
+	for i, filename := range filenames {
+		if digest, known := inputFiles[mustAbs(filename)]; known {
+			parts[i] = digest
+			continue
+		}
+		// This is not a valid input. We may not be running as a
+		// persistent worker; in that case the file's contents will not
+		// change across the run, so the filename itself is used.
+		parts[i] = filename
+	}
+	sort.Strings(parts)
+	cacheKey := strings.Join(parts, "+")
+
+	// Lazily allocate the entry map, registering this cache as active.
+	if c.entries == nil {
+		c.entries = make(map[string]Sizer)
+		activeCaches[c] = struct{}{}
+	}
+
+	// Fast path: the entry already exists.
+	if cached, found := c.entries[cacheKey]; found {
+		c.hits++
+		return cached
+	}
+
+	// Slow path: generate the entry and account for its size.
+	generated := generate()
+	c.misses++
+	c.entries[cacheKey] = generated
+	if generated != nil {
+		added := generated.Size()
+		c.size += added
+		totalCacheUsage += added
+	}
+
+	// Check the capacity of all caches. If it is greater than the maximum,
+	// everything is flushed, but the freshly generated entry is still
+	// returned to the caller.
+	if totalCacheUsage > *maximumCacheUsage {
+		for cache := range activeCaches {
+			cache.size = 0
+			cache.entries = nil
+		}
+		totalCacheUsage = 0
+	}
+
+	return generated
+}
+
+// allCacheStats returns stats for all caches.
+//
+// One line is emitted per active cache, ordered by cache name so that the
+// output does not depend on map iteration order, followed by a line with the
+// total usage. The result is empty when there are no active caches.
+func allCacheStats() string {
+	caches := make([]*Cache, 0, len(activeCaches))
+	for c := range activeCaches {
+		caches = append(caches, c)
+	}
+	sort.Slice(caches, func(i, j int) bool {
+		return caches[i].name < caches[j].name
+	})
+
+	var sb strings.Builder
+	for _, c := range caches {
+		// Guard the division so a cache without any lookups reports a
+		// ratio of zero rather than NaN.
+		var ratio float64
+		if lookups := c.hits + c.misses; lookups > 0 {
+			ratio = float64(c.hits) / float64(lookups)
+		}
+		fmt.Fprintf(&sb,
+			"% 10s: count: % 5d size: % 10d hits: % 7d misses: % 7d ratio: %2.2f\n",
+			c.name, len(c.entries), c.size, c.hits, c.misses, ratio)
+	}
+	if len(caches) > 0 {
+		fmt.Fprintf(&sb, "total: % 10d\n", totalCacheUsage)
+	}
+	return sb.String()
+}
+
+// LookupDigest returns a digest for the given file.
+//
+// The filename is converted to an absolute path before the lookup, because
+// inputFiles is keyed by absolute path. The boolean reports whether the file
+// is a known input of the current request.
+func LookupDigest(filename string) (string, bool) {
+	digest, ok := inputFiles[mustAbs(filename)]
+	return digest, ok
+}
+
+// Work invokes the main function.
+//
+// Flags are parsed from os.Args first. Without -persistent_worker, run is
+// invoked exactly once and the process exits with its return value; a single
+// argument of the form "@file" is expanded by reading newline-separated
+// arguments from that file.
+//
+// With -persistent_worker, Work loops: it reads varint-length-prefixed
+// WorkRequest protos from stdin, invokes run with the request's arguments, and
+// writes a length-prefixed WorkResponse carrying the exit code and everything
+// captured from stdout, stderr and the log package. Any protocol or I/O error
+// is fatal. Work does not return in either mode.
+func Work(run func([]string) int) {
+	flag.CommandLine.Parse(os.Args[1:])
+	if !*persistentWorker {
+		// Handle the argument file.
+		args := flag.CommandLine.Args()
+		if len(args) == 1 && len(args[0]) > 1 && args[0][0] == '@' {
+			content, err := ioutil.ReadFile(args[0][1:])
+			if err != nil {
+				log.Fatalf("unable to parse args file: %v", err)
+			}
+			// Pull arguments from the file.
+			args = strings.Split(string(content), "\n")
+			flag.CommandLine.Parse(args)
+			args = flag.CommandLine.Args()
+		}
+		os.Exit(run(args))
+	}
+
+	var listenHeader string // Emitted always.
+	if *workerDebug {
+		// Bind a server for profiling.
+		listener, err := net.Listen("tcp", "localhost:0")
+		if err != nil {
+			log.Fatalf("unable to bind a server: %v", err)
+		}
+		// Construct the header for stats output, below.
+		listenHeader = fmt.Sprintf("Listening @ http://localhost:%d\n", listener.Addr().(*net.TCPAddr).Port)
+		go http.Serve(listener, nil)
+	}
+
+	// Move stdout. This is done to prevent anything else from accidentally
+	// printing to stdout, which must contain only the valid WorkerResponse
+	// serialized protos.
+	newOutput, err := unix.Dup(1)
+	if err != nil {
+		log.Fatalf("unable to move stdout: %v", err)
+	}
+	// Stderr may be closed or may be a copy of stdout. We make sure that
+	// we have an output that is in a completely separate range.
+	for newOutput <= 2 {
+		newOutput, err = unix.Dup(newOutput)
+		if err != nil {
+			log.Fatalf("unable to move stdout: %v", err)
+		}
+	}
+
+	// Best-effort: collect logs. Both stdout and stderr are redirected into
+	// a pipe whose read end is drained into each response below.
+	rPipe, wPipe, err := os.Pipe()
+	if err != nil {
+		log.Fatalf("unable to create pipe: %v", err)
+	}
+	if err := unix.Dup2(int(wPipe.Fd()), 1); err != nil {
+		log.Fatalf("error duping over stdout: %v", err)
+	}
+	if err := unix.Dup2(int(wPipe.Fd()), 2); err != nil {
+		log.Fatalf("error duping over stderr: %v", err)
+	}
+	wPipe.Close()
+	defer rPipe.Close()
+
+	// Read requests from stdin.
+	input := bufio.NewReader(os.NewFile(0, "input"))
+	output := bufio.NewWriter(os.NewFile(uintptr(newOutput), "output"))
+	for {
+		// Every request is preceded by its size, encoded as a varint.
+		sz, err := readRequestSize(input)
+		if err != nil {
+			log.Fatalf("unable to read header: %v", err)
+		}
+
+		// Read a full message.
+		msg := make([]byte, sz)
+		if _, err := io.ReadFull(input, msg); err != nil {
+			log.Fatalf("error reading worker request: %v", err)
+		}
+		var wreq wpb.WorkRequest
+		if err := proto.Unmarshal(msg, &wreq); err != nil {
+			log.Fatalf("error unmarshaling worker request: %v", err)
+		}
+
+		// Flush relevant caches: the digests of this request replace
+		// those of the previous one.
+		inputFiles = make(map[string]string)
+		for _, in := range wreq.GetInputs() {
+			updateInputFile(in.GetPath(), string(in.GetDigest()))
+		}
+
+		// Prepare logging.
+		outputBuffer := bytes.NewBuffer(nil)
+		outputBuffer.WriteString(listenHeader)
+		log.SetOutput(outputBuffer)
+
+		// Parse all arguments.
+		flag.CommandLine.Parse(wreq.GetArguments())
+		var exitCode int
+		exitChan := make(chan int)
+		go func() { exitChan <- run(flag.CommandLine.Args()) }()
+		// NOTE(review): run may write to outputBuffer through the log
+		// package while this loop calls outputBuffer.ReadFrom, and
+		// bytes.Buffer is not synchronized. Confirm whether this needs a
+		// lock.
+		for running := true; running; {
+			select {
+			case exitCode = <-exitChan:
+				running = false
+			default:
+			}
+			// N.B. rPipe is given a read deadline of 1ms. We expect
+			// this to return a copy error after 1ms, and we just keep
+			// flushing this buffer while the task is running. The
+			// error is deliberately ignored for that reason.
+			rPipe.SetReadDeadline(time.Now().Add(time.Millisecond))
+			outputBuffer.ReadFrom(rPipe)
+		}
+
+		if *workerDebug {
+			// Attach all cache stats.
+			outputBuffer.WriteString(allCacheStats())
+		}
+
+		// Send the response, prefixed by its varint-encoded size.
+		var wresp wpb.WorkResponse
+		wresp.ExitCode = int32(exitCode)
+		wresp.Output = outputBuffer.String()
+		rmsg, err := proto.Marshal(&wresp)
+		if err != nil {
+			log.Fatalf("error marshaling response: %v", err)
+		}
+		if _, err := output.Write(append(protowire.AppendVarint(nil, uint64(len(rmsg))), rmsg...)); err != nil {
+			log.Fatalf("error sending worker response: %v", err)
+		}
+		if err := output.Flush(); err != nil {
+			log.Fatalf("error flushing output: %v", err)
+		}
+	}
+}
+
+// readRequestSize reads the varint size prefix of the next request from r.
+//
+// The prefix is consumed one byte at a time, so no more bytes than the prefix
+// itself are required to be available. A malformed prefix, a prefix cut short
+// by the end of the stream, or a size that does not fit in an int is reported
+// as an error.
+func readRequestSize(r *bufio.Reader) (int, error) {
+	// A 64-bit varint occupies at most ten bytes on the wire.
+	const maxVarintLen = 10
+	const maxInt = int(^uint(0) >> 1)
+
+	buf := make([]byte, 0, maxVarintLen)
+	for len(buf) < maxVarintLen {
+		b, err := r.ReadByte()
+		if err != nil {
+			if err == io.EOF && len(buf) > 0 {
+				err = io.ErrUnexpectedEOF
+			}
+			return 0, err
+		}
+		buf = append(buf, b)
+		if b < 0x80 {
+			break // Continuation bit is clear: this was the last byte.
+		}
+	}
+	sz, n := protowire.ConsumeVarint(buf)
+	if n < 0 {
+		return 0, protowire.ParseError(n)
+	}
+	if sz > uint64(maxInt) {
+		return 0, fmt.Errorf("request size %d is too large", sz)
+	}
+	return int(sz), nil
+}