// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package safemem

import (
	"bytes"
	"io"
	"testing"
)

func makeBlocks(slices ...[]byte) []Block {
	blocks := make([]Block, 0, len(slices))
	for _, s := range slices {
		blocks = append(blocks, BlockFromSafeSlice(s))
	}
	return blocks
}

func TestFromIOReaderFullRead(t *testing.T) {
	r := FromIOReader{bytes.NewBufferString("foobar")}
	dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
	n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
	if wantN := uint64(6); n != wantN || err != nil {
		t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
	}
	for i, want := range [][]byte{[]byte("foo"), []byte("bar")} {
		if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
			t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
		}
	}
}

type eofHidingReader struct {
	Reader io.Reader
}

func (r eofHidingReader) Read(dst []byte) (int, error) {
	n, err := r.Reader.Read(dst)
	if err == io.EOF {
		return n, nil
	}
	return n, err
}

func TestFromIOReaderPartialRead(t *testing.T) {
	r := FromIOReader{eofHidingReader{bytes.NewBufferString("foob")}}
	dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
	n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
	// FromIOReader should stop after the eofHidingReader returns (1, nil)
	// for a 3-byte read.
	if wantN := uint64(4); n != wantN || err != nil {
		t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
	}
	for i, want := range [][]byte{[]byte("foo"), []byte("b\x00\x00")} {
		if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
			t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
		}
	}
}

type singleByteReader struct {
	Reader io.Reader
}

func (r singleByteReader) Read(dst []byte) (int, error) {
	if len(dst) == 0 {
		return r.Reader.Read(dst)
	}
	return r.Reader.Read(dst[:1])
}

func TestSingleByteReader(t *testing.T) {
	r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}}
	dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
	n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
	// FromIOReader should stop after the singleByteReader returns (1, nil)
	// for a 3-byte read.
	if wantN := uint64(1); n != wantN || err != nil {
		t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
	}
	for i, want := range [][]byte{[]byte("f\x00\x00"), []byte("\x00\x00\x00")} {
		if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
			t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
		}
	}
}

func TestReadFullToBlocks(t *testing.T) {
	r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}}
	dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
	n, err := ReadFullToBlocks(r, BlockSeqFromSlice(dsts))
	// ReadFullToBlocks should call into FromIOReader => singleByteReader
	// repeatedly until dsts is exhausted.
	if wantN := uint64(6); n != wantN || err != nil {
		t.Errorf("ReadFullToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
	}
	for i, want := range [][]byte{[]byte("foo"), []byte("bar")} {
		if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
			t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
		}
	}
}

func TestFromIOWriterFullWrite(t *testing.T) {
	srcs := makeBlocks([]byte("foo"), []byte("bar"))
	var dst bytes.Buffer
	w := FromIOWriter{&dst}
	n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
	if wantN := uint64(6); n != wantN || err != nil {
		t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
	}
	if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) {
		t.Errorf("dst: got %q, wanted %q", got, want)
	}
}

type limitedWriter struct {
	Writer io.Writer
	Done   int
	Limit  int
}

func (w *limitedWriter) Write(src []byte) (int, error) {
	count := len(src)
	if count > (w.Limit - w.Done) {
		count = w.Limit - w.Done
	}
	n, err := w.Writer.Write(src[:count])
	w.Done += n
	return n, err
}

func TestFromIOWriterPartialWrite(t *testing.T) {
	srcs := makeBlocks([]byte("foo"), []byte("bar"))
	var dst bytes.Buffer
	w := FromIOWriter{&limitedWriter{&dst, 0, 4}}
	n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
	// FromIOWriter should stop after the limitedWriter returns (1, nil) for a
	// 3-byte write.
	if wantN := uint64(4); n != wantN || err != nil {
		t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
	}
	if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) {
		t.Errorf("dst: got %q, wanted %q", got, want)
	}
}

type singleByteWriter struct {
	Writer io.Writer
}

func (w singleByteWriter) Write(src []byte) (int, error) {
	if len(src) == 0 {
		return w.Writer.Write(src)
	}
	return w.Writer.Write(src[:1])
}

func TestSingleByteWriter(t *testing.T) {
	srcs := makeBlocks([]byte("foo"), []byte("bar"))
	var dst bytes.Buffer
	w := FromIOWriter{singleByteWriter{&dst}}
	n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
	// FromIOWriter should stop after the singleByteWriter returns (1, nil)
	// for a 3-byte write.
	if wantN := uint64(1); n != wantN || err != nil {
		t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
	}
	if got, want := dst.Bytes(), []byte("f"); !bytes.Equal(got, want) {
		t.Errorf("dst: got %q, wanted %q", got, want)
	}
}

func TestWriteFullToBlocks(t *testing.T) {
	srcs := makeBlocks([]byte("foo"), []byte("bar"))
	var dst bytes.Buffer
	w := FromIOWriter{singleByteWriter{&dst}}
	n, err := WriteFullFromBlocks(w, BlockSeqFromSlice(srcs))
	// WriteFullToBlocks should call into FromIOWriter => singleByteWriter
	// repeatedly until srcs is exhausted.
	if wantN := uint64(6); n != wantN || err != nil {
		t.Errorf("WriteFullFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
	}
	if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) {
		t.Errorf("dst: got %q, wanted %q", got, want)
	}
}