summaryrefslogtreecommitdiffhomepage
path: root/test/benchmarks/ml/tensorflow_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'test/benchmarks/ml/tensorflow_test.go')
-rw-r--r--test/benchmarks/ml/tensorflow_test.go12
1 files changed, 10 insertions, 2 deletions
diff --git a/test/benchmarks/ml/tensorflow_test.go b/test/benchmarks/ml/tensorflow_test.go
index f7746897d..a55329d82 100644
--- a/test/benchmarks/ml/tensorflow_test.go
+++ b/test/benchmarks/ml/tensorflow_test.go
@@ -15,12 +15,15 @@ package ml
import (
"context"
+ "os"
"testing"
"gvisor.dev/gvisor/pkg/test/dockerutil"
"gvisor.dev/gvisor/test/benchmarks/harness"
)
+var h harness.Harness
+
// BenchmarkTensorflow runs workloads from a TensorFlow tutorial.
// See: https://github.com/aymericdamien/TensorFlow-Examples
func BenchmarkTensorflow(b *testing.B) {
@@ -44,12 +47,12 @@ func BenchmarkTensorflow(b *testing.B) {
for name, workload := range workloads {
b.Run(name, func(b *testing.B) {
ctx := context.Background()
- container := machine.GetContainer(ctx, b)
- defer container.CleanUp(ctx)
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
+ container := machine.GetContainer(ctx, b)
+ defer container.CleanUp(ctx)
if err := harness.DropCaches(machine); err != nil {
b.Skipf("failed to drop caches: %v. You probably need root.", err)
}
@@ -67,3 +70,8 @@ func BenchmarkTensorflow(b *testing.B) {
}
}
+
+func TestMain(m *testing.M) {
+ h.Init()
+ os.Exit(m.Run())
+}