Diffstat (limited to 'tools/worker/worker.go')
-rw-r--r-- | tools/worker/worker.go | 325 |
1 file changed, 325 insertions, 0 deletions
diff --git a/tools/worker/worker.go b/tools/worker/worker.go
new file mode 100644
index 000000000..669a5f203
--- /dev/null
+++ b/tools/worker/worker.go
@@ -0,0 +1,325 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package worker provides an implementation of the bazel worker protocol.
+//
+// Tools may be written as a normal command line utility, except the passed
+// run function may be invoked multiple times.
+package worker
+
+import (
+	"bufio"
+	"bytes"
+	"flag"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"log"
+	"net"
+	"net/http"
+	"os"
+	"path/filepath"
+	"sort"
+	"strings"
+	"time"
+
+	_ "net/http/pprof" // For profiling.
+
+	"golang.org/x/sys/unix"
+	"google.golang.org/protobuf/encoding/protowire"
+	"google.golang.org/protobuf/proto"
+	wpb "gvisor.dev/bazel/worker_protocol_go_proto"
+)
+
+var (
+	persistentWorker  = flag.Bool("persistent_worker", false, "enable persistent worker.")
+	workerDebug       = flag.Bool("worker_debug", false, "debug persistent workers.")
+	maximumCacheUsage = flag.Int64("maximum_cache_usage", 1024*1024*1024, "maximum cache size.")
+)
+
+var (
+	// inputFiles is the last set of input files.
+	//
+	// This is used for cache invalidation. The key is the *absolute* path
+	// name, and the value is the digest in the current run.
+	inputFiles = make(map[string]string)
+
+	// activeCaches is the set of active caches.
+	activeCaches = make(map[*Cache]struct{})
+
+	// totalCacheUsage is the total usage of all caches.
+	totalCacheUsage int64
+)
+
+// mustAbs returns the absolute path of a filename or dies.
+func mustAbs(filename string) string {
+	abs, err := filepath.Abs(filename)
+	if err != nil {
+		log.Fatalf("error getting absolute path: %v", err)
+	}
+	return abs
+}
+
+// updateInputFile creates an entry in inputFiles.
+func updateInputFile(filename, digest string) {
+	inputFiles[mustAbs(filename)] = digest
+}
+
+// Sizer returns a size.
+type Sizer interface {
+	Size() int64
+}
+
+// CacheBytes is an example of a Sizer.
+type CacheBytes []byte
+
+// Size implements Sizer.Size.
+func (cb CacheBytes) Size() int64 {
+	return int64(len(cb))
+}
+
+// Cache is a worker cache.
+//
+// They can be created via NewCache.
+type Cache struct {
+	name    string
+	entries map[string]Sizer
+	size    int64
+	hits    int64
+	misses  int64
+}
+
+// NewCache returns a new cache.
+func NewCache(name string) *Cache {
+	return &Cache{
+		name: name,
+	}
+}
+
+// Lookup looks up an entry in the cache.
+//
+// It is a function of the given files.
+func (c *Cache) Lookup(filenames []string, generate func() Sizer) Sizer {
+	digests := make([]string, 0, len(filenames))
+	for _, filename := range filenames {
+		digest, ok := inputFiles[mustAbs(filename)]
+		if !ok {
+			// This is not a valid input. We may not be running as
+			// a persistent worker in this case. If so, the file's
+			// contents will not change across the run, and we just
+			// use the filename itself.
+			digest = filename
+		}
+		digests = append(digests, digest)
+	}
+
+	// Attempt the lookup.
+	sort.Slice(digests, func(i, j int) bool {
+		return digests[i] < digests[j]
+	})
+	cacheKey := strings.Join(digests, "+")
+	if c.entries == nil {
+		c.entries = make(map[string]Sizer)
+		activeCaches[c] = struct{}{}
+	}
+	entry, ok := c.entries[cacheKey]
+	if ok {
+		c.hits++
+		return entry
+	}
+
+	// Generate a new entry.
+	entry = generate()
+	c.misses++
+	c.entries[cacheKey] = entry
+	if entry != nil {
+		sz := entry.Size()
+		c.size += sz
+		totalCacheUsage += sz
+	}
+
+	// Check the capacity of all caches. If it is greater than the
+	// maximum, we flush everything but still return this entry.
+	if totalCacheUsage > *maximumCacheUsage {
+		for cache := range activeCaches {
+			// Drop all entries.
+			cache.size = 0
+			cache.entries = nil
+		}
+		totalCacheUsage = 0 // Reset.
+	}
+
+	return entry
+}
+
+// allCacheStats returns stats for all caches.
+func allCacheStats() string {
+	var sb strings.Builder
+	for cache := range activeCaches {
+		ratio := float64(cache.hits) / float64(cache.hits+cache.misses)
+		fmt.Fprintf(&sb,
+			"% 10s: count: % 5d size: % 10d hits: % 7d misses: % 7d ratio: %2.2f\n",
+			cache.name, len(cache.entries), cache.size, cache.hits, cache.misses, ratio)
+	}
+	if len(activeCaches) > 0 {
+		fmt.Fprintf(&sb, "total: % 10d\n", totalCacheUsage)
+	}
+	return sb.String()
+}
+
+// LookupDigest returns a digest for the given file.
+func LookupDigest(filename string) (string, bool) {
+	digest, ok := inputFiles[filename]
+	return digest, ok
+}
+
+// Work invokes the main function.
+func Work(run func([]string) int) {
+	flag.CommandLine.Parse(os.Args[1:])
+	if !*persistentWorker {
+		// Handle the argument file.
+		args := flag.CommandLine.Args()
+		if len(args) == 1 && len(args[0]) > 1 && args[0][0] == '@' {
+			content, err := ioutil.ReadFile(args[0][1:])
+			if err != nil {
+				log.Fatalf("unable to parse args file: %v", err)
+			}
+			// Pull arguments from the file.
+			args = strings.Split(string(content), "\n")
+			flag.CommandLine.Parse(args)
+			args = flag.CommandLine.Args()
+		}
+		os.Exit(run(args))
+	}
+
+	var listenHeader string // Emitted always.
+	if *workerDebug {
+		// Bind a server for profiling.
+		listener, err := net.Listen("tcp", "localhost:0")
+		if err != nil {
+			log.Fatalf("unable to bind a server: %v", err)
+		}
+		// Construct the header for stats output, below.
+		listenHeader = fmt.Sprintf("Listening @ http://localhost:%d\n", listener.Addr().(*net.TCPAddr).Port)
+		go http.Serve(listener, nil)
+	}
+
+	// Move stdout. This is done to prevent anything else from accidentally
+	// printing to stdout, which must contain only the valid WorkerResponse
+	// serialized protos.
+	newOutput, err := unix.Dup(1)
+	if err != nil {
+		log.Fatalf("unable to move stdout: %v", err)
+	}
+	// Stderr may be closed or may be a copy of stdout. We make sure that
+	// we have an output that is in a completely separate range.
+	for newOutput <= 2 {
+		newOutput, err = unix.Dup(newOutput)
+		if err != nil {
+			log.Fatalf("unable to move stdout: %v", err)
+		}
+	}
+
+	// Best-effort: collect logs.
+	rPipe, wPipe, err := os.Pipe()
+	if err != nil {
+		log.Fatalf("unable to create pipe: %v", err)
+	}
+	if err := unix.Dup2(int(wPipe.Fd()), 1); err != nil {
+		log.Fatalf("error duping over stdout: %v", err)
+	}
+	if err := unix.Dup2(int(wPipe.Fd()), 2); err != nil {
+		log.Fatalf("error duping over stderr: %v", err)
+	}
+	wPipe.Close()
+	defer rPipe.Close()
+
+	// Read requests from stdin.
+	input := bufio.NewReader(os.NewFile(0, "input"))
+	output := bufio.NewWriter(os.NewFile(uintptr(newOutput), "output"))
+	for {
+		szBuf, err := input.Peek(4)
+		if err != nil {
+			log.Fatalf("unable to read header: %v", err)
+		}
+
+		// Parse the size, and discard the consumed bytes.
+		sz, szBytes := protowire.ConsumeVarint(szBuf)
+		if szBytes < 0 {
+			szBytes = 0
+		}
+		if _, err := input.Discard(szBytes); err != nil {
+			log.Fatalf("error discarding size: %v", err)
+		}
+
+		// Read a full message.
+		msg := make([]byte, int(sz))
+		if _, err := io.ReadFull(input, msg); err != nil {
+			log.Fatalf("error reading worker request: %v", err)
+		}
+		var wreq wpb.WorkRequest
+		if err := proto.Unmarshal(msg, &wreq); err != nil {
+			log.Fatalf("error unmarshaling worker request: %v", err)
+		}
+
+		// Flush relevant caches.
+		inputFiles = make(map[string]string)
+		for _, input := range wreq.GetInputs() {
+			updateInputFile(input.GetPath(), string(input.GetDigest()))
+		}
+
+		// Prepare logging.
+		outputBuffer := bytes.NewBuffer(nil)
+		outputBuffer.WriteString(listenHeader)
+		log.SetOutput(outputBuffer)
+
+		// Parse all arguments.
+		flag.CommandLine.Parse(wreq.GetArguments())
+		var exitCode int
+		exitChan := make(chan int)
+		go func() { exitChan <- run(flag.CommandLine.Args()) }()
+		for running := true; running; {
+			select {
+			case exitCode = <-exitChan:
+				running = false
+			default:
+			}
+			// N.B. rPipe is given a read deadline of 1ms. We expect
+			// this to return a copy error after 1ms, and we just
+			// keep flushing this buffer while the task is running.
+			rPipe.SetReadDeadline(time.Now().Add(time.Millisecond))
+			outputBuffer.ReadFrom(rPipe)
+		}
+
+		if *workerDebug {
+			// Attach all cache stats.
+			outputBuffer.WriteString(allCacheStats())
+		}
+
+		// Send the response.
+		var wresp wpb.WorkResponse
+		wresp.ExitCode = int32(exitCode)
+		wresp.Output = outputBuffer.String()
+		rmsg, err := proto.Marshal(&wresp)
+		if err != nil {
+			log.Fatalf("error marshaling response: %v", err)
+		}
+		if _, err := output.Write(append(protowire.AppendVarint(nil, uint64(len(rmsg))), rmsg...)); err != nil {
+			log.Fatalf("error sending worker response: %v", err)
+		}
+		if err := output.Flush(); err != nil {
+			log.Fatalf("error flushing output: %v", err)
+		}
+	}
+}
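
For context, here is a minimal sketch of how a tool might build on this package: wrap the tool's entrypoint in worker.Work, and key any memoized state on input files via Cache.Lookup so entries are invalidated when bazel reports new digests in a WorkRequest. The tool itself, its file-summing logic, and the gvisor.dev/gvisor/tools/worker import path are illustrative assumptions, not part of this change.

// main.go: hypothetical bazel tool built on the worker package.
package main

import (
	"fmt"
	"io/ioutil"
	"log"

	"gvisor.dev/gvisor/tools/worker" // Assumed import path.
)

// parseCache persists across requests when run with --persistent_worker.
var parseCache = worker.NewCache("parse")

func main() {
	worker.Work(func(args []string) int {
		for _, filename := range args {
			// Lookup keys on the digest bazel reported for this
			// file (or the filename itself when not running as a
			// persistent worker), so a changed file is a miss.
			entry := parseCache.Lookup([]string{filename}, func() worker.Sizer {
				data, err := ioutil.ReadFile(filename)
				if err != nil {
					log.Printf("reading %s: %v", filename, err)
					return worker.CacheBytes(nil)
				}
				return worker.CacheBytes(data)
			})
			// Stdout is redirected into the log pipe in persistent
			// mode, so plain printing lands in the WorkResponse output.
			fmt.Printf("%s: %d bytes\n", filename, entry.Size())
		}
		return 0 // Reported as the WorkResponse exit code.
	})
}

Because Work re-parses flag.CommandLine and re-invokes run for every WorkRequest in persistent mode, run must be safe to call repeatedly; any cross-run state belongs in a Cache, which is flushed wholesale once total usage exceeds --maximum_cache_usage.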