diff options
Diffstat (limited to 'pkg/state/statefile')
-rw-r--r-- | pkg/state/statefile/BUILD | 3 | ||||
-rw-r--r-- | pkg/state/statefile/statefile.go | 11 | ||||
-rw-r--r-- | pkg/state/statefile/statefile_test.go | 70 |
3 files changed, 35 insertions, 49 deletions
diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD index 16abe1930..6be78dc9b 100644 --- a/pkg/state/statefile/BUILD +++ b/pkg/state/statefile/BUILD @@ -10,7 +10,6 @@ go_library( deps = [ "//pkg/binary", "//pkg/compressio", - "//pkg/hashio", ], ) @@ -19,5 +18,5 @@ go_test( size = "small", srcs = ["statefile_test.go"], embed = [":statefile"], - deps = ["//pkg/hashio"], + deps = ["//pkg/compressio"], ) diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go index 0b4eff8fa..9c86c1934 100644 --- a/pkg/state/statefile/statefile.go +++ b/pkg/state/statefile/statefile.go @@ -57,7 +57,6 @@ import ( "crypto/sha256" "gvisor.googlesource.com/gvisor/pkg/binary" "gvisor.googlesource.com/gvisor/pkg/compressio" - "gvisor.googlesource.com/gvisor/pkg/hashio" ) // keySize is the AES-256 key length. @@ -139,13 +138,11 @@ func NewWriter(w io.Writer, key []byte, metadata map[string]string) (io.WriteClo } } - w = hashio.NewWriter(w, h) - // Wrap in compression. We always use "best speed" mode here. When using // "best compression" mode, there is usually only a little gain in file // size reduction, which translate to even smaller gain in restore // latency reduction, while inccuring much more CPU usage at save time. - return compressio.NewWriter(w, compressionChunkSize, flate.BestSpeed) + return compressio.NewWriter(w, key, compressionChunkSize, flate.BestSpeed) } // MetadataUnsafe reads out the metadata from a state file without verifying any @@ -204,7 +201,7 @@ func metadata(r io.Reader, h hash.Hash) (map[string]string, error) { return nil, err } if !hmac.Equal(cur, buf) { - return nil, hashio.ErrHashMismatch + return nil, compressio.ErrHashMismatch } } @@ -226,10 +223,8 @@ func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) { return nil, nil, err } - r = hashio.NewReader(r, h) - // Wrap in compression. - rc, err := compressio.NewReader(r) + rc, err := compressio.NewReader(r, key) if err != nil { return nil, nil, err } diff --git a/pkg/state/statefile/statefile_test.go b/pkg/state/statefile/statefile_test.go index 66d9581ed..fa3fb9f2c 100644 --- a/pkg/state/statefile/statefile_test.go +++ b/pkg/state/statefile/statefile_test.go @@ -20,9 +20,11 @@ import ( "encoding/base64" "io" "math/rand" + "runtime" "testing" + "time" - "gvisor.googlesource.com/gvisor/pkg/hashio" + "gvisor.googlesource.com/gvisor/pkg/compressio" ) func randomKey() ([]byte, error) { @@ -42,6 +44,8 @@ type testCase struct { } func TestStatefile(t *testing.T) { + rand.Seed(time.Now().Unix()) + cases := []testCase{ // Various data sizes. {"nil", nil, nil}, @@ -59,13 +63,9 @@ func TestStatefile(t *testing.T) { // Make sure we have one longer than the hash length. {"longer than hash", []byte("012356asdjflkasjlk3jlk23j4lkjaso0d789f0aujw3lkjlkxsdf78asdful2kj3ljka78"), nil}, - // Make sure we have one longer than the segment size. - {"segments", make([]byte, 3*hashio.SegmentSize), nil}, - {"segments minus one", make([]byte, 3*hashio.SegmentSize-1), nil}, - {"segments plus one", make([]byte, 3*hashio.SegmentSize+1), nil}, - {"segments minus hash", make([]byte, 3*hashio.SegmentSize-32), nil}, - {"segments plus hash", make([]byte, 3*hashio.SegmentSize+32), nil}, - {"large", make([]byte, 30*hashio.SegmentSize), nil}, + // Make sure we have one longer than the chunk size. + {"chunks", make([]byte, 3*compressionChunkSize), nil}, + {"large", make([]byte, 30*compressionChunkSize), nil}, // Different metadata. {"one metadata", []byte("data"), map[string]string{"foo": "bar"}}, @@ -130,27 +130,31 @@ func TestStatefile(t *testing.T) { } // Change the data and verify that it fails. - b := append([]byte(nil), bufEncoded.Bytes()...) - b[rand.Intn(len(b))]++ - r, _, err = NewReader(bytes.NewReader(b), key) - if err == nil { - _, err = io.Copy(&bufDecoded, r) - } - if err == nil { - t.Error("got no error: expected error on data corruption") + if key != nil { + b := append([]byte(nil), bufEncoded.Bytes()...) + b[rand.Intn(len(b))]++ + bufDecoded.Reset() + r, _, err = NewReader(bytes.NewReader(b), key) + if err == nil { + _, err = io.Copy(&bufDecoded, r) + } + if err == nil { + t.Error("got no error: expected error on data corruption") + } } // Change the key and verify that it fails. - if key == nil { - key = integrityKey - } else { - key[rand.Intn(len(key))]++ + newKey := integrityKey + if len(key) > 0 { + newKey = append([]byte{}, key...) + newKey[rand.Intn(len(newKey))]++ } - r, _, err = NewReader(bytes.NewReader(bufEncoded.Bytes()), key) + bufDecoded.Reset() + r, _, err = NewReader(bytes.NewReader(bufEncoded.Bytes()), newKey) if err == nil { _, err = io.Copy(&bufDecoded, r) } - if err != hashio.ErrHashMismatch { + if err != compressio.ErrHashMismatch { t.Errorf("got error: %v, expected ErrHashMismatch on key mismatch", err) } }) @@ -159,7 +163,7 @@ func TestStatefile(t *testing.T) { } } -const benchmarkDataSize = 10 * 1024 * 1024 +const benchmarkDataSize = 100 * 1024 * 1024 func benchmark(b *testing.B, size int, write bool, compressible bool) { b.StopTimer() @@ -249,14 +253,6 @@ func benchmark(b *testing.B, size int, write bool, compressible bool) { } } -func BenchmarkWrite1BCompressible(b *testing.B) { - benchmark(b, 1, true, true) -} - -func BenchmarkWrite1BNoncompressible(b *testing.B) { - benchmark(b, 1, true, false) -} - func BenchmarkWrite4KCompressible(b *testing.B) { benchmark(b, 4096, true, true) } @@ -273,14 +269,6 @@ func BenchmarkWrite1MNoncompressible(b *testing.B) { benchmark(b, 1024*1024, true, false) } -func BenchmarkRead1BCompressible(b *testing.B) { - benchmark(b, 1, false, true) -} - -func BenchmarkRead1BNoncompressible(b *testing.B) { - benchmark(b, 1, false, false) -} - func BenchmarkRead4KCompressible(b *testing.B) { benchmark(b, 4096, false, true) } @@ -296,3 +284,7 @@ func BenchmarkRead1MCompressible(b *testing.B) { func BenchmarkRead1MNoncompressible(b *testing.B) { benchmark(b, 1024*1024, false, false) } + +func init() { + runtime.GOMAXPROCS(runtime.NumCPU()) +} |