// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package state provides functionality related to saving and loading object
// graphs.  For most types, it provides a set of default saving / loading logic
// that will be invoked automatically if custom logic is not defined.
//
//     Kind             Support
//     ----             -------
//     Bool             default
//     Int              default
//     Int8             default
//     Int16            default
//     Int32            default
//     Int64            default
//     Uint             default
//     Uint8            default
//     Uint16           default
//     Uint32           default
//     Uint64           default
//     Float32          default
//     Float64          default
//     Complex64        custom
//     Complex128       custom
//     Array            default
//     Chan             custom
//     Func             custom
//     Interface        custom
//     Map              default (*)
//     Ptr              default
//     Slice            default
//     String           default
//     Struct           custom
//     UnsafePointer    custom
//
// (*) Maps are treated as value types by this package, even if they are
// pointers internally. If you want to save two independent references
// to the same map value, you must explicitly use a pointer to a map.
package state

import (
	"context"
	"fmt"
	"io"
	"reflect"
	"runtime"

	pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
)

// ErrState is returned when an error is encountered during encode/decode.
//
// Besides the underlying error, it records the visit path from the root
// object to the object being processed and a stack trace, so failures deep
// inside an object graph can be located.
type ErrState struct {
	// err is the underlying error.
	err error

	// path is the visit path from root to the current object.
	path string

	// trace is the stack trace.
	trace string
}

// Error returns a sensible description of the state error: the underlying
// error, then the state path on its own line, then the stack trace.
func (e *ErrState) Error() string {
	return fmt.Sprintf("%v:\nstate path: %s\n%s", e.err, e.path, e.trace)
}

// Unwrap returns the underlying error, so that errors.Is and errors.As can
// inspect the cause of an ErrState.
func (e *ErrState) Unwrap() error {
	return e.err
}

// UnwrapErrState returns the underlying error in ErrState.
//
// If err is not *ErrState, err is returned directly. Only a direct
// *ErrState is unwrapped; an ErrState nested inside another wrapper is not
// searched for.
func UnwrapErrState(err error) error {
	if e, ok := err.(*ErrState); ok {
		return e.err
	}
	return err
}

// Save saves the given object state.
//
// rootPtr is expected to be a pointer; the value it points to is handed to
// the encoder as the root of the object graph, and the encoded output is
// written to w. stats is passed through to the encoder unchanged.
//
// Any panic raised during encoding is converted into an error by
// es.safely. encodeState is defined elsewhere; its safely is presumably
// recoverable.safely, which returns an *ErrState — confirm.
func Save(ctx context.Context, w io.Writer, rootPtr interface{}, stats *Stats) error {
	es := &encodeState{
		ctx:         ctx,
		w:           w,
		stats:       stats,
		idsByObject: make(map[uintptr]uint64),
	}

	// The root is resolved inside the protected function so that an invalid
	// rootPtr (e.g. a non-pointer) is reported as an error, not a panic.
	return es.safely(func() {
		root := reflect.ValueOf(rootPtr).Elem()
		es.Serialize(root)
	})
}

// Load loads a checkpoint.
//
// rootPtr is expected to be a pointer; the value it points to is handed to
// the decoder as the root of the object graph, and the checkpoint is read
// from r. stats is passed through to the decoder unchanged.
//
// Any panic raised during decoding is converted into an error by
// ds.safely. decodeState is defined elsewhere; its safely is presumably
// recoverable.safely, which returns an *ErrState — confirm.
func Load(ctx context.Context, r io.Reader, rootPtr interface{}, stats *Stats) error {
	ds := &decodeState{
		ctx:         ctx,
		r:           r,
		stats:       stats,
		objectsByID: make(map[uint64]*objectState),
		deferred:    make(map[uint64]*pb.Object),
	}

	// The root is resolved inside the protected function so that an invalid
	// rootPtr (e.g. a non-pointer) is reported as an error, not a panic.
	return ds.safely(func() {
		root := reflect.ValueOf(rootPtr).Elem()
		ds.Deserialize(root)
	})
}

// Fns are the state dispatch functions.
//
// Both fields are stored as interface{} and invoked via reflection; use
// Validate to check that they have the expected signature.
type Fns struct {
	// Save is a function like Save(concreteType, Map).
	Save interface{}

	// Load is a function like Load(concreteType, Map).
	Load interface{}
}

// invokeSave executes the save function, calling fns.Save(obj, m).
//
// It panics (from reflect.Value.Call) if fns.Save is not a function or does
// not accept (obj, m) as arguments.
func (fns *Fns) invokeSave(obj reflect.Value, m Map) {
	reflect.ValueOf(fns.Save).Call([]reflect.Value{obj, reflect.ValueOf(m)})
}

// invokeLoad executes the load function, calling fns.Load(obj, m).
//
// It panics (from reflect.Value.Call) if fns.Load is not a function or does
// not accept (obj, m) as arguments.
func (fns *Fns) invokeLoad(obj reflect.Value, m Map) {
	reflect.ValueOf(fns.Load).Call([]reflect.Value{obj, reflect.ValueOf(m)})
}

// validateStateFn reports whether fn is a function of exactly two
// parameters, (typ, Map), with no results — the shape required for
// Fns.Save and Fns.Load.
//
// A nil fn, or any value that is not a function, is reported as invalid
// rather than causing a panic.
func validateStateFn(fn interface{}, typ reflect.Type) bool {
	fnTyp := reflect.TypeOf(fn)
	// reflect.TypeOf returns a nil Type for a nil interface; calling Kind on
	// it would dereference nil, so reject it up front.
	if fnTyp == nil || fnTyp.Kind() != reflect.Func {
		return false
	}
	if fnTyp.NumIn() != 2 || fnTyp.NumOut() != 0 {
		return false
	}
	return fnTyp.In(0) == typ && fnTyp.In(1) == reflect.TypeOf(Map{})
}

// Validate reports whether both Save and Load are functions taking
// (typ, Map) and returning nothing. Load is only checked when Save passes.
func (fns *Fns) Validate(typ reflect.Type) bool {
	if !validateStateFn(fns.Save, typ) {
		return false
	}
	return validateStateFn(fns.Load, typ)
}

// typeDatabase is a two-way registry between names and types, together with
// the save/load functions registered for each type.
type typeDatabase struct {
	// nameToType is a forward lookup table.
	nameToType map[string]reflect.Type

	// typeToName is the reverse lookup table.
	typeToName map[reflect.Type]string

	// typeToFns is the function lookup table.
	typeToFns map[reflect.Type]Fns
}

// registeredTypes is a database used for SaveInterface and LoadInterface.
// It is populated through Register.
var registeredTypes = typeDatabase{
	nameToType: make(map[string]reflect.Type),
	typeToName: make(map[reflect.Type]string),
	typeToFns:  make(map[reflect.Type]Fns),
}

// register registers a type under the given name. This will generally be
// called via init() methods, and therefore uses panic to propagate errors.
//
// It panics if name is already bound to some type, or if typ is already
// bound to some name. Types in the panic message are rendered with
// reflect.Type's String method (e.g. "*pkg.Foo") rather than Name, because
// Name is empty for non-defined types such as the pointer types Register is
// normally called with.
func (t *typeDatabase) register(name string, typ reflect.Type, fns Fns) {
	// We can't allow name collisions.
	if ot, ok := t.nameToType[name]; ok {
		panic(fmt.Sprintf("type %q can't use name %q, already in use by type %q", typ, name, ot))
	}

	// Or multiple registrations.
	if on, ok := t.typeToName[typ]; ok {
		panic(fmt.Sprintf("type %q can't be registered as %q, already registered as %q", typ, name, on))
	}

	t.nameToType[name] = typ
	t.typeToName[typ] = name
	t.typeToFns[typ] = fns
}

// lookupType finds a type given a name. ok is false if name is unregistered.
func (t *typeDatabase) lookupType(name string) (reflect.Type, bool) {
	typ, ok := t.nameToType[name]
	return typ, ok
}

// lookupName finds a name given a type. ok is false if typ is unregistered.
func (t *typeDatabase) lookupName(typ reflect.Type) (string, bool) {
	name, ok := t.typeToName[typ]
	return name, ok
}

// lookupFns finds functions given a type. ok is false if typ is unregistered.
func (t *typeDatabase) lookupFns(typ reflect.Type) (Fns, bool) {
	fns, ok := t.typeToFns[typ]
	return fns, ok
}

// Register must be called for any interface implementation types that
// implement Loader.
//
// Register should be called either immediately after startup or via init()
// methods. Double registration of either names or types will result in a
// panic, as will passing a nil instance (use a typed nil such as (*Foo)(nil)).
//
// No synchronization is provided; this should only be called in init.
//
// Example usage:
//
//	state.Register("Foo", (*Foo)(nil), state.Fns{
//		Save: (*Foo).Save,
//		Load: (*Foo).Load,
//	})
func Register(name string, instance interface{}, fns Fns) {
	typ := reflect.TypeOf(instance)
	if typ == nil {
		// reflect.TypeOf(nil) carries no type information; registering it
		// would bind name to a nil reflect.Type that no real value has.
		panic(fmt.Sprintf("state.Register(%q): instance must not be a nil interface; pass a typed value such as (*T)(nil)", name))
	}
	registeredTypes.register(name, typ, fns)
}

// IsZeroValue reports whether val is the zero value of its dynamic type, as
// judged by reflect.DeepEqual. A nil interface counts as zero.
//
// This function is used by the stateify tool.
func IsZeroValue(val interface{}) bool {
	typ := reflect.TypeOf(val)
	if typ == nil {
		// Only a nil interface has no dynamic type.
		return true
	}
	zero := reflect.Zero(typ).Interface()
	return reflect.DeepEqual(val, zero)
}

// step captures one encoding / decoding step. On each step, there is up to one
// choice made, which is captured by non-nil param. We intentionally do not
// eagerly create the final path string, as that will only be needed upon panic.
type step struct {
	// dereference indicates whether the current object is obtained by
	// dereferencing a pointer.
	dereference bool

	// format is the formatting string that takes param below, if
	// non-nil. For example, in array indexing case, we have "[%d]".
	format string

	// param stores the choice made at the current encoding / decoding step.
	// For example, in array indexing case, param stores the index. When no
	// choice is made, e.g. dereference, param should be nil.
	param interface{}
}

// recoverable is the state encoding / decoding panic recovery facility. It is
// also used to store encoding / decoding steps as well as the reference to the
// original queued object from which the current object is dispatched. The
// complete encoding / decoding path is synthesised from the steps in all queued
// objects leading to the current object.
type recoverable struct {
	from  *recoverable
	steps []step
}

// push enters a new context level.
func (sr *recoverable) push(dereference bool, format string, param interface{}) {
	sr.steps = append(sr.steps, step{dereference, format, param})
}

// pop exits the current context level.
//
// The bottom-most step is never removed: pop is a no-op when one or zero
// steps remain. NOTE(review): this asymmetry with push looks deliberate but
// its reason is not visible here — confirm against the encoder/decoder.
func (sr *recoverable) pop() {
	if len(sr.steps) <= 1 {
		return
	}
	sr.steps = sr.steps[:len(sr.steps)-1]
}

// path returns the complete encoding / decoding path from root. This is only
// called upon panic.
//
// The path starts at "root", then applies the steps of every recoverable in
// the from chain, outermost first. Steps recorded on the outermost
// recoverable (the one whose from is nil) are included too, so no recorded
// step is dropped. A dereference step wraps the path so far in "*(...)".
func (sr *recoverable) path() string {
	p := "root"
	if sr.from != nil {
		p = sr.from.path()
	}
	for _, s := range sr.steps {
		if s.dereference {
			p = fmt.Sprintf("*(%s)", p)
		}
		if s.param == nil {
			p += s.format
		} else {
			p += fmt.Sprintf(s.format, s.param)
		}
	}
	return p
}

// copy returns a recoverable sharing sr's from link but owning an
// independent copy of its steps, so later push/pop on either side does not
// affect the other.
func (sr *recoverable) copy() recoverable {
	return recoverable{from: sr.from, steps: append([]step(nil), sr.steps...)}
}

// safely runs fn and converts any panic it raises into an error.
//
// The state package deliberately uses panic and recover for its error flow,
// for two reasons:
//
// 1) Many reflection methods already panic on invalid data or violated
// assumptions, so those panics have to be recovered here regardless.
//
// 2) It removes error-plumbing boilerplate from Save() and Load()
// functions. When low-level serialization fails, the checkpoint as a whole
// should almost always fail too, so threading errors through every method
// adds little value. Callers needing to handle specific error conditions
// should add that handling to the objects themselves before calling Save()
// and Load().
//
// If fn panics, the result is an *ErrState holding the panic value (wrapped
// with fmt.Errorf when it is not already an error), sr's current path, and
// a stack trace of the current goroutine. Otherwise the result is nil.
func (sr *recoverable) safely(fn func()) (err error) {
	defer func() {
		r := recover()
		if r == nil {
			// recover reported no panic; keep the body's nil result.
			return
		}

		cause, isError := r.(error)
		if !isError {
			cause = fmt.Errorf("%v", r)
		}
		state := &ErrState{err: cause, path: sr.path()}

		// runtime.Stack truncates to the buffer it is given and returns the
		// number of bytes written, so keep doubling the buffer until the
		// trace fits with room to spare.
		for size := 1024; ; size *= 2 {
			buf := make([]byte, size)
			if n := runtime.Stack(buf, false); n < size {
				state.trace = string(buf[:n])
				break
			}
		}

		err = state
	}()

	fn()
	return nil
}