// Copyright 2018 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package hashio

import (
	"bytes"
	"crypto/hmac"
	"crypto/sha256"
	"fmt"
	"io"
	"math/rand"
	"testing"
)

// testKey is the fixed 32-byte secret shared by every HMAC-SHA256 instance in
// this file, so the encoding and decoding sides are always keyed identically.
var (
	testKey = []byte("01234567890123456789012345678901")
)

// runTest round-trips c through the hashio Writer and Reader, each keyed with
// HMAC-SHA256 over testKey.
//
// The encode pass runs iters times. Each pass starts from an empty buffer, so
// only one encoding survives. fn is then called with the encoded buffer; it is
// the hook where callers may modify the encoding or toggle benchmark timers.
// Finally the decode pass runs iters times over the (possibly modified)
// encoding. iters is expected to be at least 1; otherwise nothing is encoded
// or decoded.
//
// It returns the first error reported while writing, closing, or reading,
// unwrapped, so callers can compare it directly against sentinel errors. If
// decoding succeeds but the output differs from c, it returns a descriptive
// error instead of panicking. That way an undetected corruption fails only
// the calling test rather than aborting the whole test binary.
func runTest(c []byte, fn func(enc *bytes.Buffer), iters int) error {
	var enc, dec bytes.Buffer

	for i := 0; i < iters; i++ {
		enc.Reset()
		w := NewWriter(&enc, hmac.New(sha256.New, testKey))
		// bytes.NewReader leaves c untouched for the comparison below, whereas
		// bytes.NewBuffer documents that it takes ownership of its argument.
		if _, err := io.Copy(w, bytes.NewReader(c)); err != nil {
			return err
		}
		if err := w.Close(); err != nil {
			return err
		}
	}

	fn(&enc)

	for i := 0; i < iters; i++ {
		dec.Reset()
		r := NewReader(bytes.NewReader(enc.Bytes()), hmac.New(sha256.New, testKey))
		if _, err := io.Copy(&dec, r); err != nil {
			return err
		}
	}

	// A successful decode must reproduce the input exactly. Report a mismatch
	// by size and first differing offset rather than dumping the payloads,
	// which can be many megabytes in the benchmarks.
	if got := dec.Bytes(); !bytes.Equal(c, got) {
		off := 0
		for off < len(c) && off < len(got) && c[off] == got[off] {
			off++
		}
		return fmt.Errorf("data didn't match: got %d bytes, expected %d bytes, first difference at offset %d",
			len(got), len(c), off)
	}

	return nil
}

// TestTable round-trips a table of payloads through runTest, one subtest per
// (payload, flip) pair.
//
// With flip=false the round trip must succeed. With flip=true, one randomly
// chosen byte of the encoded stream is incremented before decoding, and the
// reader must report ErrHashMismatch. The chosen offset is included in failure
// messages so a failing run can be reproduced.
func TestTable(t *testing.T) {
	cases := [][]byte{
		// Various data sizes.
		nil,
		[]byte(""),
		[]byte("_"),
		[]byte("0"),
		[]byte("01"),
		[]byte("012"),
		[]byte("0123"),
		[]byte("01234"),
		[]byte("012356"),
		[]byte("0123567"),
		[]byte("01235678"),

		// Make sure we have one longer than the hash length.
		[]byte("012356asdjflkasjlk3jlk23j4lkjaso0d789f0aujw3lkjlkxsdf78asdful2kj3ljka78"),

		// Sizes at and around multiples of the segment size.
		make([]byte, 3*SegmentSize),
		make([]byte, 3*SegmentSize-1),
		make([]byte, 3*SegmentSize+1),
		make([]byte, 3*SegmentSize-32),
		make([]byte, 3*SegmentSize+32),
		make([]byte, 30*SegmentSize),
	}

	for _, c := range cases {
		for _, flip := range []bool{false, true} {
			// Corruption is not attempted on empty payloads.
			// NOTE(review): presumably because their encoding may be empty,
			// and rand.Intn panics on a non-positive bound — confirm.
			if flip && len(c) == 0 {
				continue
			}

			t.Run(fmt.Sprintf("len=%d,flip=%v", len(c), flip), func(t *testing.T) {
				corrupted := -1 // offset of the modified encoded byte, if any
				err := runTest(c, func(enc *bytes.Buffer) {
					if flip {
						corrupted = rand.Intn(enc.Len())
						enc.Bytes()[corrupted]++
					}
				}, 1)

				switch {
				case !flip:
					if err != nil {
						t.Errorf("error during read: got %v, expected nil", err)
					}
				case err == nil:
					t.Errorf("failed to detect ErrHashMismatch on corrupted data (byte at offset %d modified)", corrupted)
				case err != ErrHashMismatch:
					t.Errorf("error during read of corrupted data (byte at offset %d modified): got %v, expected %v",
						corrupted, err, ErrHashMismatch)
				}
			})
		}
	}
}

// benchBytes is the payload size, in bytes, of every benchmark pass (10 MiB).
const benchBytes = 10 << 20

// BenchmarkWrite measures encoding throughput. runTest encodes a zero-filled
// benchBytes payload b.N times with the timer running. Its callback halts the
// timer once encoding is finished, so the verification decode pass is not
// counted.
func BenchmarkWrite(b *testing.B) {
	b.StopTimer()
	payload := make([]byte, benchBytes)
	b.SetBytes(benchBytes)
	b.StartTimer()

	stopAfterEncode := func(*bytes.Buffer) { b.StopTimer() }
	err := runTest(payload, stopAfterEncode, b.N)
	if err != nil {
		b.Errorf("benchmark failed: %v", err)
	}
}

// BenchmarkRead measures decoding throughput. The timer stays halted while
// runTest encodes a zero-filled benchBytes payload b.N times. Its callback
// then starts the timer, so the b.N decode passes (and the final comparison)
// are what gets measured.
func BenchmarkRead(b *testing.B) {
	b.StopTimer()
	payload := make([]byte, benchBytes)
	b.SetBytes(benchBytes)

	startBeforeDecode := func(*bytes.Buffer) { b.StartTimer() }
	err := runTest(payload, startBeforeDecode, b.N)
	if err != nil {
		b.Errorf("benchmark failed: %v", err)
	}
}