summaryrefslogtreecommitdiffhomepage
path: root/pkg/compressio/compressio.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/compressio/compressio.go')
-rw-r--r--pkg/compressio/compressio.go773
1 files changed, 773 insertions, 0 deletions
diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go
new file mode 100644
index 000000000..b094c5662
--- /dev/null
+++ b/pkg/compressio/compressio.go
@@ -0,0 +1,773 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package compressio provides parallel compression and decompression, as well
+// as optional SHA-256 hashing.
+//
+// The stream format is defined as follows.
+//
+// /------------------------------------------------------\
+// | chunk size (4-bytes) |
+// +------------------------------------------------------+
+// | (optional) hash (32-bytes) |
+// +------------------------------------------------------+
+// | compressed data size (4-bytes) |
+// +------------------------------------------------------+
+// | compressed data |
+// +------------------------------------------------------+
+// | (optional) hash (32-bytes) |
+// +------------------------------------------------------+
+// | compressed data size (4-bytes) |
+// +------------------------------------------------------+
+// | ...... |
+// \------------------------------------------------------/
+//
+// where each subsequent hash is calculated from the following items in order
+//
+// compressed data
+// compressed data size
+// previous hash
+//
+// so the stream integrity cannot be compromised by switching and mixing
+// compressed chunks.
+package compressio
+
+import (
+ "bytes"
+ "compress/flate"
+ "crypto/hmac"
+ "crypto/sha256"
+ "errors"
+ "hash"
+ "io"
+ "runtime"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+var bufPool = sync.Pool{
+ New: func() interface{} {
+ return bytes.NewBuffer(nil)
+ },
+}
+
+var chunkPool = sync.Pool{
+ New: func() interface{} {
+ return new(chunk)
+ },
+}
+
+// chunk is a unit of work.
+type chunk struct {
+ // compressed is compressed data.
+ //
+ // This will always be returned to the bufPool directly when work has
+ // finished (in schedule) and therefore must be allocated.
+ compressed *bytes.Buffer
+
+ // uncompressed is the uncompressed data.
+ //
+ // This is not returned to the bufPool automatically, since it may
+ // correspond to a inline slice (provided directly to Read or Write).
+ uncompressed *bytes.Buffer
+
+ // The current hash object. Only used in compress mode.
+ h hash.Hash
+
+ // The hash from previous chunks. Only used in uncompress mode.
+ lastSum []byte
+
+ // The expected hash after current chunk. Only used in uncompress mode.
+ sum []byte
+}
+
+// newChunk allocates a new chunk object (or pulls one from the pool). Buffers
+// will be allocated if nil is provided for compressed or uncompressed.
+func newChunk(lastSum []byte, sum []byte, compressed *bytes.Buffer, uncompressed *bytes.Buffer) *chunk {
+ c := chunkPool.Get().(*chunk)
+ c.lastSum = lastSum
+ c.sum = sum
+ if compressed != nil {
+ c.compressed = compressed
+ } else {
+ c.compressed = bufPool.Get().(*bytes.Buffer)
+ }
+ if uncompressed != nil {
+ c.uncompressed = uncompressed
+ } else {
+ c.uncompressed = bufPool.Get().(*bytes.Buffer)
+ }
+ return c
+}
+
+// result is the result of some work; it includes the original chunk.
+type result struct {
+ *chunk
+ err error
+}
+
+// worker is a compression/decompression worker.
+//
+// The associated worker goroutine reads in uncompressed buffers from input and
+// writes compressed buffers to its output. Alternatively, the worker reads
+// compressed buffers from input and writes uncompressed buffers to its output.
+//
+// The goroutine will exit when input is closed, and the goroutine will close
+// output.
+type worker struct {
+ hashPool *hashPool
+ input chan *chunk
+ output chan result
+}
+
+// work is the main work routine; see worker.
+func (w *worker) work(compress bool, level int) {
+ defer close(w.output)
+
+ var h hash.Hash
+
+ for c := range w.input {
+ if h == nil && w.hashPool != nil {
+ h = w.hashPool.getHash()
+ }
+ if compress {
+ mw := io.Writer(c.compressed)
+ if h != nil {
+ mw = io.MultiWriter(mw, h)
+ }
+
+ // Encode this slice.
+ fw, err := flate.NewWriter(mw, level)
+ if err != nil {
+ w.output <- result{c, err}
+ continue
+ }
+
+ // Encode the input.
+ if _, err := io.CopyN(fw, c.uncompressed, int64(c.uncompressed.Len())); err != nil {
+ w.output <- result{c, err}
+ continue
+ }
+ if err := fw.Close(); err != nil {
+ w.output <- result{c, err}
+ continue
+ }
+
+ // Write the hash, if enabled.
+ if h != nil {
+ binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len()))
+ c.h = h
+ h = nil
+ }
+ } else {
+ // Check the hash of the compressed contents.
+ if h != nil {
+ h.Write(c.compressed.Bytes())
+ binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len()))
+ io.CopyN(h, bytes.NewReader(c.lastSum), int64(len(c.lastSum)))
+
+ sum := h.Sum(nil)
+ h.Reset()
+ if !hmac.Equal(c.sum, sum) {
+ w.output <- result{c, ErrHashMismatch}
+ continue
+ }
+ }
+
+ // Decode this slice.
+ fr := flate.NewReader(c.compressed)
+
+ // Decode the input.
+ if _, err := io.Copy(c.uncompressed, fr); err != nil {
+ w.output <- result{c, err}
+ continue
+ }
+ }
+
+ // Send the output.
+ w.output <- result{c, nil}
+ }
+}
+
+type hashPool struct {
+ // mu protexts the hash list.
+ mu sync.Mutex
+
+ // key is the key used to create hash objects.
+ key []byte
+
+ // hashes is the hash object free list. Note that this cannot be
+ // globally shared across readers or writers, as it is key-specific.
+ hashes []hash.Hash
+}
+
+// getHash gets a hash object for the pool. It should only be called when the
+// pool key is non-nil.
+func (p *hashPool) getHash() hash.Hash {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if len(p.hashes) == 0 {
+ return hmac.New(sha256.New, p.key)
+ }
+
+ h := p.hashes[len(p.hashes)-1]
+ p.hashes = p.hashes[:len(p.hashes)-1]
+ return h
+}
+
+func (p *hashPool) putHash(h hash.Hash) {
+ h.Reset()
+
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ p.hashes = append(p.hashes, h)
+}
+
+// pool is common functionality for reader/writers.
+type pool struct {
+ // workers are the compression/decompression workers.
+ workers []worker
+
+ // chunkSize is the chunk size. This is the first four bytes in the
+ // stream and is shared across both the reader and writer.
+ chunkSize uint32
+
+ // mu protects below; it is generally the responsibility of users to
+ // acquire this mutex before calling any methods on the pool.
+ mu sync.Mutex
+
+ // nextInput is the next worker for input (scheduling).
+ nextInput int
+
+ // nextOutput is the next worker for output (result).
+ nextOutput int
+
+ // buf is the current active buffer; the exact semantics of this buffer
+ // depending on whether this is a reader or a writer.
+ buf *bytes.Buffer
+
+ // lasSum records the hash of the last chunk processed.
+ lastSum []byte
+
+ // hashPool is the hash object pool. It cannot be embedded into pool
+ // itself as worker refers to it and that would stop pool from being
+ // GCed.
+ hashPool *hashPool
+}
+
+// init initializes the worker pool.
+//
+// This should only be called once.
+func (p *pool) init(key []byte, workers int, compress bool, level int) {
+ if key != nil {
+ p.hashPool = &hashPool{key: key}
+ }
+ p.workers = make([]worker, workers)
+ for i := 0; i < len(p.workers); i++ {
+ p.workers[i] = worker{
+ hashPool: p.hashPool,
+ input: make(chan *chunk, 1),
+ output: make(chan result, 1),
+ }
+ go p.workers[i].work(compress, level) // S/R-SAFE: In save path only.
+ }
+ runtime.SetFinalizer(p, (*pool).stop)
+}
+
+// stop stops all workers.
+func (p *pool) stop() {
+ for i := 0; i < len(p.workers); i++ {
+ close(p.workers[i].input)
+ }
+ p.workers = nil
+ p.hashPool = nil
+}
+
+// handleResult calls the callback.
+func handleResult(r result, callback func(*chunk) error) error {
+ defer func() {
+ r.chunk.compressed.Reset()
+ bufPool.Put(r.chunk.compressed)
+ chunkPool.Put(r.chunk)
+ }()
+ if r.err != nil {
+ return r.err
+ }
+ return callback(r.chunk)
+}
+
+// schedule schedules the given buffers.
+//
+// If c is non-nil, then it will return as soon as the chunk is scheduled. If c
+// is nil, then it will return only when no more work is left to do.
+//
+// If no callback function is provided, then the output channel will be
+// ignored. You must be sure that the input is schedulable in this case.
+func (p *pool) schedule(c *chunk, callback func(*chunk) error) error {
+ for {
+ var (
+ inputChan chan *chunk
+ outputChan chan result
+ )
+ if c != nil && len(p.workers) != 0 {
+ inputChan = p.workers[(p.nextInput+1)%len(p.workers)].input
+ }
+ if callback != nil && p.nextOutput != p.nextInput && len(p.workers) != 0 {
+ outputChan = p.workers[(p.nextOutput+1)%len(p.workers)].output
+ }
+ if inputChan == nil && outputChan == nil {
+ return nil
+ }
+
+ select {
+ case inputChan <- c:
+ p.nextInput++
+ return nil
+ case r := <-outputChan:
+ p.nextOutput++
+ if err := handleResult(r, callback); err != nil {
+ return err
+ }
+ }
+ }
+}
+
+// Reader is a compressed reader.
+type Reader struct {
+ pool
+
+ // in is the source.
+ in io.Reader
+}
+
+var _ io.Reader = (*Reader)(nil)
+
+// NewReader returns a new compressed reader. If key is non-nil, the data stream
+// is assumed to contain expected hash values, which will be compared against
+// hash values computed from the compressed bytes. See package comments for
+// details.
+func NewReader(in io.Reader, key []byte) (*Reader, error) {
+ r := &Reader{
+ in: in,
+ }
+
+ // Use double buffering for read.
+ r.init(key, 2*runtime.GOMAXPROCS(0), false, 0)
+
+ var err error
+ if r.chunkSize, err = binary.ReadUint32(in, binary.BigEndian); err != nil {
+ return nil, err
+ }
+
+ if r.hashPool != nil {
+ h := r.hashPool.getHash()
+ binary.WriteUint32(h, binary.BigEndian, r.chunkSize)
+ r.lastSum = h.Sum(nil)
+ r.hashPool.putHash(h)
+ sum := make([]byte, len(r.lastSum))
+ if _, err := io.ReadFull(r.in, sum); err != nil {
+ return nil, err
+ }
+ if !hmac.Equal(r.lastSum, sum) {
+ return nil, ErrHashMismatch
+ }
+ }
+
+ return r, nil
+}
+
+// errNewBuffer is returned when a new buffer is completed.
+var errNewBuffer = errors.New("buffer ready")
+
+// ErrHashMismatch is returned if the hash does not match.
+var ErrHashMismatch = errors.New("hash mismatch")
+
+// ReadByte implements wire.Reader.ReadByte.
+func (r *Reader) ReadByte() (byte, error) {
+ var p [1]byte
+ n, err := r.Read(p[:])
+ if n != 1 {
+ return p[0], err
+ }
+ // Suppress EOF.
+ return p[0], nil
+}
+
+// Read implements io.Reader.Read.
+func (r *Reader) Read(p []byte) (int, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ // Total bytes completed; this is declared up front because it must be
+ // adjustable by the callback below.
+ done := 0
+
+ // Total bytes pending in the asynchronous workers for buffers. This is
+ // used to process the proper regions of the input as inline buffers.
+ var (
+ pendingPre = r.nextInput - r.nextOutput
+ pendingInline = 0
+ )
+
+ // Define our callback for completed work.
+ callback := func(c *chunk) error {
+ // Check for an inline buffer.
+ if pendingPre == 0 && pendingInline > 0 {
+ pendingInline--
+ done += c.uncompressed.Len()
+ return nil
+ }
+
+ // Copy the resulting buffer to our intermediate one, and
+ // return errNewBuffer to ensure that we aren't called a second
+ // time. This error code is handled specially below.
+ //
+ // c.buf will be freed and return to the pool when it is done.
+ if pendingPre > 0 {
+ pendingPre--
+ }
+ r.buf = c.uncompressed
+ return errNewBuffer
+ }
+
+ for done < len(p) {
+ // Do we have buffered data available?
+ if r.buf != nil {
+ n, err := r.buf.Read(p[done:])
+ done += n
+ if err == io.EOF {
+ // This is the uncompressed buffer, it can be
+ // returned to the pool at this point.
+ r.buf.Reset()
+ bufPool.Put(r.buf)
+ r.buf = nil
+ } else if err != nil {
+ // Should never happen.
+ defer r.stop()
+ return done, err
+ }
+ continue
+ }
+
+ // Read the length of the next chunk and reset the
+ // reader. The length is used to limit the reader.
+ //
+ // See writer.flush.
+ l, err := binary.ReadUint32(r.in, binary.BigEndian)
+ if err != nil {
+ // This is generally okay as long as there
+ // are still buffers outstanding. We actually
+ // just wait for completion of those buffers here
+ // and continue our loop.
+ if err := r.schedule(nil, callback); err == nil {
+ // We've actually finished all buffers; this is
+ // the normal EOF exit path.
+ defer r.stop()
+ return done, io.EOF
+ } else if err == errNewBuffer {
+ // A new buffer is now available.
+ continue
+ } else {
+ // Some other error occurred; we cannot
+ // process any further.
+ defer r.stop()
+ return done, err
+ }
+ }
+
+ // Read this chunk and schedule decompression.
+ compressed := bufPool.Get().(*bytes.Buffer)
+ if _, err := io.CopyN(compressed, r.in, int64(l)); err != nil {
+ // Some other error occurred; see above.
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return done, err
+ }
+
+ var sum []byte
+ if r.hashPool != nil {
+ sum = make([]byte, len(r.lastSum))
+ if _, err := io.ReadFull(r.in, sum); err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return done, err
+ }
+ }
+
+ // Are we doing inline decoding?
+ //
+ // Note that we need to check the length here against
+ // bytes.MinRead, since the bytes library will choose to grow
+ // the slice if the available capacity is not at least
+ // bytes.MinRead. This limits inline decoding to chunkSizes
+ // that are at least bytes.MinRead (which is not unreasonable).
+ var c *chunk
+ start := done + ((pendingPre + pendingInline) * int(r.chunkSize))
+ if len(p) >= start+int(r.chunkSize) && len(p) >= start+bytes.MinRead {
+ c = newChunk(r.lastSum, sum, compressed, bytes.NewBuffer(p[start:start]))
+ pendingInline++
+ } else {
+ c = newChunk(r.lastSum, sum, compressed, nil)
+ }
+ r.lastSum = sum
+ if err := r.schedule(c, callback); err == errNewBuffer {
+ // A new buffer was completed while we were reading.
+ // That's great, but we need to force schedule the
+ // current buffer so that it does not get lost.
+ //
+ // It is safe to pass nil as an output function here,
+ // because we know that we just freed up a slot above.
+ r.schedule(c, nil)
+ } else if err != nil {
+ // Some other error occurred; see above.
+ defer r.stop()
+ return done, err
+ }
+ }
+
+ // Make sure that everything has been decoded successfully, otherwise
+ // parts of p may not actually have completed.
+ for pendingInline > 0 {
+ if err := r.schedule(nil, func(c *chunk) error {
+ if err := callback(c); err != nil {
+ return err
+ }
+ // The nil case means that an inline buffer has
+ // completed. The callback will have already removed
+ // the inline buffer from the map, so we just return an
+ // error to check the top of the loop again.
+ return errNewBuffer
+ }); err != errNewBuffer {
+ // Some other error occurred; see above.
+ return done, err
+ }
+ }
+
+ // Need to return done here, since it may have been adjusted by the
+ // callback to compensation for partial reads on some inline buffer.
+ return done, nil
+}
+
+// Writer is a compressed writer.
+type Writer struct {
+ pool
+
+ // out is the underlying writer.
+ out io.Writer
+
+ // closed indicates whether the file has been closed.
+ closed bool
+}
+
+var _ io.Writer = (*Writer)(nil)
+
+// NewWriter returns a new compressed writer. If key is non-nil, hash values are
+// generated and written out for compressed bytes. See package comments for
+// details.
+//
+// The recommended chunkSize is on the order of 1M. Extra memory may be
+// buffered (in the form of read-ahead, or buffered writes), and is limited to
+// O(chunkSize * [1+GOMAXPROCS]).
+func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (*Writer, error) {
+ w := &Writer{
+ pool: pool{
+ chunkSize: chunkSize,
+ buf: bufPool.Get().(*bytes.Buffer),
+ },
+ out: out,
+ }
+ w.init(key, 1+runtime.GOMAXPROCS(0), true, level)
+
+ if err := binary.WriteUint32(w.out, binary.BigEndian, chunkSize); err != nil {
+ return nil, err
+ }
+
+ if w.hashPool != nil {
+ h := w.hashPool.getHash()
+ binary.WriteUint32(h, binary.BigEndian, chunkSize)
+ w.lastSum = h.Sum(nil)
+ w.hashPool.putHash(h)
+ if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil {
+ return nil, err
+ }
+ }
+
+ return w, nil
+}
+
+// flush writes a single buffer.
+func (w *Writer) flush(c *chunk) error {
+ // Prefix each chunk with a length; this allows the reader to safely
+ // limit reads while buffering.
+ l := uint32(c.compressed.Len())
+ if err := binary.WriteUint32(w.out, binary.BigEndian, l); err != nil {
+ return err
+ }
+
+ // Write out to the stream.
+ if _, err := io.CopyN(w.out, c.compressed, int64(c.compressed.Len())); err != nil {
+ return err
+ }
+
+ if w.hashPool != nil {
+ io.CopyN(c.h, bytes.NewReader(w.lastSum), int64(len(w.lastSum)))
+ sum := c.h.Sum(nil)
+ w.hashPool.putHash(c.h)
+ c.h = nil
+ if _, err := io.CopyN(w.out, bytes.NewReader(sum), int64(len(sum))); err != nil {
+ return err
+ }
+ w.lastSum = sum
+ }
+
+ return nil
+}
+
+// WriteByte implements wire.Writer.WriteByte.
+//
+// Note that this implementation is necessary on the object itself, as an
+// interface-based dispatch cannot tell whether the array backing the slice
+// escapes, therefore the all bytes written will generate an escape.
+func (w *Writer) WriteByte(b byte) error {
+ var p [1]byte
+ p[0] = b
+ n, err := w.Write(p[:])
+ if n != 1 {
+ return err
+ }
+ return nil
+}
+
+// Write implements io.Writer.Write.
+func (w *Writer) Write(p []byte) (int, error) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ // Did we close already?
+ if w.closed {
+ return 0, io.ErrUnexpectedEOF
+ }
+
+ // See above; we need to track in the same way.
+ var (
+ pendingPre = w.nextInput - w.nextOutput
+ pendingInline = 0
+ )
+ callback := func(c *chunk) error {
+ if pendingPre == 0 && pendingInline > 0 {
+ pendingInline--
+ return w.flush(c)
+ }
+ if pendingPre > 0 {
+ pendingPre--
+ }
+ err := w.flush(c)
+ c.uncompressed.Reset()
+ bufPool.Put(c.uncompressed)
+ return err
+ }
+
+ for done := 0; done < len(p); {
+ // Construct an inline buffer if we're doing an inline
+ // encoding; see above regarding the bytes.MinRead constraint.
+ if w.buf.Len() == 0 && len(p) >= done+int(w.chunkSize) && len(p) >= done+bytes.MinRead {
+ bufPool.Put(w.buf) // Return to the pool; never scheduled.
+ w.buf = bytes.NewBuffer(p[done : done+int(w.chunkSize)])
+ done += int(w.chunkSize)
+ pendingInline++
+ }
+
+ // Do we need to flush w.buf? Note that this case should be hit
+ // immediately following the inline case above.
+ left := int(w.chunkSize) - w.buf.Len()
+ if left == 0 {
+ if err := w.schedule(newChunk(nil, nil, nil, w.buf), callback); err != nil {
+ return done, err
+ }
+ // Reset the buffer, since this has now been scheduled
+ // for compression. Note that this may be trampled
+ // immediately by the bufPool.Put(w.buf) above if the
+ // next buffer happens to be inline, but that's okay.
+ w.buf = bufPool.Get().(*bytes.Buffer)
+ continue
+ }
+
+ // Read from p into w.buf.
+ toWrite := len(p) - done
+ if toWrite > left {
+ toWrite = left
+ }
+ n, err := w.buf.Write(p[done : done+toWrite])
+ done += n
+ if err != nil {
+ return done, err
+ }
+ }
+
+ // Make sure that everything has been flushed, we can't return until
+ // all the contents from p have been used.
+ for pendingInline > 0 {
+ if err := w.schedule(nil, func(c *chunk) error {
+ if err := callback(c); err != nil {
+ return err
+ }
+ // The flush was successful, return errNewBuffer here
+ // to break from the loop and check the condition
+ // again.
+ return errNewBuffer
+ }); err != errNewBuffer {
+ return len(p), err
+ }
+ }
+
+ return len(p), nil
+}
+
+// Close implements io.Closer.Close.
+func (w *Writer) Close() error {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ // Did we already close? After the call to Close, we always mark as
+ // closed, regardless of whether the flush is successful.
+ if w.closed {
+ return io.ErrUnexpectedEOF
+ }
+ w.closed = true
+ defer w.stop()
+
+ // Schedule any remaining partial buffer; we pass w.flush directly here
+ // because the final buffer is guaranteed to not be an inline buffer.
+ if w.buf.Len() > 0 {
+ if err := w.schedule(newChunk(nil, nil, nil, w.buf), w.flush); err != nil {
+ return err
+ }
+ }
+
+ // Flush all scheduled buffers; see above.
+ if err := w.schedule(nil, w.flush); err != nil {
+ return err
+ }
+
+ // Close the underlying writer (if necessary).
+ if closer, ok := w.out.(io.Closer); ok {
+ return closer.Close()
+ }
+ return nil
+}