summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorZhaozhong Ni <nzz@google.com>2018-09-12 17:23:56 -0700
committerShentubot <shentubot@google.com>2018-09-12 17:24:53 -0700
commit9dec7a3db99d8c7045324bc6d8f0c27e88407f6c (patch)
tree6f6098da7129a4e75b5b8e66df34a1fb05b7283d
parent2eff1fdd061be9cfabc36532dda8cbefeb02e534 (diff)
compressio: stop worker-pool reference / dependency loop.
PiperOrigin-RevId: 212732300 Change-Id: I9a0b9b7c28e7b7439d34656dd4f2f6114d173e22
-rw-r--r--pkg/compressio/compressio.go114
1 files changed, 62 insertions, 52 deletions
diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go
index 591b37130..b4c1c70d9 100644
--- a/pkg/compressio/compressio.go
+++ b/pkg/compressio/compressio.go
@@ -127,9 +127,9 @@ type result struct {
// The goroutine will exit when input is closed, and the goroutine will close
// output.
type worker struct {
- pool *pool
- input chan *chunk
- output chan result
+ hashPool *hashPool
+ input chan *chunk
+ output chan result
}
// work is the main work routine; see worker.
@@ -139,8 +139,8 @@ func (w *worker) work(compress bool, level int) {
var h hash.Hash
for c := range w.input {
- if h == nil && w.pool.key != nil {
- h = w.pool.getHash()
+ if h == nil && w.hashPool != nil {
+ h = w.hashPool.getHash()
}
if compress {
mw := io.Writer(c.compressed)
@@ -201,6 +201,42 @@ func (w *worker) work(compress bool, level int) {
}
}
+type hashPool struct {
+ // mu protexts the hash list.
+ mu sync.Mutex
+
+ // key is the key used to create hash objects.
+ key []byte
+
+ // hashes is the hash object free list. Note that this cannot be
+ // globally shared across readers or writers, as it is key-specific.
+ hashes []hash.Hash
+}
+
+// getHash gets a hash object for the pool. It should only be called when the
+// pool key is non-nil.
+func (p *hashPool) getHash() hash.Hash {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if len(p.hashes) == 0 {
+ return hmac.New(sha256.New, p.key)
+ }
+
+ h := p.hashes[len(p.hashes)-1]
+ p.hashes = p.hashes[:len(p.hashes)-1]
+ return h
+}
+
+func (p *hashPool) putHash(h hash.Hash) {
+ h.Reset()
+
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ p.hashes = append(p.hashes, h)
+}
+
// pool is common functionality for reader/writers.
type pool struct {
// workers are the compression/decompression workers.
@@ -210,16 +246,6 @@ type pool struct {
// stream and is shared across both the reader and writer.
chunkSize uint32
- // key is the key used to create hash objects.
- key []byte
-
- // hashMu protexts the hash list.
- hashMu sync.Mutex
-
- // hashes is the hash object free list. Note that this cannot be
- // globally shared across readers or writers, as it is key-specific.
- hashes []hash.Hash
-
// mu protects below; it is generally the responsibility of users to
// acquire this mutex before calling any methods on the pool.
mu sync.Mutex
@@ -236,19 +262,26 @@ type pool struct {
// lasSum records the hash of the last chunk processed.
lastSum []byte
+
+ // hashPool is the hash object pool. It cannot be embedded into pool
+ // itself as worker refers to it and that would stop pool from being
+ // GCed.
+ hashPool *hashPool
}
// init initializes the worker pool.
//
// This should only be called once.
func (p *pool) init(key []byte, workers int, compress bool, level int) {
- p.key = key
+ if key != nil {
+ p.hashPool = &hashPool{key: key}
+ }
p.workers = make([]worker, workers)
for i := 0; i < len(p.workers); i++ {
p.workers[i] = worker{
- pool: p,
- input: make(chan *chunk, 1),
- output: make(chan result, 1),
+ hashPool: p.hashPool,
+ input: make(chan *chunk, 1),
+ output: make(chan result, 1),
}
go p.workers[i].work(compress, level) // S/R-SAFE: In save path only.
}
@@ -261,30 +294,7 @@ func (p *pool) stop() {
close(p.workers[i].input)
}
p.workers = nil
-}
-
-// getHash gets a hash object for the pool. It should only be called when the
-// pool key is non-nil.
-func (p *pool) getHash() hash.Hash {
- p.hashMu.Lock()
- defer p.hashMu.Unlock()
-
- if len(p.hashes) == 0 {
- return hmac.New(sha256.New, p.key)
- }
-
- h := p.hashes[len(p.hashes)-1]
- p.hashes = p.hashes[:len(p.hashes)-1]
- return h
-}
-
-func (p *pool) putHash(h hash.Hash) {
- h.Reset()
-
- p.hashMu.Lock()
- defer p.hashMu.Unlock()
-
- p.hashes = append(p.hashes, h)
+ p.hashPool = nil
}
// handleResult calls the callback.
@@ -361,11 +371,11 @@ func NewReader(in io.Reader, key []byte) (io.Reader, error) {
return nil, err
}
- if r.key != nil {
- h := r.getHash()
+ if r.hashPool != nil {
+ h := r.hashPool.getHash()
binary.WriteUint32(h, binary.BigEndian, r.chunkSize)
r.lastSum = h.Sum(nil)
- r.putHash(h)
+ r.hashPool.putHash(h)
sum := make([]byte, len(r.lastSum))
if _, err := io.ReadFull(r.in, sum); err != nil {
return nil, err
@@ -477,7 +487,7 @@ func (r *reader) Read(p []byte) (int, error) {
}
var sum []byte
- if r.key != nil {
+ if r.hashPool != nil {
sum = make([]byte, len(r.lastSum))
if _, err := io.ReadFull(r.in, sum); err != nil {
if err == io.EOF {
@@ -573,11 +583,11 @@ func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.Write
return nil, err
}
- if w.key != nil {
- h := w.getHash()
+ if w.hashPool != nil {
+ h := w.hashPool.getHash()
binary.WriteUint32(h, binary.BigEndian, chunkSize)
w.lastSum = h.Sum(nil)
- w.putHash(h)
+ w.hashPool.putHash(h)
if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil {
return nil, err
}
@@ -600,10 +610,10 @@ func (w *writer) flush(c *chunk) error {
return err
}
- if w.key != nil {
+ if w.hashPool != nil {
io.CopyN(c.h, bytes.NewReader(w.lastSum), int64(len(w.lastSum)))
sum := c.h.Sum(nil)
- w.putHash(c.h)
+ w.hashPool.putHash(c.h)
c.h = nil
if _, err := io.CopyN(w.out, bytes.NewReader(sum), int64(len(sum))); err != nil {
return err