// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package state

import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"reflect"
	"sort"

	"github.com/golang/protobuf/proto"
	pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
)

// objectState represents an object that may be in the process of being
// decoded. Specifically, it represents either a decoded object, or an
// interest in a future object that will be decoded. When that interest is
// registered (via register), the storage for the object will be created, but
// it will not be decoded until the object is encountered in the stream.
type objectState struct {
	// id is the id for this object.
	//
	// If this field is zero, then this is an anonymous (unregistered,
	// non-reference primitive) object. This is immutable.
	id uint64

	// obj is the object. This may or may not be valid yet, depending on
	// whether complete returns true. However, regardless of whether the
	// object is valid, obj contains a final storage location for the
	// object. This is immutable.
	//
	// Note that this must be addressable (obj.Addr() must not panic).
	//
	// The obj passed to the decode methods below will equal this obj only
	// in the case of decoding the top-level object. However, the passed
	// obj may represent individual fields, elements of a slice, etc. that
	// are effectively embedded within the reflect.Value below but with
	// distinct types.
	obj reflect.Value

	// blockedBy is the number of dependencies this object is still
	// waiting on. It is incremented by waitFor and decremented by
	// checkComplete on the object being waited for.
	blockedBy int

	// blocking is a list of the objects blocked by this one.
	blocking []*objectState

	// callbacks is a set of callbacks to execute on load. They are queued
	// by waitFor and fired by checkComplete once this object is no longer
	// blocked.
	callbacks []func()

	// path is the decoding path to the object.
	path recoverable
}

// complete reports whether the object is finished: it has no outstanding
// dependencies and no callbacks left to run.
func (os *objectState) complete() bool {
	if os.blockedBy != 0 {
		return false
	}
	return len(os.callbacks) == 0
}

// checkComplete finalizes the object if nothing is blocking it any longer.
//
// When the object is unblocked, its pending callbacks are run and every
// object waiting on it is released and, in turn, checked for completion. The
// work is bracketed by stats.Start and stats.Done.
func (os *objectState) checkComplete(stats *Stats) {
	if os.blockedBy > 0 {
		// Still waiting on at least one dependency.
		return
	}
	stats.Start(os.obj)

	// Run the pending callbacks in registration order, then drop them.
	pending := os.callbacks
	for i := range pending {
		pending[i]()
	}
	os.callbacks = nil

	// Release every waiter and give each one a chance to complete.
	waiters := os.blocking
	for i := range waiters {
		waiter := waiters[i]
		waiter.blockedBy--
		waiter.checkComplete(stats)
	}
	os.blocking = nil
	stats.Done()
}

// waitFor records that os depends on other.
//
// The dependency count of os is bumped, os is added to the list of objects
// that other is blocking and, when callback is non-nil, it is queued on other
// so that it runs when other completes.
func (os *objectState) waitFor(other *objectState, callback func()) {
	os.blockedBy++
	other.blocking = append(other.blocking, os)
	if callback == nil {
		return
	}
	other.callbacks = append(other.callbacks, callback)
}

// findCycleFor searches the objects transitively blocked by os for target.
//
// If target is reachable, the returned slice holds target followed by the
// objects on the path to it, innermost first (the starting object os itself
// is not included). If target is not reachable, nil is returned.
//
// The search remembers which objects it has already visited, so it
// terminates even when the blocking graph contains a cycle that does not
// pass through target.
func (os *objectState) findCycleFor(target *objectState) []*objectState {
	return os.findCycleForVisited(target, make(map[*objectState]struct{}))
}

// findCycleForVisited is the depth-first worker for findCycleFor; visited is
// the set of objects whose blocking lists have already been (or are being)
// explored.
func (os *objectState) findCycleForVisited(target *objectState, visited map[*objectState]struct{}) []*objectState {
	if _, seen := visited[os]; seen {
		// Exploring this object again cannot find a new path to target.
		return nil
	}
	visited[os] = struct{}{}

	for _, other := range os.blocking {
		if other == target {
			return []*objectState{target}
		}
		if childList := other.findCycleForVisited(target, visited); childList != nil {
			return append(childList, other)
		}
	}
	return nil
}

// findCycle returns the dependency cycle that os takes part in, as found by
// findCycleFor with os as both start and target, with os itself appended at
// the end. If no cycle is found, the result contains only os.
func (os *objectState) findCycle() []*objectState {
	path := os.findCycleFor(os)
	path = append(path, os)
	return path
}

// decodeState is a graph of objects in the process of being decoded.
//
// The decode process involves loading the breadth-first graph generated by
// encode. This graph is read in its entirety, ensuring that all object
// storage is complete.
//
// As the graph is being deserialized, a set of completion callbacks are
// executed. These completion callbacks should form a set of acyclic subgraphs
// over the original one. After decoding is complete, the objects are scanned
// to ensure that all callbacks are executed, otherwise the callback graph was
// not acyclic.
type decodeState struct {
	// ctx is the decode context.
	ctx context.Context

	// objectsByID is the set of objects in progress, keyed by object id.
	objectsByID map[uint64]*objectState

	// deferred are objects that have been read, but no interest has been
	// registered yet. These will be decoded once interest is registered.
	deferred map[uint64]*pb.Object

	// outstanding is the number of outstanding objects: those for which
	// interest has been registered but which have not been decoded from
	// the stream yet.
	outstanding uint32

	// r is the input stream.
	r io.Reader

	// stats is the passed stats object.
	stats *Stats

	// recoverable is the panic recover facility.
	recoverable
}

// lookup returns the objectState registered under id, or nil when no object
// with that id has been registered so far.
func (ds *decodeState) lookup(id uint64) *objectState {
	if found, ok := ds.objectsByID[id]; ok {
		return found
	}
	return nil
}

// wait registers a dependency of waiter on the object with the given id.
//
// As a special case, we always allow _useable_ references back to the first
// decoding object because it may have fields that are already decoded. We also
// allow trivial self reference, since they can be handled internally. In these
// cases, and for the nil reference (id zero), nothing is queued and callback
// (if non-nil) is invoked immediately.
func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) {
	// id 0 is a nil pointer, waiter.id is a trivial self reference and 1
	// is the root object; none of these require an actual wait.
	if id == 0 || id == waiter.id || id == 1 {
		if callback != nil {
			callback()
		}
		return
	}

	// The target is expected to have been registered already, so lookup
	// is not expected to return nil here.
	waiter.waitFor(ds.lookup(id), callback)
}

// waitObject notes a blocking relationship between os and whatever the
// encoded object p refers to.
//
// References and slices wait on the id they carry, interfaces are unwrapped
// and handled recursively, and for every other kind of value there is nothing
// to wait for, so callback (if non-nil) runs immediately.
func (ds *decodeState) waitObject(os *objectState, p *pb.Object, callback func()) {
	switch v := p.Value.(type) {
	case *pb.Object_RefValue:
		// Refs can encode pointers and maps.
		ds.wait(os, v.RefValue, callback)
	case *pb.Object_SliceValue:
		// See decodeObject; we need to wait for the array (if non-nil).
		ds.wait(os, v.SliceValue.RefValue, callback)
	case *pb.Object_InterfaceValue:
		// It's an interface (wait recursively).
		ds.waitObject(os, v.InterfaceValue.Value, callback)
	default:
		// Nothing to wait for: execute the callback immediately.
		if callback != nil {
			callback()
		}
	}
}

// register returns the objectState for id, creating it if necessary.
//
// typ is only used to instantiate storage for a new object; if id has been
// registered previously the existing state is returned and typ is ignored.
// If the object was already read from the stream and deferred, it is decoded
// immediately; otherwise it is counted as outstanding.
func (ds *decodeState) register(id uint64, typ reflect.Type) *objectState {
	if existing, found := ds.objectsByID[id]; found {
		return existing
	}

	// Create the storage: maps are made directly, everything else is
	// allocated so that the resulting value is addressable.
	var storage reflect.Value
	switch typ.Kind() {
	case reflect.Map:
		storage = reflect.MakeMap(typ)
	default:
		storage = reflect.New(typ).Elem()
	}

	// Record in the object index.
	state := &objectState{id: id, obj: storage, path: ds.recoverable.copy()}
	ds.objectsByID[id] = state

	pending, wasDeferred := ds.deferred[id]
	if !wasDeferred {
		// Not seen in the stream yet.
		ds.outstanding++
		return state
	}

	// The object was read before anyone was interested; decode it now.
	delete(ds.deferred, id) // Free memory.
	ds.decodeObject(state, state.obj, pending, "", nil)
	return state
}

// decodeStruct decodes the encoded struct s into obj.
//
// The encoded fields are gathered into a Map sorted by name, which is then
// handed to the load function registered for obj's pointer type. Empty
// structs without a registered type are accepted silently; any other
// unregistered type results in a panic carrying an error.
func (ds *decodeState) decodeStruct(os *objectState, obj reflect.Value, s *pb.Struct) {
	// Gather the fields; the internal map goes back to the pool on return.
	fields := Map{newInternalMap(nil, ds, os)}
	defer internalMapPool.Put(fields.internalMap)
	for i := range s.Fields {
		f := s.Fields[i]
		fields.data = append(fields.data, entry{name: f.Name, object: f.Value})
	}

	// Sort the fields for efficient searching.
	//
	// Technically, these should already appear in sorted order in the
	// state ordering, so this cost is effectively a single scan to ensure
	// that the order is correct.
	if len(fields.data) > 1 {
		byName := func(i, j int) bool {
			return fields.data[i].name < fields.data[j].name
		}
		sort.Slice(fields.data, byName)
	}

	// Invoke the load; this will recursively decode other objects.
	ptr := obj.Addr()
	fns, registered := registeredTypes.lookupFns(ptr.Type())
	switch {
	case registered:
		// Invoke the loader.
		fns.invokeLoad(ptr, fields)
	case obj.NumField() == 0:
		// Allow anonymous empty structs.
	default:
		// Propagate an error.
		panic(fmt.Errorf("unregistered type %s", obj.Type()))
	}
}

// decodeMap decodes the encoded map m into obj, allocating the map first if
// obj is nil.
//
// Keys and values are stored as parallel lists in m. A length mismatch
// between the two indicates a malformed stream and is reported by panicking
// with an error, mirroring the length check in decodeArray, rather than by an
// index-out-of-range failure or by silently ignoring surplus values.
func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) {
	if len(m.Keys) != len(m.Values) {
		panic(fmt.Errorf("mismatching map lengths keys=%d, values=%d", len(m.Keys), len(m.Values)))
	}
	if obj.IsNil() {
		obj.Set(reflect.MakeMap(obj.Type()))
	}
	for i := 0; i < len(m.Keys); i++ {
		// Decode the objects into fresh, settable key and value storage.
		kv := reflect.New(obj.Type().Key()).Elem()
		vv := reflect.New(obj.Type().Elem()).Elem()
		ds.decodeObject(os, kv, m.Keys[i], ".(key %d)", i)
		ds.decodeObject(os, vv, m.Values[i], "[%#v]", kv.Interface())

		// The map's owner depends on anything the key or value refers to.
		ds.waitObject(os, m.Keys[i], nil)
		ds.waitObject(os, m.Values[i], nil)

		// Set in the map.
		obj.SetMapIndex(kv, vv)
	}
}

// decodeArray decodes the encoded array a into obj, element by element.
//
// The encoded array must have exactly as many elements as obj; otherwise a
// panic carrying an error is raised.
func (ds *decodeState) decodeArray(os *objectState, obj reflect.Value, a *pb.Array) {
	got, want := len(a.Contents), obj.Len()
	if got != want {
		panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", want, got))
	}
	// Decode the contents into the array and note what each element
	// refers to.
	for i, elem := range a.Contents {
		ds.decodeObject(os, obj.Index(i), elem, "[%d]", i)
		ds.waitObject(os, elem, nil)
	}
}

// decodeInterface decodes the encoded interface i into obj.
//
// An empty type name denotes a nil interface, in which case obj is left
// untouched. Otherwise the type name must resolve to a registered type, or a
// panic carrying an error is raised.
func (ds *decodeState) decodeInterface(os *objectState, obj reflect.Value, i *pb.Interface) {
	typeName := i.Type
	if typeName == "" {
		// Nil value: just leave obj alone.
		return
	}

	// Get the dispatchable type. This may not be used if the given
	// reference has already been resolved, but if not we need to know the
	// type to create.
	concrete, found := registeredTypes.lookupType(typeName)
	if !found {
		panic(fmt.Errorf("no valid type for %q", typeName))
	}

	// Set the obj to be the given typed value; this actually sets obj to
	// be a non-zero value -- namely, it inserts type information. There's
	// no need to do this for maps.
	if obj.Kind() != reflect.Map {
		obj.Set(reflect.Zero(concrete))
	}

	// Decode the dereferenced element; there is no need to wait here, as
	// the interface object shares the current object state.
	ds.decodeObject(os, obj, i.Value, ".(%s)", typeName)
}

// decodeObject decodes the encoded object into obj.
//
// os is the state of the top-level object being decoded, while obj is the
// storage for the value at hand, which may be a field or element embedded in
// that object. format and param describe this step of the decoding path and
// are pushed for the duration of the call.
//
// Scalar values are checked after being stored: if the stored value does not
// equal the encoded one, a panic carrying an error is raised. References and
// slices only resolve the storage of their target via register; the target
// itself may not be decoded yet (see wait).
func (ds *decodeState) decodeObject(os *objectState, obj reflect.Value, object *pb.Object, format string, param interface{}) {
	ds.push(false, format, param)
	ds.stats.Add(obj)
	ds.stats.Start(obj)

	switch x := object.GetValue().(type) {
	case *pb.Object_BoolValue:
		obj.SetBool(x.BoolValue)
	case *pb.Object_StringValue:
		obj.SetString(string(x.StringValue))
	case *pb.Object_Int64Value:
		obj.SetInt(x.Int64Value)
		if obj.Int() != x.Int64Value {
			panic(fmt.Errorf("signed integer truncated in %v for %s", object, obj.Type()))
		}
	case *pb.Object_Uint64Value:
		obj.SetUint(x.Uint64Value)
		if obj.Uint() != x.Uint64Value {
			panic(fmt.Errorf("unsigned integer truncated in %v for %s", object, obj.Type()))
		}
	case *pb.Object_DoubleValue:
		obj.SetFloat(x.DoubleValue)
		// A NaN never compares equal to itself (f != f holds only for
		// NaN), so the case where both sides are NaN is detected
		// explicitly rather than misreporting a faithfully restored
		// NaN as truncation.
		got, want := obj.Float(), x.DoubleValue
		bothNaN := got != got && want != want
		if got != want && !bothNaN {
			panic(fmt.Errorf("float truncated in %v for %s", object, obj.Type()))
		}
	case *pb.Object_RefValue:
		// Resolve the pointer itself, even though the object may not
		// be decoded yet. You need to use wait() in order to ensure
		// that is the case. See wait above, and Map.Barrier.
		if id := x.RefValue; id != 0 {
			// Decoding the interface should have imparted type
			// information, so from this point it's safe to resolve
			// and use this dynamic information for actually
			// creating the object in register.
			//
			// (For non-interfaces this is a no-op).
			dyntyp := reflect.TypeOf(obj.Interface())
			if dyntyp.Kind() == reflect.Map {
				// Remove the map object count here to avoid
				// double counting, as this object will be
				// counted again when it gets processed later.
				// We do not add a reference count as the
				// reference is artificial.
				ds.stats.Remove(obj)
				obj.Set(ds.register(id, dyntyp).obj)
			} else if dyntyp.Kind() == reflect.Ptr {
				ds.push(true /* dereference */, "", nil)
				obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
				ds.pop()
			} else {
				obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
			}
		} else {
			// We leave obj alone here. That's because if obj
			// represents an interface, it may have been imbued
			// with type information in decodeInterface, and we
			// don't want to destroy that information.
		}
	case *pb.Object_SliceValue:
		// It's okay to slice the array here, since the contents will
		// still be provided later on. These semantics are a bit
		// strange but they are handled in the Map.Barrier properly.
		//
		// The special semantics of zero ref apply here too.
		if id := x.SliceValue.RefValue; id != 0 && x.SliceValue.Capacity > 0 {
			// The backing storage is an array of the slice's
			// capacity, registered under the referenced id.
			v := reflect.ArrayOf(int(x.SliceValue.Capacity), obj.Type().Elem())
			obj.Set(ds.register(id, v).obj.Slice3(0, int(x.SliceValue.Length), int(x.SliceValue.Capacity)))
		}
	case *pb.Object_ArrayValue:
		ds.decodeArray(os, obj, x.ArrayValue)
	case *pb.Object_StructValue:
		ds.decodeStruct(os, obj, x.StructValue)
	case *pb.Object_MapValue:
		ds.decodeMap(os, obj, x.MapValue)
	case *pb.Object_InterfaceValue:
		ds.decodeInterface(os, obj, x.InterfaceValue)
	case *pb.Object_ByteArrayValue:
		copyArray(obj, reflect.ValueOf(x.ByteArrayValue))
	case *pb.Object_Uint16ArrayValue:
		// 16-bit slices are serialized as 32-bit slices.
		// See object.proto for details.
		s := x.Uint16ArrayValue.Values
		t := obj.Slice(0, obj.Len()).Interface().([]uint16)
		if len(t) != len(s) {
			panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
		}
		for i := range s {
			t[i] = uint16(s[i])
		}
	case *pb.Object_Uint32ArrayValue:
		copyArray(obj, reflect.ValueOf(x.Uint32ArrayValue.Values))
	case *pb.Object_Uint64ArrayValue:
		copyArray(obj, reflect.ValueOf(x.Uint64ArrayValue.Values))
	case *pb.Object_UintptrArrayValue:
		copyArray(obj, castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0))))
	case *pb.Object_Int8ArrayValue:
		copyArray(obj, castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0))))
	case *pb.Object_Int16ArrayValue:
		// 16-bit slices are serialized as 32-bit slices.
		// See object.proto for details.
		s := x.Int16ArrayValue.Values
		t := obj.Slice(0, obj.Len()).Interface().([]int16)
		if len(t) != len(s) {
			panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
		}
		for i := range s {
			t[i] = int16(s[i])
		}
	case *pb.Object_Int32ArrayValue:
		copyArray(obj, reflect.ValueOf(x.Int32ArrayValue.Values))
	case *pb.Object_Int64ArrayValue:
		copyArray(obj, reflect.ValueOf(x.Int64ArrayValue.Values))
	case *pb.Object_BoolArrayValue:
		copyArray(obj, reflect.ValueOf(x.BoolArrayValue.Values))
	case *pb.Object_Float64ArrayValue:
		copyArray(obj, reflect.ValueOf(x.Float64ArrayValue.Values))
	case *pb.Object_Float32ArrayValue:
		copyArray(obj, reflect.ValueOf(x.Float32ArrayValue.Values))
	default:
		// Should not happen, not propagated as an error.
		panic(fmt.Sprintf("unknown object %v for %s", object, obj.Type()))
	}

	ds.stats.Done()
	ds.pop()
}

// copyArray copies the elements of src into dest, converting src to dest's
// element type via castSlice first.
//
// Both values must have the same length; otherwise a panic carrying an error
// is raised.
func copyArray(dest reflect.Value, src reflect.Value) {
	want, got := dest.Len(), src.Len()
	if want != got {
		panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", want, got))
	}
	converted := castSlice(src, dest.Type().Elem())
	reflect.Copy(dest, converted)
}

// Deserialize deserializes the object state into obj, which becomes the root
// object (id 1).
//
// Objects are read from the stream in id order for as long as there are
// outstanding interests. An object for which interest has been registered is
// decoded in place; any other object is kept in the deferred set until
// interest is registered. Afterwards the zero-length terminal header is
// verified, all completion callbacks are fired, and any object that is still
// incomplete is reported as a dependency cycle.
//
// This function may panic and should be run in safely().
func (ds *decodeState) Deserialize(obj reflect.Value) {
	ds.objectsByID[1] = &objectState{id: 1, obj: obj, path: ds.recoverable.copy()}
	ds.outstanding = 1 // The root object.

	// Decode all objects in the stream.
	//
	// See above, we never process objects while we have no outstanding
	// interests (other than the very first object).
	for id := uint64(1); ds.outstanding > 0; id++ {
		// os is nil when no interest has been registered for this id
		// yet; there is then no storage to attribute a stats sample
		// to, so sampling is limited to registered objects.
		os := ds.lookup(id)
		if os != nil {
			ds.stats.Start(os.obj)
		}

		o, err := ds.readObject()
		if err != nil {
			panic(err)
		}

		if os == nil {
			// If an object hasn't had interest registered
			// previously, we defer decoding until interest is
			// registered.
			ds.deferred[id] = o
			continue
		}

		// Decode the object.
		ds.from = &os.path
		ds.decodeObject(os, os.obj, o, "", nil)
		ds.outstanding--
		ds.stats.Done()
	}

	// Check the zero-length header at the end.
	length, object, err := ReadHeader(ds.r)
	if err != nil {
		panic(err)
	}
	if length != 0 {
		panic(fmt.Sprintf("expected zero-length terminal, got %d", length))
	}
	if object {
		panic("expected non-object terminal")
	}

	// Check if we have any deferred objects.
	if count := len(ds.deferred); count > 0 {
		// Should not happen, not propagated as an error.
		panic(fmt.Sprintf("still have %d deferred objects", count))
	}

	// Scan and fire all callbacks.
	for _, os := range ds.objectsByID {
		os.checkComplete(ds.stats)
	}

	// Check if we have any remaining dependency cycles.
	for _, os := range ds.objectsByID {
		if !os.complete() {
			// This must be the result of a dependency cycle.
			cycle := os.findCycle()
			var buf bytes.Buffer
			buf.WriteString("dependency cycle: {")
			for i, cycleOS := range cycle {
				if i > 0 {
					buf.WriteString(" => ")
				}
				buf.WriteString(cycleOS.obj.Type().String())
			}
			buf.WriteString("}")
			// Panic as an error; propagate to the caller.
			panic(errors.New(buf.String()))
		}
	}
}

// byteReader adapts an io.Reader so that it can be read one byte at a time.
type byteReader struct {
	io.Reader
}

// ReadByte implements io.ByteReader.
//
// A byte that was read is returned even if the underlying Read also reported
// an error. A Read that yields neither data nor an error is reported as
// io.ErrUnexpectedEOF.
func (br byteReader) ReadByte() (byte, error) {
	var one [1]byte
	n, err := br.Reader.Read(one[:])
	switch {
	case n > 0:
		return one[0], nil
	case err != nil:
		return 0, err
	default:
		return 0, io.ErrUnexpectedEOF
	}
}

// ReadHeader reads an object header.
//
// Each object written to the statefile is prefixed with a header. See
// WriteHeader for more information; these functions are exported to allow
// non-state writes to the file to play nice with debugging tools.
//
// The header is a single unsigned varint: its lowest bit is returned as
// object and the remaining bits, shifted down by one, as length. If reading
// the varint fails, the error is returned along with object set to false.
func ReadHeader(r io.Reader) (length uint64, object bool, err error) {
	// Read the header.
	raw, err := binary.ReadUvarint(byteReader{r})
	if err != nil {
		return raw, false, err
	}

	// Split the flag bit from the length.
	return raw >> 1, raw&0x1 != 0, nil
}

// readObject reads the next object from the stream.
//
// The object's header is read first and must be flagged as an object header;
// the number of bytes it announces is then read from the stream and
// unmarshalled into a pb.Object.
func (ds *decodeState) readObject() (*pb.Object, error) {
	// Read the header.
	size, isObject, err := ReadHeader(ds.r)
	if err != nil {
		return nil, err
	}
	if !isObject {
		return nil, fmt.Errorf("invalid object header")
	}

	// Read the object, tolerating short reads; a read that makes no
	// progress and reports an error aborts.
	payload := make([]byte, size)
	for rest := payload; len(rest) > 0; {
		n, err := ds.r.Read(rest)
		if n == 0 && err != nil {
			return nil, err
		}
		rest = rest[n:]
	}

	// Unmarshal.
	decoded := new(pb.Object)
	if err := proto.Unmarshal(payload, decoded); err != nil {
		return nil, err
	}
	return decoded, nil
}