// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package tests

import (
	"bytes"
	"context"
	"encoding/gob"
	"fmt"
	"testing"

	"gvisor.dev/gvisor/pkg/state"
	"gvisor.dev/gvisor/pkg/state/wire"
)

// buildPtrObject builds a benchmark object.
func buildPtrObject(n int) interface{} {
	b := new(benchStruct)
	for i := 0; i < n; i++ {
		b = &benchStruct{B: b}
	}
	return b
}

// buildMapObject builds a benchmark object.
func buildMapObject(n int) interface{} {
	b := new(benchStruct)
	m := make(map[int]*benchStruct)
	for i := 0; i < n; i++ {
		m[i] = b
	}
	return &m
}

// buildSliceObject builds a benchmark object.
func buildSliceObject(n int) interface{} {
	b := new(benchStruct)
	s := make([]*benchStruct, 0, n)
	for i := 0; i < n; i++ {
		s = append(s, b)
	}
	return &s
}

var allObjects = map[string]struct {
	New func(int) interface{}
}{
	"ptr": {
		New: buildPtrObject,
	},
	"map": {
		New: buildMapObject,
	},
	"slice": {
		New: buildSliceObject,
	},
}

func buildObjects(n int, fn func(int) interface{}) (iters int, v interface{}) {
	// maxSize is the maximum size of an individual object below. For an N
	// larger than this, we start to return multiple objects.
	const maxSize = 1024
	if n <= maxSize {
		return 1, fn(n)
	}
	iters = (n + maxSize - 1) / maxSize
	return iters, fn(maxSize)
}

// gobSave is a version of save using gob (no stats available).
func gobSave(_ context.Context, w wire.Writer, v interface{}) (_ state.Stats, err error) {
	enc := gob.NewEncoder(w)
	err = enc.Encode(v)
	return
}

// gobLoad is a version of load using gob (no stats available).
func gobLoad(_ context.Context, r wire.Reader, v interface{}) (_ state.Stats, err error) {
	dec := gob.NewDecoder(r)
	err = dec.Decode(v)
	return
}

var allAlgos = map[string]struct {
	Save   func(context.Context, wire.Writer, interface{}) (state.Stats, error)
	Load   func(context.Context, wire.Reader, interface{}) (state.Stats, error)
	MaxPtr int
}{
	"state": {
		Save: state.Save,
		Load: state.Load,
	},
	"gob": {
		Save: gobSave,
		Load: gobLoad,
	},
}

func BenchmarkEncoding(b *testing.B) {
	for objName, objInfo := range allObjects {
		for algoName, algoInfo := range allAlgos {
			b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) {
				b.StopTimer()
				n, v := buildObjects(b.N, objInfo.New)
				b.ReportAllocs()
				b.StartTimer()
				for i := 0; i < n; i++ {
					if _, err := algoInfo.Save(context.Background(), discard{}, v); err != nil {
						b.Errorf("save failed: %v", err)
					}
				}
				b.StopTimer()
			})
		}
	}
}

func BenchmarkDecoding(b *testing.B) {
	for objName, objInfo := range allObjects {
		for algoName, algoInfo := range allAlgos {
			b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) {
				b.StopTimer()
				n, v := buildObjects(b.N, objInfo.New)
				buf := new(bytes.Buffer)
				if _, err := algoInfo.Save(context.Background(), buf, v); err != nil {
					b.Errorf("save failed: %v", err)
				}
				b.ReportAllocs()
				b.StartTimer()
				var r bytes.Reader
				for i := 0; i < n; i++ {
					r.Reset(buf.Bytes())
					if _, err := algoInfo.Load(context.Background(), &r, v); err != nil {
						b.Errorf("load failed: %v", err)
					}
				}
				b.StopTimer()
			})
		}
	}
}