Diffstat (limited to 'pkg/compressio/compressio.go')
 pkg/compressio/compressio.go | 223 ++++++++++++++++++++++++++++++++++++-----
 1 file changed, 200 insertions(+), 23 deletions(-)
diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go
index ef8cbd2a5..591b37130 100644
--- a/pkg/compressio/compressio.go
+++ b/pkg/compressio/compressio.go
@@ -12,17 +12,48 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package compressio provides parallel compression and decompression.
+// Package compressio provides parallel compression and decompression, as well
+// as optional SHA-256 hashing.
+//
+// The stream format is defined as follows.
+//
+// /------------------------------------------------------\
+// | chunk size (4-bytes) |
+// +------------------------------------------------------+
+// | (optional) hash (32-bytes) |
+// +------------------------------------------------------+
+// | compressed data size (4-bytes) |
+// +------------------------------------------------------+
+// | compressed data |
+// +------------------------------------------------------+
+// | (optional) hash (32-bytes) |
+// +------------------------------------------------------+
+// | compressed data size (4-bytes) |
+// +------------------------------------------------------+
+// | ...... |
+// \------------------------------------------------------/
+//
+// where each subsequent hash is calculated from the following items in order
+//
+// compressed data
+// compressed data size
+// previous hash
+//
+// so that stream integrity cannot be compromised by reordering chunks or
+// mixing in chunks from another stream.
package compressio
import (
	"bytes"
	"compress/flate"
+	"crypto/hmac"
+	"crypto/sha256"
	"errors"
+	"hash"
	"io"
	"runtime"
	"sync"
	"gvisor.googlesource.com/gvisor/pkg/binary"
)
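To make the chained hash in the new package comment concrete: each chunk's sum is an HMAC-SHA256 over the compressed bytes, the 4-byte big-endian compressed length, and the previous sum, with the chain seeded by a sum over the chunk size header. Below is a minimal, self-contained sketch of that chain. It uses the standard encoding/binary in place of gvisor's pkg/binary helper, and chunkSum, the key, and the chunk values are illustrative, not part of the package API.

    package main

    import (
        "crypto/hmac"
        "crypto/sha256"
        "encoding/binary"
        "fmt"
    )

    // chunkSum computes one link of the chain:
    // sum_i = HMAC-SHA256(key, compressed_i || len(compressed_i) || sum_{i-1}).
    // The helper name is hypothetical; the package does this inside its workers.
    func chunkSum(key, compressed, lastSum []byte) []byte {
        h := hmac.New(sha256.New, key)
        h.Write(compressed)
        binary.Write(h, binary.BigEndian, uint32(len(compressed)))
        h.Write(lastSum)
        return h.Sum(nil)
    }

    func main() {
        key := []byte("example key")

        // The chain is seeded by hashing the 4-byte chunk size header.
        h := hmac.New(sha256.New, key)
        binary.Write(h, binary.BigEndian, uint32(1<<20))
        lastSum := h.Sum(nil)

        // Every sum depends on all preceding chunks, so chunks cannot be
        // reordered or swapped in from another stream without detection.
        for _, c := range [][]byte{[]byte("chunk-0"), []byte("chunk-1")} {
            lastSum = chunkSum(key, c, lastSum)
            fmt.Printf("%x\n", lastSum)
        }
    }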
@@ -51,12 +82,23 @@ type chunk struct {
// This is not returned to the bufPool automatically, since it may
// correspond to an inline slice (provided directly to Read or Write).
uncompressed *bytes.Buffer
+
+ // The current hash object. Only used in compress mode.
+ h hash.Hash
+
+ // The hash from previous chunks. Only used in uncompress mode.
+ lastSum []byte
+
+ // The expected hash after the current chunk. Only used in uncompress mode.
+ sum []byte
}
// newChunk allocates a new chunk object (or pulls one from the pool). Buffers
// will be allocated if nil is provided for compressed or uncompressed.
-func newChunk(compressed *bytes.Buffer, uncompressed *bytes.Buffer) *chunk {
+func newChunk(lastSum []byte, sum []byte, compressed *bytes.Buffer, uncompressed *bytes.Buffer) *chunk {
c := chunkPool.Get().(*chunk)
+ c.lastSum = lastSum
+ c.sum = sum
if compressed != nil {
c.compressed = compressed
} else {
@@ -85,6 +127,7 @@ type result struct {
// The goroutine will exit when input is closed, and the goroutine will close
// output.
type worker struct {
+ pool *pool
input chan *chunk
output chan result
}
@@ -93,17 +136,27 @@ type worker struct {
func (w *worker) work(compress bool, level int) {
defer close(w.output)
+ var h hash.Hash
+
for c := range w.input {
+ if h == nil && w.pool.key != nil {
+ h = w.pool.getHash()
+ }
if compress {
+ mw := io.Writer(c.compressed)
+ if h != nil {
+ mw = io.MultiWriter(mw, h)
+ }
+
// Encode this slice.
- fw, err := flate.NewWriter(c.compressed, level)
+ fw, err := flate.NewWriter(mw, level)
if err != nil {
w.output <- result{c, err}
continue
}
// Encode the input.
- if _, err := io.Copy(fw, c.uncompressed); err != nil {
+ if _, err := io.CopyN(fw, c.uncompressed, int64(c.uncompressed.Len())); err != nil {
w.output <- result{c, err}
continue
}
@@ -111,7 +164,28 @@ func (w *worker) work(compress bool, level int) {
w.output <- result{c, err}
continue
}
+
+ // Write the hash, if enabled.
+ if h != nil {
+ binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len()))
+ c.h = h
+ h = nil
+ }
} else {
+ // Check the hash of the compressed contents.
+ if h != nil {
+ h.Write(c.compressed.Bytes())
+ binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len()))
+ io.CopyN(h, bytes.NewReader(c.lastSum), int64(len(c.lastSum)))
+
+ sum := h.Sum(nil)
+ h.Reset()
+ if !hmac.Equal(c.sum, sum) {
+ w.output <- result{c, ErrHashMismatch}
+ continue
+ }
+ }
+
// Decode this slice.
fr := flate.NewReader(c.compressed)
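Note how the compress branch above avoids a second pass over the output: flate writes through an io.MultiWriter, so the HMAC consumes the compressed bytes as they are produced. A standalone sketch of that pattern, with an illustrative buffer, key, and input:

    package main

    import (
        "bytes"
        "compress/flate"
        "crypto/hmac"
        "crypto/sha256"
        "fmt"
        "io"
    )

    func main() {
        var compressed bytes.Buffer
        h := hmac.New(sha256.New, []byte("example key"))

        // Tee the flate output into both the buffer and the HMAC, so the
        // hash over the compressed bytes costs no extra pass.
        fw, err := flate.NewWriter(io.MultiWriter(&compressed, h), flate.BestSpeed)
        if err != nil {
            panic(err)
        }
        if _, err := fw.Write([]byte("uncompressed input")); err != nil {
            panic(err)
        }
        if err := fw.Close(); err != nil {
            panic(err)
        }

        fmt.Printf("%d compressed bytes, sum %x\n", compressed.Len(), h.Sum(nil))
    }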
@@ -136,6 +210,16 @@ type pool struct {
// stream and is shared across both the reader and writer.
chunkSize uint32
+ // key is the key used to create hash objects.
+ key []byte
+
+ // hashMu protects the hash list.
+ hashMu sync.Mutex
+
+ // hashes is the hash object free list. Note that this cannot be
+ // globally shared across readers or writers, as it is key-specific.
+ hashes []hash.Hash
+
// mu protects below; it is generally the responsibility of users to
// acquire this mutex before calling any methods on the pool.
mu sync.Mutex
@@ -149,15 +233,20 @@ type pool struct {
// buf is the current active buffer; the exact semantics of this buffer
// depending on whether this is a reader or a writer.
buf *bytes.Buffer
+
+ // lastSum records the hash of the last chunk processed.
+ lastSum []byte
}
// init initializes the worker pool.
//
// This should only be called once.
-func (p *pool) init(compress bool, level int) {
- p.workers = make([]worker, 1+runtime.GOMAXPROCS(0))
+func (p *pool) init(key []byte, workers int, compress bool, level int) {
+ p.key = key
+ p.workers = make([]worker, workers)
for i := 0; i < len(p.workers); i++ {
p.workers[i] = worker{
+ pool: p,
input: make(chan *chunk, 1),
output: make(chan result, 1),
}
@@ -174,6 +263,30 @@ func (p *pool) stop() {
p.workers = nil
}
+// getHash gets a hash object for the pool. It should only be called when the
+// pool key is non-nil.
+func (p *pool) getHash() hash.Hash {
+ p.hashMu.Lock()
+ defer p.hashMu.Unlock()
+
+ if len(p.hashes) == 0 {
+ return hmac.New(sha256.New, p.key)
+ }
+
+ h := p.hashes[len(p.hashes)-1]
+ p.hashes = p.hashes[:len(p.hashes)-1]
+ return h
+}
+
+func (p *pool) putHash(h hash.Hash) {
+ h.Reset()
+
+ p.hashMu.Lock()
+ defer p.hashMu.Unlock()
+
+ p.hashes = append(p.hashes, h)
+}
+
// handleResult calls the callback.
func handleResult(r result, callback func(*chunk) error) error {
defer func() {
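getHash and putHash keep a per-pool free list, rather than a globally shared one, because each HMAC object is constructed around a specific key. What makes recycling safe is Reset, which restores the initial keyed state. A small demonstration of that property, with an illustrative key:

    package main

    import (
        "crypto/hmac"
        "crypto/sha256"
        "fmt"
    )

    func main() {
        h := hmac.New(sha256.New, []byte("example key"))

        h.Write([]byte("first chunk"))
        first := h.Sum(nil)

        // Reset restores the initial keyed state, so a used object can go
        // back on the free list instead of being reallocated.
        h.Reset()
        h.Write([]byte("first chunk"))
        again := h.Sum(nil)

        fmt.Println(hmac.Equal(first, again)) // true
    }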
@@ -231,22 +344,46 @@ type reader struct {
in io.Reader
}
-// NewReader returns a new compressed reader.
-func NewReader(in io.Reader) (io.Reader, error) {
+// NewReader returns a new compressed reader. If key is non-nil, the data stream
+// is assumed to contain expected hash values, which will be compared against
+// hash values computed from the compressed bytes. See package comments for
+// details.
+func NewReader(in io.Reader, key []byte) (io.Reader, error) {
r := &reader{
in: in,
}
- r.init(false, 0)
+
+ // Use double buffering for read.
+ r.init(key, 2*runtime.GOMAXPROCS(0), false, 0)
+
var err error
- if r.chunkSize, err = binary.ReadUint32(r.in, binary.BigEndian); err != nil {
+ if r.chunkSize, err = binary.ReadUint32(in, binary.BigEndian); err != nil {
return nil, err
}
+
+ if r.key != nil {
+ h := r.getHash()
+ binary.WriteUint32(h, binary.BigEndian, r.chunkSize)
+ r.lastSum = h.Sum(nil)
+ r.putHash(h)
+ sum := make([]byte, len(r.lastSum))
+ if _, err := io.ReadFull(r.in, sum); err != nil {
+ return nil, err
+ }
+ if !hmac.Equal(r.lastSum, sum) {
+ return nil, ErrHashMismatch
+ }
+ }
+
return r, nil
}
// errNewBuffer is returned when a new buffer is completed.
var errNewBuffer = errors.New("buffer ready")
+// ErrHashMismatch is returned if the hash does not match.
+var ErrHashMismatch = errors.New("hash mismatch")
+
// Read implements io.Reader.Read.
func (r *reader) Read(p []byte) (int, error) {
r.mu.Lock()
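NewReader now authenticates the stream prologue before any chunk is touched: it recomputes HMAC(key, chunkSize) and compares it, via the constant-time hmac.Equal, against the 32-byte sum that follows the header. A sketch of just that check, under assumed names (readHeader is a hypothetical helper) and using the standard encoding/binary:

    package main

    import (
        "bytes"
        "crypto/hmac"
        "crypto/sha256"
        "encoding/binary"
        "fmt"
        "io"
    )

    // readHeader verifies the stream prologue: a 4-byte chunk size followed
    // by HMAC-SHA256(key, chunkSize). Illustrative, not the package API.
    func readHeader(in io.Reader, key []byte) (uint32, []byte, error) {
        var chunkSize uint32
        if err := binary.Read(in, binary.BigEndian, &chunkSize); err != nil {
            return 0, nil, err
        }
        h := hmac.New(sha256.New, key)
        binary.Write(h, binary.BigEndian, chunkSize)
        want := h.Sum(nil)

        got := make([]byte, len(want))
        if _, err := io.ReadFull(in, got); err != nil {
            return 0, nil, err
        }
        if !hmac.Equal(want, got) {
            return 0, nil, fmt.Errorf("hash mismatch")
        }
        // want becomes lastSum, seeding the per-chunk hash chain.
        return chunkSize, want, nil
    }

    func main() {
        // Construct a valid header by hand, then verify it.
        key := []byte("example key")
        var buf bytes.Buffer
        binary.Write(&buf, binary.BigEndian, uint32(1<<20))
        h := hmac.New(sha256.New, key)
        binary.Write(h, binary.BigEndian, uint32(1<<20))
        buf.Write(h.Sum(nil))

        size, lastSum, err := readHeader(&buf, key)
        fmt.Println(size, len(lastSum), err) // 1048576 32 <nil>
    }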
@@ -331,14 +468,25 @@ func (r *reader) Read(p []byte) (int, error) {
// Read this chunk and schedule decompression.
compressed := bufPool.Get().(*bytes.Buffer)
- if _, err := io.Copy(compressed, &io.LimitedReader{
- R: r.in,
- N: int64(l),
- }); err != nil {
+ if _, err := io.CopyN(compressed, r.in, int64(l)); err != nil {
// Some other error occurred; see above.
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
return done, err
}
+ var sum []byte
+ if r.key != nil {
+ sum = make([]byte, len(r.lastSum))
+ if _, err := io.ReadFull(r.in, sum); err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return done, err
+ }
+ }
+
// Are we doing inline decoding?
//
// Note that we need to check the length here against
@@ -349,11 +497,12 @@ func (r *reader) Read(p []byte) (int, error) {
var c *chunk
start := done + ((pendingPre + pendingInline) * int(r.chunkSize))
if len(p) >= start+int(r.chunkSize) && len(p) >= start+bytes.MinRead {
- c = newChunk(compressed, bytes.NewBuffer(p[start:start]))
+ c = newChunk(r.lastSum, sum, compressed, bytes.NewBuffer(p[start:start]))
pendingInline++
} else {
- c = newChunk(compressed, nil)
+ c = newChunk(r.lastSum, sum, compressed, nil)
}
+ r.lastSum = sum
if err := r.schedule(c, callback); err == errNewBuffer {
// A new buffer was completed while we were reading.
// That's great, but we need to force schedule the
@@ -403,12 +552,14 @@ type writer struct {
closed bool
}
-// NewWriter returns a new compressed writer.
+// NewWriter returns a new compressed writer. If key is non-nil, hash values are
+// generated and written out for compressed bytes. See package comments for
+// details.
//
// The recommended chunkSize is on the order of 1M. Extra memory may be
// buffered (in the form of read-ahead, or buffered writes), and is limited to
// O(chunkSize * [1+GOMAXPROCS]).
-func NewWriter(out io.Writer, chunkSize uint32, level int) (io.WriteCloser, error) {
+func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.WriteCloser, error) {
w := &writer{
pool: pool{
chunkSize: chunkSize,
@@ -416,10 +567,22 @@ func NewWriter(out io.Writer, chunkSize uint32, level int) (io.WriteCloser, erro
},
out: out,
}
- w.init(true, level)
+ w.init(key, 1+runtime.GOMAXPROCS(0), true, level)
+
if err := binary.WriteUint32(w.out, binary.BigEndian, chunkSize); err != nil {
return nil, err
}
+
+ if w.key != nil {
+ h := w.getHash()
+ binary.WriteUint32(h, binary.BigEndian, chunkSize)
+ w.lastSum = h.Sum(nil)
+ w.putHash(h)
+ if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil {
+ return nil, err
+ }
+ }
+
return w, nil
}
@@ -433,8 +596,22 @@ func (w *writer) flush(c *chunk) error {
}
// Write out to the stream.
- _, err := io.Copy(w.out, c.compressed)
- return err
+ if _, err := io.CopyN(w.out, c.compressed, int64(c.compressed.Len())); err != nil {
+ return err
+ }
+
+ if w.key != nil {
+ io.CopyN(c.h, bytes.NewReader(w.lastSum), int64(len(w.lastSum)))
+ sum := c.h.Sum(nil)
+ w.putHash(c.h)
+ c.h = nil
+ if _, err := io.CopyN(w.out, bytes.NewReader(sum), int64(len(sum))); err != nil {
+ return err
+ }
+ w.lastSum = sum
+ }
+
+ return nil
}
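On the wire, flush emits one frame per chunk in the order the package comment gives: the 4-byte big-endian compressed size (written earlier in flush, above this hunk), the compressed bytes, then the chained sum when hashing is enabled. A hypothetical writeFrame helper, not part of the package, makes the framing explicit:

    package main

    import (
        "bytes"
        "encoding/binary"
        "fmt"
        "io"
    )

    // writeFrame emits one frame of the documented format: 4-byte big-endian
    // compressed size, compressed bytes, then the chained sum (when keyed).
    func writeFrame(out io.Writer, compressed, sum []byte) error {
        if err := binary.Write(out, binary.BigEndian, uint32(len(compressed))); err != nil {
            return err
        }
        if _, err := out.Write(compressed); err != nil {
            return err
        }
        if sum != nil {
            if _, err := out.Write(sum); err != nil {
                return err
            }
        }
        return nil
    }

    func main() {
        var buf bytes.Buffer
        sum := make([]byte, 32) // stand-in for a real HMAC-SHA256 sum
        if err := writeFrame(&buf, []byte("deflate bytes"), sum); err != nil {
            panic(err)
        }
        fmt.Println(buf.Len()) // 4 + 13 + 32 = 49
    }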
// Write implements io.Writer.Write.
@@ -480,7 +657,7 @@ func (w *writer) Write(p []byte) (int, error) {
// immediately following the inline case above.
left := int(w.chunkSize) - w.buf.Len()
if left == 0 {
- if err := w.schedule(newChunk(nil, w.buf), callback); err != nil {
+ if err := w.schedule(newChunk(nil, nil, nil, w.buf), callback); err != nil {
return done, err
}
// Reset the buffer, since this has now been scheduled
@@ -538,7 +715,7 @@ func (w *writer) Close() error {
// Schedule any remaining partial buffer; we pass w.flush directly here
// because the final buffer is guaranteed to not be an inline buffer.
if w.buf.Len() > 0 {
- if err := w.schedule(newChunk(nil, w.buf), w.flush); err != nil {
+ if err := w.schedule(newChunk(nil, nil, nil, w.buf), w.flush); err != nil {
return err
}
}
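Putting the new signatures together, a round trip looks like the sketch below; the key, chunk size, and input are illustrative, and passing a nil key on both sides disables hashing entirely. ErrHashMismatch is what surfaces if the stream is tampered with or the keys disagree. This assumes the import path shown in the diff resolves in your build.

    package main

    import (
        "bytes"
        "compress/flate"
        "fmt"
        "io/ioutil"

        "gvisor.googlesource.com/gvisor/pkg/compressio"
    )

    func main() {
        key := []byte("a secret key")
        var buf bytes.Buffer

        // Compress with hashing enabled; the same key must be used to read.
        w, err := compressio.NewWriter(&buf, key, 1<<20, flate.BestSpeed)
        if err != nil {
            panic(err)
        }
        if _, err := w.Write([]byte("hello, compressio")); err != nil {
            panic(err)
        }
        if err := w.Close(); err != nil {
            panic(err)
        }

        // Decompress and verify; a wrong key or a modified stream is
        // reported as compressio.ErrHashMismatch.
        r, err := compressio.NewReader(&buf, key)
        if err != nil {
            panic(err)
        }
        out, err := ioutil.ReadAll(r)
        if err != nil {
            panic(err)
        }
        fmt.Println(string(out))
    }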