// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package tests tests the state packages.
package tests

import (
	"bytes"
	"context"
	"fmt"
	"math"
	"reflect"
	"testing"

	"gvisor.dev/gvisor/pkg/state"
	"gvisor.dev/gvisor/pkg/state/pretty"
)

// discard is an implementation of wire.Writer that drops everything written
// to it.
type discard struct{}

// Write implements wire.Writer.Write. It reports every byte of buf as
// consumed and never fails.
func (discard) Write(buf []byte) (n int, err error) {
	n = len(buf)
	return n, nil
}

// WriteByte implements wire.Writer.WriteByte. The byte is ignored.
func (discard) WriteByte(_ byte) error {
	return nil
}

// checkEqual checks if two objects are equal.
//
// Equality is reflect.DeepEqual, extended so that NaN compares equal to NaN
// when root and loadedValue are both float32 or both float64. Complex values
// of the same width are compared component-wise with the same NaN rule.
// Values whose dynamic types differ are never equal.
//
// N.B. This only handles one level of dereferences for NaN. Otherwise we
// would need to fork the entire implementation of reflect.DeepEqual.
func checkEqual(root, loadedValue interface{}) bool {
	if reflect.DeepEqual(root, loadedValue) {
		return true
	}

	// NaN is not equal to itself. We handle the case of raw floating point
	// and complex primitives here, but don't handle this case nested. Each
	// branch requires loadedValue to have exactly the same type as root.
	switch r := root.(type) {
	case float32:
		l, ok := loadedValue.(float32)
		return ok && math.IsNaN(float64(r)) && math.IsNaN(float64(l))
	case float64:
		l, ok := loadedValue.(float64)
		return ok && math.IsNaN(r) && math.IsNaN(l)
	case complex64:
		// The components are float32 values, handled by the case above.
		l, ok := loadedValue.(complex64)
		return ok && checkEqual(real(r), real(l)) && checkEqual(imag(r), imag(l))
	case complex128:
		// The components are float64 values, handled by the case above.
		l, ok := loadedValue.(complex128)
		return ok && checkEqual(real(r), real(l)) && checkEqual(imag(r), imag(l))
	}
	return false
}

// runTestCases runs one subtest per element of objects, named by prefix
// followed by the element's index.
//
// Each subtest saves the object with state.Save, logs the encoding (raw and
// pretty-printed), loads it back with state.Load, and compares the result
// using checkEqual. When shouldFail is set, a failure in save, load, or the
// comparison ends the subtest without error, and a clean round trip is
// reported as a fatal error. After a successful round trip the subtest also
// checks that state.Load rejects the encoding truncated to half its length
// and truncated by a single byte.
func runTestCases(t *testing.T, shouldFail bool, prefix string, objects []interface{}) {
	t.Helper()
	for idx, obj := range objects {
		name := fmt.Sprintf("%s%d", prefix, idx)
		t.Run(name, func(t *testing.T) {
			t.Logf("Original object:\n%#v", obj)
			ctx := context.Background()
			typ := reflect.TypeOf(obj)

			// state.Save is handed a pointer to a copy of the object.
			var encoded bytes.Buffer
			src := reflect.New(typ)
			src.Elem().Set(reflect.ValueOf(obj))
			saveStats, err := state.Save(ctx, &encoded, src.Interface())
			if err != nil {
				if !shouldFail {
					t.Fatalf("Save failed unexpectedly: %v", err)
				}
				return
			}
			raw := encoded.Bytes()

			// Dump the encoding to aid debugging. Pretty-printer errors
			// only count as failures when the case was expected to pass.
			var text bytes.Buffer
			t.Logf("Raw state:\n%v", raw)
			if perr := pretty.PrintText(&text, bytes.NewReader(raw)); perr != nil && !shouldFail {
				t.Errorf("PrettyPrint(html=false) failed unexpected: %v", perr)
			}
			if perr := pretty.PrintHTML(discard{}, bytes.NewReader(raw)); perr != nil && !shouldFail {
				t.Errorf("PrettyPrint(html=true) failed unexpected: %v", perr)
			}
			t.Logf("Encoded state:\n%s", text.String())
			t.Logf("Save stats:\n%s", saveStats.String())

			// Decode into a fresh value of the same type.
			dst := reflect.New(typ)
			loadStats, err := state.Load(ctx, bytes.NewReader(raw), dst.Interface())
			if err != nil {
				if !shouldFail {
					t.Fatalf("Load failed unexpectedly: %v", err)
				}
				return
			}

			got := dst.Elem().Interface()
			if !checkEqual(obj, got) {
				if !shouldFail {
					t.Fatalf("Objects differ:\n\toriginal: %#v\n\tloaded:   %#v\n", obj, got)
				}
				return
			}

			// A clean round trip is itself a failure for negative cases.
			if shouldFail {
				t.Fatalf("This test was expected to fail, but didn't.")
			}
			t.Logf("Load stats:\n%s", loadStats.String())

			// A stream cut in half, or missing only its final byte, must
			// be rejected.
			size := len(raw)
			if _, err := state.Load(ctx, bytes.NewReader(raw[:size/2]), dst.Interface()); err == nil {
				t.Errorf("Load with half bytes succeeded unexpectedly.")
			}
			if _, err := state.Load(ctx, bytes.NewReader(raw[:size-1]), dst.Interface()); err == nil {
				t.Errorf("Load with missing byte succeeded unexpectedly.")
			}
		})
	}
}

// convert returns the elements of the slice v, each boxed as an
// interface{}. It returns nil when v is empty. It panics if v is not a
// slice or array.
func convert(v interface{}) (r []interface{}) {
	slice := reflect.ValueOf(v) // Must be slice.
	for idx, n := 0, slice.Len(); idx < n; idx++ {
		elem := slice.Index(idx).Interface()
		r = append(r, elem)
	}
	return r
}

// flatten concatenates the elements of every slice in vs, in argument
// order, into a single []interface{}. It returns nil if there are no
// elements at all.
func flatten(vs ...interface{}) (r []interface{}) {
	for idx := range vs {
		elems := convert(vs[idx])
		r = append(r, elems...)
	}
	return r
}

// filter maps from one slice to another: fn is applied to each element of
// the slice vs in order, and its first result is kept whenever its second
// result is true. It returns nil if nothing is kept.
func filter(vs interface{}, fn func(interface{}) (interface{}, bool)) (r []interface{}) {
	src := reflect.ValueOf(vs)
	count := src.Len()
	for pos := 0; pos < count; pos++ {
		if out, keep := fn(src.Index(pos).Interface()); keep {
			r = append(r, out)
		}
	}
	return r
}

// combine applies fn to every pair drawn from the slices v1 and v2 (the
// cartesian product) and collects the results. The pairs are taken in
// row-major order: every element of v2 for the first element of v1, then
// every element of v2 for the second element of v1, and so on.
func combine(v1, v2 interface{}, fn func(_, _ interface{}) interface{}) (r []interface{}) {
	left, right := reflect.ValueOf(v1), reflect.ValueOf(v2)
	for i, n := 0, left.Len(); i < n; i++ {
		a := left.Index(i).Interface()
		for j := 0; j < right.Len(); j++ {
			r = append(r, fn(a, right.Index(j).Interface()))
		}
	}
	return r
}

// pointersTo returns, for each element of the slice vs, a pointer to a
// newly allocated copy of that element. The pointer's element type is the
// element's dynamic type.
func pointersTo(vs interface{}) []interface{} {
	toPtr := func(elem interface{}) (interface{}, bool) {
		ptr := reflect.New(reflect.TypeOf(elem))
		ptr.Elem().Set(reflect.ValueOf(elem))
		return ptr.Interface(), true
	}
	return filter(vs, toPtr)
}

// interfacesTo wraps each element of the slice vs in a single-element
// [1]interface{} array, so the element is stored behind an interface-typed
// field.
func interfacesTo(vs interface{}) []interface{} {
	wrap := func(elem interface{}) (interface{}, bool) {
		return [1]interface{}{elem}, true
	}
	return filter(vs, wrap)
}