diff options
author | Zhaozhong Ni <nzz@google.com> | 2018-08-24 14:52:23 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2018-08-24 14:53:31 -0700 |
commit | a6b00502b04ced2f12cfcf35c6f276cff349737b (patch) | |
tree | d443ea0679091b193bcc5568f0aa5aff3ba1a0f3 /pkg | |
parent | 02dfceab6d4c4a2a3342ef69be0265b7ab03e5d7 (diff) |
compressio: support optional hashing and eliminate hashio.
Compared to previous compressio / hashio nesting, there is up to 100% speedup.
PiperOrigin-RevId: 210161269
Change-Id: I481aa9fe980bb817fe465fe34d32ea33fc8abf1c
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/compressio/compressio.go | 223 | ||||
-rw-r--r-- | pkg/compressio/compressio_test.go | 145 | ||||
-rw-r--r-- | pkg/hashio/BUILD | 19 | ||||
-rw-r--r-- | pkg/hashio/hashio.go | 296 | ||||
-rw-r--r-- | pkg/hashio/hashio_test.go | 142 | ||||
-rw-r--r-- | pkg/state/statefile/BUILD | 3 | ||||
-rw-r--r-- | pkg/state/statefile/statefile.go | 11 | ||||
-rw-r--r-- | pkg/state/statefile/statefile_test.go | 70 |
8 files changed, 339 insertions, 570 deletions
diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go index ef8cbd2a5..591b37130 100644 --- a/pkg/compressio/compressio.go +++ b/pkg/compressio/compressio.go @@ -12,17 +12,48 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package compressio provides parallel compression and decompression. +// Package compressio provides parallel compression and decompression, as well +// as optional SHA-256 hashing. +// +// The stream format is defined as follows. +// +// /------------------------------------------------------\ +// | chunk size (4-bytes) | +// +------------------------------------------------------+ +// | (optional) hash (32-bytes) | +// +------------------------------------------------------+ +// | compressed data size (4-bytes) | +// +------------------------------------------------------+ +// | compressed data | +// +------------------------------------------------------+ +// | (optional) hash (32-bytes) | +// +------------------------------------------------------+ +// | compressed data size (4-bytes) | +// +------------------------------------------------------+ +// | ...... | +// \------------------------------------------------------/ +// +// where each subsequent hash is calculated from the following items in order +// +// compressed data +// compressed data size +// previous hash +// +// so the stream integrity cannot be compromised by switching and mixing +// compressed chunks. package compressio import ( "bytes" "compress/flate" "errors" + "hash" "io" "runtime" "sync" + "crypto/hmac" + "crypto/sha256" "gvisor.googlesource.com/gvisor/pkg/binary" ) @@ -51,12 +82,23 @@ type chunk struct { // This is not returned to the bufPool automatically, since it may // correspond to a inline slice (provided directly to Read or Write). uncompressed *bytes.Buffer + + // The current hash object. Only used in compress mode. + h hash.Hash + + // The hash from previous chunks. Only used in uncompress mode. + lastSum []byte + + // The expected hash after current chunk. Only used in uncompress mode. + sum []byte } // newChunk allocates a new chunk object (or pulls one from the pool). Buffers // will be allocated if nil is provided for compressed or uncompressed. -func newChunk(compressed *bytes.Buffer, uncompressed *bytes.Buffer) *chunk { +func newChunk(lastSum []byte, sum []byte, compressed *bytes.Buffer, uncompressed *bytes.Buffer) *chunk { c := chunkPool.Get().(*chunk) + c.lastSum = lastSum + c.sum = sum if compressed != nil { c.compressed = compressed } else { @@ -85,6 +127,7 @@ type result struct { // The goroutine will exit when input is closed, and the goroutine will close // output. type worker struct { + pool *pool input chan *chunk output chan result } @@ -93,17 +136,27 @@ type worker struct { func (w *worker) work(compress bool, level int) { defer close(w.output) + var h hash.Hash + for c := range w.input { + if h == nil && w.pool.key != nil { + h = w.pool.getHash() + } if compress { + mw := io.Writer(c.compressed) + if h != nil { + mw = io.MultiWriter(mw, h) + } + // Encode this slice. - fw, err := flate.NewWriter(c.compressed, level) + fw, err := flate.NewWriter(mw, level) if err != nil { w.output <- result{c, err} continue } // Encode the input. - if _, err := io.Copy(fw, c.uncompressed); err != nil { + if _, err := io.CopyN(fw, c.uncompressed, int64(c.uncompressed.Len())); err != nil { w.output <- result{c, err} continue } @@ -111,7 +164,28 @@ func (w *worker) work(compress bool, level int) { w.output <- result{c, err} continue } + + // Write the hash, if enabled. + if h != nil { + binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len())) + c.h = h + h = nil + } } else { + // Check the hash of the compressed contents. + if h != nil { + h.Write(c.compressed.Bytes()) + binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len())) + io.CopyN(h, bytes.NewReader(c.lastSum), int64(len(c.lastSum))) + + sum := h.Sum(nil) + h.Reset() + if !hmac.Equal(c.sum, sum) { + w.output <- result{c, ErrHashMismatch} + continue + } + } + // Decode this slice. fr := flate.NewReader(c.compressed) @@ -136,6 +210,16 @@ type pool struct { // stream and is shared across both the reader and writer. chunkSize uint32 + // key is the key used to create hash objects. + key []byte + + // hashMu protexts the hash list. + hashMu sync.Mutex + + // hashes is the hash object free list. Note that this cannot be + // globally shared across readers or writers, as it is key-specific. + hashes []hash.Hash + // mu protects below; it is generally the responsibility of users to // acquire this mutex before calling any methods on the pool. mu sync.Mutex @@ -149,15 +233,20 @@ type pool struct { // buf is the current active buffer; the exact semantics of this buffer // depending on whether this is a reader or a writer. buf *bytes.Buffer + + // lasSum records the hash of the last chunk processed. + lastSum []byte } // init initializes the worker pool. // // This should only be called once. -func (p *pool) init(compress bool, level int) { - p.workers = make([]worker, 1+runtime.GOMAXPROCS(0)) +func (p *pool) init(key []byte, workers int, compress bool, level int) { + p.key = key + p.workers = make([]worker, workers) for i := 0; i < len(p.workers); i++ { p.workers[i] = worker{ + pool: p, input: make(chan *chunk, 1), output: make(chan result, 1), } @@ -174,6 +263,30 @@ func (p *pool) stop() { p.workers = nil } +// getHash gets a hash object for the pool. It should only be called when the +// pool key is non-nil. +func (p *pool) getHash() hash.Hash { + p.hashMu.Lock() + defer p.hashMu.Unlock() + + if len(p.hashes) == 0 { + return hmac.New(sha256.New, p.key) + } + + h := p.hashes[len(p.hashes)-1] + p.hashes = p.hashes[:len(p.hashes)-1] + return h +} + +func (p *pool) putHash(h hash.Hash) { + h.Reset() + + p.hashMu.Lock() + defer p.hashMu.Unlock() + + p.hashes = append(p.hashes, h) +} + // handleResult calls the callback. func handleResult(r result, callback func(*chunk) error) error { defer func() { @@ -231,22 +344,46 @@ type reader struct { in io.Reader } -// NewReader returns a new compressed reader. -func NewReader(in io.Reader) (io.Reader, error) { +// NewReader returns a new compressed reader. If key is non-nil, the data stream +// is assumed to contain expected hash values, which will be compared against +// hash values computed from the compressed bytes. See package comments for +// details. +func NewReader(in io.Reader, key []byte) (io.Reader, error) { r := &reader{ in: in, } - r.init(false, 0) + + // Use double buffering for read. + r.init(key, 2*runtime.GOMAXPROCS(0), false, 0) + var err error - if r.chunkSize, err = binary.ReadUint32(r.in, binary.BigEndian); err != nil { + if r.chunkSize, err = binary.ReadUint32(in, binary.BigEndian); err != nil { return nil, err } + + if r.key != nil { + h := r.getHash() + binary.WriteUint32(h, binary.BigEndian, r.chunkSize) + r.lastSum = h.Sum(nil) + r.putHash(h) + sum := make([]byte, len(r.lastSum)) + if _, err := io.ReadFull(r.in, sum); err != nil { + return nil, err + } + if !hmac.Equal(r.lastSum, sum) { + return nil, ErrHashMismatch + } + } + return r, nil } // errNewBuffer is returned when a new buffer is completed. var errNewBuffer = errors.New("buffer ready") +// ErrHashMismatch is returned if the hash does not match. +var ErrHashMismatch = errors.New("hash mismatch") + // Read implements io.Reader.Read. func (r *reader) Read(p []byte) (int, error) { r.mu.Lock() @@ -331,14 +468,25 @@ func (r *reader) Read(p []byte) (int, error) { // Read this chunk and schedule decompression. compressed := bufPool.Get().(*bytes.Buffer) - if _, err := io.Copy(compressed, &io.LimitedReader{ - R: r.in, - N: int64(l), - }); err != nil { + if _, err := io.CopyN(compressed, r.in, int64(l)); err != nil { // Some other error occurred; see above. + if err == io.EOF { + err = io.ErrUnexpectedEOF + } return done, err } + var sum []byte + if r.key != nil { + sum = make([]byte, len(r.lastSum)) + if _, err := io.ReadFull(r.in, sum); err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return done, err + } + } + // Are we doing inline decoding? // // Note that we need to check the length here against @@ -349,11 +497,12 @@ func (r *reader) Read(p []byte) (int, error) { var c *chunk start := done + ((pendingPre + pendingInline) * int(r.chunkSize)) if len(p) >= start+int(r.chunkSize) && len(p) >= start+bytes.MinRead { - c = newChunk(compressed, bytes.NewBuffer(p[start:start])) + c = newChunk(r.lastSum, sum, compressed, bytes.NewBuffer(p[start:start])) pendingInline++ } else { - c = newChunk(compressed, nil) + c = newChunk(r.lastSum, sum, compressed, nil) } + r.lastSum = sum if err := r.schedule(c, callback); err == errNewBuffer { // A new buffer was completed while we were reading. // That's great, but we need to force schedule the @@ -403,12 +552,14 @@ type writer struct { closed bool } -// NewWriter returns a new compressed writer. +// NewWriter returns a new compressed writer. If key is non-nil, hash values are +// generated and written out for compressed bytes. See package comments for +// details. // // The recommended chunkSize is on the order of 1M. Extra memory may be // buffered (in the form of read-ahead, or buffered writes), and is limited to // O(chunkSize * [1+GOMAXPROCS]). -func NewWriter(out io.Writer, chunkSize uint32, level int) (io.WriteCloser, error) { +func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.WriteCloser, error) { w := &writer{ pool: pool{ chunkSize: chunkSize, @@ -416,10 +567,22 @@ func NewWriter(out io.Writer, chunkSize uint32, level int) (io.WriteCloser, erro }, out: out, } - w.init(true, level) + w.init(key, 1+runtime.GOMAXPROCS(0), true, level) + if err := binary.WriteUint32(w.out, binary.BigEndian, chunkSize); err != nil { return nil, err } + + if w.key != nil { + h := w.getHash() + binary.WriteUint32(h, binary.BigEndian, chunkSize) + w.lastSum = h.Sum(nil) + w.putHash(h) + if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil { + return nil, err + } + } + return w, nil } @@ -433,8 +596,22 @@ func (w *writer) flush(c *chunk) error { } // Write out to the stream. - _, err := io.Copy(w.out, c.compressed) - return err + if _, err := io.CopyN(w.out, c.compressed, int64(c.compressed.Len())); err != nil { + return err + } + + if w.key != nil { + io.CopyN(c.h, bytes.NewReader(w.lastSum), int64(len(w.lastSum))) + sum := c.h.Sum(nil) + w.putHash(c.h) + c.h = nil + if _, err := io.CopyN(w.out, bytes.NewReader(sum), int64(len(sum))); err != nil { + return err + } + w.lastSum = sum + } + + return nil } // Write implements io.Writer.Write. @@ -480,7 +657,7 @@ func (w *writer) Write(p []byte) (int, error) { // immediately following the inline case above. left := int(w.chunkSize) - w.buf.Len() if left == 0 { - if err := w.schedule(newChunk(nil, w.buf), callback); err != nil { + if err := w.schedule(newChunk(nil, nil, nil, w.buf), callback); err != nil { return done, err } // Reset the buffer, since this has now been scheduled @@ -538,7 +715,7 @@ func (w *writer) Close() error { // Schedule any remaining partial buffer; we pass w.flush directly here // because the final buffer is guaranteed to not be an inline buffer. if w.buf.Len() > 0 { - if err := w.schedule(newChunk(nil, w.buf), w.flush); err != nil { + if err := w.schedule(newChunk(nil, nil, nil, w.buf), w.flush); err != nil { return err } } diff --git a/pkg/compressio/compressio_test.go b/pkg/compressio/compressio_test.go index d7911419d..7cb5f8dc4 100644 --- a/pkg/compressio/compressio_test.go +++ b/pkg/compressio/compressio_test.go @@ -59,6 +59,7 @@ type testOpts struct { PostDecompress func() CompressIters int DecompressIters int + CorruptData bool } func doTest(t harness, opts testOpts) { @@ -104,15 +105,22 @@ func doTest(t harness, opts testOpts) { if opts.DecompressIters <= 0 { opts.DecompressIters = 1 } + if opts.CorruptData { + b := compressed.Bytes() + b[rand.Intn(len(b))]++ + } for i := 0; i < opts.DecompressIters; i++ { decompressed.Reset() r, err := opts.NewReader(bytes.NewBuffer(compressed.Bytes())) if err != nil { + if opts.CorruptData { + continue + } t.Errorf("%s: NewReader got err %v, expected nil", opts.Name, err) return } - if _, err := io.Copy(&decompressed, r); err != nil { - t.Errorf("%s: decompress got err %v, expected nil", opts.Name, err) + if _, err := io.Copy(&decompressed, r); (err != nil) != opts.CorruptData { + t.Errorf("%s: decompress got err %v unexpectly", opts.Name, err) return } } @@ -121,6 +129,10 @@ func doTest(t harness, opts testOpts) { } decompressionTime := time.Since(decompressionStartTime) + if opts.CorruptData { + return + } + // Verify. if decompressed.Len() != len(opts.Data) { t.Errorf("%s: got %d bytes, expected %d", opts.Name, decompressed.Len(), len(opts.Data)) @@ -136,7 +148,11 @@ func doTest(t harness, opts testOpts) { opts.Name, compressionTime, compressionRatio, decompressionTime) } +var hashKey = []byte("01234567890123456789012345678901") + func TestCompress(t *testing.T) { + rand.Seed(time.Now().Unix()) + var ( data = initTest(t, 10*1024*1024) data0 = data[:0] @@ -153,17 +169,27 @@ func TestCompress(t *testing.T) { continue } - // Do the compress test. - doTest(t, testOpts{ - Name: fmt.Sprintf("len(data)=%d, blockSize=%d", len(data), blockSize), - Data: data, - NewWriter: func(b *bytes.Buffer) (io.Writer, error) { - return NewWriter(b, blockSize, flate.BestCompression) - }, - NewReader: func(b *bytes.Buffer) (io.Reader, error) { - return NewReader(b) - }, - }) + for _, key := range [][]byte{nil, hashKey} { + for _, corruptData := range []bool{false, true} { + if key == nil && corruptData { + // No need to test corrupt data + // case when not doing hashing. + continue + } + // Do the compress test. + doTest(t, testOpts{ + Name: fmt.Sprintf("len(data)=%d, blockSize=%d, key=%s, corruptData=%v", len(data), blockSize, string(key), corruptData), + Data: data, + NewWriter: func(b *bytes.Buffer) (io.Writer, error) { + return NewWriter(b, key, blockSize, flate.BestSpeed) + }, + NewReader: func(b *bytes.Buffer) (io.Reader, error) { + return NewReader(b, key) + }, + CorruptData: corruptData, + }) + } + } } // Do the vanilla test. @@ -171,7 +197,7 @@ func TestCompress(t *testing.T) { Name: fmt.Sprintf("len(data)=%d, vanilla flate", len(data)), Data: data, NewWriter: func(b *bytes.Buffer) (io.Writer, error) { - return flate.NewWriter(b, flate.BestCompression) + return flate.NewWriter(b, flate.BestSpeed) }, NewReader: func(b *bytes.Buffer) (io.Reader, error) { return flate.NewReader(b), nil @@ -181,47 +207,84 @@ func TestCompress(t *testing.T) { } const ( - // benchBlockSize is the blockSize for benchmarks. - benchBlockSize = 32 * 1024 - - // benchDataSize is the amount of data for benchmarks. - benchDataSize = 10 * 1024 * 1024 + benchDataSize = 600 * 1024 * 1024 ) -func BenchmarkCompress(b *testing.B) { +func benchmark(b *testing.B, compress bool, hash bool, blockSize uint32) { b.StopTimer() b.SetBytes(benchDataSize) data := initTest(b, benchDataSize) + compIters := b.N + decompIters := b.N + if compress { + decompIters = 0 + } else { + compIters = 0 + } + key := hashKey + if !hash { + key = nil + } doTest(b, testOpts{ - Name: fmt.Sprintf("len(data)=%d, blockSize=%d", len(data), benchBlockSize), + Name: fmt.Sprintf("compress=%t, hash=%t, len(data)=%d, blockSize=%d", compress, hash, len(data), blockSize), Data: data, PreCompress: b.StartTimer, PostCompress: b.StopTimer, NewWriter: func(b *bytes.Buffer) (io.Writer, error) { - return NewWriter(b, benchBlockSize, flate.BestCompression) + return NewWriter(b, key, blockSize, flate.BestSpeed) }, NewReader: func(b *bytes.Buffer) (io.Reader, error) { - return NewReader(b) + return NewReader(b, key) }, - CompressIters: b.N, + CompressIters: compIters, + DecompressIters: decompIters, }) } -func BenchmarkDecompress(b *testing.B) { - b.StopTimer() - b.SetBytes(benchDataSize) - data := initTest(b, benchDataSize) - doTest(b, testOpts{ - Name: fmt.Sprintf("len(data)=%d, blockSize=%d", len(data), benchBlockSize), - Data: data, - PreDecompress: b.StartTimer, - PostDecompress: b.StopTimer, - NewWriter: func(b *bytes.Buffer) (io.Writer, error) { - return NewWriter(b, benchBlockSize, flate.BestCompression) - }, - NewReader: func(b *bytes.Buffer) (io.Reader, error) { - return NewReader(b) - }, - DecompressIters: b.N, - }) +func BenchmarkCompressNoHash64K(b *testing.B) { + benchmark(b, true, false, 64*1024) +} + +func BenchmarkCompressHash64K(b *testing.B) { + benchmark(b, true, true, 64*1024) +} + +func BenchmarkDecompressNoHash64K(b *testing.B) { + benchmark(b, false, false, 64*1024) +} + +func BenchmarkDecompressHash64K(b *testing.B) { + benchmark(b, false, true, 64*1024) +} + +func BenchmarkCompressNoHash1M(b *testing.B) { + benchmark(b, true, false, 1024*1024) +} + +func BenchmarkCompressHash1M(b *testing.B) { + benchmark(b, true, true, 1024*1024) +} + +func BenchmarkDecompressNoHash1M(b *testing.B) { + benchmark(b, false, false, 1024*1024) +} + +func BenchmarkDecompressHash1M(b *testing.B) { + benchmark(b, false, true, 1024*1024) +} + +func BenchmarkCompressNoHash16M(b *testing.B) { + benchmark(b, true, false, 16*1024*1024) +} + +func BenchmarkCompressHash16M(b *testing.B) { + benchmark(b, true, true, 16*1024*1024) +} + +func BenchmarkDecompressNoHash16M(b *testing.B) { + benchmark(b, false, false, 16*1024*1024) +} + +func BenchmarkDecompressHash16M(b *testing.B) { + benchmark(b, false, true, 16*1024*1024) } diff --git a/pkg/hashio/BUILD b/pkg/hashio/BUILD deleted file mode 100644 index 5736e2e73..000000000 --- a/pkg/hashio/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -package(licenses = ["notice"]) # Apache 2.0 - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - -go_library( - name = "hashio", - srcs = [ - "hashio.go", - ], - importpath = "gvisor.googlesource.com/gvisor/pkg/hashio", - visibility = ["//:sandbox"], -) - -go_test( - name = "hashio_test", - size = "small", - srcs = ["hashio_test.go"], - embed = [":hashio"], -) diff --git a/pkg/hashio/hashio.go b/pkg/hashio/hashio.go deleted file mode 100644 index e0e8ef413..000000000 --- a/pkg/hashio/hashio.go +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2018 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* -Package hashio provides hash-verified I/O streams. - -The I/O stream format is defined as follows. - -/-----------------------------------------\ -| payload | -+-----------------------------------------+ -| hash | -+-----------------------------------------+ -| payload | -+-----------------------------------------+ -| hash | -+-----------------------------------------+ -| ...... | -\-----------------------------------------/ - -Payload bytes written to / read from the stream are automatically split -into segments, each followed by a hash. All data read out must have already -passed hash verification. Hence the client code can safely do any kind of -(stream) processing of these data. -*/ -package hashio - -import ( - "errors" - "hash" - "io" - "sync" - - "crypto/hmac" -) - -// SegmentSize is the unit we split payload data and insert hash at. -const SegmentSize = 8 * 1024 - -// ErrHashMismatch is returned if the ErrHashMismatch does not match. -var ErrHashMismatch = errors.New("hash mismatch") - -// writer computes hashs during writes. -type writer struct { - mu sync.Mutex - w io.Writer - h hash.Hash - written int - closed bool - hashv []byte -} - -// NewWriter creates a hash-verified IO stream writer. -func NewWriter(w io.Writer, h hash.Hash) io.WriteCloser { - return &writer{ - w: w, - h: h, - hashv: make([]byte, h.Size()), - } -} - -// Write writes the given data. -func (w *writer) Write(p []byte) (int, error) { - w.mu.Lock() - defer w.mu.Unlock() - - // Did we already close? - if w.closed { - return 0, io.ErrUnexpectedEOF - } - - for done := 0; done < len(p); { - // Slice the data at segment boundary. - left := SegmentSize - w.written - if left > len(p[done:]) { - left = len(p[done:]) - } - - // Write the rest of the segment and write to hash writer the - // same number of bytes. Hash.Write may never return an error. - n, err := w.w.Write(p[done : done+left]) - w.h.Write(p[done : done+left]) - w.written += n - done += n - - // And only check the actual write errors here. - if n == 0 && err != nil { - return done, err - } - - // Write hash if starting a new segment. - if w.written == SegmentSize { - if err := w.closeSegment(); err != nil { - return done, err - } - } - } - - return len(p), nil -} - -// closeSegment closes the current segment and writes out its hash. -func (w *writer) closeSegment() error { - // Serialize and write the current segment's hash. - hashv := w.h.Sum(w.hashv[:0]) - for done := 0; done < len(hashv); { - n, err := w.w.Write(hashv[done:]) - done += n - if n == 0 && err != nil { - return err - } - } - w.written = 0 // reset counter. - return nil -} - -// Close writes the final hash to the stream and closes the underlying Writer. -func (w *writer) Close() error { - w.mu.Lock() - defer w.mu.Unlock() - - // Did we already close? - if w.closed { - return io.ErrUnexpectedEOF - } - - // Always mark as closed, regardless of errors. - w.closed = true - - // Write the final segment. - if err := w.closeSegment(); err != nil { - return err - } - - // Call the underlying closer. - if c, ok := w.w.(io.Closer); ok { - return c.Close() - } - return nil -} - -// reader computes and verifies hashs during reads. -type reader struct { - mu sync.Mutex - r io.Reader - h hash.Hash - - // data is remaining verified but unused payload data. This is - // populated on short reads and may be consumed without any - // verification. - data [SegmentSize]byte - - // index is the index into data above. - index int - - // available is the amount of valid data above. - available int - - // hashv is the read hash for the current segment. - hashv []byte - - // computev is the computed hash for the current segment. - computev []byte -} - -// NewReader creates a hash-verified IO stream reader. -func NewReader(r io.Reader, h hash.Hash) io.Reader { - return &reader{ - r: r, - h: h, - hashv: make([]byte, h.Size()), - computev: make([]byte, h.Size()), - } -} - -// readSegment reads a segment and hash vector. -// -// Precondition: datav must have length SegmentSize. -func (r *reader) readSegment(datav []byte) (data []byte, err error) { - // Make two reads: the first is the segment, the second is the hash - // which needs verification. We may need to adjust the resulting slices - // in the case of short reads. - for done := 0; done < SegmentSize; { - n, err := r.r.Read(datav[done:]) - done += n - if n == 0 && err == io.EOF { - if done == 0 { - // No data at all. - return nil, io.EOF - } else if done < len(r.hashv) { - // Not enough for a hash. - return nil, ErrHashMismatch - } - // Truncate the data and copy to the hash. - copy(r.hashv, datav[done-len(r.hashv):]) - datav = datav[:done-len(r.hashv)] - return datav, nil - } else if n == 0 && err != nil { - return nil, err - } - } - for done := 0; done < len(r.hashv); { - n, err := r.r.Read(r.hashv[done:]) - done += n - if n == 0 && err == io.EOF { - // Copy over from the data. - missing := len(r.hashv) - done - copy(r.hashv[missing:], r.hashv[:done]) - copy(r.hashv[:missing], datav[len(datav)-missing:]) - datav = datav[:len(datav)-missing] - return datav, nil - } else if n == 0 && err != nil { - return nil, err - } - } - return datav, nil -} - -// verifyHash verifies the given hash. -// -// The passed hash will be returned to the pool. -func (r *reader) verifyHash(datav []byte) error { - for done := 0; done < len(datav); { - n, _ := r.h.Write(datav[done:]) - done += n - } - computev := r.h.Sum(r.computev[:0]) - if !hmac.Equal(r.hashv, computev) { - return ErrHashMismatch - } - return nil -} - -// Read reads the data. -func (r *reader) Read(p []byte) (int, error) { - r.mu.Lock() - defer r.mu.Unlock() - - for done := 0; done < len(p); { - // Check for pending data. - if r.index < r.available { - n := copy(p[done:], r.data[r.index:r.available]) - done += n - r.index += n - continue - } - - // Prepare the next read. - var ( - datav []byte - inline bool - ) - - // We need to read a new segment. Can we read directly? - if len(p[done:]) >= SegmentSize { - datav = p[done : done+SegmentSize] - inline = true - } else { - datav = r.data[:] - inline = false - } - - // Read the next segments. - datav, err := r.readSegment(datav) - if err != nil && err != io.EOF { - return 0, err - } else if err == io.EOF { - return done, io.EOF - } - if err := r.verifyHash(datav); err != nil { - return done, err - } - - if inline { - // Move the cursor. - done += len(datav) - } else { - // Reset index & available. - r.index = 0 - r.available = len(datav) - } - } - - return len(p), nil -} diff --git a/pkg/hashio/hashio_test.go b/pkg/hashio/hashio_test.go deleted file mode 100644 index 41dbdf860..000000000 --- a/pkg/hashio/hashio_test.go +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright 2018 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package hashio - -import ( - "bytes" - "crypto/hmac" - "crypto/sha256" - "fmt" - "io" - "math/rand" - "testing" -) - -var testKey = []byte("01234567890123456789012345678901") - -func runTest(c []byte, fn func(enc *bytes.Buffer), iters int) error { - // Encoding happens via a buffer. - var ( - enc bytes.Buffer - dec bytes.Buffer - ) - - for i := 0; i < iters; i++ { - enc.Reset() - w := NewWriter(&enc, hmac.New(sha256.New, testKey)) - if _, err := io.Copy(w, bytes.NewBuffer(c)); err != nil { - return err - } - if err := w.Close(); err != nil { - return err - } - } - - fn(&enc) - - for i := 0; i < iters; i++ { - dec.Reset() - r := NewReader(bytes.NewReader(enc.Bytes()), hmac.New(sha256.New, testKey)) - if _, err := io.Copy(&dec, r); err != nil { - return err - } - } - - // Check that the data matches; this should never fail. - if !bytes.Equal(c, dec.Bytes()) { - panic(fmt.Sprintf("data didn't match: got %v, expected %v", dec.Bytes(), c)) - } - - return nil -} - -func TestTable(t *testing.T) { - cases := [][]byte{ - // Various data sizes. - nil, - []byte(""), - []byte("_"), - []byte("0"), - []byte("01"), - []byte("012"), - []byte("0123"), - []byte("01234"), - []byte("012356"), - []byte("0123567"), - []byte("01235678"), - - // Make sure we have one longer than the hash length. - []byte("012356asdjflkasjlk3jlk23j4lkjaso0d789f0aujw3lkjlkxsdf78asdful2kj3ljka78"), - - // Make sure we have one longer than the segment size. - make([]byte, 3*SegmentSize), - make([]byte, 3*SegmentSize-1), - make([]byte, 3*SegmentSize+1), - make([]byte, 3*SegmentSize-32), - make([]byte, 3*SegmentSize+32), - make([]byte, 30*SegmentSize), - } - - for _, c := range cases { - for _, flip := range []bool{false, true} { - if len(c) == 0 && flip == true { - continue - } - - // Log the case. - t.Logf("case: len=%d flip=%v", len(c), flip) - - if err := runTest(c, func(enc *bytes.Buffer) { - if flip { - corrupted := rand.Intn(enc.Len()) - enc.Bytes()[corrupted]++ - } - }, 1); err != nil { - if !flip || err != ErrHashMismatch { - t.Errorf("error during read: got %v, expected nil", err) - } - continue - } else if flip { - t.Errorf("failed to detect ErrHashMismatch on corrupted data!") - continue - } - } - } -} - -const benchBytes = 10 * 1024 * 1024 // 10 MB. - -func BenchmarkWrite(b *testing.B) { - b.StopTimer() - x := make([]byte, benchBytes) - b.SetBytes(benchBytes) - b.StartTimer() - if err := runTest(x, func(enc *bytes.Buffer) { - b.StopTimer() - }, b.N); err != nil { - b.Errorf("benchmark failed: %v", err) - } -} - -func BenchmarkRead(b *testing.B) { - b.StopTimer() - x := make([]byte, benchBytes) - b.SetBytes(benchBytes) - if err := runTest(x, func(enc *bytes.Buffer) { - b.StartTimer() - }, b.N); err != nil { - b.Errorf("benchmark failed: %v", err) - } -} diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD index 16abe1930..6be78dc9b 100644 --- a/pkg/state/statefile/BUILD +++ b/pkg/state/statefile/BUILD @@ -10,7 +10,6 @@ go_library( deps = [ "//pkg/binary", "//pkg/compressio", - "//pkg/hashio", ], ) @@ -19,5 +18,5 @@ go_test( size = "small", srcs = ["statefile_test.go"], embed = [":statefile"], - deps = ["//pkg/hashio"], + deps = ["//pkg/compressio"], ) diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go index 0b4eff8fa..9c86c1934 100644 --- a/pkg/state/statefile/statefile.go +++ b/pkg/state/statefile/statefile.go @@ -57,7 +57,6 @@ import ( "crypto/sha256" "gvisor.googlesource.com/gvisor/pkg/binary" "gvisor.googlesource.com/gvisor/pkg/compressio" - "gvisor.googlesource.com/gvisor/pkg/hashio" ) // keySize is the AES-256 key length. @@ -139,13 +138,11 @@ func NewWriter(w io.Writer, key []byte, metadata map[string]string) (io.WriteClo } } - w = hashio.NewWriter(w, h) - // Wrap in compression. We always use "best speed" mode here. When using // "best compression" mode, there is usually only a little gain in file // size reduction, which translate to even smaller gain in restore // latency reduction, while inccuring much more CPU usage at save time. - return compressio.NewWriter(w, compressionChunkSize, flate.BestSpeed) + return compressio.NewWriter(w, key, compressionChunkSize, flate.BestSpeed) } // MetadataUnsafe reads out the metadata from a state file without verifying any @@ -204,7 +201,7 @@ func metadata(r io.Reader, h hash.Hash) (map[string]string, error) { return nil, err } if !hmac.Equal(cur, buf) { - return nil, hashio.ErrHashMismatch + return nil, compressio.ErrHashMismatch } } @@ -226,10 +223,8 @@ func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) { return nil, nil, err } - r = hashio.NewReader(r, h) - // Wrap in compression. - rc, err := compressio.NewReader(r) + rc, err := compressio.NewReader(r, key) if err != nil { return nil, nil, err } diff --git a/pkg/state/statefile/statefile_test.go b/pkg/state/statefile/statefile_test.go index 66d9581ed..fa3fb9f2c 100644 --- a/pkg/state/statefile/statefile_test.go +++ b/pkg/state/statefile/statefile_test.go @@ -20,9 +20,11 @@ import ( "encoding/base64" "io" "math/rand" + "runtime" "testing" + "time" - "gvisor.googlesource.com/gvisor/pkg/hashio" + "gvisor.googlesource.com/gvisor/pkg/compressio" ) func randomKey() ([]byte, error) { @@ -42,6 +44,8 @@ type testCase struct { } func TestStatefile(t *testing.T) { + rand.Seed(time.Now().Unix()) + cases := []testCase{ // Various data sizes. {"nil", nil, nil}, @@ -59,13 +63,9 @@ func TestStatefile(t *testing.T) { // Make sure we have one longer than the hash length. {"longer than hash", []byte("012356asdjflkasjlk3jlk23j4lkjaso0d789f0aujw3lkjlkxsdf78asdful2kj3ljka78"), nil}, - // Make sure we have one longer than the segment size. - {"segments", make([]byte, 3*hashio.SegmentSize), nil}, - {"segments minus one", make([]byte, 3*hashio.SegmentSize-1), nil}, - {"segments plus one", make([]byte, 3*hashio.SegmentSize+1), nil}, - {"segments minus hash", make([]byte, 3*hashio.SegmentSize-32), nil}, - {"segments plus hash", make([]byte, 3*hashio.SegmentSize+32), nil}, - {"large", make([]byte, 30*hashio.SegmentSize), nil}, + // Make sure we have one longer than the chunk size. + {"chunks", make([]byte, 3*compressionChunkSize), nil}, + {"large", make([]byte, 30*compressionChunkSize), nil}, // Different metadata. {"one metadata", []byte("data"), map[string]string{"foo": "bar"}}, @@ -130,27 +130,31 @@ func TestStatefile(t *testing.T) { } // Change the data and verify that it fails. - b := append([]byte(nil), bufEncoded.Bytes()...) - b[rand.Intn(len(b))]++ - r, _, err = NewReader(bytes.NewReader(b), key) - if err == nil { - _, err = io.Copy(&bufDecoded, r) - } - if err == nil { - t.Error("got no error: expected error on data corruption") + if key != nil { + b := append([]byte(nil), bufEncoded.Bytes()...) + b[rand.Intn(len(b))]++ + bufDecoded.Reset() + r, _, err = NewReader(bytes.NewReader(b), key) + if err == nil { + _, err = io.Copy(&bufDecoded, r) + } + if err == nil { + t.Error("got no error: expected error on data corruption") + } } // Change the key and verify that it fails. - if key == nil { - key = integrityKey - } else { - key[rand.Intn(len(key))]++ + newKey := integrityKey + if len(key) > 0 { + newKey = append([]byte{}, key...) + newKey[rand.Intn(len(newKey))]++ } - r, _, err = NewReader(bytes.NewReader(bufEncoded.Bytes()), key) + bufDecoded.Reset() + r, _, err = NewReader(bytes.NewReader(bufEncoded.Bytes()), newKey) if err == nil { _, err = io.Copy(&bufDecoded, r) } - if err != hashio.ErrHashMismatch { + if err != compressio.ErrHashMismatch { t.Errorf("got error: %v, expected ErrHashMismatch on key mismatch", err) } }) @@ -159,7 +163,7 @@ func TestStatefile(t *testing.T) { } } -const benchmarkDataSize = 10 * 1024 * 1024 +const benchmarkDataSize = 100 * 1024 * 1024 func benchmark(b *testing.B, size int, write bool, compressible bool) { b.StopTimer() @@ -249,14 +253,6 @@ func benchmark(b *testing.B, size int, write bool, compressible bool) { } } -func BenchmarkWrite1BCompressible(b *testing.B) { - benchmark(b, 1, true, true) -} - -func BenchmarkWrite1BNoncompressible(b *testing.B) { - benchmark(b, 1, true, false) -} - func BenchmarkWrite4KCompressible(b *testing.B) { benchmark(b, 4096, true, true) } @@ -273,14 +269,6 @@ func BenchmarkWrite1MNoncompressible(b *testing.B) { benchmark(b, 1024*1024, true, false) } -func BenchmarkRead1BCompressible(b *testing.B) { - benchmark(b, 1, false, true) -} - -func BenchmarkRead1BNoncompressible(b *testing.B) { - benchmark(b, 1, false, false) -} - func BenchmarkRead4KCompressible(b *testing.B) { benchmark(b, 4096, false, true) } @@ -296,3 +284,7 @@ func BenchmarkRead1MCompressible(b *testing.B) { func BenchmarkRead1MNoncompressible(b *testing.B) { benchmark(b, 1024*1024, false, false) } + +func init() { + runtime.GOMAXPROCS(runtime.NumCPU()) +} |