// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package state

import (
	"container/list"
	"context"
	"encoding/binary"
	"fmt"
	"io"
	"reflect"
	"sort"

	"github.com/golang/protobuf/proto"
	pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
)

// queuedObject is an object queued for encoding.
//
// Entries are created by encodeState.register and taken from the front of
// encodeState.pending by encodeState.Serialize. The fields are:
//
//   - id: the object ID handed out by register; Serialize passes it to
//     writeObject when the object is emitted.
//   - obj: the value to encode. For a registered pointer this is the pointee
//     (obj.Elem()); for a registered map it is the map value itself.
//   - path: a copy of the encoder's recoverable state taken at registration
//     time. Serialize points es.from at it before encoding the object.
type queuedObject struct {
	id   uint64
	obj  reflect.Value
	path recoverable
}

// encodeState is state used for encoding.
//
// The encoding process is a breadth-first traversal of the object graph. The
// inherent races and dependencies are much simpler than the decode case.
type encodeState struct {
	// ctx is the encode context.
	ctx context.Context

	// lastID is the last object ID.
	//
	// See idsByObject for context. Because of the special zero encoding
	// used for reference values, the first ID must be 1. register
	// increments lastID before handing it out, so the zero value of this
	// field is the correct starting point.
	lastID uint64

	// idsByObject is a set of objects, indexed by address. register
	// obtains the key via:
	//
	//	reflect.Value.Pointer
	//
	// This provides IDs for objects.
	idsByObject map[uintptr]uint64

	// values stores values that span the addresses.
	//
	// addrSet is a generated type which efficiently stores ranges of
	// addresses. When encoding pointers, these ranges are filled in and
	// used to check for overlapping or conflicting pointers. This would
	// indicate a pointer to a field, or a non-type safe value, neither of
	// which are currently decodable.
	//
	// See the usage of values in register for more context.
	values addrSet

	// w is the output stream. Headers and marshaled objects are written
	// to it by WriteHeader and writeObject.
	w io.Writer

	// pending is the list of objects to be serialized.
	//
	// Elements are queuedObject values; register pushes to the back and
	// Serialize consumes from the front.
	pending list.List

	// done is a list of finished objects.
	//
	// This is kept to prevent garbage collection and address reuse.
	// Serialize pushes the *list.Element it removed from pending, which
	// still references the queuedObject.
	done list.List

	// stats is the passed stats object.
	stats *Stats

	// recoverable is the panic recover facility. Its from field is set by
	// Serialize to the path recorded for the object being encoded.
	recoverable
}

// register returns the ID associated with the given pointer or map,
// allocating a new ID when the object has not been seen before.
//
// A newly registered object is appended to es.pending so that Serialize
// encodes it later. For pointers, the queued value is the pointee and, when
// the pointee has a non-zero size, its address range is recorded in es.values;
// a range that intersects an already recorded one causes a panic. Passing
// anything other than a pointer or a map also panics.
func (es *encodeState) register(obj reflect.Value) uint64 {
	// Only pointers and maps may be registered; anything else is a caller
	// error reported through a (recoverable) panic.
	kind := obj.Kind()
	isPtr := kind == reflect.Ptr
	if !isPtr && kind != reflect.Map {
		panic(fmt.Errorf("non-pointer %#v registered", obj.Interface()))
	}

	addr := obj.Pointer()

	// Pointers to zero-sized objects never reuse an ID: the runtime may
	// hand out the same address for unrelated zero-sized values, so the
	// address says nothing about identity and each registration is given
	// a fresh ID. Everything else is looked up by address first.
	if !(isPtr && obj.Elem().Type().Size() == 0) {
		if known, found := es.idsByObject[addr]; found {
			return known
		}
	}

	// Increment before use so that the first ID handed out is one; zero
	// is reserved for nil references.
	es.lastID++
	id := es.lastID
	es.idsByObject[addr] = id

	if !isPtr {
		// Maps are queued as-is; Serialize encodes queued objects with
		// mapAsValue set, so the map contents are written at that point.
		es.pending.PushBack(queuedObject{id: id, obj: obj, path: es.recoverable.copy()})
		return id
	}

	// Pointers are queued as the value they point to.
	elem := obj.Elem()
	es.pending.PushBack(queuedObject{id: id, obj: elem, path: es.recoverable.copy()})

	size := elem.Type().Size()
	if size == 0 {
		return id
	}

	// Claim every address covered by the object, refusing overlaps with
	// anything registered earlier.
	span := addrRange{addr, addr + size}
	if !es.values.IsEmptyRange(span) {
		prior := es.values.LowerBoundSegment(addr).Value().Interface().(recoverable)
		panic(fmt.Errorf("overlapping objects: [new object] %#v [existing object path] %s", obj.Interface(), prior.path()))
	}
	es.values.Add(span, reflect.ValueOf(es.recoverable.copy()))
	return id
}

// encodeMap encodes the keys and values of a map as two parallel lists.
//
// Entries are visited in the order returned by reflect.Value.MapKeys. For each
// entry the key is encoded first, then the value; the i-th encoded key
// corresponds to the i-th encoded value.
func (es *encodeState) encodeMap(obj reflect.Value) *pb.Map {
	var encodedKeys, encodedValues []*pb.Object

	mapKeys := obj.MapKeys()
	for idx := range mapKeys {
		key := mapKeys[idx]
		val := obj.MapIndex(key)

		// The key's path component is its iteration index, while the
		// value's path component is the key itself.
		encodedKeys = append(encodedKeys, es.encodeObject(key, false, ".(key %d)", idx))
		encodedValues = append(encodedValues, es.encodeObject(val, false, "[%#v]", key.Interface()))
	}

	return &pb.Map{Keys: encodedKeys, Values: encodedValues}
}

// encodeStruct encodes a composite object by running its registered save
// function and collecting the fields that function records.
//
// Structs without registered functions are accepted only when they have no
// fields, in which case an empty pb.Struct is returned; any other
// unregistered type causes a panic. Recording the same field name twice also
// panics. The resulting fields are ordered by name.
func (es *encodeState) encodeStruct(obj reflect.Value) *pb.Struct {
	m := Map{newInternalMap(es, nil, nil)}
	defer internalMapPool.Put(m.internalMap)

	// The save functions are looked up by, and invoked on, a pointer to
	// the struct. A value that is not addressable is first copied into
	// freshly allocated storage.
	if !obj.CanAddr() {
		addressable := reflect.New(obj.Type())
		addressable.Elem().Set(obj)
		obj = addressable.Elem()
	}
	ptr := obj.Addr()

	fns, registered := registeredTypes.lookupFns(ptr.Type())
	switch {
	case registered:
		// Let the registered saver populate m.
		fns.invokeSave(ptr, m)
	case obj.NumField() == 0:
		// Unregistered anonymous, empty structs are permitted.
		return &pb.Struct{}
	default:
		panic(fmt.Errorf("unregistered type %T", obj.Interface()))
	}

	// Order the entries by name and reject duplicates. Sorting once here,
	// rather than on every insertion, is far cheaper; after the sort any
	// duplicate names are adjacent.
	entries := m.data
	if len(entries) > 1 {
		sort.Slice(entries, func(a, b int) bool {
			return entries[a].name < entries[b].name
		})
		for i := 1; i < len(entries); i++ {
			if entries[i].name == entries[i-1].name {
				panic(fmt.Errorf("duplicate name %s", entries[i].name))
			}
		}
	}

	// Convert the entries to their wire representation.
	fields := make([]*pb.Field, len(entries))
	for i := range entries {
		fields[i] = &pb.Field{
			Name:  entries[i].name,
			Value: entries[i].object,
		}
	}
	return &pb.Struct{Fields: fields}
}

// encodeArray encodes each element of an array, in index order, through
// encodeObject. A zero-length array yields a pb.Array with nil Contents.
func (es *encodeState) encodeArray(obj reflect.Value) *pb.Array {
	var encoded []*pb.Object
	for idx, count := 0, obj.Len(); idx < count; idx++ {
		// The element index doubles as the path component.
		encoded = append(encoded, es.encodeObject(obj.Index(idx), false, "[%d]", idx))
	}
	return &pb.Array{Contents: encoded}
}

// encodeInterface encodes an interface value together with the registered
// name of its dynamic type.
//
// Precondition: obj is a valid reflect.Value, as obj.Interface() is called
// unconditionally. An interface holding no value at all is encoded with an
// empty type name and a zero reference. A dynamic type that has not been
// registered causes a panic.
func (es *encodeState) encodeInterface(obj reflect.Value) *pb.Interface {
	// Unwrap to the dynamic value; the result is invalid only for the nil
	// interface (no type, no value).
	concrete := reflect.ValueOf(obj.Interface())
	if !concrete.IsValid() {
		nilRef := &pb.Object{Value: &pb.Object_RefValue{0}}
		return &pb.Interface{
			Type:  "", // left alone in decode.
			Value: nilRef,
		}
	}

	// The dynamic type is saved by its registered name so that it can be
	// dispatched on later.
	name, registered := registeredTypes.lookupName(concrete.Type())
	if !registered {
		panic(fmt.Errorf("type %s is not registered", concrete.Type()))
	}

	// Encode the underlying value under a path component naming its type.
	encoded := es.encodeObject(concrete, false, ".(%s)", name)
	return &pb.Interface{Type: name, Value: encoded}
}

// encodeObject encodes an object.
//
// The encoding is selected by the value's reflect.Kind:
//
//   - Booleans, integers and floats are stored inline (integers as 64-bit
//     values, floats as doubles).
//   - Arrays of the basic numeric and boolean kinds use dedicated packed
//     representations; all other arrays are encoded element by element via
//     encodeArray.
//   - Slices, pointers and (unless mapAsValue is set) maps are encoded as
//     references to objects registered via register. Nil values of these
//     kinds, as well as zero-capacity slices, are encoded as the zero
//     reference.
//   - Strings are stored as bytes, interfaces via encodeInterface and structs
//     via encodeStruct.
//
// Any other kind results in a panic.
//
// format and param are handed to es.push and are popped again before
// returning normally; callers in this file use them to describe the position
// of obj within its parent (e.g. "[%d]" with an index).
//
// If mapAsValue is true, then a map will be encoded directly.
func (es *encodeState) encodeObject(obj reflect.Value, mapAsValue bool, format string, param interface{}) (object *pb.Object) {
	// Record the path component and begin stats accounting for this
	// object. Both are undone at the bottom of the function; note that
	// neither is deferred, so they are skipped if encoding panics.
	es.push(false, format, param)
	es.stats.Add(obj)
	es.stats.Start(obj)

	switch obj.Kind() {
	case reflect.Bool:
		object = &pb.Object{Value: &pb.Object_BoolValue{obj.Bool()}}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		// obj.Int() widens every signed integer kind to int64.
		object = &pb.Object{Value: &pb.Object_Int64Value{obj.Int()}}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		// obj.Uint() widens every unsigned integer kind to uint64.
		object = &pb.Object{Value: &pb.Object_Uint64Value{obj.Uint()}}
	case reflect.Float32, reflect.Float64:
		// obj.Float() widens float32 to float64.
		object = &pb.Object{Value: &pb.Object_DoubleValue{obj.Float()}}
	case reflect.Array:
		// Arrays of basic element kinds are stored in packed form.
		// pbSlice is defined elsewhere; judging by the type assertions
		// below it yields a slice value of the proto-compatible
		// element type for the array (TODO confirm).
		switch obj.Type().Elem().Kind() {
		case reflect.Uint8:
			object = &pb.Object{Value: &pb.Object_ByteArrayValue{pbSlice(obj).Interface().([]byte)}}
		case reflect.Uint16:
			// 16-bit slices are serialized as 32-bit slices.
			// See object.proto for details.
			s := pbSlice(obj).Interface().([]uint16)
			t := make([]uint32, len(s))
			for i := range s {
				t[i] = uint32(s[i])
			}
			object = &pb.Object{Value: &pb.Object_Uint16ArrayValue{&pb.Uint16S{Values: t}}}
		case reflect.Uint32:
			object = &pb.Object{Value: &pb.Object_Uint32ArrayValue{&pb.Uint32S{Values: pbSlice(obj).Interface().([]uint32)}}}
		case reflect.Uint64:
			object = &pb.Object{Value: &pb.Object_Uint64ArrayValue{&pb.Uint64S{Values: pbSlice(obj).Interface().([]uint64)}}}
		case reflect.Uintptr:
			// Stored as 64-bit values.
			object = &pb.Object{Value: &pb.Object_UintptrArrayValue{&pb.Uintptrs{Values: pbSlice(obj).Interface().([]uint64)}}}
		case reflect.Int8:
			// Stored as raw bytes.
			object = &pb.Object{Value: &pb.Object_Int8ArrayValue{&pb.Int8S{Values: pbSlice(obj).Interface().([]byte)}}}
		case reflect.Int16:
			// 16-bit slices are serialized as 32-bit slices.
			// See object.proto for details.
			s := pbSlice(obj).Interface().([]int16)
			t := make([]int32, len(s))
			for i := range s {
				t[i] = int32(s[i])
			}
			object = &pb.Object{Value: &pb.Object_Int16ArrayValue{&pb.Int16S{Values: t}}}
		case reflect.Int32:
			object = &pb.Object{Value: &pb.Object_Int32ArrayValue{&pb.Int32S{Values: pbSlice(obj).Interface().([]int32)}}}
		case reflect.Int64:
			object = &pb.Object{Value: &pb.Object_Int64ArrayValue{&pb.Int64S{Values: pbSlice(obj).Interface().([]int64)}}}
		case reflect.Bool:
			object = &pb.Object{Value: &pb.Object_BoolArrayValue{&pb.Bools{Values: pbSlice(obj).Interface().([]bool)}}}
		case reflect.Float32:
			object = &pb.Object{Value: &pb.Object_Float32ArrayValue{&pb.Float32S{Values: pbSlice(obj).Interface().([]float32)}}}
		case reflect.Float64:
			object = &pb.Object{Value: &pb.Object_Float64ArrayValue{&pb.Float64S{Values: pbSlice(obj).Interface().([]float64)}}}
		default:
			// Any other element kind: encode element by element.
			object = &pb.Object{Value: &pb.Object_ArrayValue{es.encodeArray(obj)}}
		}
	case reflect.Slice:
		if obj.IsNil() || obj.Cap() == 0 {
			// Handled specially in decode; store as nil value.
			object = &pb.Object{Value: &pb.Object_RefValue{0}}
		} else {
			// Serialize a slice as the array plus length and capacity.
			// The array is registered as an object of its own and
			// referenced by ID. arrayFromSlice is defined elsewhere;
			// presumably it returns a pointer to the backing array
			// (register only accepts pointers and maps).
			// NOTE(review): length and capacity are truncated to
			// 32 bits here.
			object = &pb.Object{Value: &pb.Object_SliceValue{&pb.Slice{
				Capacity: uint32(obj.Cap()),
				Length:   uint32(obj.Len()),
				RefValue: es.register(arrayFromSlice(obj)),
			}}}
		}
	case reflect.String:
		object = &pb.Object{Value: &pb.Object_StringValue{[]byte(obj.String())}}
	case reflect.Ptr:
		if obj.IsNil() {
			// Handled specially in decode; store as a nil value.
			object = &pb.Object{Value: &pb.Object_RefValue{0}}
		} else {
			// Add a dereference component to the path for the
			// duration of the registration, so that the path copy
			// taken by register includes it.
			es.push(true /* dereference */, "", nil)
			object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}}
			es.pop()
		}
	case reflect.Interface:
		// We don't check for IsNil here, as we want to encode type
		// information. The case of the empty interface (no type, no
		// value) is handled by encodeInterface.
		object = &pb.Object{Value: &pb.Object_InterfaceValue{es.encodeInterface(obj)}}
	case reflect.Struct:
		object = &pb.Object{Value: &pb.Object_StructValue{es.encodeStruct(obj)}}
	case reflect.Map:
		if obj.IsNil() {
			// Handled specially in decode; store as a nil value.
			object = &pb.Object{Value: &pb.Object_RefValue{0}}
		} else if mapAsValue {
			// Encode the map directly.
			object = &pb.Object{Value: &pb.Object_MapValue{es.encodeMap(obj)}}
		} else {
			// Encode a reference to the map.
			//
			// Remove the map object count here to avoid double
			// counting, as this object will be counted again when
			// it gets processed later. We do not add a reference
			// count as the reference is artificial.
			es.stats.Remove(obj)
			object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}}
		}
	default:
		// Every remaining kind is unsupported.
		panic(fmt.Errorf("unknown primitive %#v", obj.Interface()))
	}

	// Undo the Start and push performed at the top of the function.
	es.stats.Done()
	es.pop()
	return
}

// Serialize writes the object graph reachable from obj to the output stream.
//
// obj must be addressable: its address is registered as the first object.
// Queued objects are then encoded and written one at a time, in the order in
// which they were registered, until the queue is empty. Encoding an object
// may register, and thereby queue, further objects.
//
// This function may panic and should be run in safely().
func (es *encodeState) Serialize(obj reflect.Value) {
	es.register(obj.Addr())

	// Drain the queue from the front until nothing is left.
	for front := es.pending.Front(); front != nil; front = es.pending.Front() {
		next := front.Value.(queuedObject)
		es.stats.Start(next.obj)
		es.pending.Remove(front)

		// Encode relative to the path recorded at registration time.
		// Queued maps are written out by value at this point.
		es.from = &next.path
		encoded := es.encodeObject(next.obj, true, "", nil)

		// Emit to the output stream.
		if err := es.writeObject(next.id, encoded); err != nil {
			panic(err)
		}

		// Retain the finished entry; see the done field.
		es.done.PushBack(front)
		es.stats.Done()
	}

	// Terminate the stream with a zero-length, non-object header; this is
	// a sanity check applied at decode time as well (see decode.go).
	if err := WriteHeader(es.w, 0, false); err != nil {
		panic(err)
	}
}

// WriteHeader writes a header.
//
// Each object written to the statefile should be prefixed with a header. In
// order to generate statefiles that play nicely with debugging tools, raw
// writes should be prefixed with a header with object set to false and the
// appropriate length. This will allow tools to skip these regions.
//
// The header is the unsigned varint encoding of (length << 1), with the
// lowest-order bit set when object is true. Consequently length must fit in
// 63 bits; larger values are rejected with an error, and nothing is written,
// rather than being silently truncated.
//
// Short writes are retried for as long as the writer makes progress. An error
// is returned as soon as a write makes no progress: the writer's own error if
// it reported one, io.ErrShortWrite otherwise.
func WriteHeader(w io.Writer, length uint64, object bool) error {
	// The shift below would discard the top bit of length.
	if length>>63 != 0 {
		return fmt.Errorf("header length %d does not fit in 63 bits", length)
	}

	// The lowest-order bit encodes whether this is a valid object. This is
	// a purely internal convention, but allows the object flag to be
	// returned from ReadHeader.
	encoded := length << 1
	if object {
		encoded |= 0x1
	}

	// binary.MaxVarintLen64 bytes are enough for any uint64.
	var hdr [binary.MaxVarintLen64]byte
	encodedLen := binary.PutUvarint(hdr[:], encoded)
	for done := 0; done < encodedLen; {
		n, err := w.Write(hdr[done:encodedLen])
		done += n
		if n == 0 {
			// No progress. Without this check a writer returning
			// (0, nil) would keep this loop spinning forever.
			if err == nil {
				err = io.ErrShortWrite
			}
			return err
		}
	}

	return nil
}

// writeObject writes an object to the stream.
//
// The object is marshaled to its protobuf wire form and written to es.w,
// preceded by a header carrying the marshaled length with the object flag
// set. The id is not written by this function.
//
// Short writes are retried for as long as the writer makes progress. An error
// is returned if marshaling fails, if the header cannot be written, or as
// soon as a write of the payload makes no progress: the writer's own error
// if it reported one, io.ErrShortWrite otherwise.
func (es *encodeState) writeObject(id uint64, obj *pb.Object) error {
	// Marshal the proto.
	buf, err := proto.Marshal(obj)
	if err != nil {
		return err
	}

	// Write the object header.
	if err := WriteHeader(es.w, uint64(len(buf)), true); err != nil {
		return err
	}

	// Write the object.
	for done := 0; done < len(buf); {
		n, err := es.w.Write(buf[done:])
		done += n
		if n == 0 {
			// No progress. Without this check a writer returning
			// (0, nil) would keep this loop spinning forever.
			if err == nil {
				err = io.ErrShortWrite
			}
			return err
		}
	}

	return nil
}

// addrSetFunctions supplies the key bounds and value hooks used by addrSet.
type addrSetFunctions struct{}

// MinKey returns the lowest address, zero.
func (addrSetFunctions) MinKey() uintptr {
	return uintptr(0)
}

// MaxKey returns the highest representable address, i.e. all bits set.
func (addrSetFunctions) MaxKey() uintptr {
	var zero uintptr
	return ^zero
}

// ClearValue does nothing; the value is left untouched.
func (addrSetFunctions) ClearValue(val *reflect.Value) {}

// Merge returns val1 together with whether val1 and val2 compare equal under
// the == operator on reflect.Value. The ranges are ignored.
func (addrSetFunctions) Merge(_ addrRange, val1 reflect.Value, _ addrRange, val2 reflect.Value) (reflect.Value, bool) {
	identical := val1 == val2
	return val1, identical
}

// Split returns val for both halves. The range and the split point are
// ignored.
func (addrSetFunctions) Split(_ addrRange, val reflect.Value, _ uintptr) (reflect.Value, reflect.Value) {
	first, second := val, val
	return first, second
}