Diffstat (limited to 'pkg/compressio')
 pkg/compressio/compressio.go      | 223
 pkg/compressio/compressio_test.go | 145
 2 files changed, 304 insertions(+), 64 deletions(-)
diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go
index ef8cbd2a5..591b37130 100644
--- a/pkg/compressio/compressio.go
+++ b/pkg/compressio/compressio.go
@@ -12,17 +12,48 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Package compressio provides parallel compression and decompression.
+// Package compressio provides parallel compression and decompression, as well
+// as optional SHA-256 hashing.
+//
+// The stream format is defined as follows.
+//
+// /------------------------------------------------------\
+// |                 chunk size (4-bytes)                 |
+// +------------------------------------------------------+
+// |              (optional) hash (32-bytes)              |
+// +------------------------------------------------------+
+// |           compressed data size (4-bytes)             |
+// +------------------------------------------------------+
+// |                   compressed data                    |
+// +------------------------------------------------------+
+// |              (optional) hash (32-bytes)              |
+// +------------------------------------------------------+
+// |           compressed data size (4-bytes)             |
+// +------------------------------------------------------+
+// |                        ......                        |
+// \------------------------------------------------------/
+//
+// where each subsequent hash is calculated from the following items in order
+//
+//     compressed data
+//     compressed data size
+//     previous hash
+//
+// so the stream integrity cannot be compromised by switching and mixing
+// compressed chunks.
 package compressio
 
 import (
     "bytes"
     "compress/flate"
     "errors"
+    "hash"
     "io"
     "runtime"
     "sync"
 
+    "crypto/hmac"
+    "crypto/sha256"
+
     "gvisor.googlesource.com/gvisor/pkg/binary"
 )
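
(The chaining rule described in the package comment, written out directly: a minimal standalone Go sketch, not part of this change; chunkHash is a hypothetical name. Because each hash folds in the previous one, replacing, reordering, or mixing chunks from different streams invalidates every hash that follows.)

    package main

    import (
        "crypto/hmac"
        "crypto/sha256"
        "encoding/binary"
        "fmt"
    )

    // chunkHash mirrors the chaining rule: each hash covers the compressed
    // data, then its 4-byte big-endian length, then the previous hash.
    func chunkHash(key, compressed, lastSum []byte) []byte {
        h := hmac.New(sha256.New, key)
        h.Write(compressed)
        binary.Write(h, binary.BigEndian, uint32(len(compressed)))
        h.Write(lastSum)
        return h.Sum(nil)
    }

    func main() {
        key := []byte("01234567890123456789012345678901")
        // The real stream seeds the chain with a hash of the 4-byte chunk
        // size header; nil is used here only for brevity.
        first := chunkHash(key, []byte("chunk-0"), nil)
        fmt.Printf("%x\n", chunkHash(key, []byte("chunk-1"), first))
    }
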
@@ -51,12 +82,23 @@ type chunk struct {
     // This is not returned to the bufPool automatically, since it may
     // correspond to a inline slice (provided directly to Read or Write).
     uncompressed *bytes.Buffer
+
+    // The current hash object. Only used in compress mode.
+    h hash.Hash
+
+    // The hash from previous chunks. Only used in uncompress mode.
+    lastSum []byte
+
+    // The expected hash after current chunk. Only used in uncompress mode.
+    sum []byte
 }
 
 // newChunk allocates a new chunk object (or pulls one from the pool). Buffers
 // will be allocated if nil is provided for compressed or uncompressed.
-func newChunk(compressed *bytes.Buffer, uncompressed *bytes.Buffer) *chunk {
+func newChunk(lastSum []byte, sum []byte, compressed *bytes.Buffer, uncompressed *bytes.Buffer) *chunk {
     c := chunkPool.Get().(*chunk)
+    c.lastSum = lastSum
+    c.sum = sum
     if compressed != nil {
         c.compressed = compressed
     } else {
@@ -85,6 +127,7 @@ type result struct {
 // The goroutine will exit when input is closed, and the goroutine will close
 // output.
 type worker struct {
+    pool   *pool
     input  chan *chunk
     output chan result
 }
@@ -93,17 +136,27 @@ type worker struct {
 func (w *worker) work(compress bool, level int) {
     defer close(w.output)
 
+    var h hash.Hash
+
     for c := range w.input {
+        if h == nil && w.pool.key != nil {
+            h = w.pool.getHash()
+        }
         if compress {
+            mw := io.Writer(c.compressed)
+            if h != nil {
+                mw = io.MultiWriter(mw, h)
+            }
+
             // Encode this slice.
-            fw, err := flate.NewWriter(c.compressed, level)
+            fw, err := flate.NewWriter(mw, level)
             if err != nil {
                 w.output <- result{c, err}
                 continue
             }
 
             // Encode the input.
-            if _, err := io.Copy(fw, c.uncompressed); err != nil {
+            if _, err := io.CopyN(fw, c.uncompressed, int64(c.uncompressed.Len())); err != nil {
                 w.output <- result{c, err}
                 continue
             }
@@ -111,7 +164,28 @@ func (w *worker) work(compress bool, level int) {
                 w.output <- result{c, err}
                 continue
             }
+
+            // Write the hash, if enabled.
+            if h != nil {
+                binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len()))
+                c.h = h
+                h = nil
+            }
         } else {
+            // Check the hash of the compressed contents.
+            if h != nil {
+                h.Write(c.compressed.Bytes())
+                binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len()))
+                io.CopyN(h, bytes.NewReader(c.lastSum), int64(len(c.lastSum)))
+
+                sum := h.Sum(nil)
+                h.Reset()
+                if !hmac.Equal(c.sum, sum) {
+                    w.output <- result{c, ErrHashMismatch}
+                    continue
+                }
+            }
+
             // Decode this slice.
             fr := flate.NewReader(c.compressed)
 
@@ -136,6 +210,16 @@ type pool struct {
     // stream and is shared across both the reader and writer.
     chunkSize uint32
 
+    // key is the key used to create hash objects.
+    key []byte
+
+    // hashMu protects the hash list.
+    hashMu sync.Mutex
+
+    // hashes is the hash object free list. Note that this cannot be
+    // globally shared across readers or writers, as it is key-specific.
+    hashes []hash.Hash
+
     // mu protects below; it is generally the responsibility of users to
     // acquire this mutex before calling any methods on the pool.
     mu sync.Mutex
@@ -149,15 +233,20 @@ type pool struct {
     // buf is the current active buffer; the exact semantics of this buffer
     // depending on whether this is a reader or a writer.
     buf *bytes.Buffer
+
+    // lastSum records the hash of the last chunk processed.
+    lastSum []byte
 }
 
 // init initializes the worker pool.
 //
 // This should only be called once.
-func (p *pool) init(compress bool, level int) {
-    p.workers = make([]worker, 1+runtime.GOMAXPROCS(0))
+func (p *pool) init(key []byte, workers int, compress bool, level int) {
+    p.key = key
+    p.workers = make([]worker, workers)
     for i := 0; i < len(p.workers); i++ {
         p.workers[i] = worker{
+            pool:   p,
             input:  make(chan *chunk, 1),
             output: make(chan result, 1),
         }
@@ -174,6 +263,30 @@ func (p *pool) stop() {
     p.workers = nil
 }
 
+// getHash gets a hash object for the pool. It should only be called when the
+// pool key is non-nil.
+func (p *pool) getHash() hash.Hash {
+    p.hashMu.Lock()
+    defer p.hashMu.Unlock()
+
+    if len(p.hashes) == 0 {
+        return hmac.New(sha256.New, p.key)
+    }
+
+    h := p.hashes[len(p.hashes)-1]
+    p.hashes = p.hashes[:len(p.hashes)-1]
+    return h
+}
+
+func (p *pool) putHash(h hash.Hash) {
+    h.Reset()
+
+    p.hashMu.Lock()
+    defer p.hashMu.Unlock()
+
+    p.hashes = append(p.hashes, h)
+}
+
 // handleResult calls the callback.
 func handleResult(r result, callback func(*chunk) error) error {
     defer func() {
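
(The getHash/putHash free list above exists because hmac.New re-derives keyed state on every call, a cost that would otherwise be paid per chunk, and because the objects are key-specific and so cannot live in one global pool. As a comparison only, a minimal sketch of the same idea with sync.Pool, assuming a single fixed key:)

    package main

    import (
        "crypto/hmac"
        "crypto/sha256"
        "fmt"
        "hash"
        "sync"
    )

    var key = []byte("01234567890123456789012345678901")

    // hashPool sketches the same free-list idea with sync.Pool, which only
    // works when the key is fixed and global; compressio keeps an explicit
    // per-pool slice because each pool owns its own key.
    var hashPool = sync.Pool{
        New: func() interface{} { return hmac.New(sha256.New, key) },
    }

    func main() {
        h := hashPool.Get().(hash.Hash)
        h.Write([]byte("chunk"))
        fmt.Printf("%x\n", h.Sum(nil))
        h.Reset() // reset before returning, exactly as putHash does
        hashPool.Put(h)
    }
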
@@ -231,22 +344,46 @@ type reader struct {
     in io.Reader
 }
 
-// NewReader returns a new compressed reader.
-func NewReader(in io.Reader) (io.Reader, error) {
+// NewReader returns a new compressed reader. If key is non-nil, the data stream
+// is assumed to contain expected hash values, which will be compared against
+// hash values computed from the compressed bytes. See package comments for
+// details.
+func NewReader(in io.Reader, key []byte) (io.Reader, error) {
     r := &reader{
         in: in,
     }
-    r.init(false, 0)
+
+    // Use double buffering for read.
+    r.init(key, 2*runtime.GOMAXPROCS(0), false, 0)
+
     var err error
-    if r.chunkSize, err = binary.ReadUint32(r.in, binary.BigEndian); err != nil {
+    if r.chunkSize, err = binary.ReadUint32(in, binary.BigEndian); err != nil {
         return nil, err
     }
+
+    if r.key != nil {
+        h := r.getHash()
+        binary.WriteUint32(h, binary.BigEndian, r.chunkSize)
+        r.lastSum = h.Sum(nil)
+        r.putHash(h)
+        sum := make([]byte, len(r.lastSum))
+        if _, err := io.ReadFull(r.in, sum); err != nil {
+            return nil, err
+        }
+        if !hmac.Equal(r.lastSum, sum) {
+            return nil, ErrHashMismatch
+        }
+    }
+
     return r, nil
 }
 
 // errNewBuffer is returned when a new buffer is completed.
 var errNewBuffer = errors.New("buffer ready")
 
+// ErrHashMismatch is returned if the hash does not match.
+var ErrHashMismatch = errors.New("hash mismatch")
+
 // Read implements io.Reader.Read.
 func (r *reader) Read(p []byte) (int, error) {
     r.mu.Lock()
@@ -331,14 +468,25 @@ func (r *reader) Read(p []byte) (int, error) {
 
     // Read this chunk and schedule decompression.
     compressed := bufPool.Get().(*bytes.Buffer)
-    if _, err := io.Copy(compressed, &io.LimitedReader{
-        R: r.in,
-        N: int64(l),
-    }); err != nil {
+    if _, err := io.CopyN(compressed, r.in, int64(l)); err != nil {
         // Some other error occurred; see above.
+        if err == io.EOF {
+            err = io.ErrUnexpectedEOF
+        }
         return done, err
     }
 
+    var sum []byte
+    if r.key != nil {
+        sum = make([]byte, len(r.lastSum))
+        if _, err := io.ReadFull(r.in, sum); err != nil {
+            if err == io.EOF {
+                err = io.ErrUnexpectedEOF
+            }
+            return done, err
+        }
+    }
+
     // Are we doing inline decoding?
     //
     // Note that we need to check the length here against
@@ -349,11 +497,12 @@ func (r *reader) Read(p []byte) (int, error) {
     var c *chunk
     start := done + ((pendingPre + pendingInline) * int(r.chunkSize))
     if len(p) >= start+int(r.chunkSize) && len(p) >= start+bytes.MinRead {
-        c = newChunk(compressed, bytes.NewBuffer(p[start:start]))
+        c = newChunk(r.lastSum, sum, compressed, bytes.NewBuffer(p[start:start]))
         pendingInline++
     } else {
-        c = newChunk(compressed, nil)
+        c = newChunk(r.lastSum, sum, compressed, nil)
     }
+    r.lastSum = sum
     if err := r.schedule(c, callback); err == errNewBuffer {
         // A new buffer was completed while we were reading.
         // That's great, but we need to force schedule the
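
(Why the reader now uses io.CopyN instead of io.Copy over a LimitedReader: CopyN returns io.EOF when the source ends before N bytes, which the code above upgrades to io.ErrUnexpectedEOF, whereas the old form reported a truncated chunk as clean success. A standalone illustration, not part of the change:)

    package main

    import (
        "bytes"
        "fmt"
        "io"
    )

    func main() {
        short := bytes.NewReader([]byte("abc")) // a "chunk" that should be 10 bytes

        var dst bytes.Buffer
        n, err := io.Copy(&dst, &io.LimitedReader{R: short, N: 10})
        fmt.Println(n, err) // 3 <nil>: truncation passes silently

        short.Seek(0, io.SeekStart)
        dst.Reset()
        n, err = io.CopyN(&dst, short, 10)
        fmt.Println(n, err) // 3 EOF: truncation is detected
    }
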
@@ -403,12 +552,14 @@ type writer struct {
     closed bool
 }
 
-// NewWriter returns a new compressed writer.
+// NewWriter returns a new compressed writer. If key is non-nil, hash values are
+// generated and written out for compressed bytes. See package comments for
+// details.
 //
 // The recommended chunkSize is on the order of 1M. Extra memory may be
 // buffered (in the form of read-ahead, or buffered writes), and is limited to
 // O(chunkSize * [1+GOMAXPROCS]).
-func NewWriter(out io.Writer, chunkSize uint32, level int) (io.WriteCloser, error) {
+func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.WriteCloser, error) {
     w := &writer{
         pool: pool{
             chunkSize: chunkSize,
@@ -416,10 +567,22 @@ func NewWriter(out io.Writer, chunkSize uint32, level int) (io.WriteCloser, error) {
         },
         out: out,
     }
-    w.init(true, level)
+    w.init(key, 1+runtime.GOMAXPROCS(0), true, level)
+
     if err := binary.WriteUint32(w.out, binary.BigEndian, chunkSize); err != nil {
         return nil, err
     }
+
+    if w.key != nil {
+        h := w.getHash()
+        binary.WriteUint32(h, binary.BigEndian, chunkSize)
+        w.lastSum = h.Sum(nil)
+        w.putHash(h)
+        if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil {
+            return nil, err
+        }
+    }
+
     return w, nil
 }
 
@@ -433,8 +596,22 @@ func (w *writer) flush(c *chunk) error {
     }
 
     // Write out to the stream.
-    _, err := io.Copy(w.out, c.compressed)
-    return err
+    if _, err := io.CopyN(w.out, c.compressed, int64(c.compressed.Len())); err != nil {
+        return err
+    }
+
+    if w.key != nil {
+        io.CopyN(c.h, bytes.NewReader(w.lastSum), int64(len(w.lastSum)))
+        sum := c.h.Sum(nil)
+        w.putHash(c.h)
+        c.h = nil
+        if _, err := io.CopyN(w.out, bytes.NewReader(sum), int64(len(sum))); err != nil {
+            return err
+        }
+        w.lastSum = sum
+    }
+
+    return nil
 }
 
 // Write implements io.Writer.Write.
@@ -480,7 +657,7 @@ func (w *writer) Write(p []byte) (int, error) {
     // immediately following the inline case above.
     left := int(w.chunkSize) - w.buf.Len()
     if left == 0 {
-        if err := w.schedule(newChunk(nil, w.buf), callback); err != nil {
+        if err := w.schedule(newChunk(nil, nil, nil, w.buf), callback); err != nil {
            return done, err
         }
         // Reset the buffer, since this has now been scheduled
@@ -538,7 +715,7 @@ func (w *writer) Close() error {
     // Schedule any remaining partial buffer; we pass w.flush directly here
     // because the final buffer is guaranteed to not be an inline buffer.
     if w.buf.Len() > 0 {
-        if err := w.schedule(newChunk(nil, w.buf), w.flush); err != nil {
+        if err := w.schedule(newChunk(nil, nil, nil, w.buf), w.flush); err != nil {
            return err
         }
     }
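
(For orientation before the test changes below, a round trip through the new keyed API looks roughly like this. A sketch only: the key bytes and chunk size are arbitrary, and the import path is assumed to match the gvisor tree this file lives in:)

    package main

    import (
        "bytes"
        "compress/flate"
        "fmt"
        "io"

        "gvisor.googlesource.com/gvisor/pkg/compressio"
    )

    func check(err error) {
        if err != nil {
            panic(err)
        }
    }

    func main() {
        key := []byte("01234567890123456789012345678901") // arbitrary 32-byte HMAC key
        var buf bytes.Buffer

        // Compress with hashing enabled (non-nil key) and 1M chunks.
        w, err := compressio.NewWriter(&buf, key, 1024*1024, flate.BestSpeed)
        check(err)
        _, err = w.Write([]byte("hello, compressio"))
        check(err)
        check(w.Close())

        // Decompress; the same key must be supplied so every chunk's hash
        // can be verified as it is read.
        r, err := compressio.NewReader(bytes.NewReader(buf.Bytes()), key)
        check(err)
        var out bytes.Buffer
        _, err = io.Copy(&out, r)
        check(err)
        fmt.Println(out.String())
    }
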
diff --git a/pkg/compressio/compressio_test.go b/pkg/compressio/compressio_test.go
index d7911419d..7cb5f8dc4 100644
--- a/pkg/compressio/compressio_test.go
+++ b/pkg/compressio/compressio_test.go
@@ -59,6 +59,7 @@ type testOpts struct {
     PostDecompress  func()
     CompressIters   int
     DecompressIters int
+    CorruptData     bool
 }
 
 func doTest(t harness, opts testOpts) {
@@ -104,15 +105,22 @@ func doTest(t harness, opts testOpts) {
     if opts.DecompressIters <= 0 {
         opts.DecompressIters = 1
     }
+    if opts.CorruptData {
+        b := compressed.Bytes()
+        b[rand.Intn(len(b))]++
+    }
     for i := 0; i < opts.DecompressIters; i++ {
         decompressed.Reset()
         r, err := opts.NewReader(bytes.NewBuffer(compressed.Bytes()))
         if err != nil {
+            if opts.CorruptData {
+                continue
+            }
             t.Errorf("%s: NewReader got err %v, expected nil", opts.Name, err)
             return
         }
-        if _, err := io.Copy(&decompressed, r); err != nil {
-            t.Errorf("%s: decompress got err %v, expected nil", opts.Name, err)
+        if _, err := io.Copy(&decompressed, r); (err != nil) != opts.CorruptData {
+            t.Errorf("%s: decompress got err %v unexpectedly", opts.Name, err)
             return
         }
     }
@@ -121,6 +129,10 @@ func doTest(t harness, opts testOpts) {
     }
     decompressionTime := time.Since(decompressionStartTime)
 
+    if opts.CorruptData {
+        return
+    }
+
     // Verify.
     if decompressed.Len() != len(opts.Data) {
         t.Errorf("%s: got %d bytes, expected %d",
             opts.Name, decompressed.Len(), len(opts.Data))
@@ -136,7 +148,11 @@ func doTest(t harness, opts testOpts) {
         opts.Name, compressionTime, compressionRatio, decompressionTime)
 }
 
+var hashKey = []byte("01234567890123456789012345678901")
+
 func TestCompress(t *testing.T) {
+    rand.Seed(time.Now().Unix())
+
     var (
         data  = initTest(t, 10*1024*1024)
         data0 = data[:0]
@@ -153,17 +169,27 @@ func TestCompress(t *testing.T) {
             continue
         }
 
-        // Do the compress test.
-        doTest(t, testOpts{
-            Name: fmt.Sprintf("len(data)=%d, blockSize=%d", len(data), blockSize),
-            Data: data,
-            NewWriter: func(b *bytes.Buffer) (io.Writer, error) {
-                return NewWriter(b, blockSize, flate.BestCompression)
-            },
-            NewReader: func(b *bytes.Buffer) (io.Reader, error) {
-                return NewReader(b)
-            },
-        })
+        for _, key := range [][]byte{nil, hashKey} {
+            for _, corruptData := range []bool{false, true} {
+                if key == nil && corruptData {
+                    // No need to test corrupt data
+                    // case when not doing hashing.
+                    continue
+                }
+                // Do the compress test.
+                doTest(t, testOpts{
+                    Name: fmt.Sprintf("len(data)=%d, blockSize=%d, key=%s, corruptData=%v", len(data), blockSize, string(key), corruptData),
+                    Data: data,
+                    NewWriter: func(b *bytes.Buffer) (io.Writer, error) {
+                        return NewWriter(b, key, blockSize, flate.BestSpeed)
+                    },
+                    NewReader: func(b *bytes.Buffer) (io.Reader, error) {
+                        return NewReader(b, key)
+                    },
+                    CorruptData: corruptData,
+                })
+            }
+        }
     }
 
     // Do the vanilla test.
@@ -171,7 +197,7 @@ func TestCompress(t *testing.T) {
         Name: fmt.Sprintf("len(data)=%d, vanilla flate", len(data)),
         Data: data,
         NewWriter: func(b *bytes.Buffer) (io.Writer, error) {
-            return flate.NewWriter(b, flate.BestCompression)
+            return flate.NewWriter(b, flate.BestSpeed)
         },
         NewReader: func(b *bytes.Buffer) (io.Reader, error) {
             return flate.NewReader(b), nil
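
(The corrupt-data cases above increment one random byte of the compressed stream; with a key in use, either NewReader must fail on the header hash or a later Read must fail with ErrHashMismatch. A standalone sketch of that expectation, with the same import-path assumption as the previous example:)

    package main

    import (
        "bytes"
        "compress/flate"
        "fmt"
        "io"
        "io/ioutil"

        "gvisor.googlesource.com/gvisor/pkg/compressio"
    )

    func main() {
        key := []byte("01234567890123456789012345678901")
        var buf bytes.Buffer
        w, err := compressio.NewWriter(&buf, key, 64*1024, flate.BestSpeed)
        if err != nil {
            panic(err)
        }
        if _, err := w.Write(bytes.Repeat([]byte{'x'}, 1<<20)); err != nil {
            panic(err)
        }
        if err := w.Close(); err != nil {
            panic(err)
        }

        // Flip one byte. Everything past the 4-byte header is covered by a
        // chained HMAC, so either NewReader or the copy below must fail.
        b := buf.Bytes()
        b[len(b)/2]++

        r, err := compressio.NewReader(bytes.NewReader(b), key)
        if err == nil {
            _, err = io.Copy(ioutil.Discard, r)
        }
        fmt.Println(err) // expect non-nil, typically "hash mismatch"
    }
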
@@ -181,47 +207,84 @@ func TestCompress(t *testing.T) {
         },
     })
 }
 
 const (
-    // benchBlockSize is the blockSize for benchmarks.
-    benchBlockSize = 32 * 1024
-
-    // benchDataSize is the amount of data for benchmarks.
-    benchDataSize = 10 * 1024 * 1024
+    benchDataSize = 600 * 1024 * 1024
 )
 
-func BenchmarkCompress(b *testing.B) {
+func benchmark(b *testing.B, compress bool, hash bool, blockSize uint32) {
     b.StopTimer()
     b.SetBytes(benchDataSize)
     data := initTest(b, benchDataSize)
+    compIters := b.N
+    decompIters := b.N
+    if compress {
+        decompIters = 0
+    } else {
+        compIters = 0
+    }
+    key := hashKey
+    if !hash {
+        key = nil
+    }
     doTest(b, testOpts{
-        Name: fmt.Sprintf("len(data)=%d, blockSize=%d", len(data), benchBlockSize),
+        Name: fmt.Sprintf("compress=%t, hash=%t, len(data)=%d, blockSize=%d", compress, hash, len(data), blockSize),
         Data: data,
         PreCompress:  b.StartTimer,
         PostCompress: b.StopTimer,
         NewWriter: func(b *bytes.Buffer) (io.Writer, error) {
-            return NewWriter(b, benchBlockSize, flate.BestCompression)
+            return NewWriter(b, key, blockSize, flate.BestSpeed)
         },
         NewReader: func(b *bytes.Buffer) (io.Reader, error) {
-            return NewReader(b)
+            return NewReader(b, key)
         },
-        CompressIters: b.N,
+        CompressIters:   compIters,
+        DecompressIters: decompIters,
     })
 }
 
-func BenchmarkDecompress(b *testing.B) {
-    b.StopTimer()
-    b.SetBytes(benchDataSize)
-    data := initTest(b, benchDataSize)
-    doTest(b, testOpts{
-        Name: fmt.Sprintf("len(data)=%d, blockSize=%d", len(data), benchBlockSize),
-        Data: data,
-        PreDecompress:  b.StartTimer,
-        PostDecompress: b.StopTimer,
-        NewWriter: func(b *bytes.Buffer) (io.Writer, error) {
-            return NewWriter(b, benchBlockSize, flate.BestCompression)
-        },
-        NewReader: func(b *bytes.Buffer) (io.Reader, error) {
-            return NewReader(b)
-        },
-        DecompressIters: b.N,
-    })
+func BenchmarkCompressNoHash64K(b *testing.B) {
+    benchmark(b, true, false, 64*1024)
+}
+
+func BenchmarkCompressHash64K(b *testing.B) {
+    benchmark(b, true, true, 64*1024)
+}
+
+func BenchmarkDecompressNoHash64K(b *testing.B) {
+    benchmark(b, false, false, 64*1024)
+}
+
+func BenchmarkDecompressHash64K(b *testing.B) {
+    benchmark(b, false, true, 64*1024)
+}
+
+func BenchmarkCompressNoHash1M(b *testing.B) {
+    benchmark(b, true, false, 1024*1024)
+}
+
+func BenchmarkCompressHash1M(b *testing.B) {
+    benchmark(b, true, true, 1024*1024)
+}
+
+func BenchmarkDecompressNoHash1M(b *testing.B) {
+    benchmark(b, false, false, 1024*1024)
+}
+
+func BenchmarkDecompressHash1M(b *testing.B) {
+    benchmark(b, false, true, 1024*1024)
+}
+
+func BenchmarkCompressNoHash16M(b *testing.B) {
+    benchmark(b, true, false, 16*1024*1024)
+}
+
+func BenchmarkCompressHash16M(b *testing.B) {
+    benchmark(b, true, true, 16*1024*1024)
+}
+
+func BenchmarkDecompressNoHash16M(b *testing.B) {
+    benchmark(b, false, false, 16*1024*1024)
+}
+
+func BenchmarkDecompressHash16M(b *testing.B) {
+    benchmark(b, false, true, 16*1024*1024)
+}
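
(Usage note, not part of the change: the benchmark variants can be run with the standard tooling, for example

    go test -run=NONE -bench=Hash64K ./pkg/compressio/

keeping in mind that benchDataSize is now 600MB, so each iteration allocates and processes a correspondingly large buffer.)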