// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package urpc provides a minimal RPC package based on unet.
//
// RPC requests are _not_ concurrent and methods must be explicitly
// registered. However, files may be sent as part of the payload.
package urpc

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"os"
	"reflect"
	"runtime"

	"gvisor.dev/gvisor/pkg/fd"
	"gvisor.dev/gvisor/pkg/log"
	"gvisor.dev/gvisor/pkg/sync"
	"gvisor.dev/gvisor/pkg/unet"
)

// maxFiles is the largest number of files that may accompany a single message.
// It bounds both the descriptors accepted when reading and the files a caller
// or method may attach when writing.
const maxFiles = 32

var (
	// ErrTooManyFiles is returned when more than maxFiles files would have to
	// be attached to a single message.
	ErrTooManyFiles = errors.New("too many files")

	// ErrUnknownMethod is reported when a call names a method that was never
	// registered with the server.
	ErrUnknownMethod = errors.New("unknown method")

	// errStopped is an internal error signalling that the server was stopped
	// before a request could be processed.
	errStopped = errors.New("stopped")
)

// RemoteError reports a failure of the remote method itself.
//
// When a Call returns a RemoteError, the request and reply were transported
// correctly; it was the invoked function that returned an error.
type RemoteError struct {
	// Message holds the text produced by the remote error's Error method.
	Message string
}

// Error implements the error interface by reporting the remote text verbatim.
func (e RemoteError) Error() string {
	msg := e.Message
	return msg
}

// FilePayload may be _embedded_ in another type in order to send or receive a
// file as a result of an RPC. These are not actually serialized, rather they
// are sent via an accompanying SCM_RIGHTS message (plumbed through the unet
// package).
//
// When embedding a FilePayload in an argument struct, the argument type _must_
// be a pointer to the struct rather than the struct type itself. This is
// because the urpc package defines pointer methods on FilePayload.
//
// Attempting to send more than maxFiles files in one message fails with
// ErrTooManyFiles.
//
// Ownership: on the server, files received with an argument are closed once
// the registered method returns, so a method that needs a file beyond the
// call must duplicate it. On the client, files received with a result are
// handed to the caller and are not closed by Call.
type FilePayload struct {
	Files []*os.File `json:"-"`
}

// ReleaseFD converts the file at the given index into an *fd.FD using
// fd.NewFromFile. It panics if index is out of range for f.Files.
//
// The entry in f.Files is left in place. NOTE(review): server-side argument
// files are closed after the method returns, so this presumably relies on
// fd.NewFromFile duplicating the descriptor rather than taking it over —
// confirm against pkg/fd.
func (f *FilePayload) ReleaseFD(index int) (*fd.FD, error) {
	file := f.Files[index]
	return fd.NewFromFile(file)
}

// filePayload reports the files carried by this payload. The result may be
// nil.
func (f *FilePayload) filePayload() []*os.File {
	files := f.Files
	return files
}

// setFilePayload replaces the files carried by this payload with fs.
func (f *FilePayload) setFilePayload(fs []*os.File) {
	f.Files = fs
}

// closeAll closes every file in files. Close errors are ignored because this
// is best-effort cleanup. A nil or empty slice is a no-op.
func closeAll(files []*os.File) {
	for i := range files {
		_ = files[i].Close() // Best effort; nothing useful to do on failure.
	}
}

// filePayloader is satisfied by *FilePayload and, through method promotion,
// by pointers to any struct that embeds FilePayload. Because the methods are
// unexported, embedding FilePayload is the only way for a type outside this
// package to satisfy it.
type filePayloader interface {
	// setFilePayload stores the files received alongside a message.
	setFilePayload([]*os.File)

	// filePayload returns the files to send alongside a message.
	filePayload() []*os.File
}

// clientCall is the client=>server method call on the client side.
//
// Method is the registered "TypeName.MethodName"; Arg is the caller's
// argument value, JSON-encoded in place. The field order here fixes the key
// order of the encoded JSON.
type clientCall struct {
	Method string      `json:"method"`
	Arg    interface{} `json:"arg"`
}

// serverCall is the client=>server method call on the server side.
//
// It mirrors clientCall, but keeps Arg as raw JSON so it can be decoded once
// the method (and therefore the argument type) has been looked up.
type serverCall struct {
	Method string          `json:"method"`
	Arg    json.RawMessage `json:"arg"`
}

// callResult is the server=>client method call result.
//
// Success is true only when the method ran and returned nil. Otherwise, Err
// carries the error text and Result is left unset. On the client, Result is
// pre-populated with the caller's result pointer so the reply decodes
// directly into it.
type callResult struct {
	Success bool        `json:"success"`
	Err     string      `json:"err"`
	Result  interface{} `json:"result"`
}

// registeredMethod is a method registered with the server.
type registeredMethod struct {
	// fn is the underlying function (reflect.Method.Func), which takes the
	// receiver as its first argument.
	fn reflect.Value

	// rcvr is the receiver value, i.e. the object passed to Register.
	rcvr reflect.Value

	// argType is the method's argument type; Register ensures it is a
	// pointer type.
	argType reflect.Type

	// resultType is the method's result type; Register ensures it is a
	// pointer type.
	resultType reflect.Type
}

// clientState tracks where a connected client is in its request lifecycle.
// Server.clients maps each registered socket to its current state; entries
// are created as idle by clientRegister and removed by clientUnregister.
//
// The following are valid states:
//
// idle - not processing any requests, no close request.
// processing - actively processing, no close request.
// closeRequested - actively processing, pending close.
// closed - client connection has been closed.
//
// The following transitions are possible:
//
// idle -> processing (clientBeginRequest), closed (Stop)
// processing -> idle (clientEndRequest), closeRequested (Stop)
// closeRequested -> closed (clientEndRequest)
type clientState int

// See clientState. The numeric values (idle=0 ... closed=3) are what the
// state-machine panic messages print via %d.
const (
	idle clientState = iota
	processing
	closeRequested
	closed
)

// Server is an RPC server.
//
// Create one with NewServer or NewServerWithCallback, Register receiver
// objects, then serve connections with Handle or StartHandling. Stop closes
// all clients and waits for their handlers to finish.
type Server struct {
	// mu protects methods, stoppers and clients. wg is not protected, and
	// afterRPCCallback is only written at construction.
	mu sync.Mutex

	// methods is the set of server methods, keyed by "TypeName.MethodName".
	methods map[string]registeredMethod

	// stoppers are all registered stoppers.
	stoppers []Stopper

	// clients maps each registered connection to its current state.
	clients map[*unet.Socket]clientState

	// wg counts registered clients: clientRegister adds one and
	// clientUnregister releases it. Stop waits on it.
	wg sync.WaitGroup

	// afterRPCCallback, if non-nil, is called after every request that was
	// read from a client, whether or not the call itself succeeded. It is
	// read without holding mu.
	afterRPCCallback func()
}

// NewServer returns a new server with no post-RPC callback.
func NewServer() *Server {
	var noCallback func()
	return NewServerWithCallback(noCallback)
}

// NewServerWithCallback returns a new server that calls afterRPCCallback
// after handling each request it reads from a client, whether or not the call
// succeeded. A nil callback disables this.
func NewServerWithCallback(afterRPCCallback func()) *Server {
	s := &Server{afterRPCCallback: afterRPCCallback}
	s.methods = make(map[string]registeredMethod)
	s.clients = make(map[*unet.Socket]clientState)
	return s
}

// Stopper is an optional interface, that when implemented, allows an object
// to have a callback executed when the server is shutting down.
//
// Register checks every registered object for this interface. A conforming
// Stop method is not exposed as an RPC; instead, Server.Stop calls it before
// closing clients.
type Stopper interface {
	Stop()
}

// Register registers the given object as an RPC receiver.
//
// Every exported method of obj's type is exposed under the name
// "TypeName.MethodName" and must have the form
//
//	func (T) Name(arg *ArgType, result *ResultType) error
//
// The one exception is a Stop method satisfying Stopper: it is not exposed,
// and instead Server.Stop calls it.
//
// Registration works like the standard net/rpc package but is stricter.
// Instead of skipping or reporting a non-conforming method, it panics
// immediately. Registering an object of an unnamed type that has methods, or
// a method name that is already registered, also panics.
func (s *Server) Register(obj interface{}) {
	s.mu.Lock()
	defer s.mu.Unlock()

	objType := reflect.TypeOf(obj)
	objValue := reflect.ValueOf(obj)
	stopper, isStopper := obj.(Stopper)

	// Methods are named after the underlying type, so look through a pointer.
	baseType := objType
	if objType.Kind() == reflect.Ptr {
		baseType = objType.Elem()
	}

	for i, count := 0, objType.NumMethod(); i < count; i++ {
		method := objType.Method(i)
		if baseType.Name() == "" {
			// Anonymous types cannot be addressed by name.
			panic("type not named.")
		}
		if isStopper && method.Name == "Stop" {
			// The Stopper hook is not an RPC.
			s.stoppers = append(s.stoppers, stopper)
			continue
		}

		name := baseType.Name() + "." + method.Name
		if _, dup := s.methods[name]; dup {
			panic(fmt.Sprintf("method %s is duplicated.", name))
		}

		argType, resultType := checkRPCMethod(name, method)
		s.methods[name] = registeredMethod{
			fn:         method.Func,
			rcvr:       objValue,
			argType:    argType,
			resultType: resultType,
		}
	}
}

// checkRPCMethod panics unless method is exported and has the shape
// func (T) Name(*Arg, *Result) error. It returns the argument and result types.
func checkRPCMethod(name string, method reflect.Method) (argType, resultType reflect.Type) {
	if method.PkgPath != "" {
		panic(fmt.Sprintf("method %s is not exported.", name))
	}
	sig := method.Type
	if sig.NumIn() != 3 {
		// Receiver plus exactly two arguments.
		panic(fmt.Sprintf("method %s has wrong number of arguments.", name))
	}
	argType, resultType = sig.In(1), sig.In(2)
	if argType.Kind() != reflect.Ptr {
		panic(fmt.Sprintf("method %s has non-pointer first argument.", name))
	}
	if resultType.Kind() != reflect.Ptr {
		panic(fmt.Sprintf("method %s has non-pointer second argument.", name))
	}
	if sig.NumOut() != 1 {
		panic(fmt.Sprintf("method %s has wrong number of returns.", name))
	}
	if sig.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
		panic(fmt.Sprintf("method %s has non-error return value.", name))
	}
	return argType, resultType
}

// lookup returns the registered method with the given "TypeName.MethodName"
// and whether it was found. It takes mu for the map access.
func (s *Server) lookup(method string) (rm registeredMethod, ok bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	rm, ok = s.methods[method]
	return rm, ok
}

// handleOne handles a single call.
//
// It reads one request from client, dispatches it to the registered method,
// and writes the reply. A non-nil return means the connection should no
// longer be served. That happens when the request cannot be read, the reply
// cannot be written, or the server has been stopped (errStopped). Failures of
// the call itself are sent to the client in the reply and are not returned
// here, unless sending the reply fails. Those failures are: unknown method,
// bad argument JSON, method error, or too many result files.
//
// Deferred cleanup runs in reverse registration order:
//  1. clientEndRequest, when the request was started.
//  2. Closing the received argument files.
//  3. afterRPCCallback, when set.
//
// afterRPCCallback therefore runs even for rejected calls, but not when the
// request could not be read.
func (s *Server) handleOne(client *unet.Socket) error {
	// Unmarshal the call.
	var c serverCall
	newFs, err := unmarshal(client, &c)
	if err != nil {
		// Client is dead.
		return err
	}
	if s.afterRPCCallback != nil {
		defer s.afterRPCCallback()
	}

	// Explicitly close all these files after the call.
	//
	// This is also explicitly a reference to the files after the call,
	// which means they are kept open for the duration of the call.
	defer closeAll(newFs)

	// Start the request.
	if !s.clientBeginRequest(client) {
		// Client is dead; don't process this call.
		return errStopped
	}
	defer s.clientEndRequest(client)

	// Lookup the method.
	rm, ok := s.lookup(c.Method)
	if !ok {
		// Try to serialize the error.
		return marshal(client, &callResult{Err: ErrUnknownMethod.Error()}, nil)
	}

	// Unmarshal the arguments now that we know the type.
	na := reflect.New(rm.argType.Elem())
	if err := json.Unmarshal(c.Arg, na.Interface()); err != nil {
		return marshal(client, &callResult{Err: err.Error()}, nil)
	}

	// Set the file payload as an argument.
	if fp, ok := na.Interface().(filePayloader); ok {
		fp.setFilePayload(newFs)
	}

	// Call the method. Register guaranteed the single return value is an
	// error, so the type assertion below cannot fail.
	re := reflect.New(rm.resultType.Elem())
	rValues := rm.fn.Call([]reflect.Value{rm.rcvr, na, re})
	if errVal := rValues[0].Interface(); errVal != nil {
		return marshal(client, &callResult{Err: errVal.(error).Error()}, nil)
	}

	// Set the resulting payload.
	//
	// The result files are never closed here, including on the
	// ErrTooManyFiles path. NOTE(review): ownership presumably stays with the
	// method's receiver — confirm.
	var fs []*os.File
	if fp, ok := re.Interface().(filePayloader); ok {
		fs = fp.filePayload()
		if len(fs) > maxFiles {
			// Ugh. Send an error to the client, despite success.
			return marshal(client, &callResult{Err: ErrTooManyFiles.Error()}, nil)
		}
	}

	// Marshal the result.
	return marshal(client, &callResult{Success: true, Result: re.Interface()}, fs)
}

// clientBeginRequest moves client from idle to processing.
//
// It returns true if the request may be processed. It returns false if the
// server has already closed the client, in which case the request must be
// skipped. Any other state is an invariant violation and panics.
func (s *Server) clientBeginRequest(client *unet.Socket) bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	state := s.clients[client]
	if state == closed {
		// Stop closed the connection right after the request was read.
		// We could not send a proper reply, so don't run the RPC.
		return false
	}
	if state != idle {
		panic(fmt.Sprintf("expected idle or closed, got %d", state))
	}
	s.clients[client] = processing
	return true
}

// clientEndRequest ends a request started by clientBeginRequest.
//
// A processing client returns to idle. If Stop requested a close while the
// request was in flight (closeRequested), the connection is closed now and
// the client is marked closed. Any other state is an invariant violation and
// panics.
func (s *Server) clientEndRequest(client *unet.Socket) {
	s.mu.Lock()
	defer s.mu.Unlock()
	switch state := s.clients[client]; state {
	case processing:
		// Return to idle.
		s.clients[client] = idle
	case closeRequested:
		// Close the connection, as Stop asked.
		client.Close()
		s.clients[client] = closed
	default:
		// Should not happen.
		panic(fmt.Sprintf("expected processing or closeRequested, got %d", state))
	}
}

// clientRegister records client as an idle connection and adds one to wg, so
// that Stop can close it and wait for it. clientUnregister undoes both.
func (s *Server) clientRegister(client *unet.Socket) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Every new connection starts out idle.
	s.clients[client] = idle
	// Released by clientUnregister; Stop waits for it.
	s.wg.Add(1)
}

// clientUnregister forgets client and releases its wg count. An idle
// connection is closed first. A connection Stop has already closed is left
// alone. Any other state is an invariant violation and panics.
func (s *Server) clientUnregister(client *unet.Socket) {
	s.mu.Lock()
	defer s.mu.Unlock()

	state := s.clients[client]
	if state == idle {
		// Still open: close it ourselves.
		client.Close()
	} else if state != closed {
		panic(fmt.Sprintf("expected idle or closed, got %d", state))
	}

	delete(s.clients, client)
	s.wg.Done()
}

// handleRegistered serves calls from an already registered client, one after
// another, until handling a call fails. It returns that error.
func (s *Server) handleRegistered(client *unet.Socket) error {
	var err error
	for err == nil {
		err = s.handleOne(client)
	}
	return err
}

// Handle serves client on the calling goroutine until the connection fails or
// the server is stopped. It returns the error that ended the session. The
// client stays registered for the duration, so Stop can close it and wait for
// it.
func (s *Server) Handle(client *unet.Socket) error {
	s.clientRegister(client)
	defer s.clientUnregister(client)

	err := s.handleRegistered(client)
	return err
}

// StartHandling registers client, then serves it on a new goroutine and
// returns immediately. The goroutine unregisters the client when the session
// ends, and the error that ended it is discarded.
func (s *Server) StartHandling(client *unet.Socket) {
	s.clientRegister(client)
	go func() { // S/R-SAFE: out of scope
		defer s.clientUnregister(client)
		_ = s.handleRegistered(client) // Nobody is waiting for this error.
	}()
}

// Stop safely terminates outstanding clients.
//
// Stop runs in three steps:
//  1. It calls the Stop method of every registered Stopper.
//  2. It closes idle clients immediately and marks clients with an in-flight
//     RPC so they are closed as soon as that RPC completes.
//  3. It blocks until every client has been unregistered.
//
// No new requests should be initiated after calling Stop.
func (s *Server) Stop() {
	// Wait for all outstanding clients. This is deferred first so that it
	// runs last, after mu has been released below.
	defer s.wg.Wait()

	// s.stoppers is guarded by mu (Register appends to it under mu), so
	// snapshot it under the lock. The callbacks themselves still run without
	// the lock held, so a callback that calls back into the Server cannot
	// deadlock.
	s.mu.Lock()
	stoppers := s.stoppers
	s.mu.Unlock()
	for _, stopper := range stoppers {
		stopper.Stop()
	}

	// Close all known clients.
	s.mu.Lock()
	defer s.mu.Unlock()
	for client, state := range s.clients {
		switch state {
		case idle:
			// No RPC in flight: close the connection now. This should make
			// the handler's pending read fail, after which it unregisters
			// the client.
			client.Close()
			s.clients[client] = closed
		case processing:
			// An RPC is in flight; clientEndRequest closes the connection
			// once it finishes.
			s.clients[client] = closeRequested
		}
	}
}

// Client is a urpc client.
//
// A Client issues one RPC at a time over Socket; concurrent Calls are
// serialized by mu.
type Client struct {
	// mu protects all members.
	//
	// It also enforces single-call semantics. Call holds it for the whole
	// request/response exchange, and Close holds it while closing Socket.
	mu sync.Mutex

	// Socket is the underlying socket for this client.
	//
	// This _must_ be provided and must be closed manually by calling
	// Close.
	Socket *unet.Socket
}

// NewClient returns a Client that issues calls over socket. Calling Close on
// the returned Client closes socket.
func NewClient(socket *unet.Socket) *Client {
	c := new(Client)
	c.Socket = socket
	return c
}

// marshal JSON-encodes v and writes it to s as a single message. If fs is
// non-empty, the files' descriptors are attached (SCM_RIGHTS) to the first
// chunk written.
//
// The files in fs are not closed; ownership stays with the caller. An error is
// returned if v cannot be encoded or writing to the socket fails or stalls.
func marshal(s *unet.Socket, v interface{}, fs []*os.File) error {
	// Marshal to a buffer.
	data, err := json.Marshal(v)
	if err != nil {
		log.Warningf("urpc: error marshalling %v: %s", v, err.Error())
		return err
	}

	// Write to the socket. Pack the descriptors (if any) so they ride along
	// with the first write.
	w := s.Writer(true)
	if len(fs) > 0 {
		fds := make([]int, 0, len(fs))
		for _, f := range fs {
			fds = append(fds, int(f.Fd()))
		}
		w.PackFDs(fds...)
	}

	// Send.
	for n := 0; n < len(data); {
		cur, err := w.WriteVec([][]byte{data[n:]})
		if n == 0 && cur < len(data) {
			// Don't send FDs anymore. This call is only made on
			// the first successful call to WriteVec, assuming cur
			// is not sufficient to fill the entire buffer.
			w.PackFDs()
		}
		n += cur
		if err == nil && cur == 0 {
			// No progress and no error: fail rather than spin forever.
			err = io.ErrShortWrite
		}
		if err != nil {
			log.Warningf("urpc: error writing %v: %s", data[n:], err.Error())
			return err
		}
	}

	// We're done sending the fds to the client. Explicitly prevent fs from
	// being GCed until here. Urpc rpcs often unlink the file to send, relying
	// on the kernel to automatically delete it once the last reference is
	// dropped. Until we successfully call sendmsg(2), fs may contain the last
	// references to these files. Without this explicit reference to fs here,
	// the go runtime is free to assume we're done with fs after the fd
	// collection loop above, since it just sees us copying ints.
	runtime.KeepAlive(fs)

	log.Debugf("urpc: successfully marshalled %d bytes.", len(data))
	return nil
}

// unmarshal reads a single JSON-encoded message from s and decodes it into v.
//
// Any file descriptors received with the first byte of the message (at most
// maxFiles) are wrapped in *os.File and returned, and the caller owns them.
// The returned slice is nil when no descriptors arrived. If decoding fails,
// those files are closed and nil is returned with the error.
func unmarshal(s *unet.Socket, v interface{}) ([]*os.File, error) {
	// Receive a single byte, with room for up to maxFiles descriptors.
	r := s.Reader(true)
	r.EnableFDs(maxFiles)
	firstByte := make([]byte, 1)

	// Extract any FDs that may be there.
	if _, err := r.ReadVec([][]byte{firstByte}); err != nil {
		return nil, err
	}
	fds, err := r.ExtractFDs()
	if err != nil {
		log.Warningf("urpc: error extracting fds: %s", err.Error())
		return nil, err
	}
	// Named hostFD to avoid shadowing the imported fd package.
	var fs []*os.File
	for _, hostFD := range fds {
		fs = append(fs, os.NewFile(uintptr(hostFD), "urpc"))
	}

	// Read the rest, replaying the byte consumed above ahead of the socket.
	//
	// json.Decoder may read past the end of this message. That is safe only
	// because the protocol is strictly request/response: the peer sends
	// nothing more until it has received a reply.
	d := json.NewDecoder(io.MultiReader(bytes.NewReader(firstByte), s))
	// urpc internally decodes / re-encodes the data with interface{} as the
	// intermediate type. We have to unmarshal integers to json.Number type
	// instead of the default float type for those intermediate values, such
	// that when they get re-encoded, their values are not printed out in
	// floating-point formats such as 1e9, which could not be decoded to
	// explicitly typed integers later.
	d.UseNumber()
	if err := d.Decode(v); err != nil {
		log.Warningf("urpc: error decoding: %s", err.Error())
		closeAll(fs)
		return nil, err
	}

	// All set.
	log.Debugf("urpc: unmarshal success.")
	return fs, nil
}

// Call invokes the remote method named method ("TypeName.MethodName", as
// registered on the server) with arg, and decodes the reply into result.
//
// arg and result are JSON-encoded and decoded. If either is a pointer to a
// struct embedding FilePayload, its files are sent or received out of band.
// Received files are handed to result; if result cannot hold them, they are
// closed. arg must not be a bare FilePayload value, because its files would
// silently not be sent.
//
// Call returns:
//   - ErrTooManyFiles if arg carries more than maxFiles files.
//   - The marshal error if the request cannot be sent.
//   - A wrapped transport error if the reply cannot be read. It supports
//     errors.Is / errors.As.
//   - A RemoteError if the server reported a failure.
//
// Calls on one Client are serialized.
func (c *Client) Call(method string, arg interface{}, result interface{}) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	// If arg is a FilePayload, not a *FilePayload, files won't actually be
	// sent, so error out.
	if _, ok := arg.(FilePayload); ok {
		return errors.New("argument is a FilePayload, but should be a *FilePayload")
	}

	// Are there files to send?
	var fs []*os.File
	if fp, ok := arg.(filePayloader); ok {
		fs = fp.filePayload()
		if len(fs) > maxFiles {
			return ErrTooManyFiles
		}
	}

	// Marshal the data.
	if err := marshal(c.Socket, &clientCall{Method: method, Arg: arg}, fs); err != nil {
		return err
	}

	// Wait for the response. Result is pre-populated so the reply decodes
	// directly into the caller's value.
	callR := callResult{Result: result}
	newFs, err := unmarshal(c.Socket, &callR)
	if err != nil {
		// %w keeps the cause (e.g. io.EOF) inspectable; the text is unchanged.
		return fmt.Errorf("urpc method %q failed: %w", method, err)
	}

	// Set the file payload.
	if fp, ok := result.(filePayloader); ok {
		fp.setFilePayload(newFs)
	} else {
		closeAll(newFs)
	}

	// Did an error occur?
	if !callR.Success {
		return RemoteError{Message: callR.Err}
	}

	// All set.
	return nil
}

// Close closes the client's socket, waiting for any in-progress Call to
// finish first.
//
// Further calls to the client may result in undefined behavior.
func (c *Client) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	sock := c.Socket
	return sock.Close()
}