// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package p9

import (
	"errors"
	"fmt"
	"io"
	"io/ioutil"

	"golang.org/x/sys/unix"
	"gvisor.dev/gvisor/pkg/fd"
	"gvisor.dev/gvisor/pkg/log"
	"gvisor.dev/gvisor/pkg/sync"
	"gvisor.dev/gvisor/pkg/unet"
)

// ErrSocket marks an error raised by the transport layer itself: a failed
// socket read or write, or a message header that cannot be accepted.
//
// Callers may handle it differently from errors that concern only a single
// request.
type ErrSocket struct {
	// error is the wrapped underlying error. Embedding it gives ErrSocket
	// an Error method, so it satisfies the error interface.
	error
}

// ErrMessageTooLarge reports that a message header declared a size above the
// limit the receiver accepts.
type ErrMessageTooLarge struct {
	size  uint32 // size declared by the rejected message
	msize uint32 // limit reported in the error text
}

// Error describes the rejected size together with the limit it exceeded.
func (e *ErrMessageTooLarge) Error() string {
	const format = "message too large for fixed buffer: size is %d, limit is %d"
	return fmt.Sprintf(format, e.size, e.msize)
}

// ErrNoValidMessage is returned when the received bytes cannot be decoded
// into a valid message.
var ErrNoValidMessage error = errors.New("buffer contained no valid message")

const (
	// headerLength is the byte size of every message header, which is
	// encoded as size, type, tag (size[4] type[1] tag[2] in 9P notation).
	headerLength uint32 = 7

	// maximumLength caps the total size (header included) of any message
	// recv will accept: 1 MiB.
	maximumLength uint32 = 1024 * 1024

	// DefaultMessageSize is a sensible default message size: 64 KiB.
	DefaultMessageSize uint32 = 64 * 1024

	// initialBufferLength is the length of each freshly allocated pool
	// buffer.
	initialBufferLength uint32 = 64
)

// dataPool recycles the scratch byte slices that send and recv use for
// message bodies.
//
// Entries are *[]byte rather than []byte, because storing a bare slice
// header in an interface{} would allocate on every Put
// (see https://staticcheck.io/docs/checks#SA6002).
var dataPool = sync.Pool{
	New: func() interface{} {
		// Fresh entries are initialBufferLength bytes long. Callers that
		// need more swap in a larger slice.
		bufp := new([]byte)
		*bufp = make([]byte, initialBufferLength)
		return bufp
	},
}

// send encodes m and writes it to the socket as a single framed message.
//
// The frame is a header (total length including the header, message type,
// tag), then the encoded message data, then the raw payload if m implements
// payloader. Partial writes are retried until every byte has been sent.
//
// If m implements filer and carries a file, that file's FD is attached as a
// control message to the first write and is not resent after a partial
// write. The file is closed when send returns.
//
// Write failures are returned wrapped in ErrSocket.
func send(s *unet.Socket, tag Tag, m message) error {
	data := dataPool.Get().(*[]byte)
	dataBuf := buffer{data: (*data)[:0]}
	defer func() {
		// Recycle the encode buffer, which may have grown, on every exit
		// path, including write failures. Storing it back through the
		// pointer obtained from the pool avoids allocating a new *[]byte
		// for Put.
		*data = dataBuf.data
		dataPool.Put(data)
	}()

	if log.IsLogging(log.Debug) {
		log.Debugf("send [FD %d] [Tag %06d] %s", s.FD(), tag, m.String())
	}

	// Encode the message. The buffer will grow automatically.
	m.encode(&dataBuf)

	// Get our vectors to send: header, encoded data, then payload.
	var hdr [headerLength]byte
	vecs := make([][]byte, 0, 3)
	vecs = append(vecs, hdr[:])
	if len(dataBuf.data) > 0 {
		vecs = append(vecs, dataBuf.data)
	}
	totalLength := headerLength + uint32(len(dataBuf.data))

	// Is there a payload?
	if payloader, ok := m.(payloader); ok {
		p := payloader.Payload()
		if len(p) > 0 {
			vecs = append(vecs, p)
			totalLength += uint32(len(p))
		}
	}

	// Construct the header in place. hdr is already in vecs[0].
	headerBuf := buffer{data: hdr[:0]}
	headerBuf.Write32(totalLength)
	headerBuf.WriteMsgType(m.Type())
	headerBuf.WriteTag(tag)

	// Pack any files if necessary.
	w := s.Writer(true)
	if filer, ok := m.(filer); ok {
		if f := filer.FilePayload(); f != nil {
			defer f.Close()
			// Pack the file into the message.
			w.PackFDs(f.FD())
		}
	}

	for n := 0; n < int(totalLength); {
		cur, err := w.WriteVec(vecs)
		if err != nil {
			return ErrSocket{err}
		}
		n += cur

		// Consume iovecs: drop the fully written ones and trim a partially
		// written head.
		for consumed := 0; consumed < cur; {
			if len(vecs[0]) <= cur-consumed {
				consumed += len(vecs[0])
				vecs = vecs[1:]
			} else {
				vecs[0] = vecs[0][cur-consumed:]
				break
			}
		}

		if n > 0 && n < int(totalLength) {
			// Don't resend any control message.
			w.UnpackFDs()
		}
	}

	// All set. The deferred function returns the buffer to the pool.
	return nil
}

// lookupTagAndType resolves the message that a received header should be
// decoded into, either by finding an existing message or by creating a new
// one.
//
// recv calls it right after decoding the header. Any error it returns is
// propagated back to recv's caller. By design, messageByType can be used
// directly as a lookupTagAndType.
type lookupTagAndType func(tag Tag, msgType MsgType) (message, error)

// recv reads and decodes one message from the socket.
//
// The header and body are read in two separate phases, so recv is not safe
// for multiple concurrent callers on the same socket.
//
// File descriptors are only accepted while the header is read. If the
// decoded message implements filer, the first received FD is attached to it.
// Every other received FD is closed before recv returns.
//
// The special error type ErrSocket is returned on a socket error, on an
// unacceptable header, and when the body of a rejected message cannot be
// fully drained. In each of these cases the stream can no longer be trusted.
//
// If lookup fails, the message body is discarded, and lookup's error is
// returned together with the tag decoded from the header. For every other
// non-nil error, the returned tag is NoTag.
func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message, error) {
	// Read a header.
	//
	// send attaches any control message to its first write, so control
	// messages always arrive along with the header. This means we need to
	// be careful about closing FDs during errors to prevent leaks.
	var hdr [headerLength]byte
	r := s.Reader(true)
	r.EnableFDs(1)

	n, err := r.ReadVec([][]byte{hdr[:]})
	if err != nil && (n == 0 || err != io.EOF) {
		r.CloseFDs()
		return NoTag, nil, ErrSocket{err}
	}

	fds, err := r.ExtractFDs()
	if err != nil {
		return NoTag, nil, ErrSocket{err}
	}
	defer func() {
		// Close anything left open. When an FD is handed to the message
		// below, fds is set to nil so nothing is closed here.
		for _, fd := range fds {
			unix.Close(fd)
		}
	}()
	r.EnableFDs(0)

	// Continue reading if the header arrived short.
	for n < int(headerLength) {
		cur, err := r.ReadVec([][]byte{hdr[n:]})
		if err != nil && (cur == 0 || err != io.EOF) {
			return NoTag, nil, ErrSocket{err}
		}
		n += cur
	}

	// Decode the header.
	headerBuf := buffer{data: hdr[:]}
	size := headerBuf.Read32()
	t := headerBuf.ReadMsgType()
	tag := headerBuf.ReadTag()
	if size < headerLength {
		// The message is too small. The size field cannot be trusted, so
		// message boundaries are lost.
		return NoTag, nil, ErrSocket{ErrNoValidMessage}
	}
	// The effective limit is the smaller of the protocol maximum and msize.
	// Report that value so the error names the bound that was exceeded.
	limit := msize
	if limit > maximumLength {
		limit = maximumLength
	}
	if size > limit {
		// The message is too big.
		return NoTag, nil, ErrSocket{&ErrMessageTooLarge{size, limit}}
	}
	remaining := size - headerLength

	// discard throws away the unread body of a message that will not be
	// decoded. A failed or short drain leaves the stream misaligned, so it
	// is reported as a socket error instead of being ignored.
	discard := func() error {
		if remaining == 0 {
			return nil
		}
		copied, err := io.Copy(ioutil.Discard, &io.LimitedReader{R: s, N: int64(remaining)})
		if err != nil {
			return ErrSocket{err}
		}
		if copied != int64(remaining) {
			return ErrSocket{io.ErrUnexpectedEOF}
		}
		return nil
	}

	// Find our message to decode.
	m, err := lookup(tag, t)
	if err != nil {
		// Throw away the contents of this message.
		if derr := discard(); derr != nil {
			return NoTag, nil, derr
		}
		return tag, nil, err
	}

	// Not yet initialized.
	var dataBuf buffer
	var vecs [][]byte

	appendBuffer := func(size int) *[]byte {
		// Pull a data buffer from the pool.
		datap := dataPool.Get().(*[]byte)
		data := *datap
		if size > cap(data) {
			// Create a larger data buffer.
			data = make([]byte, size)
			datap = &data
		} else {
			// Reuse the pooled backing array. Comparing against cap rather
			// than len lets buffers that send returned with a short length
			// be reused. Every byte of data[:size] is overwritten by the
			// read loop below before anything is decoded.
			data = data[:size]
		}
		dataBuf = buffer{data: data}
		vecs = append(vecs, data)
		return datap
	}

	// Read the rest of the payload.
	//
	// This requires some special care to ensure that the vectors all line
	// up the way they should. We do this to minimize copying data around.
	if payloader, ok := m.(payloader); ok {
		fixedSize := payloader.FixedSize()

		// Do we need more than there is?
		if fixedSize > remaining {
			// This is not a valid message.
			if derr := discard(); derr != nil {
				return NoTag, nil, derr
			}
			return NoTag, nil, ErrNoValidMessage
		}

		if fixedSize != 0 {
			datap := appendBuffer(int(fixedSize))
			defer dataPool.Put(datap)
		}

		// Include the payload. The message's existing payload slice is
		// reused only if it is non-nil and already has exactly the right
		// length.
		p := payloader.Payload()
		if p == nil || len(p) != int(remaining-fixedSize) {
			p = make([]byte, remaining-fixedSize)
			payloader.SetPayload(p)
		}
		if len(p) > 0 {
			vecs = append(vecs, p)
		}
	} else if remaining != 0 {
		datap := appendBuffer(int(remaining))
		defer dataPool.Put(datap)
	}

	if len(vecs) > 0 {
		// Read the rest of the message.
		//
		// FDs were only enabled for the header read, so this reader does
		// not need to handle a control message.
		body := s.Reader(true)
		for n := 0; n < int(remaining); {
			cur, err := body.ReadVec(vecs)
			if err != nil && (cur == 0 || err != io.EOF) {
				return NoTag, nil, ErrSocket{err}
			}
			n += cur

			// Consume iovecs: drop the fully read ones and trim a
			// partially read head.
			for consumed := 0; consumed < cur; {
				if len(vecs[0]) <= cur-consumed {
					consumed += len(vecs[0])
					vecs = vecs[1:]
				} else {
					vecs[0] = vecs[0][cur-consumed:]
					break
				}
			}
		}
	}

	// Decode the message data.
	m.decode(&dataBuf)
	if dataBuf.isOverrun() {
		// No need to drain the socket: the whole body has been read.
		return NoTag, nil, ErrNoValidMessage
	}

	// Save the file, if any came out.
	if filer, ok := m.(filer); ok && len(fds) > 0 {
		// Set the file object.
		filer.SetFilePayload(fd.New(fds[0]))

		// Close the rest. We support only one.
		for i := 1; i < len(fds); i++ {
			unix.Close(fds[i])
		}

		// Don't close in the defer.
		fds = nil
	}

	if log.IsLogging(log.Debug) {
		log.Debugf("recv [FD %d] [Tag %06d] %s", s.FD(), tag, m.String())
	}

	// All set.
	return tag, m, nil
}