diff options
Diffstat (limited to 'pkg/compressio/compressio.go')
-rw-r--r-- | pkg/compressio/compressio.go | 773 |
1 files changed, 773 insertions, 0 deletions
diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go new file mode 100644 index 000000000..b094c5662 --- /dev/null +++ b/pkg/compressio/compressio.go @@ -0,0 +1,773 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package compressio provides parallel compression and decompression, as well +// as optional SHA-256 hashing. +// +// The stream format is defined as follows. +// +// /------------------------------------------------------\ +// | chunk size (4-bytes) | +// +------------------------------------------------------+ +// | (optional) hash (32-bytes) | +// +------------------------------------------------------+ +// | compressed data size (4-bytes) | +// +------------------------------------------------------+ +// | compressed data | +// +------------------------------------------------------+ +// | (optional) hash (32-bytes) | +// +------------------------------------------------------+ +// | compressed data size (4-bytes) | +// +------------------------------------------------------+ +// | ...... | +// \------------------------------------------------------/ +// +// where each subsequent hash is calculated from the following items in order +// +// compressed data +// compressed data size +// previous hash +// +// so the stream integrity cannot be compromised by switching and mixing +// compressed chunks. +package compressio + +import ( + "bytes" + "compress/flate" + "crypto/hmac" + "crypto/sha256" + "errors" + "hash" + "io" + "runtime" + + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/sync" +) + +var bufPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(nil) + }, +} + +var chunkPool = sync.Pool{ + New: func() interface{} { + return new(chunk) + }, +} + +// chunk is a unit of work. +type chunk struct { + // compressed is compressed data. + // + // This will always be returned to the bufPool directly when work has + // finished (in schedule) and therefore must be allocated. + compressed *bytes.Buffer + + // uncompressed is the uncompressed data. + // + // This is not returned to the bufPool automatically, since it may + // correspond to a inline slice (provided directly to Read or Write). + uncompressed *bytes.Buffer + + // The current hash object. Only used in compress mode. + h hash.Hash + + // The hash from previous chunks. Only used in uncompress mode. + lastSum []byte + + // The expected hash after current chunk. Only used in uncompress mode. + sum []byte +} + +// newChunk allocates a new chunk object (or pulls one from the pool). Buffers +// will be allocated if nil is provided for compressed or uncompressed. +func newChunk(lastSum []byte, sum []byte, compressed *bytes.Buffer, uncompressed *bytes.Buffer) *chunk { + c := chunkPool.Get().(*chunk) + c.lastSum = lastSum + c.sum = sum + if compressed != nil { + c.compressed = compressed + } else { + c.compressed = bufPool.Get().(*bytes.Buffer) + } + if uncompressed != nil { + c.uncompressed = uncompressed + } else { + c.uncompressed = bufPool.Get().(*bytes.Buffer) + } + return c +} + +// result is the result of some work; it includes the original chunk. +type result struct { + *chunk + err error +} + +// worker is a compression/decompression worker. +// +// The associated worker goroutine reads in uncompressed buffers from input and +// writes compressed buffers to its output. Alternatively, the worker reads +// compressed buffers from input and writes uncompressed buffers to its output. +// +// The goroutine will exit when input is closed, and the goroutine will close +// output. +type worker struct { + hashPool *hashPool + input chan *chunk + output chan result +} + +// work is the main work routine; see worker. +func (w *worker) work(compress bool, level int) { + defer close(w.output) + + var h hash.Hash + + for c := range w.input { + if h == nil && w.hashPool != nil { + h = w.hashPool.getHash() + } + if compress { + mw := io.Writer(c.compressed) + if h != nil { + mw = io.MultiWriter(mw, h) + } + + // Encode this slice. + fw, err := flate.NewWriter(mw, level) + if err != nil { + w.output <- result{c, err} + continue + } + + // Encode the input. + if _, err := io.CopyN(fw, c.uncompressed, int64(c.uncompressed.Len())); err != nil { + w.output <- result{c, err} + continue + } + if err := fw.Close(); err != nil { + w.output <- result{c, err} + continue + } + + // Write the hash, if enabled. + if h != nil { + binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len())) + c.h = h + h = nil + } + } else { + // Check the hash of the compressed contents. + if h != nil { + h.Write(c.compressed.Bytes()) + binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len())) + io.CopyN(h, bytes.NewReader(c.lastSum), int64(len(c.lastSum))) + + sum := h.Sum(nil) + h.Reset() + if !hmac.Equal(c.sum, sum) { + w.output <- result{c, ErrHashMismatch} + continue + } + } + + // Decode this slice. + fr := flate.NewReader(c.compressed) + + // Decode the input. + if _, err := io.Copy(c.uncompressed, fr); err != nil { + w.output <- result{c, err} + continue + } + } + + // Send the output. + w.output <- result{c, nil} + } +} + +type hashPool struct { + // mu protexts the hash list. + mu sync.Mutex + + // key is the key used to create hash objects. + key []byte + + // hashes is the hash object free list. Note that this cannot be + // globally shared across readers or writers, as it is key-specific. + hashes []hash.Hash +} + +// getHash gets a hash object for the pool. It should only be called when the +// pool key is non-nil. +func (p *hashPool) getHash() hash.Hash { + p.mu.Lock() + defer p.mu.Unlock() + + if len(p.hashes) == 0 { + return hmac.New(sha256.New, p.key) + } + + h := p.hashes[len(p.hashes)-1] + p.hashes = p.hashes[:len(p.hashes)-1] + return h +} + +func (p *hashPool) putHash(h hash.Hash) { + h.Reset() + + p.mu.Lock() + defer p.mu.Unlock() + + p.hashes = append(p.hashes, h) +} + +// pool is common functionality for reader/writers. +type pool struct { + // workers are the compression/decompression workers. + workers []worker + + // chunkSize is the chunk size. This is the first four bytes in the + // stream and is shared across both the reader and writer. + chunkSize uint32 + + // mu protects below; it is generally the responsibility of users to + // acquire this mutex before calling any methods on the pool. + mu sync.Mutex + + // nextInput is the next worker for input (scheduling). + nextInput int + + // nextOutput is the next worker for output (result). + nextOutput int + + // buf is the current active buffer; the exact semantics of this buffer + // depending on whether this is a reader or a writer. + buf *bytes.Buffer + + // lasSum records the hash of the last chunk processed. + lastSum []byte + + // hashPool is the hash object pool. It cannot be embedded into pool + // itself as worker refers to it and that would stop pool from being + // GCed. + hashPool *hashPool +} + +// init initializes the worker pool. +// +// This should only be called once. +func (p *pool) init(key []byte, workers int, compress bool, level int) { + if key != nil { + p.hashPool = &hashPool{key: key} + } + p.workers = make([]worker, workers) + for i := 0; i < len(p.workers); i++ { + p.workers[i] = worker{ + hashPool: p.hashPool, + input: make(chan *chunk, 1), + output: make(chan result, 1), + } + go p.workers[i].work(compress, level) // S/R-SAFE: In save path only. + } + runtime.SetFinalizer(p, (*pool).stop) +} + +// stop stops all workers. +func (p *pool) stop() { + for i := 0; i < len(p.workers); i++ { + close(p.workers[i].input) + } + p.workers = nil + p.hashPool = nil +} + +// handleResult calls the callback. +func handleResult(r result, callback func(*chunk) error) error { + defer func() { + r.chunk.compressed.Reset() + bufPool.Put(r.chunk.compressed) + chunkPool.Put(r.chunk) + }() + if r.err != nil { + return r.err + } + return callback(r.chunk) +} + +// schedule schedules the given buffers. +// +// If c is non-nil, then it will return as soon as the chunk is scheduled. If c +// is nil, then it will return only when no more work is left to do. +// +// If no callback function is provided, then the output channel will be +// ignored. You must be sure that the input is schedulable in this case. +func (p *pool) schedule(c *chunk, callback func(*chunk) error) error { + for { + var ( + inputChan chan *chunk + outputChan chan result + ) + if c != nil && len(p.workers) != 0 { + inputChan = p.workers[(p.nextInput+1)%len(p.workers)].input + } + if callback != nil && p.nextOutput != p.nextInput && len(p.workers) != 0 { + outputChan = p.workers[(p.nextOutput+1)%len(p.workers)].output + } + if inputChan == nil && outputChan == nil { + return nil + } + + select { + case inputChan <- c: + p.nextInput++ + return nil + case r := <-outputChan: + p.nextOutput++ + if err := handleResult(r, callback); err != nil { + return err + } + } + } +} + +// Reader is a compressed reader. +type Reader struct { + pool + + // in is the source. + in io.Reader +} + +var _ io.Reader = (*Reader)(nil) + +// NewReader returns a new compressed reader. If key is non-nil, the data stream +// is assumed to contain expected hash values, which will be compared against +// hash values computed from the compressed bytes. See package comments for +// details. +func NewReader(in io.Reader, key []byte) (*Reader, error) { + r := &Reader{ + in: in, + } + + // Use double buffering for read. + r.init(key, 2*runtime.GOMAXPROCS(0), false, 0) + + var err error + if r.chunkSize, err = binary.ReadUint32(in, binary.BigEndian); err != nil { + return nil, err + } + + if r.hashPool != nil { + h := r.hashPool.getHash() + binary.WriteUint32(h, binary.BigEndian, r.chunkSize) + r.lastSum = h.Sum(nil) + r.hashPool.putHash(h) + sum := make([]byte, len(r.lastSum)) + if _, err := io.ReadFull(r.in, sum); err != nil { + return nil, err + } + if !hmac.Equal(r.lastSum, sum) { + return nil, ErrHashMismatch + } + } + + return r, nil +} + +// errNewBuffer is returned when a new buffer is completed. +var errNewBuffer = errors.New("buffer ready") + +// ErrHashMismatch is returned if the hash does not match. +var ErrHashMismatch = errors.New("hash mismatch") + +// ReadByte implements wire.Reader.ReadByte. +func (r *Reader) ReadByte() (byte, error) { + var p [1]byte + n, err := r.Read(p[:]) + if n != 1 { + return p[0], err + } + // Suppress EOF. + return p[0], nil +} + +// Read implements io.Reader.Read. +func (r *Reader) Read(p []byte) (int, error) { + r.mu.Lock() + defer r.mu.Unlock() + + // Total bytes completed; this is declared up front because it must be + // adjustable by the callback below. + done := 0 + + // Total bytes pending in the asynchronous workers for buffers. This is + // used to process the proper regions of the input as inline buffers. + var ( + pendingPre = r.nextInput - r.nextOutput + pendingInline = 0 + ) + + // Define our callback for completed work. + callback := func(c *chunk) error { + // Check for an inline buffer. + if pendingPre == 0 && pendingInline > 0 { + pendingInline-- + done += c.uncompressed.Len() + return nil + } + + // Copy the resulting buffer to our intermediate one, and + // return errNewBuffer to ensure that we aren't called a second + // time. This error code is handled specially below. + // + // c.buf will be freed and return to the pool when it is done. + if pendingPre > 0 { + pendingPre-- + } + r.buf = c.uncompressed + return errNewBuffer + } + + for done < len(p) { + // Do we have buffered data available? + if r.buf != nil { + n, err := r.buf.Read(p[done:]) + done += n + if err == io.EOF { + // This is the uncompressed buffer, it can be + // returned to the pool at this point. + r.buf.Reset() + bufPool.Put(r.buf) + r.buf = nil + } else if err != nil { + // Should never happen. + defer r.stop() + return done, err + } + continue + } + + // Read the length of the next chunk and reset the + // reader. The length is used to limit the reader. + // + // See writer.flush. + l, err := binary.ReadUint32(r.in, binary.BigEndian) + if err != nil { + // This is generally okay as long as there + // are still buffers outstanding. We actually + // just wait for completion of those buffers here + // and continue our loop. + if err := r.schedule(nil, callback); err == nil { + // We've actually finished all buffers; this is + // the normal EOF exit path. + defer r.stop() + return done, io.EOF + } else if err == errNewBuffer { + // A new buffer is now available. + continue + } else { + // Some other error occurred; we cannot + // process any further. + defer r.stop() + return done, err + } + } + + // Read this chunk and schedule decompression. + compressed := bufPool.Get().(*bytes.Buffer) + if _, err := io.CopyN(compressed, r.in, int64(l)); err != nil { + // Some other error occurred; see above. + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return done, err + } + + var sum []byte + if r.hashPool != nil { + sum = make([]byte, len(r.lastSum)) + if _, err := io.ReadFull(r.in, sum); err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return done, err + } + } + + // Are we doing inline decoding? + // + // Note that we need to check the length here against + // bytes.MinRead, since the bytes library will choose to grow + // the slice if the available capacity is not at least + // bytes.MinRead. This limits inline decoding to chunkSizes + // that are at least bytes.MinRead (which is not unreasonable). + var c *chunk + start := done + ((pendingPre + pendingInline) * int(r.chunkSize)) + if len(p) >= start+int(r.chunkSize) && len(p) >= start+bytes.MinRead { + c = newChunk(r.lastSum, sum, compressed, bytes.NewBuffer(p[start:start])) + pendingInline++ + } else { + c = newChunk(r.lastSum, sum, compressed, nil) + } + r.lastSum = sum + if err := r.schedule(c, callback); err == errNewBuffer { + // A new buffer was completed while we were reading. + // That's great, but we need to force schedule the + // current buffer so that it does not get lost. + // + // It is safe to pass nil as an output function here, + // because we know that we just freed up a slot above. + r.schedule(c, nil) + } else if err != nil { + // Some other error occurred; see above. + defer r.stop() + return done, err + } + } + + // Make sure that everything has been decoded successfully, otherwise + // parts of p may not actually have completed. + for pendingInline > 0 { + if err := r.schedule(nil, func(c *chunk) error { + if err := callback(c); err != nil { + return err + } + // The nil case means that an inline buffer has + // completed. The callback will have already removed + // the inline buffer from the map, so we just return an + // error to check the top of the loop again. + return errNewBuffer + }); err != errNewBuffer { + // Some other error occurred; see above. + return done, err + } + } + + // Need to return done here, since it may have been adjusted by the + // callback to compensation for partial reads on some inline buffer. + return done, nil +} + +// Writer is a compressed writer. +type Writer struct { + pool + + // out is the underlying writer. + out io.Writer + + // closed indicates whether the file has been closed. + closed bool +} + +var _ io.Writer = (*Writer)(nil) + +// NewWriter returns a new compressed writer. If key is non-nil, hash values are +// generated and written out for compressed bytes. See package comments for +// details. +// +// The recommended chunkSize is on the order of 1M. Extra memory may be +// buffered (in the form of read-ahead, or buffered writes), and is limited to +// O(chunkSize * [1+GOMAXPROCS]). +func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (*Writer, error) { + w := &Writer{ + pool: pool{ + chunkSize: chunkSize, + buf: bufPool.Get().(*bytes.Buffer), + }, + out: out, + } + w.init(key, 1+runtime.GOMAXPROCS(0), true, level) + + if err := binary.WriteUint32(w.out, binary.BigEndian, chunkSize); err != nil { + return nil, err + } + + if w.hashPool != nil { + h := w.hashPool.getHash() + binary.WriteUint32(h, binary.BigEndian, chunkSize) + w.lastSum = h.Sum(nil) + w.hashPool.putHash(h) + if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil { + return nil, err + } + } + + return w, nil +} + +// flush writes a single buffer. +func (w *Writer) flush(c *chunk) error { + // Prefix each chunk with a length; this allows the reader to safely + // limit reads while buffering. + l := uint32(c.compressed.Len()) + if err := binary.WriteUint32(w.out, binary.BigEndian, l); err != nil { + return err + } + + // Write out to the stream. + if _, err := io.CopyN(w.out, c.compressed, int64(c.compressed.Len())); err != nil { + return err + } + + if w.hashPool != nil { + io.CopyN(c.h, bytes.NewReader(w.lastSum), int64(len(w.lastSum))) + sum := c.h.Sum(nil) + w.hashPool.putHash(c.h) + c.h = nil + if _, err := io.CopyN(w.out, bytes.NewReader(sum), int64(len(sum))); err != nil { + return err + } + w.lastSum = sum + } + + return nil +} + +// WriteByte implements wire.Writer.WriteByte. +// +// Note that this implementation is necessary on the object itself, as an +// interface-based dispatch cannot tell whether the array backing the slice +// escapes, therefore the all bytes written will generate an escape. +func (w *Writer) WriteByte(b byte) error { + var p [1]byte + p[0] = b + n, err := w.Write(p[:]) + if n != 1 { + return err + } + return nil +} + +// Write implements io.Writer.Write. +func (w *Writer) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + + // Did we close already? + if w.closed { + return 0, io.ErrUnexpectedEOF + } + + // See above; we need to track in the same way. + var ( + pendingPre = w.nextInput - w.nextOutput + pendingInline = 0 + ) + callback := func(c *chunk) error { + if pendingPre == 0 && pendingInline > 0 { + pendingInline-- + return w.flush(c) + } + if pendingPre > 0 { + pendingPre-- + } + err := w.flush(c) + c.uncompressed.Reset() + bufPool.Put(c.uncompressed) + return err + } + + for done := 0; done < len(p); { + // Construct an inline buffer if we're doing an inline + // encoding; see above regarding the bytes.MinRead constraint. + if w.buf.Len() == 0 && len(p) >= done+int(w.chunkSize) && len(p) >= done+bytes.MinRead { + bufPool.Put(w.buf) // Return to the pool; never scheduled. + w.buf = bytes.NewBuffer(p[done : done+int(w.chunkSize)]) + done += int(w.chunkSize) + pendingInline++ + } + + // Do we need to flush w.buf? Note that this case should be hit + // immediately following the inline case above. + left := int(w.chunkSize) - w.buf.Len() + if left == 0 { + if err := w.schedule(newChunk(nil, nil, nil, w.buf), callback); err != nil { + return done, err + } + // Reset the buffer, since this has now been scheduled + // for compression. Note that this may be trampled + // immediately by the bufPool.Put(w.buf) above if the + // next buffer happens to be inline, but that's okay. + w.buf = bufPool.Get().(*bytes.Buffer) + continue + } + + // Read from p into w.buf. + toWrite := len(p) - done + if toWrite > left { + toWrite = left + } + n, err := w.buf.Write(p[done : done+toWrite]) + done += n + if err != nil { + return done, err + } + } + + // Make sure that everything has been flushed, we can't return until + // all the contents from p have been used. + for pendingInline > 0 { + if err := w.schedule(nil, func(c *chunk) error { + if err := callback(c); err != nil { + return err + } + // The flush was successful, return errNewBuffer here + // to break from the loop and check the condition + // again. + return errNewBuffer + }); err != errNewBuffer { + return len(p), err + } + } + + return len(p), nil +} + +// Close implements io.Closer.Close. +func (w *Writer) Close() error { + w.mu.Lock() + defer w.mu.Unlock() + + // Did we already close? After the call to Close, we always mark as + // closed, regardless of whether the flush is successful. + if w.closed { + return io.ErrUnexpectedEOF + } + w.closed = true + defer w.stop() + + // Schedule any remaining partial buffer; we pass w.flush directly here + // because the final buffer is guaranteed to not be an inline buffer. + if w.buf.Len() > 0 { + if err := w.schedule(newChunk(nil, nil, nil, w.buf), w.flush); err != nil { + return err + } + } + + // Flush all scheduled buffers; see above. + if err := w.schedule(nil, w.flush); err != nil { + return err + } + + // Close the underlying writer (if necessary). + if closer, ok := w.out.(io.Closer); ok { + return closer.Close() + } + return nil +} |