summaryrefslogtreecommitdiffhomepage
path: root/pkg/state/statefile
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/state/statefile')
-rw-r--r--pkg/state/statefile/BUILD23
-rw-r--r--pkg/state/statefile/statefile.go233
-rw-r--r--pkg/state/statefile/statefile_test.go299
3 files changed, 555 insertions, 0 deletions
diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD
new file mode 100644
index 000000000..df2c6a578
--- /dev/null
+++ b/pkg/state/statefile/BUILD
@@ -0,0 +1,23 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+# Library target for the statefile package. Its deps mirror the three project
+# packages imported by statefile.go (binary, compressio, hashio).
+go_library(
+ name = "statefile",
+ srcs = ["statefile.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/state/statefile",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/binary",
+ "//pkg/compressio",
+ "//pkg/hashio",
+ ],
+)
+
+# Tests and benchmarks from statefile_test.go. The library is embedded because
+# the test file is declared in package statefile and uses unexported names.
+go_test(
+ name = "statefile_test",
+ size = "small",
+ srcs = ["statefile_test.go"],
+ embed = [":statefile"],
+ deps = ["//pkg/hashio"],
+)
diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go
new file mode 100644
index 000000000..b25b743b7
--- /dev/null
+++ b/pkg/state/statefile/statefile.go
@@ -0,0 +1,233 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package statefile defines the state file data stream.
+//
+// This package currently does not include any details regarding the state
+// encoding itself, only details regarding state metadata and data layout.
+//
+// The file format is defined as follows.
+//
+// /------------------------------------------------------\
+// | header (8-bytes) |
+// +------------------------------------------------------+
+// | metadata length (8-bytes) |
+// +------------------------------------------------------+
+// | metadata |
+// +------------------------------------------------------+
+// | data |
+// \------------------------------------------------------/
+//
+// First, it includes an 8-byte magic header, which is the following
+// sequence of bytes: [0x67, 0x56, 0x69, 0x73, 0x6f, 0x72, 0x53, 0x46].
+//
+// This header is followed by an 8-byte length N (big endian), and an
+// ASCII-encoded JSON map that is exactly N bytes long.
+//
+// This map includes only strings for keys and strings for values. Keys in the
+// map that begin with "_" are for internal use only. They may be read, but may
+// not be provided by the user. In the future, this metadata may contain some
+// information relating to the state encoding itself.
+//
+// After the map comes an HMAC-SHA256 digest of all preceding bytes (header,
+// metadata length and metadata). The remainder of the file is the state data.
+package statefile
+
+import (
+ "bytes"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/json"
+ "fmt"
+ "hash"
+ "io"
+ "strings"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/compressio"
+ "gvisor.googlesource.com/gvisor/pkg/hashio"
+)
+
+// keySize is the AES-256 key length.
+//
+// NOTE(review): in this package the key is only ever handed to HMAC-SHA256
+// (see NewWriter and NewReader); no AES use is visible here, so confirm the
+// intended meaning of this comment with the callers that generate keys.
+const keySize = 32
+
+// compressionChunkSize is the chunk size for compression. NewWriter passes it
+// to compressio.NewWriter.
+const compressionChunkSize = 1024 * 1024
+
+// maxMetadataSize is the size limit of metadata section, measured on the
+// JSON-encoded bytes (16 MiB). It is enforced both when writing (NewWriter)
+// and when reading (metadata).
+const maxMetadataSize = 16 * 1024 * 1024
+
+// magicHeader is the byte sequence beginning each file (ASCII "gVisorSF").
+var magicHeader = []byte("\x67\x56\x69\x73\x6f\x72\x53\x46")
+
+// ErrBadMagic is returned if the header does not match.
+var ErrBadMagic = fmt.Errorf("bad magic header")
+
+// ErrMetadataMissing is returned if the state file is missing mandatory metadata.
+// Nothing in this file returns it; it is exported for use by callers.
+var ErrMetadataMissing = fmt.Errorf("missing metadata")
+
+// ErrInvalidMetadataLength is returned if the metadata length is too large,
+// i.e. it exceeds maxMetadataSize on either the write or the read path.
+var ErrInvalidMetadataLength = fmt.Errorf("metadata length invalid, maximum size is %d", maxMetadataSize)
+
+// ErrMetadataInvalid is returned if passed metadata is invalid, which here
+// means a caller-supplied key begins with the reserved "_" prefix.
+var ErrMetadataInvalid = fmt.Errorf("metadata invalid, can't start with _")
+
+// NewWriter returns a state data writer for a statefile.
+//
+// Before returning, it writes the file preamble to w: the magic header, the
+// big-endian metadata length, the JSON-encoded metadata (with an added
+// "_timestamp" entry) and an HMAC-SHA256 digest, keyed with key, of all of
+// those bytes. Data subsequently written to the returned writer is passed
+// through compressio and then hashio before reaching w.
+//
+// metadata may be nil. Keys beginning with "_" are reserved and cause
+// ErrMetadataInvalid; encoded metadata larger than maxMetadataSize causes
+// ErrInvalidMetadataLength. In both cases nothing is written to w. The
+// caller's map is never modified.
+//
+// Note that the returned WriteCloser must be closed.
+func NewWriter(w io.Writer, key []byte, metadata map[string]string, compressionLevel int) (io.WriteCloser, error) {
+	// Validate into a private copy so the caller's map is never touched,
+	// not even temporarily; it may be shared with other goroutines.
+	md := make(map[string]string, len(metadata)+1)
+	for k, v := range metadata {
+		if strings.HasPrefix(k, "_") {
+			return nil, ErrMetadataInvalid
+		}
+		md[k] = v
+	}
+
+	// Generate a timestamp, for convenience only.
+	md["_timestamp"] = time.Now().UTC().String()
+
+	// Encode and size-check the metadata before emitting any bytes, so
+	// that a rejected map does not leave a dangling header in w.
+	b, err := json.Marshal(md)
+	if err != nil {
+		return nil, err
+	}
+	if len(b) > maxMetadataSize {
+		return nil, ErrInvalidMetadataLength
+	}
+
+	// Create our HMAC function; everything written via mw is hashed too.
+	h := hmac.New(sha256.New, key)
+	mw := io.MultiWriter(w, h)
+
+	// First, write the header.
+	if _, err := mw.Write(magicHeader); err != nil {
+		return nil, err
+	}
+
+	// Metadata length.
+	if err := binary.WriteUint64(mw, binary.BigEndian, uint64(len(b))); err != nil {
+		return nil, err
+	}
+	// Metadata bytes; io.MultiWriter will return a short write error if
+	// any of the writers returns < n.
+	if _, err := mw.Write(b); err != nil {
+		return nil, err
+	}
+
+	// Write the current hash. It deliberately goes through mw as well: the
+	// reader tees the stored digest into its own HMAC, so both sides must
+	// fold these bytes into the running state. No retry loop is needed
+	// because io.MultiWriter reports short writes as errors.
+	cur := h.Sum(nil)
+	if _, err := mw.Write(cur); err != nil {
+		return nil, err
+	}
+
+	// From here on, data is hashed in segments by hashio.
+	w = hashio.NewWriter(w, h)
+
+	// Wrap in compression.
+	return compressio.NewWriter(w, compressionChunkSize, compressionLevel)
+}
+
+// MetadataUnsafe reads out the metadata from a state file without verifying any
+// HMAC. This function shouldn't be called for untrusted input files.
+func MetadataUnsafe(r io.Reader) (map[string]string, error) {
+ return metadata(r, nil)
+}
+
+// metadata validates the magic header and reads out the metadata from a state
+// data stream.
+//
+// If h is non-nil, every byte consumed from r is also fed into h, and the
+// digest stored after the metadata is compared with h's running sum before
+// the metadata is decoded; a mismatch yields hashio.ErrHashMismatch. If h is
+// nil, no digest is read or checked.
+//
+// It returns ErrBadMagic when the header does not match,
+// ErrInvalidMetadataLength when the declared length exceeds maxMetadataSize,
+// and otherwise any read or JSON decoding error encountered.
+func metadata(r io.Reader, h hash.Hash) (map[string]string, error) {
+	if h != nil {
+		r = io.TeeReader(r, h)
+	}
+
+	// Read and validate magic header. io.ReadFull is required: a single
+	// Read may legally return fewer bytes than requested (pipes, sockets,
+	// small buffers), which would otherwise be misreported as a bad magic
+	// and leave the stream position out of sync.
+	b := make([]byte, len(magicHeader))
+	if _, err := io.ReadFull(r, b); err != nil {
+		return nil, err
+	}
+	if !bytes.Equal(b, magicHeader) {
+		return nil, ErrBadMagic
+	}
+
+	// Read and validate metadata. Any panic raised while decoding the
+	// length or reading the payload is converted into an ordinary error.
+	b, err := func() (b []byte, err error) {
+		defer func() {
+			if p := recover(); p != nil {
+				b = nil
+				err = fmt.Errorf("%v", p)
+			}
+		}()
+
+		metadataLen, err := binary.ReadUint64(r, binary.BigEndian)
+		if err != nil {
+			return nil, err
+		}
+		// Reject oversized lengths before allocating.
+		if metadataLen > maxMetadataSize {
+			return nil, ErrInvalidMetadataLength
+		}
+		b = make([]byte, int(metadataLen))
+		if _, err := io.ReadFull(r, b); err != nil {
+			return nil, err
+		}
+		return b, nil
+	}()
+	if err != nil {
+		return nil, err
+	}
+
+	if h != nil {
+		// Check the hash prior to decoding. The sum is taken before the
+		// stored digest is read; reading it through the tee then folds
+		// the digest bytes into h, matching what the writer does.
+		cur := h.Sum(nil)
+		buf := make([]byte, len(cur))
+		if _, err := io.ReadFull(r, buf); err != nil {
+			return nil, err
+		}
+		if !hmac.Equal(cur, buf) {
+			return nil, hashio.ErrHashMismatch
+		}
+	}
+
+	// Decode the metadata.
+	md := make(map[string]string)
+	if err := json.Unmarshal(b, &md); err != nil {
+		return nil, err
+	}
+
+	return md, nil
+}
+
+// NewReader opens a statefile stream for reading.
+//
+// It consumes and verifies the preamble (magic header, metadata and its
+// HMAC-SHA256 digest keyed with key) and returns a reader for the state data
+// together with the decoded metadata map. The data reader is layered as
+// compressio on top of hashio on top of r.
+func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) {
+	// The same HMAC instance covers the preamble and, via hashio, the data.
+	mac := hmac.New(sha256.New, key)
+
+	md, err := metadata(r, mac)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// Integrity checking first, then decompression on top of it.
+	verified := hashio.NewReader(r, mac)
+	data, err := compressio.NewReader(verified)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	return data, md, nil
+}
diff --git a/pkg/state/statefile/statefile_test.go b/pkg/state/statefile/statefile_test.go
new file mode 100644
index 000000000..6e67b51de
--- /dev/null
+++ b/pkg/state/statefile/statefile_test.go
@@ -0,0 +1,299 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package statefile
+
+import (
+ "bytes"
+ "compress/flate"
+ crand "crypto/rand"
+ "encoding/base64"
+ "io"
+ "math/rand"
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/hashio"
+)
+
+// randomKey returns a fresh keySize-byte key made of unpadded standard base64
+// characters, derived from cryptographically random bytes.
+func randomKey() ([]byte, error) {
+	enc := base64.RawStdEncoding
+
+	// Draw exactly as many raw bytes as encode to keySize characters.
+	raw := make([]byte, enc.DecodedLen(keySize))
+	if _, err := io.ReadFull(crand.Reader, raw); err != nil {
+		return nil, err
+	}
+
+	encoded := make([]byte, keySize)
+	enc.Encode(encoded, raw)
+	return encoded, nil
+}
+
+// testCase describes one round-trip scenario exercised by TestStatefile.
+type testCase struct {
+ // name is used as the subtest name.
+ name string
+ // data is the payload written through NewWriter and expected back
+ // unchanged from NewReader.
+ data []byte
+ // metadata is the user metadata passed to NewWriter; it may be nil.
+ metadata map[string]string
+}
+
+// TestStatefile round-trips a range of payloads and metadata through
+// NewWriter and NewReader, with both a nil key and a random key. For each
+// combination it checks that the data and metadata survive unchanged, that
+// flipping a byte of the encoded stream is detected, and that reading with a
+// different key fails with hashio.ErrHashMismatch.
+func TestStatefile(t *testing.T) {
+	cases := []testCase{
+		// Various data sizes.
+		{"nil", nil, nil},
+		{"empty", []byte(""), nil},
+		{"some", []byte("_"), nil},
+		{"one", []byte("0"), nil},
+		{"two", []byte("01"), nil},
+		{"three", []byte("012"), nil},
+		{"four", []byte("0123"), nil},
+		{"five", []byte("01234"), nil},
+		{"six", []byte("012356"), nil},
+		{"seven", []byte("0123567"), nil},
+		{"eight", []byte("01235678"), nil},
+
+		// Make sure we have one longer than the hash length.
+		{"longer than hash", []byte("012356asdjflkasjlk3jlk23j4lkjaso0d789f0aujw3lkjlkxsdf78asdful2kj3ljka78"), nil},
+
+		// Make sure we have one longer than the segment size.
+		{"segments", make([]byte, 3*hashio.SegmentSize), nil},
+		{"segments minus one", make([]byte, 3*hashio.SegmentSize-1), nil},
+		{"segments plus one", make([]byte, 3*hashio.SegmentSize+1), nil},
+		{"segments minus hash", make([]byte, 3*hashio.SegmentSize-32), nil},
+		{"segments plus hash", make([]byte, 3*hashio.SegmentSize+32), nil},
+		{"large", make([]byte, 30*hashio.SegmentSize), nil},
+
+		// Different metadata.
+		{"one metadata", []byte("data"), map[string]string{"foo": "bar"}},
+		{"two metadata", []byte("data"), map[string]string{"foo": "bar", "one": "two"}},
+	}
+
+	for _, c := range cases {
+		// Generate a key.
+		integrityKey, err := randomKey()
+		if err != nil {
+			t.Errorf("can't generate key: got %v, expected nil", err)
+			continue
+		}
+
+		t.Run(c.name, func(t *testing.T) {
+			for _, key := range [][]byte{nil, integrityKey} {
+				t.Run("key="+string(key), func(t *testing.T) {
+					// Encoding happens via a buffer.
+					var bufEncoded bytes.Buffer
+					var bufDecoded bytes.Buffer
+
+					// Do all the writing.
+					w, err := NewWriter(&bufEncoded, key, c.metadata, flate.BestSpeed)
+					if err != nil {
+						t.Fatalf("error creating writer: got %v, expected nil", err)
+					}
+					if _, err := io.Copy(w, bytes.NewBuffer(c.data)); err != nil {
+						t.Fatalf("error during write: got %v, expected nil", err)
+					}
+
+					// Finish the sum.
+					if err := w.Close(); err != nil {
+						t.Fatalf("error during close: got %v, expected nil", err)
+					}
+
+					t.Logf("original data: %d bytes, encoded: %d bytes.",
+						len(c.data), len(bufEncoded.Bytes()))
+
+					// Do all the reading.
+					r, metadata, err := NewReader(bytes.NewReader(bufEncoded.Bytes()), key)
+					if err != nil {
+						t.Fatalf("error creating reader: got %v, expected nil", err)
+					}
+					if _, err := io.Copy(&bufDecoded, r); err != nil {
+						t.Fatalf("error during read: got %v, expected nil", err)
+					}
+
+					// Check that the data matches.
+					if !bytes.Equal(c.data, bufDecoded.Bytes()) {
+						t.Fatalf("data didn't match (%d vs %d bytes)", len(bufDecoded.Bytes()), len(c.data))
+					}
+
+					// Check that the metadata matches.
+					for k, v := range c.metadata {
+						nv, ok := metadata[k]
+						if !ok {
+							t.Fatalf("missing metadata: %s", k)
+						}
+						if v != nv {
+							t.Fatalf("mismatched metadata for %s: got %s, expected %s", k, nv, v)
+						}
+					}
+
+					// Change the data and verify that it fails. The
+					// corruption is applied to a copy so the pristine
+					// encoding can be reused below.
+					b := append([]byte(nil), bufEncoded.Bytes()...)
+					b[rand.Intn(len(b))]++
+					r, _, err = NewReader(bytes.NewReader(b), key)
+					if err == nil {
+						_, err = io.Copy(&bufDecoded, r)
+					}
+					if err == nil {
+						t.Error("got no error: expected error on data corruption")
+					}
+
+					// Change the key and verify that it fails. A nil
+					// key is swapped for the random one; a real key
+					// gets one byte altered in a private copy so the
+					// shared integrityKey slice is left intact.
+					var wrongKey []byte
+					if key == nil {
+						wrongKey = integrityKey
+					} else {
+						wrongKey = append([]byte(nil), key...)
+						wrongKey[rand.Intn(len(wrongKey))]++
+					}
+					r, _, err = NewReader(bytes.NewReader(bufEncoded.Bytes()), wrongKey)
+					if err == nil {
+						_, err = io.Copy(&bufDecoded, r)
+					}
+					if err != hashio.ErrHashMismatch {
+						t.Errorf("got error: %v, expected ErrHashMismatch on key mismatch", err)
+					}
+				})
+			}
+		})
+	}
+}
+
+// benchmarkDataSize is the amount of raw input (10 MiB) that benchmark
+// generates: the length of the all-zero payload, or the number of random
+// bytes fed to the base64 encoder for the non-compressible payload.
+const benchmarkDataSize = 10 * 1024 * 1024
+
+// benchmark measures statefile throughput.
+//
+// size is the number of bytes passed to each individual Write or Read call.
+// write selects the direction: true times NewWriter plus writing the whole
+// payload, false times NewReader plus reading it back. compressible selects
+// the payload: all zeros when true, base64-encoded pseudo-random bytes
+// (fixed seed) when false.
+func benchmark(b *testing.B, size int, write bool, compressible bool) {
+	b.StopTimer()
+
+	// Generate source data.
+	var source []byte
+	if compressible {
+		// For compressible data, we use essentially all zeros.
+		source = make([]byte, benchmarkDataSize)
+	} else {
+		// For non-compressible data, we use random base64 data (to
+		// make it marginally compressible, a ratio of 75%).
+		var sourceBuf bytes.Buffer
+		bufW := base64.NewEncoder(base64.RawStdEncoding, &sourceBuf)
+		bufR := rand.New(rand.NewSource(0))
+		if _, err := io.CopyN(bufW, bufR, benchmarkDataSize); err != nil {
+			b.Fatalf("unable to seed random data: %v", err)
+		}
+		// The encoder buffers partial 3-byte groups; Close flushes them.
+		if err := bufW.Close(); err != nil {
+			b.Fatalf("unable to seed random data: %v", err)
+		}
+		source = sourceBuf.Bytes()
+	}
+
+	// Report throughput against the bytes actually processed per
+	// iteration; the base64 payload is larger than benchmarkDataSize.
+	b.SetBytes(int64(len(source)))
+
+	// Generate a random key for integrity check.
+	key, err := randomKey()
+	if err != nil {
+		b.Fatalf("error generating key: %v", err)
+	}
+
+	// Reads land in a separate buffer so the source payload is never
+	// overwritten by decoded output.
+	readBuf := make([]byte, len(source))
+
+	// Define our benchmark functions. Prior to running the readState
+	// function here, you must execute the writeState function at least
+	// once (done below).
+	var stateBuf bytes.Buffer
+	writeState := func() {
+		stateBuf.Reset()
+		w, err := NewWriter(&stateBuf, key, nil, flate.BestSpeed)
+		if err != nil {
+			b.Fatalf("error creating writer: %v", err)
+		}
+		for done := 0; done < len(source); {
+			chunk := size // limit size.
+			if done+chunk > len(source) {
+				chunk = len(source) - done
+			}
+			n, err := w.Write(source[done : done+chunk])
+			done += n
+			if n == 0 && err != nil {
+				b.Fatalf("error during write: %v", err)
+			}
+		}
+		if err := w.Close(); err != nil {
+			b.Fatalf("error closing writer: %v", err)
+		}
+	}
+	readState := func() {
+		tmpBuf := bytes.NewBuffer(stateBuf.Bytes())
+		r, _, err := NewReader(tmpBuf, key)
+		if err != nil {
+			b.Fatalf("error creating reader: %v", err)
+		}
+		for done := 0; done < len(readBuf); {
+			chunk := size // limit size.
+			if done+chunk > len(readBuf) {
+				chunk = len(readBuf) - done
+			}
+			n, err := r.Read(readBuf[done : done+chunk])
+			done += n
+			if n == 0 && err != nil {
+				b.Fatalf("error during read: %v", err)
+			}
+		}
+	}
+
+	// Generate the state once without timing to ensure that buffers have
+	// been appropriately allocated.
+	writeState()
+
+	op := readState
+	if write {
+		op = writeState
+	}
+	b.StartTimer()
+	for i := 0; i < b.N; i++ {
+		op()
+	}
+	b.StopTimer()
+}
+
+// BenchmarkWrite1BCompressible times writing the all-zero payload in 1-byte
+// chunks.
+func BenchmarkWrite1BCompressible(b *testing.B) {
+ benchmark(b, 1, true, true)
+}
+
+// BenchmarkWrite1BNoncompressible times writing the base64 random payload in
+// 1-byte chunks.
+func BenchmarkWrite1BNoncompressible(b *testing.B) {
+ benchmark(b, 1, true, false)
+}
+
+// BenchmarkWrite4KCompressible times writing the all-zero payload in
+// 4096-byte chunks.
+func BenchmarkWrite4KCompressible(b *testing.B) {
+ benchmark(b, 4096, true, true)
+}
+
+// BenchmarkWrite4KNoncompressible times writing the base64 random payload in
+// 4096-byte chunks.
+func BenchmarkWrite4KNoncompressible(b *testing.B) {
+ benchmark(b, 4096, true, false)
+}
+
+// BenchmarkWrite1MCompressible times writing the all-zero payload in 1 MiB
+// chunks.
+func BenchmarkWrite1MCompressible(b *testing.B) {
+ benchmark(b, 1024*1024, true, true)
+}
+
+// BenchmarkWrite1MNoncompressible times writing the base64 random payload in
+// 1 MiB chunks.
+func BenchmarkWrite1MNoncompressible(b *testing.B) {
+ benchmark(b, 1024*1024, true, false)
+}
+
+// BenchmarkRead1BCompressible times reading back the all-zero payload in
+// 1-byte chunks.
+func BenchmarkRead1BCompressible(b *testing.B) {
+ benchmark(b, 1, false, true)
+}
+
+// BenchmarkRead1BNoncompressible times reading back the base64 random payload
+// in 1-byte chunks.
+func BenchmarkRead1BNoncompressible(b *testing.B) {
+ benchmark(b, 1, false, false)
+}
+
+// BenchmarkRead4KCompressible times reading back the all-zero payload in
+// 4096-byte chunks.
+func BenchmarkRead4KCompressible(b *testing.B) {
+ benchmark(b, 4096, false, true)
+}
+
+// BenchmarkRead4KNoncompressible times reading back the base64 random payload
+// in 4096-byte chunks.
+func BenchmarkRead4KNoncompressible(b *testing.B) {
+ benchmark(b, 4096, false, false)
+}
+
+// BenchmarkRead1MCompressible times reading back the all-zero payload in
+// 1 MiB chunks.
+func BenchmarkRead1MCompressible(b *testing.B) {
+ benchmark(b, 1024*1024, false, true)
+}
+
+// BenchmarkRead1MNoncompressible times reading back the base64 random payload
+// in 1 MiB chunks.
+func BenchmarkRead1MNoncompressible(b *testing.B) {
+ benchmark(b, 1024*1024, false, false)
+}