Diffstat (limited to 'pkg/hashio/hashio.go')
-rw-r--r--  pkg/hashio/hashio.go  295
1 file changed, 295 insertions, 0 deletions
diff --git a/pkg/hashio/hashio.go b/pkg/hashio/hashio.go
new file mode 100644
index 000000000..d97948850
--- /dev/null
+++ b/pkg/hashio/hashio.go
@@ -0,0 +1,295 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+Package hashio provides hash-verified I/O streams.
+
+The I/O stream format is defined as follows.
+
+/-----------------------------------------\
+| payload |
++-----------------------------------------+
+| hash |
++-----------------------------------------+
+| payload |
++-----------------------------------------+
+| hash |
++-----------------------------------------+
+| ...... |
+\-----------------------------------------/
+
+Payload bytes written to / read from the stream are automatically split
+into fixed-size segments, each followed by its hash. Data is returned to
+the reader only after passing hash verification, so client code can
+safely stream-process it.
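+
+A minimal usage sketch (illustrative only: f stands for any io.Writer or
+io.Reader over the same bytes, key is a caller-provided key, payload is
+the data to protect, and any hash.Hash may be used in place of
+HMAC-SHA256):
+
+	w := hashio.NewWriter(f, hmac.New(sha256.New, key))
+	if _, err := w.Write(payload); err != nil {
+		// Handle the write error.
+	}
+	if err := w.Close(); err != nil { // Close flushes the final hash.
+		// Handle the close error.
+	}
+
+	r := hashio.NewReader(f, hmac.New(sha256.New, key))
+	data, err := ioutil.ReadAll(r) // err is ErrHashMismatch on corruption.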
+*/
+package hashio
+
+import (
+ "crypto/hmac"
+ "errors"
+ "hash"
+ "io"
+ "sync"
+)
+
+// SegmentSize is the size at which payload data is split and a hash is
+// inserted into the stream.
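+//
+// For example, with a SHA-256 hash (32 bytes per hash), a 20 KiB
+// payload is stored as three segments of 8 KiB, 8 KiB and 4 KiB, each
+// followed by its 32-byte hash, for 96 bytes of total overhead.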
+const SegmentSize = 8 * 1024
+
+// ErrHashMismatch is returned when the computed hash does not match the
+// hash stored in the stream.
+var ErrHashMismatch = errors.New("hash mismatch")
+
+// writer computes hashes during writes.
+type writer struct {
+ mu sync.Mutex
+ w io.Writer
+ h hash.Hash
+ written int
+ closed bool
+ hashv []byte
+}
+
+// NewWriter creates a writer for a hash-verified I/O stream.
+func NewWriter(w io.Writer, h hash.Hash) io.WriteCloser {
+ return &writer{
+ w: w,
+ h: h,
+ hashv: make([]byte, h.Size()),
+ }
+}
+
+// Write writes the given data.
+func (w *writer) Write(p []byte) (int, error) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ // Did we already close?
+ if w.closed {
+ return 0, io.ErrUnexpectedEOF
+ }
+
+ for done := 0; done < len(p); {
+ // Slice the data at segment boundary.
+ left := SegmentSize - w.written
+ if left > len(p[done:]) {
+ left = len(p[done:])
+ }
+
+ // Write up to the segment boundary, and feed the hash the same
+ // bytes that were actually written. Hash.Write never returns an
+ // error.
+ n, err := w.w.Write(p[done : done+left])
+ w.h.Write(p[done : done+n])
+ w.written += n
+ done += n
+
+ // Report a write error only if no progress was made.
+ if n == 0 && err != nil {
+ return done, err
+ }
+
+ // Write hash if starting a new segment.
+ if w.written == SegmentSize {
+ if err := w.closeSegment(); err != nil {
+ return done, err
+ }
+ }
+ }
+
+ return len(p), nil
+}
+
+// closeSegment closes the current segment and writes out its hash.
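+//
+// Note that the hash state is not reset between segments, so each
+// segment's hash covers the entire payload written so far.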
+func (w *writer) closeSegment() error {
+ // Serialize and write the current segment's hash.
+ hashv := w.h.Sum(w.hashv[:0])
+ for done := 0; done < len(hashv); {
+ n, err := w.w.Write(hashv[done:])
+ done += n
+ if n == 0 && err != nil {
+ return err
+ }
+ }
+ w.written = 0 // reset counter.
+ return nil
+}
+
+// Close writes the final hash to the stream and closes the underlying Writer.
+func (w *writer) Close() error {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ // Did we already close?
+ if w.closed {
+ return io.ErrUnexpectedEOF
+ }
+
+ // Always mark as closed, regardless of errors.
+ w.closed = true
+
+ // Write the final segment.
+ if err := w.closeSegment(); err != nil {
+ return err
+ }
+
+ // Call the underlying closer.
+ if c, ok := w.w.(io.Closer); ok {
+ return c.Close()
+ }
+ return nil
+}
+
+// reader computes and verifies hashes during reads.
+type reader struct {
+ mu sync.Mutex
+ r io.Reader
+ h hash.Hash
+
+ // data holds verified but not yet consumed payload data. It is
+ // populated when the caller's buffer is too small to hold a full
+ // segment, and is served on later reads without re-verification.
+ data [SegmentSize]byte
+
+ // index is the index into data above.
+ index int
+
+ // available is the amount of valid data above.
+ available int
+
+ // hashv is the read hash for the current segment.
+ hashv []byte
+
+ // computev is the computed hash for the current segment.
+ computev []byte
+}
+
+// NewReader creates a reader for a hash-verified I/O stream.
+func NewReader(r io.Reader, h hash.Hash) io.Reader {
+ return &reader{
+ r: r,
+ h: h,
+ hashv: make([]byte, h.Size()),
+ computev: make([]byte, h.Size()),
+ }
+}
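+
+// A hypothetical read loop (sketch only; process is a placeholder for
+// caller logic). Verification failures surface as ErrHashMismatch,
+// which indicates corruption rather than an ordinary I/O error:
+//
+//	buf := make([]byte, 4096)
+//	for {
+//		n, err := r.Read(buf)
+//		process(buf[:n]) // Only verified bytes are returned.
+//		if err == io.EOF {
+//			break // Clean end of stream.
+//		}
+//		if err != nil {
+//			return err // Possibly ErrHashMismatch.
+//		}
+//	}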
+
+// readSegment reads the next payload segment and its trailing hash.
+//
+// Precondition: datav must have length SegmentSize.
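+//
+// Note that the final segment may be shorter than SegmentSize but is
+// still followed by a full hash, so when EOF is hit mid-segment the
+// trailing h.Size() bytes read so far are the hash rather than payload.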
+func (r *reader) readSegment(datav []byte) (data []byte, err error) {
+ // Make two reads: the first is the segment, the second is the hash
+ // which needs verification. We may need to adjust the resulting slices
+ // in the case of short reads.
+ for done := 0; done < SegmentSize; {
+ n, err := r.r.Read(datav[done:])
+ done += n
+ if n == 0 && err == io.EOF {
+ if done == 0 {
+ // No data at all.
+ return nil, io.EOF
+ } else if done < len(r.hashv) {
+ // Not enough for a hash.
+ return nil, ErrHashMismatch
+ }
+ // Truncate the data and copy to the hash.
+ copy(r.hashv, datav[done-len(r.hashv):])
+ datav = datav[:done-len(r.hashv)]
+ return datav, nil
+ } else if n == 0 && err != nil {
+ return nil, err
+ }
+ }
+ for done := 0; done < len(r.hashv); {
+ n, err := r.r.Read(r.hashv[done:])
+ done += n
+ if n == 0 && err == io.EOF {
+ // The stream ended mid-hash: the missing leading hash bytes
+ // were read as payload, so shift what we have and reclaim the
+ // rest from the end of datav.
+ missing := len(r.hashv) - done
+ copy(r.hashv[missing:], r.hashv[:done])
+ copy(r.hashv[:missing], datav[len(datav)-missing:])
+ datav = datav[:len(datav)-missing]
+ return datav, nil
+ } else if n == 0 && err != nil {
+ return nil, err
+ }
+ }
+ return datav, nil
+}
+
+// verifyHash computes the hash of the given data and compares it to the
+// hash read for the current segment.
+func (r *reader) verifyHash(datav []byte) error {
+ for done := 0; done < len(datav); {
+ n, _ := r.h.Write(datav[done:])
+ done += n
+ }
+ computev := r.h.Sum(r.computev[:0])
+ if !hmac.Equal(r.hashv, computev) {
+ return ErrHashMismatch
+ }
+ return nil
+}
+
+// Read reads payload data, verifying each segment's hash before
+// returning it.
+func (r *reader) Read(p []byte) (int, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ for done := 0; done < len(p); {
+ // Check for pending data.
+ if r.index < r.available {
+ n := copy(p[done:], r.data[r.index:r.available])
+ done += n
+ r.index += n
+ continue
+ }
+
+ // Prepare the next read.
+ var (
+ datav []byte
+ inline bool
+ )
+
+ // We need to read a new segment. Reading directly into p avoids
+ // a copy when there is room for a full segment.
+ if len(p[done:]) >= SegmentSize {
+ datav = p[done : done+SegmentSize]
+ inline = true
+ } else {
+ datav = r.data[:]
+ inline = false
+ }
+
+ // Read the next segment.
+ datav, err := r.readSegment(datav)
+ if err != nil && err != io.EOF {
+ // Return any verified bytes already copied out.
+ return done, err
+ } else if err == io.EOF {
+ return done, io.EOF
+ }
+ if err := r.verifyHash(datav); err != nil {
+ return done, err
+ }
+
+ if inline {
+ // Move the cursor.
+ done += len(datav)
+ } else {
+ // Reset index & available.
+ r.index = 0
+ r.available = len(datav)
+ }
+ }
+
+ return len(p), nil
+}