// Copyright 2018 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
Package hashio provides hash-verified I/O streams.

The I/O stream format is defined as follows.

/-----------------------------------------\
|                 payload                 |
+-----------------------------------------+
|                  hash                   |
+-----------------------------------------+
|                 payload                 |
+-----------------------------------------+
|                  hash                   |
+-----------------------------------------+
|                 ......                  |
\-----------------------------------------/

Payload bytes written to / read from the stream are automatically split
into segments, each followed by a hash. All data read out must have already
passed hash verification. Hence the client code can safely do any kind of
(stream) processing of these data.
*/
package hashio

import (
	"crypto/hmac"
	"errors"
	"hash"
	"io"
	"sync"
)

// SegmentSize is the maximum number of payload bytes in one segment. The
// writer emits a hash into the stream each time this many payload bytes have
// been written, and once more when the stream is closed.
const SegmentSize = 8 * 1024

// ErrHashMismatch is returned if the hash read from the stream does not match
// the hash computed over the payload, or if the stream ends before a complete
// hash could be read.
var ErrHashMismatch = errors.New("hash mismatch")

// writer computes hashes during writes.
//
// Payload bytes are forwarded to w and fed to h; the current value of h is
// written to w at every segment boundary and on Close. Nothing in this file
// resets h, so each emitted hash covers all payload written so far.
type writer struct {
	// mu serializes Write and Close.
	mu      sync.Mutex
	// w is the underlying writer that receives payload and hash bytes.
	w       io.Writer
	// h is the running hash of the payload.
	h       hash.Hash
	// written is the number of payload bytes in the current segment; it is
	// reset to zero by closeSegment.
	written int
	// closed is set by Close; afterwards Write and Close fail.
	closed  bool
	// hashv is scratch space of h.Size() bytes, reused to serialize the hash.
	hashv   []byte
}

// NewWriter returns a writer that forwards payload to w and interleaves the
// hash computed by h after every segment.
//
// The returned value must be closed so that the hash of the final segment is
// written out.
func NewWriter(w io.Writer, h hash.Hash) io.WriteCloser {
	hw := new(writer)
	hw.w = w
	hw.h = h
	// Scratch buffer sized for exactly one serialized hash.
	hw.hashv = make([]byte, h.Size())
	return hw
}

// Write writes the given data.
//
// The payload is forwarded to the underlying writer and fed to the hash.
// Each time SegmentSize payload bytes have accumulated, the current hash is
// written to the stream before any further payload.
//
// Write returns the number of payload bytes consumed from p; hash bytes are
// not counted. After Close it fails with io.ErrUnexpectedEOF. An error is
// also returned when the underlying writer fails without accepting any bytes
// or when a segment hash cannot be written.
func (w *writer) Write(p []byte) (int, error) {
	w.mu.Lock()
	defer w.mu.Unlock()

	// Did we already close?
	if w.closed {
		return 0, io.ErrUnexpectedEOF
	}

	for done := 0; done < len(p); {
		// Slice the data at segment boundary.
		left := SegmentSize - w.written
		if left > len(p[done:]) {
			left = len(p[done:])
		}

		// Write the rest of the segment.
		n, err := w.w.Write(p[done : done+left])

		// Hash exactly the bytes the underlying writer accepted. On a
		// partial write the remainder is retried by the next iteration;
		// hashing it now as well would count it twice and the emitted
		// hash would no longer match the payload in the stream.
		// Hash.Write never returns an error.
		w.h.Write(p[done : done+n])
		w.written += n
		done += n

		// Only give up when no progress was made; an error that
		// accompanies a partial write resurfaces on the retry.
		if n == 0 && err != nil {
			return done, err
		}

		// Write hash if the segment is complete.
		if w.written == SegmentSize {
			if err := w.closeSegment(); err != nil {
				return done, err
			}
		}
	}

	return len(p), nil
}

// closeSegment terminates the current segment by emitting the present value
// of the hash to the underlying writer, then restarts the per-segment byte
// counter.
//
// On failure the error of the underlying writer is returned and the counter
// is left untouched.
func (w *writer) closeSegment() error {
	// Serialize into the reusable scratch buffer.
	sum := w.h.Sum(w.hashv[:0])

	// Keep writing until the whole hash is out; partial writes are retried.
	sent := 0
	for sent < len(sum) {
		n, werr := w.w.Write(sum[sent:])
		sent += n
		if werr != nil && n == 0 {
			return werr
		}
	}

	// A fresh segment starts with no payload bytes.
	w.written = 0
	return nil
}

// Close writes the final hash to the stream and closes the underlying Writer
// if it implements io.Closer.
//
// The writer is marked closed regardless of the outcome, so Close cannot be
// retried; a second call fails with io.ErrUnexpectedEOF. The underlying
// writer is therefore closed even when the final hash could not be written,
// otherwise it would be leaked. The first error encountered is returned.
func (w *writer) Close() error {
	w.mu.Lock()
	defer w.mu.Unlock()

	// Did we already close?
	if w.closed {
		return io.ErrUnexpectedEOF
	}

	// Always mark as closed, regardless of errors.
	w.closed = true

	// Write the final segment.
	err := w.closeSegment()

	// Call the underlying closer, keeping the earlier error if there was one.
	if c, ok := w.w.(io.Closer); ok {
		if cerr := c.Close(); err == nil {
			err = cerr
		}
	}
	return err
}

// reader computes and verifies hashes during reads.
//
// Segments are read from r, fed to h and compared against the hash that
// follows them in the stream; only payload that passed this comparison is
// handed to the caller.
type reader struct {
	// mu serializes Read.
	mu sync.Mutex
	// r is the underlying stream of interleaved payload and hashes.
	r  io.Reader
	// h is the running hash of the payload read so far.
	h  hash.Hash

	// data is remaining verified but unused payload data. This is
	// populated when the caller's buffer cannot hold a whole segment and
	// may be consumed without any further verification.
	data [SegmentSize]byte

	// index is the index into data above.
	index int

	// available is the amount of valid data above.
	available int

	// hashv is the read hash for the current segment.
	hashv []byte

	// computev is the computed hash for the current segment.
	computev []byte
}

// NewReader returns a reader that yields the payload carried by r, checking
// every segment against its trailing hash using h before handing it out.
func NewReader(r io.Reader, h hash.Hash) io.Reader {
	hr := new(reader)
	hr.r = r
	hr.h = h
	// One buffer for the hash found in the stream, one for the hash
	// computed locally.
	hr.hashv = make([]byte, h.Size())
	hr.computev = make([]byte, h.Size())
	return hr
}

// readSegment reads one segment of payload into datav and the hash that
// follows it into r.hashv. It does not verify the hash.
//
// The final segment of a stream may be shorter than SegmentSize, in which
// case some or all of its hash lands in datav; the hash is then moved into
// r.hashv and the returned slice is trimmed to the payload only.
//
// It returns io.EOF when the stream holds no more bytes at all,
// ErrHashMismatch when the stream ends with fewer bytes than one hash, and
// any other error of the underlying reader unchanged.
//
// Precondition: datav must have length SegmentSize.
func (r *reader) readSegment(datav []byte) (data []byte, err error) {
	hashLen := len(r.hashv)

	// Phase one: fill the payload buffer.
	filled := 0
	for filled < SegmentSize {
		n, rerr := r.r.Read(datav[filled:])
		filled += n
		if n != 0 || rerr == nil {
			// Progress was made, or nothing happened yet; try again.
			continue
		}
		if rerr != io.EOF {
			return nil, rerr
		}

		// The stream ended inside this segment.
		if filled == 0 {
			// No data at all.
			return nil, io.EOF
		}
		if filled < hashLen {
			// Not enough for a hash.
			return nil, ErrHashMismatch
		}

		// The tail of what was read is the hash; the rest is payload.
		split := filled - hashLen
		copy(r.hashv, datav[split:])
		return datav[:split], nil
	}

	// Phase two: read the hash following a full buffer.
	got := 0
	for got < hashLen {
		n, rerr := r.r.Read(r.hashv[got:])
		got += n
		if n != 0 || rerr == nil {
			continue
		}
		if rerr != io.EOF {
			return nil, rerr
		}

		// The stream ended inside the hash, so the head of the hash is
		// sitting at the end of datav. Shift the part already read to
		// the back of hashv and pull the head over from datav.
		missing := hashLen - got
		copy(r.hashv[missing:], r.hashv[:got])
		payloadLen := len(datav) - missing
		copy(r.hashv[:missing], datav[payloadLen:])
		return datav[:payloadLen], nil
	}

	return datav, nil
}

// verifyHash feeds datav into the running hash and compares the result with
// the hash previously stored in r.hashv.
//
// It returns nil when both agree and ErrHashMismatch otherwise. The payload
// is added to the running hash in either case.
func (r *reader) verifyHash(datav []byte) error {
	// Hash.Write never fails, so only the byte count is of interest.
	fed := 0
	for fed < len(datav) {
		n, _ := r.h.Write(datav[fed:])
		fed += n
	}

	// Serialize into the reusable buffer and compare in constant time.
	sum := r.h.Sum(r.computev[:0])
	if hmac.Equal(r.hashv, sum) {
		return nil
	}
	return ErrHashMismatch
}

// Read reads verified payload into p.
//
// Read keeps going until p is full, the stream ends, or an error occurs.
// Every byte counted in the returned length has passed hash verification.
// When the stream ends or fails part-way, the bytes already delivered are
// still reported alongside the error, so the caller does not lose them. A
// failed verification is reported as ErrHashMismatch.
func (r *reader) Read(p []byte) (int, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	for done := 0; done < len(p); {
		// Check for pending data.
		if r.index < r.available {
			n := copy(p[done:], r.data[r.index:r.available])
			done += n
			r.index += n
			continue
		}

		// Prepare the next read.
		var (
			datav  []byte
			inline bool
		)

		// We need to read a new segment. Can we read directly?
		if len(p[done:]) >= SegmentSize {
			// The caller's buffer has room for a whole segment, so
			// read straight into it and skip the extra copy.
			datav = p[done : done+SegmentSize]
			inline = true
		} else {
			datav = r.data[:]
			inline = false
		}

		// Read the next segment. On any failure, io.EOF included, report
		// the verified bytes that were already copied into p.
		datav, err := r.readSegment(datav)
		if err != nil {
			return done, err
		}
		if err := r.verifyHash(datav); err != nil {
			// done is not advanced, so the unverified bytes of this
			// segment are never reported to the caller.
			return done, err
		}

		if inline {
			// Move the cursor.
			done += len(datav)
		} else {
			// Reset index & available; the data is copied out by the
			// next iteration.
			r.index = 0
			r.available = len(datav)
		}
	}

	return len(p), nil
}