diff options
Diffstat (limited to 'pkg/state/statefile')
-rw-r--r-- | pkg/state/statefile/BUILD | 23 | ||||
-rw-r--r-- | pkg/state/statefile/statefile.go | 233 | ||||
-rw-r--r-- | pkg/state/statefile/statefile_test.go | 299 |
3 files changed, 555 insertions, 0 deletions
diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD new file mode 100644 index 000000000..df2c6a578 --- /dev/null +++ b/pkg/state/statefile/BUILD @@ -0,0 +1,23 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "statefile", + srcs = ["statefile.go"], + importpath = "gvisor.googlesource.com/gvisor/pkg/state/statefile", + visibility = ["//:sandbox"], + deps = [ + "//pkg/binary", + "//pkg/compressio", + "//pkg/hashio", + ], +) + +go_test( + name = "statefile_test", + size = "small", + srcs = ["statefile_test.go"], + embed = [":statefile"], + deps = ["//pkg/hashio"], +) diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go new file mode 100644 index 000000000..b25b743b7 --- /dev/null +++ b/pkg/state/statefile/statefile.go @@ -0,0 +1,233 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package statefile defines the state file data stream. +// +// This package currently does not include any details regarding the state +// encoding itself, only details regarding state metadata and data layout. +// +// The file format is defined as follows. +// +// /------------------------------------------------------\ +// | header (8-bytes) | +// +------------------------------------------------------+ +// | metadata length (8-bytes) | +// +------------------------------------------------------+ +// | metadata | +// +------------------------------------------------------+ +// | data | +// \------------------------------------------------------/ +// +// First, it includes a 8-byte magic header which is the following +// sequence of bytes [0x67, 0x56, 0x69, 0x73, 0x6f, 0x72, 0x53, 0x46] +// +// This header is followed by an 8-byte length N (big endian), and an +// ASCII-encoded JSON map that is exactly N bytes long. +// +// This map includes only strings for keys and strings for values. Keys in the +// map that begin with "_" are for internal use only. They may be read, but may +// not be provided by the user. In the future, this metadata may contain some +// information relating to the state encoding itself. +// +// After the map, the remainder of the file is the state data. +package statefile + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/json" + "fmt" + "hash" + "io" + "strings" + "time" + + "gvisor.googlesource.com/gvisor/pkg/binary" + "gvisor.googlesource.com/gvisor/pkg/compressio" + "gvisor.googlesource.com/gvisor/pkg/hashio" +) + +// keySize is the AES-256 key length. +const keySize = 32 + +// compressionChunkSize is the chunk size for compression. +const compressionChunkSize = 1024 * 1024 + +// maxMetadataSize is the size limit of metadata section. +const maxMetadataSize = 16 * 1024 * 1024 + +// magicHeader is the byte sequence beginning each file. +var magicHeader = []byte("\x67\x56\x69\x73\x6f\x72\x53\x46") + +// ErrBadMagic is returned if the header does not match. +var ErrBadMagic = fmt.Errorf("bad magic header") + +// ErrMetadataMissing is returned if the state file is missing mandatory metadata. +var ErrMetadataMissing = fmt.Errorf("missing metadata") + +// ErrInvalidMetadataLength is returned if the metadata length is too large. +var ErrInvalidMetadataLength = fmt.Errorf("metadata length invalid, maximum size is %d", maxMetadataSize) + +// ErrMetadataInvalid is returned if passed metadata is invalid. +var ErrMetadataInvalid = fmt.Errorf("metadata invalid, can't start with _") + +// NewWriter returns a state data writer for a statefile. +// +// Note that the returned WriteCloser must be closed. +func NewWriter(w io.Writer, key []byte, metadata map[string]string, compressionLevel int) (io.WriteCloser, error) { + if metadata == nil { + metadata = make(map[string]string) + } + for k := range metadata { + if strings.HasPrefix(k, "_") { + return nil, ErrMetadataInvalid + } + } + + // Create our HMAC function. + h := hmac.New(sha256.New, key) + mw := io.MultiWriter(w, h) + + // First, write the header. + if _, err := mw.Write(magicHeader); err != nil { + return nil, err + } + + // Generate a timestamp, for convenience only. + metadata["_timestamp"] = time.Now().UTC().String() + defer delete(metadata, "_timestamp") + + // Write the metadata. + b, err := json.Marshal(metadata) + if err != nil { + return nil, err + } + + if len(b) > maxMetadataSize { + return nil, ErrInvalidMetadataLength + } + + // Metadata length. + if err := binary.WriteUint64(mw, binary.BigEndian, uint64(len(b))); err != nil { + return nil, err + } + // Metadata bytes; io.MultiWriter will return a short write error if + // any of the writers returns < n. + if _, err := mw.Write(b); err != nil { + return nil, err + } + // Write the current hash. + cur := h.Sum(nil) + for done := 0; done < len(cur); { + n, err := mw.Write(cur[done:]) + done += n + if err != nil { + return nil, err + } + } + + w = hashio.NewWriter(w, h) + + // Wrap in compression. + return compressio.NewWriter(w, compressionChunkSize, compressionLevel) +} + +// MetadataUnsafe reads out the metadata from a state file without verifying any +// HMAC. This function shouldn't be called for untrusted input files. +func MetadataUnsafe(r io.Reader) (map[string]string, error) { + return metadata(r, nil) +} + +// metadata validates the magic header and reads out the metadata from a state +// data stream. +func metadata(r io.Reader, h hash.Hash) (map[string]string, error) { + if h != nil { + r = io.TeeReader(r, h) + } + + // Read and validate magic header. + b := make([]byte, len(magicHeader)) + if _, err := r.Read(b); err != nil { + return nil, err + } + if !bytes.Equal(b, magicHeader) { + return nil, ErrBadMagic + } + + // Read and validate metadata. + b, err := func() (b []byte, err error) { + defer func() { + if r := recover(); r != nil { + b = nil + err = fmt.Errorf("%v", r) + } + }() + + metadataLen, err := binary.ReadUint64(r, binary.BigEndian) + if err != nil { + return nil, err + } + if metadataLen > maxMetadataSize { + return nil, ErrInvalidMetadataLength + } + b = make([]byte, int(metadataLen)) + if _, err := io.ReadFull(r, b); err != nil { + return nil, err + } + return b, nil + }() + if err != nil { + return nil, err + } + + if h != nil { + // Check the hash prior to decoding. + cur := h.Sum(nil) + buf := make([]byte, len(cur)) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + if !hmac.Equal(cur, buf) { + return nil, hashio.ErrHashMismatch + } + } + + // Decode the metadata. + metadata := make(map[string]string) + if err := json.Unmarshal(b, &metadata); err != nil { + return nil, err + } + + return metadata, nil +} + +// NewReader returns a reader for a statefile. +func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) { + // Read the metadata with the hash. + h := hmac.New(sha256.New, key) + metadata, err := metadata(r, h) + if err != nil { + return nil, nil, err + } + + r = hashio.NewReader(r, h) + + // Wrap in compression. + rc, err := compressio.NewReader(r) + if err != nil { + return nil, nil, err + } + return rc, metadata, nil +} diff --git a/pkg/state/statefile/statefile_test.go b/pkg/state/statefile/statefile_test.go new file mode 100644 index 000000000..6e67b51de --- /dev/null +++ b/pkg/state/statefile/statefile_test.go @@ -0,0 +1,299 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package statefile + +import ( + "bytes" + "compress/flate" + crand "crypto/rand" + "encoding/base64" + "io" + "math/rand" + "testing" + + "gvisor.googlesource.com/gvisor/pkg/hashio" +) + +func randomKey() ([]byte, error) { + r := make([]byte, base64.RawStdEncoding.DecodedLen(keySize)) + if _, err := io.ReadFull(crand.Reader, r); err != nil { + return nil, err + } + key := make([]byte, keySize) + base64.RawStdEncoding.Encode(key, r) + return key, nil +} + +type testCase struct { + name string + data []byte + metadata map[string]string +} + +func TestStatefile(t *testing.T) { + cases := []testCase{ + // Various data sizes. + {"nil", nil, nil}, + {"empty", []byte(""), nil}, + {"some", []byte("_"), nil}, + {"one", []byte("0"), nil}, + {"two", []byte("01"), nil}, + {"three", []byte("012"), nil}, + {"four", []byte("0123"), nil}, + {"five", []byte("01234"), nil}, + {"six", []byte("012356"), nil}, + {"seven", []byte("0123567"), nil}, + {"eight", []byte("01235678"), nil}, + + // Make sure we have one longer than the hash length. + {"longer than hash", []byte("012356asdjflkasjlk3jlk23j4lkjaso0d789f0aujw3lkjlkxsdf78asdful2kj3ljka78"), nil}, + + // Make sure we have one longer than the segment size. + {"segments", make([]byte, 3*hashio.SegmentSize), nil}, + {"segments minus one", make([]byte, 3*hashio.SegmentSize-1), nil}, + {"segments plus one", make([]byte, 3*hashio.SegmentSize+1), nil}, + {"segments minus hash", make([]byte, 3*hashio.SegmentSize-32), nil}, + {"segments plus hash", make([]byte, 3*hashio.SegmentSize+32), nil}, + {"large", make([]byte, 30*hashio.SegmentSize), nil}, + + // Different metadata. + {"one metadata", []byte("data"), map[string]string{"foo": "bar"}}, + {"two metadata", []byte("data"), map[string]string{"foo": "bar", "one": "two"}}, + } + + for _, c := range cases { + // Generate a key. + integrityKey, err := randomKey() + if err != nil { + t.Errorf("can't generate key: got %v, excepted nil", err) + continue + } + + t.Run(c.name, func(t *testing.T) { + for _, key := range [][]byte{nil, integrityKey} { + t.Run("key="+string(key), func(t *testing.T) { + // Encoding happens via a buffer. + var bufEncoded bytes.Buffer + var bufDecoded bytes.Buffer + + // Do all the writing. + w, err := NewWriter(&bufEncoded, key, c.metadata, flate.BestSpeed) + if err != nil { + t.Fatalf("error creating writer: got %v, expected nil", err) + } + if _, err := io.Copy(w, bytes.NewBuffer(c.data)); err != nil { + t.Fatalf("error during write: got %v, expected nil", err) + } + + // Finish the sum. + if err := w.Close(); err != nil { + t.Fatalf("error during close: got %v, expected nil", err) + } + + t.Logf("original data: %d bytes, encoded: %d bytes.", + len(c.data), len(bufEncoded.Bytes())) + + // Do all the reading. + r, metadata, err := NewReader(bytes.NewReader(bufEncoded.Bytes()), key) + if err != nil { + t.Fatalf("error creating reader: got %v, expected nil", err) + } + if _, err := io.Copy(&bufDecoded, r); err != nil { + t.Fatalf("error during read: got %v, expected nil", err) + } + + // Check that the data matches. + if !bytes.Equal(c.data, bufDecoded.Bytes()) { + t.Fatalf("data didn't match (%d vs %d bytes)", len(bufDecoded.Bytes()), len(c.data)) + } + + // Check that the metadata matches. + for k, v := range c.metadata { + nv, ok := metadata[k] + if !ok { + t.Fatalf("missing metadata: %s", k) + } + if v != nv { + t.Fatalf("mismatched metdata for %s: got %s, expected %s", k, nv, v) + } + } + + // Change the data and verify that it fails. + b := append([]byte(nil), bufEncoded.Bytes()...) + b[rand.Intn(len(b))]++ + r, _, err = NewReader(bytes.NewReader(b), key) + if err == nil { + _, err = io.Copy(&bufDecoded, r) + } + if err == nil { + t.Error("got no error: expected error on data corruption") + } + + // Change the key and verify that it fails. + if key == nil { + key = integrityKey + } else { + key[rand.Intn(len(key))]++ + } + r, _, err = NewReader(bytes.NewReader(bufEncoded.Bytes()), key) + if err == nil { + _, err = io.Copy(&bufDecoded, r) + } + if err != hashio.ErrHashMismatch { + t.Errorf("got error: %v, expected ErrHashMismatch on key mismatch", err) + } + }) + } + }) + } +} + +const benchmarkDataSize = 10 * 1024 * 1024 + +func benchmark(b *testing.B, size int, write bool, compressible bool) { + b.StopTimer() + b.SetBytes(benchmarkDataSize) + + // Generate source data. + var source []byte + if compressible { + // For compressible data, we use essentially all zeros. + source = make([]byte, benchmarkDataSize) + } else { + // For non-compressible data, we use random base64 data (to + // make it marginally compressible, a ratio of 75%). + var sourceBuf bytes.Buffer + bufW := base64.NewEncoder(base64.RawStdEncoding, &sourceBuf) + bufR := rand.New(rand.NewSource(0)) + if _, err := io.CopyN(bufW, bufR, benchmarkDataSize); err != nil { + b.Fatalf("unable to seed random data: %v", err) + } + source = sourceBuf.Bytes() + } + + // Generate a random key for integrity check. + key, err := randomKey() + if err != nil { + b.Fatalf("error generating key: %v", err) + } + + // Define our benchmark functions. Prior to running the readState + // function here, you must execute the writeState function at least + // once (done below). + var stateBuf bytes.Buffer + writeState := func() { + stateBuf.Reset() + w, err := NewWriter(&stateBuf, key, nil, flate.BestSpeed) + if err != nil { + b.Fatalf("error creating writer: %v", err) + } + for done := 0; done < len(source); { + chunk := size // limit size. + if done+chunk > len(source) { + chunk = len(source) - done + } + n, err := w.Write(source[done : done+chunk]) + done += n + if n == 0 && err != nil { + b.Fatalf("error during write: %v", err) + } + } + if err := w.Close(); err != nil { + b.Fatalf("error closing writer: %v", err) + } + } + readState := func() { + tmpBuf := bytes.NewBuffer(stateBuf.Bytes()) + r, _, err := NewReader(tmpBuf, key) + if err != nil { + b.Fatalf("error creating reader: %v", err) + } + for done := 0; done < len(source); { + chunk := size // limit size. + if done+chunk > len(source) { + chunk = len(source) - done + } + n, err := r.Read(source[done : done+chunk]) + done += n + if n == 0 && err != nil { + b.Fatalf("error during read: %v", err) + } + } + } + // Generate the state once without timing to ensure that buffers have + // been appropriately allocated. + writeState() + if write { + b.StartTimer() + for i := 0; i < b.N; i++ { + writeState() + } + b.StopTimer() + } else { + b.StartTimer() + for i := 0; i < b.N; i++ { + readState() + } + b.StopTimer() + } +} + +func BenchmarkWrite1BCompressible(b *testing.B) { + benchmark(b, 1, true, true) +} + +func BenchmarkWrite1BNoncompressible(b *testing.B) { + benchmark(b, 1, true, false) +} + +func BenchmarkWrite4KCompressible(b *testing.B) { + benchmark(b, 4096, true, true) +} + +func BenchmarkWrite4KNoncompressible(b *testing.B) { + benchmark(b, 4096, true, false) +} + +func BenchmarkWrite1MCompressible(b *testing.B) { + benchmark(b, 1024*1024, true, true) +} + +func BenchmarkWrite1MNoncompressible(b *testing.B) { + benchmark(b, 1024*1024, true, false) +} + +func BenchmarkRead1BCompressible(b *testing.B) { + benchmark(b, 1, false, true) +} + +func BenchmarkRead1BNoncompressible(b *testing.B) { + benchmark(b, 1, false, false) +} + +func BenchmarkRead4KCompressible(b *testing.B) { + benchmark(b, 4096, false, true) +} + +func BenchmarkRead4KNoncompressible(b *testing.B) { + benchmark(b, 4096, false, false) +} + +func BenchmarkRead1MCompressible(b *testing.B) { + benchmark(b, 1024*1024, false, true) +} + +func BenchmarkRead1MNoncompressible(b *testing.B) { + benchmark(b, 1024*1024, false, false) +} |