summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorZhaozhong Ni <nzz@google.com>2018-08-24 14:52:23 -0700
committerShentubot <shentubot@google.com>2018-08-24 14:53:31 -0700
commita6b00502b04ced2f12cfcf35c6f276cff349737b (patch)
treed443ea0679091b193bcc5568f0aa5aff3ba1a0f3 /pkg
parent02dfceab6d4c4a2a3342ef69be0265b7ab03e5d7 (diff)
compressio: support optional hashing and eliminate hashio.
Compared to previous compressio / hashio nesting, there is up to 100% speedup. PiperOrigin-RevId: 210161269 Change-Id: I481aa9fe980bb817fe465fe34d32ea33fc8abf1c
Diffstat (limited to 'pkg')
-rw-r--r--pkg/compressio/compressio.go223
-rw-r--r--pkg/compressio/compressio_test.go145
-rw-r--r--pkg/hashio/BUILD19
-rw-r--r--pkg/hashio/hashio.go296
-rw-r--r--pkg/hashio/hashio_test.go142
-rw-r--r--pkg/state/statefile/BUILD3
-rw-r--r--pkg/state/statefile/statefile.go11
-rw-r--r--pkg/state/statefile/statefile_test.go70
8 files changed, 339 insertions, 570 deletions
diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go
index ef8cbd2a5..591b37130 100644
--- a/pkg/compressio/compressio.go
+++ b/pkg/compressio/compressio.go
@@ -12,17 +12,48 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package compressio provides parallel compression and decompression.
+// Package compressio provides parallel compression and decompression, as well
+// as optional SHA-256 hashing.
+//
+// The stream format is defined as follows.
+//
+// /------------------------------------------------------\
+// | chunk size (4-bytes) |
+// +------------------------------------------------------+
+// | (optional) hash (32-bytes) |
+// +------------------------------------------------------+
+// | compressed data size (4-bytes) |
+// +------------------------------------------------------+
+// | compressed data |
+// +------------------------------------------------------+
+// | (optional) hash (32-bytes) |
+// +------------------------------------------------------+
+// | compressed data size (4-bytes) |
+// +------------------------------------------------------+
+// | ...... |
+// \------------------------------------------------------/
+//
+// where each subsequent hash is calculated from the following items in order
+//
+// compressed data
+// compressed data size
+// previous hash
+//
+// so the stream integrity cannot be compromised by switching and mixing
+// compressed chunks.
package compressio
import (
"bytes"
"compress/flate"
"errors"
+ "hash"
"io"
"runtime"
"sync"
+ "crypto/hmac"
+ "crypto/sha256"
"gvisor.googlesource.com/gvisor/pkg/binary"
)
@@ -51,12 +82,23 @@ type chunk struct {
// This is not returned to the bufPool automatically, since it may
// correspond to a inline slice (provided directly to Read or Write).
uncompressed *bytes.Buffer
+
+ // The current hash object. Only used in compress mode.
+ h hash.Hash
+
+ // The hash from previous chunks. Only used in uncompress mode.
+ lastSum []byte
+
+ // The expected hash after current chunk. Only used in uncompress mode.
+ sum []byte
}
// newChunk allocates a new chunk object (or pulls one from the pool). Buffers
// will be allocated if nil is provided for compressed or uncompressed.
-func newChunk(compressed *bytes.Buffer, uncompressed *bytes.Buffer) *chunk {
+func newChunk(lastSum []byte, sum []byte, compressed *bytes.Buffer, uncompressed *bytes.Buffer) *chunk {
c := chunkPool.Get().(*chunk)
+ c.lastSum = lastSum
+ c.sum = sum
if compressed != nil {
c.compressed = compressed
} else {
@@ -85,6 +127,7 @@ type result struct {
// The goroutine will exit when input is closed, and the goroutine will close
// output.
type worker struct {
+ pool *pool
input chan *chunk
output chan result
}
@@ -93,17 +136,27 @@ type worker struct {
func (w *worker) work(compress bool, level int) {
defer close(w.output)
+ var h hash.Hash
+
for c := range w.input {
+ if h == nil && w.pool.key != nil {
+ h = w.pool.getHash()
+ }
if compress {
+ mw := io.Writer(c.compressed)
+ if h != nil {
+ mw = io.MultiWriter(mw, h)
+ }
+
// Encode this slice.
- fw, err := flate.NewWriter(c.compressed, level)
+ fw, err := flate.NewWriter(mw, level)
if err != nil {
w.output <- result{c, err}
continue
}
// Encode the input.
- if _, err := io.Copy(fw, c.uncompressed); err != nil {
+ if _, err := io.CopyN(fw, c.uncompressed, int64(c.uncompressed.Len())); err != nil {
w.output <- result{c, err}
continue
}
@@ -111,7 +164,28 @@ func (w *worker) work(compress bool, level int) {
w.output <- result{c, err}
continue
}
+
+ // Write the hash, if enabled.
+ if h != nil {
+ binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len()))
+ c.h = h
+ h = nil
+ }
} else {
+ // Check the hash of the compressed contents.
+ if h != nil {
+ h.Write(c.compressed.Bytes())
+ binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len()))
+ io.CopyN(h, bytes.NewReader(c.lastSum), int64(len(c.lastSum)))
+
+ sum := h.Sum(nil)
+ h.Reset()
+ if !hmac.Equal(c.sum, sum) {
+ w.output <- result{c, ErrHashMismatch}
+ continue
+ }
+ }
+
// Decode this slice.
fr := flate.NewReader(c.compressed)
@@ -136,6 +210,16 @@ type pool struct {
// stream and is shared across both the reader and writer.
chunkSize uint32
+ // key is the key used to create hash objects.
+ key []byte
+
+ // hashMu protexts the hash list.
+ hashMu sync.Mutex
+
+ // hashes is the hash object free list. Note that this cannot be
+ // globally shared across readers or writers, as it is key-specific.
+ hashes []hash.Hash
+
// mu protects below; it is generally the responsibility of users to
// acquire this mutex before calling any methods on the pool.
mu sync.Mutex
@@ -149,15 +233,20 @@ type pool struct {
// buf is the current active buffer; the exact semantics of this buffer
// depending on whether this is a reader or a writer.
buf *bytes.Buffer
+
+ // lasSum records the hash of the last chunk processed.
+ lastSum []byte
}
// init initializes the worker pool.
//
// This should only be called once.
-func (p *pool) init(compress bool, level int) {
- p.workers = make([]worker, 1+runtime.GOMAXPROCS(0))
+func (p *pool) init(key []byte, workers int, compress bool, level int) {
+ p.key = key
+ p.workers = make([]worker, workers)
for i := 0; i < len(p.workers); i++ {
p.workers[i] = worker{
+ pool: p,
input: make(chan *chunk, 1),
output: make(chan result, 1),
}
@@ -174,6 +263,30 @@ func (p *pool) stop() {
p.workers = nil
}
+// getHash gets a hash object for the pool. It should only be called when the
+// pool key is non-nil.
+func (p *pool) getHash() hash.Hash {
+ p.hashMu.Lock()
+ defer p.hashMu.Unlock()
+
+ if len(p.hashes) == 0 {
+ return hmac.New(sha256.New, p.key)
+ }
+
+ h := p.hashes[len(p.hashes)-1]
+ p.hashes = p.hashes[:len(p.hashes)-1]
+ return h
+}
+
+func (p *pool) putHash(h hash.Hash) {
+ h.Reset()
+
+ p.hashMu.Lock()
+ defer p.hashMu.Unlock()
+
+ p.hashes = append(p.hashes, h)
+}
+
// handleResult calls the callback.
func handleResult(r result, callback func(*chunk) error) error {
defer func() {
@@ -231,22 +344,46 @@ type reader struct {
in io.Reader
}
-// NewReader returns a new compressed reader.
-func NewReader(in io.Reader) (io.Reader, error) {
+// NewReader returns a new compressed reader. If key is non-nil, the data stream
+// is assumed to contain expected hash values, which will be compared against
+// hash values computed from the compressed bytes. See package comments for
+// details.
+func NewReader(in io.Reader, key []byte) (io.Reader, error) {
r := &reader{
in: in,
}
- r.init(false, 0)
+
+ // Use double buffering for read.
+ r.init(key, 2*runtime.GOMAXPROCS(0), false, 0)
+
var err error
- if r.chunkSize, err = binary.ReadUint32(r.in, binary.BigEndian); err != nil {
+ if r.chunkSize, err = binary.ReadUint32(in, binary.BigEndian); err != nil {
return nil, err
}
+
+ if r.key != nil {
+ h := r.getHash()
+ binary.WriteUint32(h, binary.BigEndian, r.chunkSize)
+ r.lastSum = h.Sum(nil)
+ r.putHash(h)
+ sum := make([]byte, len(r.lastSum))
+ if _, err := io.ReadFull(r.in, sum); err != nil {
+ return nil, err
+ }
+ if !hmac.Equal(r.lastSum, sum) {
+ return nil, ErrHashMismatch
+ }
+ }
+
return r, nil
}
// errNewBuffer is returned when a new buffer is completed.
var errNewBuffer = errors.New("buffer ready")
+// ErrHashMismatch is returned if the hash does not match.
+var ErrHashMismatch = errors.New("hash mismatch")
+
// Read implements io.Reader.Read.
func (r *reader) Read(p []byte) (int, error) {
r.mu.Lock()
@@ -331,14 +468,25 @@ func (r *reader) Read(p []byte) (int, error) {
// Read this chunk and schedule decompression.
compressed := bufPool.Get().(*bytes.Buffer)
- if _, err := io.Copy(compressed, &io.LimitedReader{
- R: r.in,
- N: int64(l),
- }); err != nil {
+ if _, err := io.CopyN(compressed, r.in, int64(l)); err != nil {
// Some other error occurred; see above.
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
return done, err
}
+ var sum []byte
+ if r.key != nil {
+ sum = make([]byte, len(r.lastSum))
+ if _, err := io.ReadFull(r.in, sum); err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return done, err
+ }
+ }
+
// Are we doing inline decoding?
//
// Note that we need to check the length here against
@@ -349,11 +497,12 @@ func (r *reader) Read(p []byte) (int, error) {
var c *chunk
start := done + ((pendingPre + pendingInline) * int(r.chunkSize))
if len(p) >= start+int(r.chunkSize) && len(p) >= start+bytes.MinRead {
- c = newChunk(compressed, bytes.NewBuffer(p[start:start]))
+ c = newChunk(r.lastSum, sum, compressed, bytes.NewBuffer(p[start:start]))
pendingInline++
} else {
- c = newChunk(compressed, nil)
+ c = newChunk(r.lastSum, sum, compressed, nil)
}
+ r.lastSum = sum
if err := r.schedule(c, callback); err == errNewBuffer {
// A new buffer was completed while we were reading.
// That's great, but we need to force schedule the
@@ -403,12 +552,14 @@ type writer struct {
closed bool
}
-// NewWriter returns a new compressed writer.
+// NewWriter returns a new compressed writer. If key is non-nil, hash values are
+// generated and written out for compressed bytes. See package comments for
+// details.
//
// The recommended chunkSize is on the order of 1M. Extra memory may be
// buffered (in the form of read-ahead, or buffered writes), and is limited to
// O(chunkSize * [1+GOMAXPROCS]).
-func NewWriter(out io.Writer, chunkSize uint32, level int) (io.WriteCloser, error) {
+func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.WriteCloser, error) {
w := &writer{
pool: pool{
chunkSize: chunkSize,
@@ -416,10 +567,22 @@ func NewWriter(out io.Writer, chunkSize uint32, level int) (io.WriteCloser, erro
},
out: out,
}
- w.init(true, level)
+ w.init(key, 1+runtime.GOMAXPROCS(0), true, level)
+
if err := binary.WriteUint32(w.out, binary.BigEndian, chunkSize); err != nil {
return nil, err
}
+
+ if w.key != nil {
+ h := w.getHash()
+ binary.WriteUint32(h, binary.BigEndian, chunkSize)
+ w.lastSum = h.Sum(nil)
+ w.putHash(h)
+ if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil {
+ return nil, err
+ }
+ }
+
return w, nil
}
@@ -433,8 +596,22 @@ func (w *writer) flush(c *chunk) error {
}
// Write out to the stream.
- _, err := io.Copy(w.out, c.compressed)
- return err
+ if _, err := io.CopyN(w.out, c.compressed, int64(c.compressed.Len())); err != nil {
+ return err
+ }
+
+ if w.key != nil {
+ io.CopyN(c.h, bytes.NewReader(w.lastSum), int64(len(w.lastSum)))
+ sum := c.h.Sum(nil)
+ w.putHash(c.h)
+ c.h = nil
+ if _, err := io.CopyN(w.out, bytes.NewReader(sum), int64(len(sum))); err != nil {
+ return err
+ }
+ w.lastSum = sum
+ }
+
+ return nil
}
// Write implements io.Writer.Write.
@@ -480,7 +657,7 @@ func (w *writer) Write(p []byte) (int, error) {
// immediately following the inline case above.
left := int(w.chunkSize) - w.buf.Len()
if left == 0 {
- if err := w.schedule(newChunk(nil, w.buf), callback); err != nil {
+ if err := w.schedule(newChunk(nil, nil, nil, w.buf), callback); err != nil {
return done, err
}
// Reset the buffer, since this has now been scheduled
@@ -538,7 +715,7 @@ func (w *writer) Close() error {
// Schedule any remaining partial buffer; we pass w.flush directly here
// because the final buffer is guaranteed to not be an inline buffer.
if w.buf.Len() > 0 {
- if err := w.schedule(newChunk(nil, w.buf), w.flush); err != nil {
+ if err := w.schedule(newChunk(nil, nil, nil, w.buf), w.flush); err != nil {
return err
}
}
diff --git a/pkg/compressio/compressio_test.go b/pkg/compressio/compressio_test.go
index d7911419d..7cb5f8dc4 100644
--- a/pkg/compressio/compressio_test.go
+++ b/pkg/compressio/compressio_test.go
@@ -59,6 +59,7 @@ type testOpts struct {
PostDecompress func()
CompressIters int
DecompressIters int
+ CorruptData bool
}
func doTest(t harness, opts testOpts) {
@@ -104,15 +105,22 @@ func doTest(t harness, opts testOpts) {
if opts.DecompressIters <= 0 {
opts.DecompressIters = 1
}
+ if opts.CorruptData {
+ b := compressed.Bytes()
+ b[rand.Intn(len(b))]++
+ }
for i := 0; i < opts.DecompressIters; i++ {
decompressed.Reset()
r, err := opts.NewReader(bytes.NewBuffer(compressed.Bytes()))
if err != nil {
+ if opts.CorruptData {
+ continue
+ }
t.Errorf("%s: NewReader got err %v, expected nil", opts.Name, err)
return
}
- if _, err := io.Copy(&decompressed, r); err != nil {
- t.Errorf("%s: decompress got err %v, expected nil", opts.Name, err)
+ if _, err := io.Copy(&decompressed, r); (err != nil) != opts.CorruptData {
+ t.Errorf("%s: decompress got err %v unexpectly", opts.Name, err)
return
}
}
@@ -121,6 +129,10 @@ func doTest(t harness, opts testOpts) {
}
decompressionTime := time.Since(decompressionStartTime)
+ if opts.CorruptData {
+ return
+ }
+
// Verify.
if decompressed.Len() != len(opts.Data) {
t.Errorf("%s: got %d bytes, expected %d", opts.Name, decompressed.Len(), len(opts.Data))
@@ -136,7 +148,11 @@ func doTest(t harness, opts testOpts) {
opts.Name, compressionTime, compressionRatio, decompressionTime)
}
+var hashKey = []byte("01234567890123456789012345678901")
+
func TestCompress(t *testing.T) {
+ rand.Seed(time.Now().Unix())
+
var (
data = initTest(t, 10*1024*1024)
data0 = data[:0]
@@ -153,17 +169,27 @@ func TestCompress(t *testing.T) {
continue
}
- // Do the compress test.
- doTest(t, testOpts{
- Name: fmt.Sprintf("len(data)=%d, blockSize=%d", len(data), blockSize),
- Data: data,
- NewWriter: func(b *bytes.Buffer) (io.Writer, error) {
- return NewWriter(b, blockSize, flate.BestCompression)
- },
- NewReader: func(b *bytes.Buffer) (io.Reader, error) {
- return NewReader(b)
- },
- })
+ for _, key := range [][]byte{nil, hashKey} {
+ for _, corruptData := range []bool{false, true} {
+ if key == nil && corruptData {
+ // No need to test corrupt data
+ // case when not doing hashing.
+ continue
+ }
+ // Do the compress test.
+ doTest(t, testOpts{
+ Name: fmt.Sprintf("len(data)=%d, blockSize=%d, key=%s, corruptData=%v", len(data), blockSize, string(key), corruptData),
+ Data: data,
+ NewWriter: func(b *bytes.Buffer) (io.Writer, error) {
+ return NewWriter(b, key, blockSize, flate.BestSpeed)
+ },
+ NewReader: func(b *bytes.Buffer) (io.Reader, error) {
+ return NewReader(b, key)
+ },
+ CorruptData: corruptData,
+ })
+ }
+ }
}
// Do the vanilla test.
@@ -171,7 +197,7 @@ func TestCompress(t *testing.T) {
Name: fmt.Sprintf("len(data)=%d, vanilla flate", len(data)),
Data: data,
NewWriter: func(b *bytes.Buffer) (io.Writer, error) {
- return flate.NewWriter(b, flate.BestCompression)
+ return flate.NewWriter(b, flate.BestSpeed)
},
NewReader: func(b *bytes.Buffer) (io.Reader, error) {
return flate.NewReader(b), nil
@@ -181,47 +207,84 @@ func TestCompress(t *testing.T) {
}
const (
- // benchBlockSize is the blockSize for benchmarks.
- benchBlockSize = 32 * 1024
-
- // benchDataSize is the amount of data for benchmarks.
- benchDataSize = 10 * 1024 * 1024
+ benchDataSize = 600 * 1024 * 1024
)
-func BenchmarkCompress(b *testing.B) {
+func benchmark(b *testing.B, compress bool, hash bool, blockSize uint32) {
b.StopTimer()
b.SetBytes(benchDataSize)
data := initTest(b, benchDataSize)
+ compIters := b.N
+ decompIters := b.N
+ if compress {
+ decompIters = 0
+ } else {
+ compIters = 0
+ }
+ key := hashKey
+ if !hash {
+ key = nil
+ }
doTest(b, testOpts{
- Name: fmt.Sprintf("len(data)=%d, blockSize=%d", len(data), benchBlockSize),
+ Name: fmt.Sprintf("compress=%t, hash=%t, len(data)=%d, blockSize=%d", compress, hash, len(data), blockSize),
Data: data,
PreCompress: b.StartTimer,
PostCompress: b.StopTimer,
NewWriter: func(b *bytes.Buffer) (io.Writer, error) {
- return NewWriter(b, benchBlockSize, flate.BestCompression)
+ return NewWriter(b, key, blockSize, flate.BestSpeed)
},
NewReader: func(b *bytes.Buffer) (io.Reader, error) {
- return NewReader(b)
+ return NewReader(b, key)
},
- CompressIters: b.N,
+ CompressIters: compIters,
+ DecompressIters: decompIters,
})
}
-func BenchmarkDecompress(b *testing.B) {
- b.StopTimer()
- b.SetBytes(benchDataSize)
- data := initTest(b, benchDataSize)
- doTest(b, testOpts{
- Name: fmt.Sprintf("len(data)=%d, blockSize=%d", len(data), benchBlockSize),
- Data: data,
- PreDecompress: b.StartTimer,
- PostDecompress: b.StopTimer,
- NewWriter: func(b *bytes.Buffer) (io.Writer, error) {
- return NewWriter(b, benchBlockSize, flate.BestCompression)
- },
- NewReader: func(b *bytes.Buffer) (io.Reader, error) {
- return NewReader(b)
- },
- DecompressIters: b.N,
- })
+func BenchmarkCompressNoHash64K(b *testing.B) {
+ benchmark(b, true, false, 64*1024)
+}
+
+func BenchmarkCompressHash64K(b *testing.B) {
+ benchmark(b, true, true, 64*1024)
+}
+
+func BenchmarkDecompressNoHash64K(b *testing.B) {
+ benchmark(b, false, false, 64*1024)
+}
+
+func BenchmarkDecompressHash64K(b *testing.B) {
+ benchmark(b, false, true, 64*1024)
+}
+
+func BenchmarkCompressNoHash1M(b *testing.B) {
+ benchmark(b, true, false, 1024*1024)
+}
+
+func BenchmarkCompressHash1M(b *testing.B) {
+ benchmark(b, true, true, 1024*1024)
+}
+
+func BenchmarkDecompressNoHash1M(b *testing.B) {
+ benchmark(b, false, false, 1024*1024)
+}
+
+func BenchmarkDecompressHash1M(b *testing.B) {
+ benchmark(b, false, true, 1024*1024)
+}
+
+func BenchmarkCompressNoHash16M(b *testing.B) {
+ benchmark(b, true, false, 16*1024*1024)
+}
+
+func BenchmarkCompressHash16M(b *testing.B) {
+ benchmark(b, true, true, 16*1024*1024)
+}
+
+func BenchmarkDecompressNoHash16M(b *testing.B) {
+ benchmark(b, false, false, 16*1024*1024)
+}
+
+func BenchmarkDecompressHash16M(b *testing.B) {
+ benchmark(b, false, true, 16*1024*1024)
}
diff --git a/pkg/hashio/BUILD b/pkg/hashio/BUILD
deleted file mode 100644
index 5736e2e73..000000000
--- a/pkg/hashio/BUILD
+++ /dev/null
@@ -1,19 +0,0 @@
-package(licenses = ["notice"]) # Apache 2.0
-
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
-
-go_library(
- name = "hashio",
- srcs = [
- "hashio.go",
- ],
- importpath = "gvisor.googlesource.com/gvisor/pkg/hashio",
- visibility = ["//:sandbox"],
-)
-
-go_test(
- name = "hashio_test",
- size = "small",
- srcs = ["hashio_test.go"],
- embed = [":hashio"],
-)
diff --git a/pkg/hashio/hashio.go b/pkg/hashio/hashio.go
deleted file mode 100644
index e0e8ef413..000000000
--- a/pkg/hashio/hashio.go
+++ /dev/null
@@ -1,296 +0,0 @@
-// Copyright 2018 Google Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-/*
-Package hashio provides hash-verified I/O streams.
-
-The I/O stream format is defined as follows.
-
-/-----------------------------------------\
-| payload |
-+-----------------------------------------+
-| hash |
-+-----------------------------------------+
-| payload |
-+-----------------------------------------+
-| hash |
-+-----------------------------------------+
-| ...... |
-\-----------------------------------------/
-
-Payload bytes written to / read from the stream are automatically split
-into segments, each followed by a hash. All data read out must have already
-passed hash verification. Hence the client code can safely do any kind of
-(stream) processing of these data.
-*/
-package hashio
-
-import (
- "errors"
- "hash"
- "io"
- "sync"
-
- "crypto/hmac"
-)
-
-// SegmentSize is the unit we split payload data and insert hash at.
-const SegmentSize = 8 * 1024
-
-// ErrHashMismatch is returned if the ErrHashMismatch does not match.
-var ErrHashMismatch = errors.New("hash mismatch")
-
-// writer computes hashs during writes.
-type writer struct {
- mu sync.Mutex
- w io.Writer
- h hash.Hash
- written int
- closed bool
- hashv []byte
-}
-
-// NewWriter creates a hash-verified IO stream writer.
-func NewWriter(w io.Writer, h hash.Hash) io.WriteCloser {
- return &writer{
- w: w,
- h: h,
- hashv: make([]byte, h.Size()),
- }
-}
-
-// Write writes the given data.
-func (w *writer) Write(p []byte) (int, error) {
- w.mu.Lock()
- defer w.mu.Unlock()
-
- // Did we already close?
- if w.closed {
- return 0, io.ErrUnexpectedEOF
- }
-
- for done := 0; done < len(p); {
- // Slice the data at segment boundary.
- left := SegmentSize - w.written
- if left > len(p[done:]) {
- left = len(p[done:])
- }
-
- // Write the rest of the segment and write to hash writer the
- // same number of bytes. Hash.Write may never return an error.
- n, err := w.w.Write(p[done : done+left])
- w.h.Write(p[done : done+left])
- w.written += n
- done += n
-
- // And only check the actual write errors here.
- if n == 0 && err != nil {
- return done, err
- }
-
- // Write hash if starting a new segment.
- if w.written == SegmentSize {
- if err := w.closeSegment(); err != nil {
- return done, err
- }
- }
- }
-
- return len(p), nil
-}
-
-// closeSegment closes the current segment and writes out its hash.
-func (w *writer) closeSegment() error {
- // Serialize and write the current segment's hash.
- hashv := w.h.Sum(w.hashv[:0])
- for done := 0; done < len(hashv); {
- n, err := w.w.Write(hashv[done:])
- done += n
- if n == 0 && err != nil {
- return err
- }
- }
- w.written = 0 // reset counter.
- return nil
-}
-
-// Close writes the final hash to the stream and closes the underlying Writer.
-func (w *writer) Close() error {
- w.mu.Lock()
- defer w.mu.Unlock()
-
- // Did we already close?
- if w.closed {
- return io.ErrUnexpectedEOF
- }
-
- // Always mark as closed, regardless of errors.
- w.closed = true
-
- // Write the final segment.
- if err := w.closeSegment(); err != nil {
- return err
- }
-
- // Call the underlying closer.
- if c, ok := w.w.(io.Closer); ok {
- return c.Close()
- }
- return nil
-}
-
-// reader computes and verifies hashs during reads.
-type reader struct {
- mu sync.Mutex
- r io.Reader
- h hash.Hash
-
- // data is remaining verified but unused payload data. This is
- // populated on short reads and may be consumed without any
- // verification.
- data [SegmentSize]byte
-
- // index is the index into data above.
- index int
-
- // available is the amount of valid data above.
- available int
-
- // hashv is the read hash for the current segment.
- hashv []byte
-
- // computev is the computed hash for the current segment.
- computev []byte
-}
-
-// NewReader creates a hash-verified IO stream reader.
-func NewReader(r io.Reader, h hash.Hash) io.Reader {
- return &reader{
- r: r,
- h: h,
- hashv: make([]byte, h.Size()),
- computev: make([]byte, h.Size()),
- }
-}
-
-// readSegment reads a segment and hash vector.
-//
-// Precondition: datav must have length SegmentSize.
-func (r *reader) readSegment(datav []byte) (data []byte, err error) {
- // Make two reads: the first is the segment, the second is the hash
- // which needs verification. We may need to adjust the resulting slices
- // in the case of short reads.
- for done := 0; done < SegmentSize; {
- n, err := r.r.Read(datav[done:])
- done += n
- if n == 0 && err == io.EOF {
- if done == 0 {
- // No data at all.
- return nil, io.EOF
- } else if done < len(r.hashv) {
- // Not enough for a hash.
- return nil, ErrHashMismatch
- }
- // Truncate the data and copy to the hash.
- copy(r.hashv, datav[done-len(r.hashv):])
- datav = datav[:done-len(r.hashv)]
- return datav, nil
- } else if n == 0 && err != nil {
- return nil, err
- }
- }
- for done := 0; done < len(r.hashv); {
- n, err := r.r.Read(r.hashv[done:])
- done += n
- if n == 0 && err == io.EOF {
- // Copy over from the data.
- missing := len(r.hashv) - done
- copy(r.hashv[missing:], r.hashv[:done])
- copy(r.hashv[:missing], datav[len(datav)-missing:])
- datav = datav[:len(datav)-missing]
- return datav, nil
- } else if n == 0 && err != nil {
- return nil, err
- }
- }
- return datav, nil
-}
-
-// verifyHash verifies the given hash.
-//
-// The passed hash will be returned to the pool.
-func (r *reader) verifyHash(datav []byte) error {
- for done := 0; done < len(datav); {
- n, _ := r.h.Write(datav[done:])
- done += n
- }
- computev := r.h.Sum(r.computev[:0])
- if !hmac.Equal(r.hashv, computev) {
- return ErrHashMismatch
- }
- return nil
-}
-
-// Read reads the data.
-func (r *reader) Read(p []byte) (int, error) {
- r.mu.Lock()
- defer r.mu.Unlock()
-
- for done := 0; done < len(p); {
- // Check for pending data.
- if r.index < r.available {
- n := copy(p[done:], r.data[r.index:r.available])
- done += n
- r.index += n
- continue
- }
-
- // Prepare the next read.
- var (
- datav []byte
- inline bool
- )
-
- // We need to read a new segment. Can we read directly?
- if len(p[done:]) >= SegmentSize {
- datav = p[done : done+SegmentSize]
- inline = true
- } else {
- datav = r.data[:]
- inline = false
- }
-
- // Read the next segments.
- datav, err := r.readSegment(datav)
- if err != nil && err != io.EOF {
- return 0, err
- } else if err == io.EOF {
- return done, io.EOF
- }
- if err := r.verifyHash(datav); err != nil {
- return done, err
- }
-
- if inline {
- // Move the cursor.
- done += len(datav)
- } else {
- // Reset index & available.
- r.index = 0
- r.available = len(datav)
- }
- }
-
- return len(p), nil
-}
diff --git a/pkg/hashio/hashio_test.go b/pkg/hashio/hashio_test.go
deleted file mode 100644
index 41dbdf860..000000000
--- a/pkg/hashio/hashio_test.go
+++ /dev/null
@@ -1,142 +0,0 @@
-// Copyright 2018 Google Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package hashio
-
-import (
- "bytes"
- "crypto/hmac"
- "crypto/sha256"
- "fmt"
- "io"
- "math/rand"
- "testing"
-)
-
-var testKey = []byte("01234567890123456789012345678901")
-
-func runTest(c []byte, fn func(enc *bytes.Buffer), iters int) error {
- // Encoding happens via a buffer.
- var (
- enc bytes.Buffer
- dec bytes.Buffer
- )
-
- for i := 0; i < iters; i++ {
- enc.Reset()
- w := NewWriter(&enc, hmac.New(sha256.New, testKey))
- if _, err := io.Copy(w, bytes.NewBuffer(c)); err != nil {
- return err
- }
- if err := w.Close(); err != nil {
- return err
- }
- }
-
- fn(&enc)
-
- for i := 0; i < iters; i++ {
- dec.Reset()
- r := NewReader(bytes.NewReader(enc.Bytes()), hmac.New(sha256.New, testKey))
- if _, err := io.Copy(&dec, r); err != nil {
- return err
- }
- }
-
- // Check that the data matches; this should never fail.
- if !bytes.Equal(c, dec.Bytes()) {
- panic(fmt.Sprintf("data didn't match: got %v, expected %v", dec.Bytes(), c))
- }
-
- return nil
-}
-
-func TestTable(t *testing.T) {
- cases := [][]byte{
- // Various data sizes.
- nil,
- []byte(""),
- []byte("_"),
- []byte("0"),
- []byte("01"),
- []byte("012"),
- []byte("0123"),
- []byte("01234"),
- []byte("012356"),
- []byte("0123567"),
- []byte("01235678"),
-
- // Make sure we have one longer than the hash length.
- []byte("012356asdjflkasjlk3jlk23j4lkjaso0d789f0aujw3lkjlkxsdf78asdful2kj3ljka78"),
-
- // Make sure we have one longer than the segment size.
- make([]byte, 3*SegmentSize),
- make([]byte, 3*SegmentSize-1),
- make([]byte, 3*SegmentSize+1),
- make([]byte, 3*SegmentSize-32),
- make([]byte, 3*SegmentSize+32),
- make([]byte, 30*SegmentSize),
- }
-
- for _, c := range cases {
- for _, flip := range []bool{false, true} {
- if len(c) == 0 && flip == true {
- continue
- }
-
- // Log the case.
- t.Logf("case: len=%d flip=%v", len(c), flip)
-
- if err := runTest(c, func(enc *bytes.Buffer) {
- if flip {
- corrupted := rand.Intn(enc.Len())
- enc.Bytes()[corrupted]++
- }
- }, 1); err != nil {
- if !flip || err != ErrHashMismatch {
- t.Errorf("error during read: got %v, expected nil", err)
- }
- continue
- } else if flip {
- t.Errorf("failed to detect ErrHashMismatch on corrupted data!")
- continue
- }
- }
- }
-}
-
-const benchBytes = 10 * 1024 * 1024 // 10 MB.
-
-func BenchmarkWrite(b *testing.B) {
- b.StopTimer()
- x := make([]byte, benchBytes)
- b.SetBytes(benchBytes)
- b.StartTimer()
- if err := runTest(x, func(enc *bytes.Buffer) {
- b.StopTimer()
- }, b.N); err != nil {
- b.Errorf("benchmark failed: %v", err)
- }
-}
-
-func BenchmarkRead(b *testing.B) {
- b.StopTimer()
- x := make([]byte, benchBytes)
- b.SetBytes(benchBytes)
- if err := runTest(x, func(enc *bytes.Buffer) {
- b.StartTimer()
- }, b.N); err != nil {
- b.Errorf("benchmark failed: %v", err)
- }
-}
diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD
index 16abe1930..6be78dc9b 100644
--- a/pkg/state/statefile/BUILD
+++ b/pkg/state/statefile/BUILD
@@ -10,7 +10,6 @@ go_library(
deps = [
"//pkg/binary",
"//pkg/compressio",
- "//pkg/hashio",
],
)
@@ -19,5 +18,5 @@ go_test(
size = "small",
srcs = ["statefile_test.go"],
embed = [":statefile"],
- deps = ["//pkg/hashio"],
+ deps = ["//pkg/compressio"],
)
diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go
index 0b4eff8fa..9c86c1934 100644
--- a/pkg/state/statefile/statefile.go
+++ b/pkg/state/statefile/statefile.go
@@ -57,7 +57,6 @@ import (
"crypto/sha256"
"gvisor.googlesource.com/gvisor/pkg/binary"
"gvisor.googlesource.com/gvisor/pkg/compressio"
- "gvisor.googlesource.com/gvisor/pkg/hashio"
)
// keySize is the AES-256 key length.
@@ -139,13 +138,11 @@ func NewWriter(w io.Writer, key []byte, metadata map[string]string) (io.WriteClo
}
}
- w = hashio.NewWriter(w, h)
-
// Wrap in compression. We always use "best speed" mode here. When using
// "best compression" mode, there is usually only a little gain in file
// size reduction, which translate to even smaller gain in restore
// latency reduction, while inccuring much more CPU usage at save time.
- return compressio.NewWriter(w, compressionChunkSize, flate.BestSpeed)
+ return compressio.NewWriter(w, key, compressionChunkSize, flate.BestSpeed)
}
// MetadataUnsafe reads out the metadata from a state file without verifying any
@@ -204,7 +201,7 @@ func metadata(r io.Reader, h hash.Hash) (map[string]string, error) {
return nil, err
}
if !hmac.Equal(cur, buf) {
- return nil, hashio.ErrHashMismatch
+ return nil, compressio.ErrHashMismatch
}
}
@@ -226,10 +223,8 @@ func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) {
return nil, nil, err
}
- r = hashio.NewReader(r, h)
-
// Wrap in compression.
- rc, err := compressio.NewReader(r)
+ rc, err := compressio.NewReader(r, key)
if err != nil {
return nil, nil, err
}
diff --git a/pkg/state/statefile/statefile_test.go b/pkg/state/statefile/statefile_test.go
index 66d9581ed..fa3fb9f2c 100644
--- a/pkg/state/statefile/statefile_test.go
+++ b/pkg/state/statefile/statefile_test.go
@@ -20,9 +20,11 @@ import (
"encoding/base64"
"io"
"math/rand"
+ "runtime"
"testing"
+ "time"
- "gvisor.googlesource.com/gvisor/pkg/hashio"
+ "gvisor.googlesource.com/gvisor/pkg/compressio"
)
func randomKey() ([]byte, error) {
@@ -42,6 +44,8 @@ type testCase struct {
}
func TestStatefile(t *testing.T) {
+ rand.Seed(time.Now().Unix())
+
cases := []testCase{
// Various data sizes.
{"nil", nil, nil},
@@ -59,13 +63,9 @@ func TestStatefile(t *testing.T) {
// Make sure we have one longer than the hash length.
{"longer than hash", []byte("012356asdjflkasjlk3jlk23j4lkjaso0d789f0aujw3lkjlkxsdf78asdful2kj3ljka78"), nil},
- // Make sure we have one longer than the segment size.
- {"segments", make([]byte, 3*hashio.SegmentSize), nil},
- {"segments minus one", make([]byte, 3*hashio.SegmentSize-1), nil},
- {"segments plus one", make([]byte, 3*hashio.SegmentSize+1), nil},
- {"segments minus hash", make([]byte, 3*hashio.SegmentSize-32), nil},
- {"segments plus hash", make([]byte, 3*hashio.SegmentSize+32), nil},
- {"large", make([]byte, 30*hashio.SegmentSize), nil},
+ // Make sure we have one longer than the chunk size.
+ {"chunks", make([]byte, 3*compressionChunkSize), nil},
+ {"large", make([]byte, 30*compressionChunkSize), nil},
// Different metadata.
{"one metadata", []byte("data"), map[string]string{"foo": "bar"}},
@@ -130,27 +130,31 @@ func TestStatefile(t *testing.T) {
}
// Change the data and verify that it fails.
- b := append([]byte(nil), bufEncoded.Bytes()...)
- b[rand.Intn(len(b))]++
- r, _, err = NewReader(bytes.NewReader(b), key)
- if err == nil {
- _, err = io.Copy(&bufDecoded, r)
- }
- if err == nil {
- t.Error("got no error: expected error on data corruption")
+ if key != nil {
+ b := append([]byte(nil), bufEncoded.Bytes()...)
+ b[rand.Intn(len(b))]++
+ bufDecoded.Reset()
+ r, _, err = NewReader(bytes.NewReader(b), key)
+ if err == nil {
+ _, err = io.Copy(&bufDecoded, r)
+ }
+ if err == nil {
+ t.Error("got no error: expected error on data corruption")
+ }
}
// Change the key and verify that it fails.
- if key == nil {
- key = integrityKey
- } else {
- key[rand.Intn(len(key))]++
+ newKey := integrityKey
+ if len(key) > 0 {
+ newKey = append([]byte{}, key...)
+ newKey[rand.Intn(len(newKey))]++
}
- r, _, err = NewReader(bytes.NewReader(bufEncoded.Bytes()), key)
+ bufDecoded.Reset()
+ r, _, err = NewReader(bytes.NewReader(bufEncoded.Bytes()), newKey)
if err == nil {
_, err = io.Copy(&bufDecoded, r)
}
- if err != hashio.ErrHashMismatch {
+ if err != compressio.ErrHashMismatch {
t.Errorf("got error: %v, expected ErrHashMismatch on key mismatch", err)
}
})
@@ -159,7 +163,7 @@ func TestStatefile(t *testing.T) {
}
}
-const benchmarkDataSize = 10 * 1024 * 1024
+const benchmarkDataSize = 100 * 1024 * 1024
func benchmark(b *testing.B, size int, write bool, compressible bool) {
b.StopTimer()
@@ -249,14 +253,6 @@ func benchmark(b *testing.B, size int, write bool, compressible bool) {
}
}
-func BenchmarkWrite1BCompressible(b *testing.B) {
- benchmark(b, 1, true, true)
-}
-
-func BenchmarkWrite1BNoncompressible(b *testing.B) {
- benchmark(b, 1, true, false)
-}
-
func BenchmarkWrite4KCompressible(b *testing.B) {
benchmark(b, 4096, true, true)
}
@@ -273,14 +269,6 @@ func BenchmarkWrite1MNoncompressible(b *testing.B) {
benchmark(b, 1024*1024, true, false)
}
-func BenchmarkRead1BCompressible(b *testing.B) {
- benchmark(b, 1, false, true)
-}
-
-func BenchmarkRead1BNoncompressible(b *testing.B) {
- benchmark(b, 1, false, false)
-}
-
func BenchmarkRead4KCompressible(b *testing.B) {
benchmark(b, 4096, false, true)
}
@@ -296,3 +284,7 @@ func BenchmarkRead1MCompressible(b *testing.B) {
func BenchmarkRead1MNoncompressible(b *testing.B) {
benchmark(b, 1024*1024, false, false)
}
+
+func init() {
+ runtime.GOMAXPROCS(runtime.NumCPU())
+}