author | Zhaozhong Ni <nzz@google.com> | 2018-08-24 14:52:23 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2018-08-24 14:53:31 -0700 |
commit | a6b00502b04ced2f12cfcf35c6f276cff349737b (patch) | |
tree | d443ea0679091b193bcc5568f0aa5aff3ba1a0f3 | /pkg/compressio/compressio.go |
parent | 02dfceab6d4c4a2a3342ef69be0265b7ab03e5d7 (diff) | |
compressio: support optional hashing and eliminate hashio.
Compared to the previous compressio / hashio nesting, this yields up to a 100% speedup.
PiperOrigin-RevId: 210161269
Change-Id: I481aa9fe980bb817fe465fe34d32ea33fc8abf1c
Diffstat (limited to 'pkg/compressio/compressio.go')
-rw-r--r-- | pkg/compressio/compressio.go | 223 |
1 file changed, 200 insertions, 23 deletions
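For orientation, here is a minimal usage sketch against the new `NewWriter` / `NewReader` signatures introduced below; the key bytes, chunk size, compression level, and payload are illustrative, not taken from the commit:

```go
package main

import (
	"bytes"
	"compress/flate"
	"fmt"
	"io"

	"gvisor.googlesource.com/gvisor/pkg/compressio"
)

func main() {
	// Illustrative key; passing nil disables hashing entirely.
	key := []byte("example-hmac-key")

	// Compress with hashing enabled, using the 1M chunk size that the
	// NewWriter documentation recommends.
	var buf bytes.Buffer
	w, err := compressio.NewWriter(&buf, key, 1024*1024, flate.BestSpeed)
	if err != nil {
		panic(err)
	}
	payload := []byte("hello, compressio")
	if _, err := w.Write(payload); err != nil {
		panic(err)
	}
	if err := w.Close(); err != nil {
		panic(err)
	}

	// Decompress with the same key; a mismatched key or a tampered
	// stream surfaces as ErrHashMismatch.
	r, err := compressio.NewReader(&buf, key)
	if err != nil {
		panic(err)
	}
	out := make([]byte, len(payload))
	if _, err := io.ReadFull(r, out); err != nil {
		panic(err)
	}
	fmt.Printf("%s\n", out)
}
```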
diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go
index ef8cbd2a5..591b37130 100644
--- a/pkg/compressio/compressio.go
+++ b/pkg/compressio/compressio.go
@@ -12,17 +12,48 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Package compressio provides parallel compression and decompression.
+// Package compressio provides parallel compression and decompression, as well
+// as optional SHA-256 hashing.
+//
+// The stream format is defined as follows.
+//
+// /------------------------------------------------------\
+// | chunk size (4-bytes)                                 |
+// +------------------------------------------------------+
+// | (optional) hash (32-bytes)                           |
+// +------------------------------------------------------+
+// | compressed data size (4-bytes)                       |
+// +------------------------------------------------------+
+// | compressed data                                      |
+// +------------------------------------------------------+
+// | (optional) hash (32-bytes)                           |
+// +------------------------------------------------------+
+// | compressed data size (4-bytes)                       |
+// +------------------------------------------------------+
+// | ......                                               |
+// \------------------------------------------------------/
+//
+// where each subsequent hash is calculated from the following items in order:
+//
+//	compressed data
+//	compressed data size
+//	previous hash
+//
+// so the stream integrity cannot be compromised by switching and mixing
+// compressed chunks.
 package compressio
 
 import (
 	"bytes"
 	"compress/flate"
 	"errors"
+	"hash"
 	"io"
 	"runtime"
 	"sync"
 
+	"crypto/hmac"
+	"crypto/sha256"
 	"gvisor.googlesource.com/gvisor/pkg/binary"
 )
@@ -51,12 +82,23 @@ type chunk struct {
 	// This is not returned to the bufPool automatically, since it may
 	// correspond to an inline slice (provided directly to Read or Write).
 	uncompressed *bytes.Buffer
+
+	// The current hash object. Only used in compress mode.
+	h hash.Hash
+
+	// The hash from previous chunks. Only used in uncompress mode.
+	lastSum []byte
+
+	// The expected hash after current chunk. Only used in uncompress mode.
+	sum []byte
 }
 
 // newChunk allocates a new chunk object (or pulls one from the pool). Buffers
 // will be allocated if nil is provided for compressed or uncompressed.
-func newChunk(compressed *bytes.Buffer, uncompressed *bytes.Buffer) *chunk {
+func newChunk(lastSum []byte, sum []byte, compressed *bytes.Buffer, uncompressed *bytes.Buffer) *chunk {
 	c := chunkPool.Get().(*chunk)
+	c.lastSum = lastSum
+	c.sum = sum
 	if compressed != nil {
 		c.compressed = compressed
 	} else {
@@ -85,6 +127,7 @@ type result struct {
 // The goroutine will exit when input is closed, and the goroutine will close
 // output.
 type worker struct {
+	pool   *pool
 	input  chan *chunk
 	output chan result
 }
@@ -93,17 +136,27 @@ type worker struct {
 func (w *worker) work(compress bool, level int) {
 	defer close(w.output)
 
+	var h hash.Hash
+
 	for c := range w.input {
+		if h == nil && w.pool.key != nil {
+			h = w.pool.getHash()
+		}
 		if compress {
+			mw := io.Writer(c.compressed)
+			if h != nil {
+				mw = io.MultiWriter(mw, h)
+			}
+
 			// Encode this slice.
-			fw, err := flate.NewWriter(c.compressed, level)
+			fw, err := flate.NewWriter(mw, level)
 			if err != nil {
 				w.output <- result{c, err}
 				continue
 			}
 
 			// Encode the input.
-			if _, err := io.Copy(fw, c.uncompressed); err != nil {
+			if _, err := io.CopyN(fw, c.uncompressed, int64(c.uncompressed.Len())); err != nil {
 				w.output <- result{c, err}
 				continue
 			}
@@ -111,7 +164,28 @@ func (w *worker) work(compress bool, level int) {
 				w.output <- result{c, err}
 				continue
 			}
+
+			// Write the hash, if enabled.
+			if h != nil {
+				binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len()))
+				c.h = h
+				h = nil
+			}
 		} else {
+			// Check the hash of the compressed contents.
+			if h != nil {
+				h.Write(c.compressed.Bytes())
+				binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len()))
+				io.CopyN(h, bytes.NewReader(c.lastSum), int64(len(c.lastSum)))
+
+				sum := h.Sum(nil)
+				h.Reset()
+				if !hmac.Equal(c.sum, sum) {
+					w.output <- result{c, ErrHashMismatch}
+					continue
+				}
+			}
+
 			// Decode this slice.
 			fr := flate.NewReader(c.compressed)
@@ -136,6 +210,16 @@ type pool struct {
 	// stream and is shared across both the reader and writer.
 	chunkSize uint32
 
+	// key is the key used to create hash objects.
+	key []byte
+
+	// hashMu protects the hash list.
+	hashMu sync.Mutex
+
+	// hashes is the hash object free list. Note that this cannot be
+	// globally shared across readers or writers, as it is key-specific.
+	hashes []hash.Hash
+
 	// mu protects below; it is generally the responsibility of users to
 	// acquire this mutex before calling any methods on the pool.
 	mu sync.Mutex
@@ -149,15 +233,20 @@ type pool struct {
 	// buf is the current active buffer; the exact semantics of this buffer
 	// depend on whether this is a reader or a writer.
 	buf *bytes.Buffer
+
+	// lastSum records the hash of the last chunk processed.
+	lastSum []byte
 }
 
 // init initializes the worker pool.
 //
 // This should only be called once.
-func (p *pool) init(compress bool, level int) {
-	p.workers = make([]worker, 1+runtime.GOMAXPROCS(0))
+func (p *pool) init(key []byte, workers int, compress bool, level int) {
+	p.key = key
+	p.workers = make([]worker, workers)
 	for i := 0; i < len(p.workers); i++ {
 		p.workers[i] = worker{
+			pool:   p,
 			input:  make(chan *chunk, 1),
 			output: make(chan result, 1),
 		}
@@ -174,6 +263,30 @@ func (p *pool) stop() {
 	p.workers = nil
 }
 
+// getHash gets a hash object for the pool. It should only be called when the
+// pool key is non-nil.
+func (p *pool) getHash() hash.Hash {
+	p.hashMu.Lock()
+	defer p.hashMu.Unlock()
+
+	if len(p.hashes) == 0 {
+		return hmac.New(sha256.New, p.key)
+	}
+
+	h := p.hashes[len(p.hashes)-1]
+	p.hashes = p.hashes[:len(p.hashes)-1]
+	return h
+}
+
+func (p *pool) putHash(h hash.Hash) {
+	h.Reset()
+
+	p.hashMu.Lock()
+	defer p.hashMu.Unlock()
+
+	p.hashes = append(p.hashes, h)
+}
+
 // handleResult calls the callback.
 func handleResult(r result, callback func(*chunk) error) error {
 	defer func() {
@@ -231,22 +344,46 @@ type reader struct {
 	in io.Reader
 }
 
-// NewReader returns a new compressed reader.
-func NewReader(in io.Reader) (io.Reader, error) {
+// NewReader returns a new compressed reader. If key is non-nil, the data stream
+// is assumed to contain expected hash values, which will be compared against
+// hash values computed from the compressed bytes. See package comments for
+// details.
+func NewReader(in io.Reader, key []byte) (io.Reader, error) {
 	r := &reader{
 		in: in,
 	}
-	r.init(false, 0)
+
+	// Use double buffering for read.
+	r.init(key, 2*runtime.GOMAXPROCS(0), false, 0)
+
 	var err error
-	if r.chunkSize, err = binary.ReadUint32(r.in, binary.BigEndian); err != nil {
+	if r.chunkSize, err = binary.ReadUint32(in, binary.BigEndian); err != nil {
 		return nil, err
 	}
+
+	if r.key != nil {
+		h := r.getHash()
+		binary.WriteUint32(h, binary.BigEndian, r.chunkSize)
+		r.lastSum = h.Sum(nil)
+		r.putHash(h)
+		sum := make([]byte, len(r.lastSum))
+		if _, err := io.ReadFull(r.in, sum); err != nil {
+			return nil, err
+		}
+		if !hmac.Equal(r.lastSum, sum) {
+			return nil, ErrHashMismatch
+		}
+	}
+
 	return r, nil
 }
 
 // errNewBuffer is returned when a new buffer is completed.
 var errNewBuffer = errors.New("buffer ready")
 
+// ErrHashMismatch is returned if the hash does not match.
+var ErrHashMismatch = errors.New("hash mismatch")
+
 // Read implements io.Reader.Read.
 func (r *reader) Read(p []byte) (int, error) {
 	r.mu.Lock()
@@ -331,14 +468,25 @@ func (r *reader) Read(p []byte) (int, error) {
 
 	// Read this chunk and schedule decompression.
 	compressed := bufPool.Get().(*bytes.Buffer)
-	if _, err := io.Copy(compressed, &io.LimitedReader{
-		R: r.in,
-		N: int64(l),
-	}); err != nil {
+	if _, err := io.CopyN(compressed, r.in, int64(l)); err != nil {
 		// Some other error occurred; see above.
+		if err == io.EOF {
+			err = io.ErrUnexpectedEOF
+		}
 		return done, err
 	}
 
+	var sum []byte
+	if r.key != nil {
+		sum = make([]byte, len(r.lastSum))
+		if _, err := io.ReadFull(r.in, sum); err != nil {
+			if err == io.EOF {
+				err = io.ErrUnexpectedEOF
+			}
+			return done, err
+		}
+	}
+
 	// Are we doing inline decoding?
 	//
 	// Note that we need to check the length here against
@@ -349,11 +497,12 @@ func (r *reader) Read(p []byte) (int, error) {
 	var c *chunk
 	start := done + ((pendingPre + pendingInline) * int(r.chunkSize))
 	if len(p) >= start+int(r.chunkSize) && len(p) >= start+bytes.MinRead {
-		c = newChunk(compressed, bytes.NewBuffer(p[start:start]))
+		c = newChunk(r.lastSum, sum, compressed, bytes.NewBuffer(p[start:start]))
 		pendingInline++
 	} else {
-		c = newChunk(compressed, nil)
+		c = newChunk(r.lastSum, sum, compressed, nil)
 	}
+	r.lastSum = sum
 	if err := r.schedule(c, callback); err == errNewBuffer {
 		// A new buffer was completed while we were reading.
 		// That's great, but we need to force schedule the
@@ -403,12 +552,14 @@ type writer struct {
 	closed bool
 }
 
-// NewWriter returns a new compressed writer.
+// NewWriter returns a new compressed writer. If key is non-nil, hash values are
+// generated and written out for compressed bytes. See package comments for
+// details.
 //
 // The recommended chunkSize is on the order of 1M. Extra memory may be
 // buffered (in the form of read-ahead, or buffered writes), and is limited to
 // O(chunkSize * [1+GOMAXPROCS]).
-func NewWriter(out io.Writer, chunkSize uint32, level int) (io.WriteCloser, error) {
+func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.WriteCloser, error) {
 	w := &writer{
 		pool: pool{
 			chunkSize: chunkSize,
@@ -416,10 +567,22 @@ func NewWriter(out io.Writer, chunkSize uint32, level int) (io.WriteCloser, erro
 		},
 		out: out,
 	}
-	w.init(true, level)
+	w.init(key, 1+runtime.GOMAXPROCS(0), true, level)
+
 	if err := binary.WriteUint32(w.out, binary.BigEndian, chunkSize); err != nil {
 		return nil, err
 	}
+
+	if w.key != nil {
+		h := w.getHash()
+		binary.WriteUint32(h, binary.BigEndian, chunkSize)
+		w.lastSum = h.Sum(nil)
+		w.putHash(h)
+		if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil {
+			return nil, err
+		}
+	}
+
 	return w, nil
 }
@@ -433,8 +596,22 @@ func (w *writer) flush(c *chunk) error {
 	}
 
 	// Write out to the stream.
-	_, err := io.Copy(w.out, c.compressed)
-	return err
+	if _, err := io.CopyN(w.out, c.compressed, int64(c.compressed.Len())); err != nil {
+		return err
+	}
+
+	if w.key != nil {
+		io.CopyN(c.h, bytes.NewReader(w.lastSum), int64(len(w.lastSum)))
+		sum := c.h.Sum(nil)
+		w.putHash(c.h)
+		c.h = nil
+		if _, err := io.CopyN(w.out, bytes.NewReader(sum), int64(len(sum))); err != nil {
+			return err
+		}
+		w.lastSum = sum
+	}
+
+	return nil
 }
 
 // Write implements io.Writer.Write.
@@ -480,7 +657,7 @@ func (w *writer) Write(p []byte) (int, error) {
 	// immediately following the inline case above.
 	left := int(w.chunkSize) - w.buf.Len()
 	if left == 0 {
-		if err := w.schedule(newChunk(nil, w.buf), callback); err != nil {
+		if err := w.schedule(newChunk(nil, nil, nil, w.buf), callback); err != nil {
 			return done, err
 		}
 		// Reset the buffer, since this has now been scheduled
@@ -538,7 +715,7 @@ func (w *writer) Close() error {
 	// Schedule any remaining partial buffer; we pass w.flush directly here
 	// because the final buffer is guaranteed to not be an inline buffer.
 	if w.buf.Len() > 0 {
-		if err := w.schedule(newChunk(nil, w.buf), w.flush); err != nil {
+		if err := w.schedule(newChunk(nil, nil, nil, w.buf), w.flush); err != nil {
 			return err
 		}
 	}
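The chained sums that writer.flush emits and worker.work verifies can be reproduced with the standard library alone. A sketch of the chaining rule from the package comment, with encoding/binary standing in for gvisor's pkg/binary helpers and a made-up key and chunk contents:

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/binary"
	"fmt"
)

// headerSum mirrors the initial sum NewWriter emits after the stream
// header: an HMAC-SHA256 over the big-endian chunk size.
func headerSum(key []byte, chunkSize uint32) []byte {
	h := hmac.New(sha256.New, key)
	binary.Write(h, binary.BigEndian, chunkSize)
	return h.Sum(nil)
}

// chunkSum mirrors the per-chunk rule from the package comment:
// compressed data, then the big-endian compressed data size, then the
// previous sum, in that order.
func chunkSum(key, compressed, lastSum []byte) []byte {
	h := hmac.New(sha256.New, key)
	h.Write(compressed)
	binary.Write(h, binary.BigEndian, uint32(len(compressed)))
	h.Write(lastSum)
	return h.Sum(nil)
}

func main() {
	key := []byte("example-hmac-key")
	last := headerSum(key, 1024*1024)
	for _, c := range [][]byte{[]byte("chunk-0"), []byte("chunk-1")} {
		last = chunkSum(key, c, last)
		fmt.Printf("%x\n", last)
	}
}
```

Because each sum covers the previous one, moving a chunk to a different position changes the previous hash it is bound to, so its HMAC no longer verifies and the reader fails with ErrHashMismatch.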