// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package worker provides an implementation of the bazel worker protocol. // // Tools may be written as a normal command line utility, except the passed // run function may be invoked multiple times. package worker import ( "bufio" "bytes" "flag" "fmt" "io" "io/ioutil" "log" "net" "net/http" "os" "path/filepath" "sort" "strings" "time" _ "net/http/pprof" // For profiling. "golang.org/x/sys/unix" "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/proto" wpb "gvisor.dev/bazel/worker_protocol_go_proto" ) var ( persistentWorker = flag.Bool("persistent_worker", false, "enable persistent worker.") workerDebug = flag.Bool("worker_debug", false, "debug persistent workers.") maximumCacheUsage = flag.Int64("maximum_cache_usage", 1024*1024*1024, "maximum cache size.") ) var ( // inputFiles is the last set of input files. // // This is used for cache invalidation. The key is the *absolute* path // name, and the value is the digest in the current run. inputFiles = make(map[string]string) // activeCaches is the set of active caches. activeCaches = make(map[*Cache]struct{}) // totalCacheUsage is the total usage of all caches. totalCacheUsage int64 ) // mustAbs returns the absolute path of a filename or dies. func mustAbs(filename string) string { abs, err := filepath.Abs(filename) if err != nil { log.Fatalf("error getting absolute path: %v", err) } return abs } // updateInputFiles creates an entry in inputFiles. func updateInputFile(filename, digest string) { inputFiles[mustAbs(filename)] = digest } // Sizer returns a size. type Sizer interface { Size() int64 } // CacheBytes is an example of a Sizer. type CacheBytes []byte // Size implements Sizer.Size. func (cb CacheBytes) Size() int64 { return int64(len(cb)) } // Cache is a worker cache. // // They can be created via NewCache. type Cache struct { name string entries map[string]Sizer size int64 hits int64 misses int64 } // NewCache returns a new cache. func NewCache(name string) *Cache { return &Cache{ name: name, } } // Lookup looks up an entry in the cache. // // It is a function of the given files. func (c *Cache) Lookup(filenames []string, generate func() Sizer) Sizer { digests := make([]string, 0, len(filenames)) for _, filename := range filenames { digest, ok := inputFiles[mustAbs(filename)] if !ok { // This is not a valid input. We may not be running as // persistent worker in this cache. If that's the case, // then the file's contents will not change across the // run, and we just use the filename itself. digest = filename } digests = append(digests, digest) } // Attempt the lookup. sort.Slice(digests, func(i, j int) bool { return digests[i] < digests[j] }) cacheKey := strings.Join(digests, "+") if c.entries == nil { c.entries = make(map[string]Sizer) activeCaches[c] = struct{}{} } entry, ok := c.entries[cacheKey] if ok { c.hits++ return entry } // Generate a new entry. entry = generate() c.misses++ c.entries[cacheKey] = entry if entry != nil { sz := entry.Size() c.size += sz totalCacheUsage += sz } // Check the capacity of all caches. If it greater than the maximum, we // flush everything but still return this entry. if totalCacheUsage > *maximumCacheUsage { for entry, _ := range activeCaches { // Drop all entries. entry.size = 0 entry.entries = nil } totalCacheUsage = 0 // Reset. } return entry } // allCacheStats returns stats for all caches. func allCacheStats() string { var sb strings.Builder for entry, _ := range activeCaches { ratio := float64(entry.hits) / float64(entry.hits+entry.misses) fmt.Fprintf(&sb, "% 10s: count: % 5d size: % 10d hits: % 7d misses: % 7d ratio: %2.2f\n", entry.name, len(entry.entries), entry.size, entry.hits, entry.misses, ratio) } if len(activeCaches) > 0 { fmt.Fprintf(&sb, "total: % 10d\n", totalCacheUsage) } return sb.String() } // LookupDigest returns a digest for the given file. func LookupDigest(filename string) (string, bool) { digest, ok := inputFiles[filename] return digest, ok } // Work invokes the main function. func Work(run func([]string) int) { flag.CommandLine.Parse(os.Args[1:]) if !*persistentWorker { // Handle the argument file. args := flag.CommandLine.Args() if len(args) == 1 && len(args[0]) > 1 && args[0][0] == '@' { content, err := ioutil.ReadFile(args[0][1:]) if err != nil { log.Fatalf("unable to parse args file: %v", err) } // Pull arguments from the file. args = strings.Split(string(content), "\n") flag.CommandLine.Parse(args) args = flag.CommandLine.Args() } os.Exit(run(args)) } var listenHeader string // Emitted always. if *workerDebug { // Bind a server for profiling. listener, err := net.Listen("tcp", "localhost:0") if err != nil { log.Fatalf("unable to bind a server: %v", err) } // Construct the header for stats output, below. listenHeader = fmt.Sprintf("Listening @ http://localhost:%d\n", listener.Addr().(*net.TCPAddr).Port) go http.Serve(listener, nil) } // Move stdout. This is done to prevent anything else from accidentally // printing to stdout, which must contain only the valid WorkerResponse // serialized protos. newOutput, err := unix.Dup(1) if err != nil { log.Fatalf("unable to move stdout: %v", err) } // Stderr may be closed or may be a copy of stdout. We make sure that // we have an output that is in a completely separate range. for newOutput <= 2 { newOutput, err = unix.Dup(newOutput) if err != nil { log.Fatalf("unable to move stdout: %v", err) } } // Best-effort: collect logs. rPipe, wPipe, err := os.Pipe() if err != nil { log.Fatalf("unable to create pipe: %v", err) } if err := unix.Dup2(int(wPipe.Fd()), 1); err != nil { log.Fatalf("error duping over stdout: %v", err) } if err := unix.Dup2(int(wPipe.Fd()), 2); err != nil { log.Fatalf("error duping over stderr: %v", err) } wPipe.Close() defer rPipe.Close() // Read requests from stdin. input := bufio.NewReader(os.NewFile(0, "input")) output := bufio.NewWriter(os.NewFile(uintptr(newOutput), "output")) for { szBuf, err := input.Peek(4) if err != nil { log.Fatalf("unabel to read header: %v", err) } // Parse the size, and discard bits. sz, szBytes := protowire.ConsumeVarint(szBuf) if szBytes < 0 { szBytes = 0 } if _, err := input.Discard(szBytes); err != nil { log.Fatalf("error discarding size: %v", err) } // Read a full message. msg := make([]byte, int(sz)) if _, err := io.ReadFull(input, msg); err != nil { log.Fatalf("error reading worker request: %v", err) } var wreq wpb.WorkRequest if err := proto.Unmarshal(msg, &wreq); err != nil { log.Fatalf("error unmarshaling worker request: %v", err) } // Flush relevant caches. inputFiles = make(map[string]string) for _, input := range wreq.GetInputs() { updateInputFile(input.GetPath(), string(input.GetDigest())) } // Prepare logging. outputBuffer := bytes.NewBuffer(nil) outputBuffer.WriteString(listenHeader) log.SetOutput(outputBuffer) // Parse all arguments. flag.CommandLine.Parse(wreq.GetArguments()) var exitCode int exitChan := make(chan int) go func() { exitChan <- run(flag.CommandLine.Args()) }() for running := true; running; { select { case exitCode = <-exitChan: running = false default: } // N.B. rPipe is given a read deadline of 1ms. We expect // this to turn a copy error after 1ms, and we just keep // flushing this buffer while the task is running. rPipe.SetReadDeadline(time.Now().Add(time.Millisecond)) outputBuffer.ReadFrom(rPipe) } if *workerDebug { // Attach all cache stats. outputBuffer.WriteString(allCacheStats()) } // Send the response. var wresp wpb.WorkResponse wresp.ExitCode = int32(exitCode) wresp.Output = string(outputBuffer.Bytes()) rmsg, err := proto.Marshal(&wresp) if err != nil { log.Fatalf("error marshaling response: %v", err) } if _, err := output.Write(append(protowire.AppendVarint(nil, uint64(len(rmsg))), rmsg...)); err != nil { log.Fatalf("error sending worker response: %v", err) } if err := output.Flush(); err != nil { log.Fatalf("error flushing output: %v", err) } } }