diff options
Diffstat (limited to 'test/benchmarks/ml')
-rw-r--r-- | test/benchmarks/ml/BUILD | 10 | ||||
-rw-r--r-- | test/benchmarks/ml/ml.go | 15 | ||||
-rw-r--r-- | test/benchmarks/ml/tensorflow_test.go | 19 |
3 files changed, 19 insertions, 25 deletions
diff --git a/test/benchmarks/ml/BUILD b/test/benchmarks/ml/BUILD index 970f52706..285ec35d9 100644 --- a/test/benchmarks/ml/BUILD +++ b/test/benchmarks/ml/BUILD @@ -1,4 +1,5 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library") +load("//test/benchmarks:defs.bzl", "benchmark_test") package(licenses = ["notice"]) @@ -6,12 +7,11 @@ go_library( name = "ml", testonly = 1, srcs = ["ml.go"], - deps = ["//test/benchmarks/harness"], ) -go_test( - name = "ml_test", - size = "large", +benchmark_test( + name = "tensorflow_test", + size = "enormous", srcs = ["tensorflow_test.go"], library = ":ml", visibility = ["//:sandbox"], diff --git a/test/benchmarks/ml/ml.go b/test/benchmarks/ml/ml.go index 13282d7bb..d5fc5b7da 100644 --- a/test/benchmarks/ml/ml.go +++ b/test/benchmarks/ml/ml.go @@ -14,18 +14,3 @@ // Package ml holds benchmarks around machine learning performance. package ml - -import ( - "os" - "testing" - - "gvisor.dev/gvisor/test/benchmarks/harness" -) - -var h harness.Harness - -// TestMain is the main method for package ml. -func TestMain(m *testing.M) { - h.Init() - os.Exit(m.Run()) -} diff --git a/test/benchmarks/ml/tensorflow_test.go b/test/benchmarks/ml/tensorflow_test.go index f7746897d..b0e0c4720 100644 --- a/test/benchmarks/ml/tensorflow_test.go +++ b/test/benchmarks/ml/tensorflow_test.go @@ -15,6 +15,7 @@ package ml import ( "context" + "os" "testing" "gvisor.dev/gvisor/pkg/test/dockerutil" @@ -35,7 +36,7 @@ func BenchmarkTensorflow(b *testing.B) { "NeuralNetwork": "3_NeuralNetworks/neural_network.py", } - machine, err := h.GetMachine() + machine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine: %v", err) } @@ -44,17 +45,19 @@ func BenchmarkTensorflow(b *testing.B) { for name, workload := range workloads { b.Run(name, func(b *testing.B) { ctx := context.Background() - container := machine.GetContainer(ctx, b) - defer container.CleanUp(ctx) b.ResetTimer() + b.StopTimer() + for i := 0; i < b.N; i++ { - b.StopTimer() + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) if err := harness.DropCaches(machine); err != nil { b.Skipf("failed to drop caches: %v. You probably need root.", err) } - b.StartTimer() + // Run tensorflow. + b.StartTimer() if out, err := container.Run(ctx, dockerutil.RunOpts{ Image: "benchmarks/tensorflow", Env: []string{"PYTHONPATH=$PYTHONPATH:/TensorFlow-Examples/examples"}, @@ -62,8 +65,14 @@ func BenchmarkTensorflow(b *testing.B) { }, "python", workload); err != nil { b.Fatalf("failed to run container: %v logs: %s", err, out) } + b.StopTimer() } }) } +} +func TestMain(m *testing.M) { + harness.Init() + harness.SetFixedBenchmarks() + os.Exit(m.Run()) } |