// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package testutil

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"os"
	"os/exec"
	"strings"
	"time"

	"github.com/kr/pty"
	"golang.org/x/sys/unix"
)

// Prompt is used as shell prompt.
// It is meant to be unique enough to not be seen in command outputs.
const Prompt = "PROMPT> "

// Simplistic shell string escape.
func shellEscape(s string) string {
	// specialChars is used to determine whether s needs quoting at all.
	const specialChars = "\\'\"`${[|&;<>()*?! \t\n"
	// If s needs quoting, escapedChars is the set of characters that are
	// escaped with a backslash.
	const escapedChars = "\\\"$`"
	if len(s) == 0 {
		return "''"
	}
	if !strings.ContainsAny(s, specialChars) {
		return s
	}
	var b bytes.Buffer
	b.WriteString("\"")
	for _, c := range s {
		if strings.ContainsAny(string(c), escapedChars) {
			b.WriteString("\\")
		}
		b.WriteRune(c)
	}
	b.WriteString("\"")
	return b.String()
}

type byteOrError struct {
	b   byte
	err error
}

// Shell manages a /bin/sh invocation with convenience functions to handle I/O.
// The shell is run in its own interactive TTY and should present its prompt.
type Shell struct {
	// cmd is a reference to the underlying sh process.
	cmd *exec.Cmd
	// cmdFinished is closed when cmd exits.
	cmdFinished chan struct{}

	// echo is whether the shell will echo input back to us.
	// This helps setting expectations of getting feedback of written bytes.
	echo bool
	// Control characters we expect to see in the shell.
	controlCharIntr string
	controlCharEOF  string

	// ptyMaster and ptyReplica are the TTY pair associated with the shell.
	ptyMaster  *os.File
	ptyReplica *os.File
	// readCh is a channel where everything read from ptyMaster is written.
	readCh chan byteOrError

	// logger is used for logging. It may be nil.
	logger Logger
}

// cleanup kills the shell process and closes the TTY.
// Users of this library get a reference to this function with NewShell.
func (s *Shell) cleanup() {
	s.logf("cleanup", "Shell cleanup started.")
	if s.cmd.ProcessState == nil {
		if err := s.cmd.Process.Kill(); err != nil {
			s.logf("cleanup", "cannot kill shell process: %v", err)
		}
		// We don't log the error returned by Wait because the monitorExit
		// goroutine will already do so.
		s.cmd.Wait()
	}
	s.ptyReplica.Close()
	s.ptyMaster.Close()
	// Wait for monitorExit goroutine to write exit status to the debug log.
	<-s.cmdFinished
	// Empty out everything in the readCh, but don't wait too long for it.
	var extraBytes bytes.Buffer
	unreadTimeout := time.After(100 * time.Millisecond)
unreadLoop:
	for {
		select {
		case r, ok := <-s.readCh:
			if !ok {
				break unreadLoop
			} else if r.err == nil {
				extraBytes.WriteByte(r.b)
			}
		case <-unreadTimeout:
			break unreadLoop
		}
	}
	if extraBytes.Len() > 0 {
		s.logIO("unread", extraBytes.Bytes(), nil)
	}
	s.logf("cleanup", "Shell cleanup complete.")
}

// logIO logs byte I/O to both standard logging and the test log, if provided.
func (s *Shell) logIO(prefix string, b []byte, err error) {
	var sb strings.Builder
	if len(b) > 0 {
		sb.WriteString(fmt.Sprintf("%q", b))
	} else {
		sb.WriteString("(nothing)")
	}
	if err != nil {
		sb.WriteString(fmt.Sprintf(" [error: %v]", err))
	}
	s.logf(prefix, "%s", sb.String())
}

// logf logs something to both standard logging and the test log, if provided.
func (s *Shell) logf(prefix, format string, values ...interface{}) {
	if s.logger != nil {
		s.logger.Logf("[%s] %s", prefix, fmt.Sprintf(format, values...))
	}
}

// monitorExit waits for the shell process to exit and logs the exit result.
func (s *Shell) monitorExit() {
	if err := s.cmd.Wait(); err != nil {
		s.logf("cmd", "shell process terminated: %v", err)
	} else {
		s.logf("cmd", "shell process terminated successfully")
	}
	close(s.cmdFinished)
}

// reader continuously reads the shell output and populates readCh.
func (s *Shell) reader(ctx context.Context) {
	b := make([]byte, 4096)
	defer close(s.readCh)
	for {
		select {
		case <-s.cmdFinished:
			// Shell process terminated; stop trying to read.
			return
		case <-ctx.Done():
			// Shell process will also have terminated in this case;
			// stop trying to read.
			// We don't print an error here because doing so would print this in the
			// normal case where the context passed to NewShell is canceled at the
			// end of a successful test.
			return
		default:
			// Shell still running, try reading.
		}
		if got, err := s.ptyMaster.Read(b); err != nil {
			s.readCh <- byteOrError{err: err}
			if err == io.EOF {
				return
			}
		} else {
			for i := 0; i < got; i++ {
				s.readCh <- byteOrError{b: b[i]}
			}
		}
	}
}

// readByte reads a single byte, respecting the context.
func (s *Shell) readByte(ctx context.Context) (byte, error) {
	select {
	case <-ctx.Done():
		return 0, ctx.Err()
	case r := <-s.readCh:
		return r.b, r.err
	}
}

// readLoop reads as many bytes as possible until the context expires, b is
// full, or a short time passes. It returns how many bytes it has successfully
// read.
func (s *Shell) readLoop(ctx context.Context, b []byte) (int, error) {
	soonCtx, soonCancel := context.WithTimeout(ctx, 5*time.Second)
	defer soonCancel()
	var i int
	for i = 0; i < len(b) && soonCtx.Err() == nil; i++ {
		next, err := s.readByte(soonCtx)
		if err != nil {
			if i > 0 {
				s.logIO("read", b[:i-1], err)
			} else {
				s.logIO("read", nil, err)
			}
			return i, err
		}
		b[i] = next
	}
	s.logIO("read", b[:i], soonCtx.Err())
	return i, soonCtx.Err()
}

// readLine reads a single line. Strips out all \r characters for convenience.
// Upon error, it will still return what it has read so far.
// It will also exit quickly if the line content it has read so far (without a
// line break) matches `prompt`.
func (s *Shell) readLine(ctx context.Context, prompt string) ([]byte, error) {
	soonCtx, soonCancel := context.WithTimeout(ctx, 5*time.Second)
	defer soonCancel()
	var lineData bytes.Buffer
	var b byte
	var err error
	for soonCtx.Err() == nil && b != '\n' {
		b, err = s.readByte(soonCtx)
		if err != nil {
			data := lineData.Bytes()
			s.logIO("read", data, err)
			return data, err
		}
		if b != '\r' {
			lineData.WriteByte(b)
		}
		if bytes.Equal(lineData.Bytes(), []byte(prompt)) {
			// Assume that there will not be any further output if we get the prompt.
			// This avoids waiting for the read deadline just to read the prompt.
			break
		}
	}
	data := lineData.Bytes()
	s.logIO("read", data, soonCtx.Err())
	return data, soonCtx.Err()
}

// Expect verifies that the next `len(want)` bytes we read match `want`.
func (s *Shell) Expect(ctx context.Context, want []byte) error {
	errPrefix := fmt.Sprintf("want(%q)", want)
	b := make([]byte, len(want))
	got, err := s.readLoop(ctx, b)
	if err != nil {
		if ctx.Err() != nil {
			return fmt.Errorf("%s: context done (%w), got: %q", errPrefix, err, b[:got])
		}
		return fmt.Errorf("%s: %w", errPrefix, err)
	}
	if got < len(want) {
		return fmt.Errorf("%s: short read (read %d bytes, expected %d): %q", errPrefix, got, len(want), b[:got])
	}
	if !bytes.Equal(b, want) {
		return fmt.Errorf("got %q want %q", b, want)
	}
	return nil
}

// ExpectString verifies that the next `len(want)` bytes we read match `want`.
func (s *Shell) ExpectString(ctx context.Context, want string) error {
	return s.Expect(ctx, []byte(want))
}

// ExpectPrompt verifies that the next few bytes we read are the shell prompt.
func (s *Shell) ExpectPrompt(ctx context.Context) error {
	return s.ExpectString(ctx, Prompt)
}

// ExpectEmptyLine verifies that the next few bytes we read are an empty line,
// as defined by any number of carriage or line break characters.
func (s *Shell) ExpectEmptyLine(ctx context.Context) error {
	line, err := s.readLine(ctx, Prompt)
	if err != nil {
		return fmt.Errorf("cannot read line: %w", err)
	}
	if strings.Trim(string(line), "\r\n") != "" {
		return fmt.Errorf("line was not empty: %q", line)
	}
	return nil
}

// ExpectLine verifies that the next `len(want)` bytes we read match `want`,
// followed by carriage returns or newline characters.
func (s *Shell) ExpectLine(ctx context.Context, want string) error {
	if err := s.ExpectString(ctx, want); err != nil {
		return err
	}
	if err := s.ExpectEmptyLine(ctx); err != nil {
		return fmt.Errorf("ExpectLine(%q): no line break: %w", want, err)
	}
	return nil
}

// Write writes `b` to the shell and verifies that all of them get written.
func (s *Shell) Write(b []byte) error {
	written, err := s.ptyMaster.Write(b)
	s.logIO("write", b[:written], err)
	if err != nil {
		return fmt.Errorf("write(%q): %w", b, err)
	}
	if written != len(b) {
		return fmt.Errorf("write(%q): wrote %d of %d bytes (%q)", b, written, len(b), b[:written])
	}
	return nil
}

// WriteLine writes `line` (to which \n will be appended) to the shell.
// If the shell is in `echo` mode, it will also check that we got these bytes
// back to read.
func (s *Shell) WriteLine(ctx context.Context, line string) error {
	if err := s.Write([]byte(line + "\n")); err != nil {
		return err
	}
	if s.echo {
		// We expect to see everything we've typed.
		if err := s.ExpectLine(ctx, line); err != nil {
			return fmt.Errorf("echo: %w", err)
		}
	}
	return nil
}

// StartCommand is a convenience wrapper for WriteLine that mimics entering a
// command line and pressing Enter. It does some basic shell argument escaping.
func (s *Shell) StartCommand(ctx context.Context, cmd ...string) error {
	escaped := make([]string, len(cmd))
	for i, arg := range cmd {
		escaped[i] = shellEscape(arg)
	}
	return s.WriteLine(ctx, strings.Join(escaped, " "))
}

// GetCommandOutput gets all following bytes until the prompt is encountered.
// This is useful for matching the output of a command.
// All \r are removed for ease of matching.
func (s *Shell) GetCommandOutput(ctx context.Context) ([]byte, error) {
	return s.ReadUntil(ctx, Prompt)
}

// ReadUntil gets all following bytes until a certain line is encountered.
// This final line is not returned as part of the output, but everything before
// it (including the \n) is included.
// This is useful for matching the output of a command.
// All \r are removed for ease of matching.
func (s *Shell) ReadUntil(ctx context.Context, finalLine string) ([]byte, error) {
	var output bytes.Buffer
	for ctx.Err() == nil {
		line, err := s.readLine(ctx, finalLine)
		if err != nil {
			return nil, err
		}
		if bytes.Equal(line, []byte(finalLine)) {
			break
		}
		// readLine ensures that `line` either matches `finalLine` or contains \n.
		// Thus we can be confident that `line` has a \n here.
		output.Write(line)
	}
	return output.Bytes(), ctx.Err()
}

// RunCommand is a convenience wrapper for StartCommand + GetCommandOutput.
func (s *Shell) RunCommand(ctx context.Context, cmd ...string) ([]byte, error) {
	if err := s.StartCommand(ctx, cmd...); err != nil {
		return nil, err
	}
	return s.GetCommandOutput(ctx)
}

// RefreshSTTY interprets output from `stty -a` to check whether we are in echo
// mode and other settings.
// It will assume that any line matching `expectPrompt` means the end of
// the `stty -a` output.
// Why do this rather than using `tcgets`? Because this function can be used in
// conjunction with sub-shell processes that can allocate their own TTYs.
func (s *Shell) RefreshSTTY(ctx context.Context, expectPrompt string) error {
	// Temporarily assume we will not get any output.
	// If echo is actually on, we'll get the "stty -a" line as if it was command
	// output. This is OK because we parse the output generously.
	s.echo = false
	if err := s.WriteLine(ctx, "stty -a"); err != nil {
		return fmt.Errorf("could not run `stty -a`: %w", err)
	}
	sttyOutput, err := s.ReadUntil(ctx, expectPrompt)
	if err != nil {
		return fmt.Errorf("cannot get `stty -a` output: %w", err)
	}

	// Set default control characters in case we can't see them in the output.
	s.controlCharIntr = "^C"
	s.controlCharEOF = "^D"
	// stty output has two general notations:
	// `a = b;` (for control characters), and `option` vs `-option` (for boolean
	// options). We parse both kinds here.
	// For `a = b;`, `controlChar` contains `a`, and `previousToken` is used to
	// set `controlChar` to `previousToken` when we see an "=" token.
	var previousToken, controlChar string
	for _, token := range strings.Fields(string(sttyOutput)) {
		if controlChar != "" {
			value := strings.TrimSuffix(token, ";")
			switch controlChar {
			case "intr":
				s.controlCharIntr = value
			case "eof":
				s.controlCharEOF = value
			}
			controlChar = ""
		} else {
			switch token {
			case "=":
				controlChar = previousToken
			case "-echo":
				s.echo = false
			case "echo":
				s.echo = true
			}
		}
		previousToken = token
	}
	s.logf("stty", "refreshed settings: echo=%v, intr=%q, eof=%q", s.echo, s.controlCharIntr, s.controlCharEOF)
	return nil
}

// sendControlCode sends `code` to the shell and expects to see `repr`.
// If `expectLinebreak` is true, it also expects to see a linebreak.
func (s *Shell) sendControlCode(ctx context.Context, code byte, repr string, expectLinebreak bool) error {
	if err := s.Write([]byte{code}); err != nil {
		return fmt.Errorf("cannot send %q: %w", code, err)
	}
	if err := s.ExpectString(ctx, repr); err != nil {
		return fmt.Errorf("did not see %s: %w", repr, err)
	}
	if expectLinebreak {
		if err := s.ExpectEmptyLine(ctx); err != nil {
			return fmt.Errorf("linebreak after %s: %v", repr, err)
		}
	}
	return nil
}

// SendInterrupt sends the \x03 (Ctrl+C) control character to the shell.
func (s *Shell) SendInterrupt(ctx context.Context, expectLinebreak bool) error {
	return s.sendControlCode(ctx, 0x03, s.controlCharIntr, expectLinebreak)
}

// SendEOF sends the \x04 (Ctrl+D) control character to the shell.
func (s *Shell) SendEOF(ctx context.Context, expectLinebreak bool) error {
	return s.sendControlCode(ctx, 0x04, s.controlCharEOF, expectLinebreak)
}

// NewShell returns a new managed sh process along with a cleanup function.
// The caller is expected to call this function once it no longer needs the
// shell.
// The optional passed-in logger will be used for logging.
func NewShell(ctx context.Context, logger Logger) (*Shell, func(), error) {
	ptyMaster, ptyReplica, err := pty.Open()
	if err != nil {
		return nil, nil, fmt.Errorf("cannot create PTY: %w", err)
	}
	cmd := exec.CommandContext(ctx, "/bin/sh", "--noprofile", "--norc", "-i")
	cmd.Stdin = ptyReplica
	cmd.Stdout = ptyReplica
	cmd.Stderr = ptyReplica
	cmd.SysProcAttr = &unix.SysProcAttr{
		Setsid:  true,
		Setctty: true,
		Ctty:    0,
	}
	cmd.Env = append(cmd.Env, fmt.Sprintf("PS1=%s", Prompt))
	if err := cmd.Start(); err != nil {
		return nil, nil, fmt.Errorf("cannot start shell: %w", err)
	}
	s := &Shell{
		cmd:         cmd,
		cmdFinished: make(chan struct{}),
		ptyMaster:   ptyMaster,
		ptyReplica:  ptyReplica,
		readCh:      make(chan byteOrError, 1<<20),
		logger:      logger,
	}
	s.logf("creation", "Shell spawned.")
	go s.monitorExit()
	go s.reader(ctx)
	setupCtx, setupCancel := context.WithTimeout(ctx, 5*time.Second)
	defer setupCancel()
	// We expect to see the prompt immediately on startup,
	// since the shell is started in interactive mode.
	if err := s.ExpectPrompt(setupCtx); err != nil {
		s.cleanup()
		return nil, nil, fmt.Errorf("did not get initial prompt: %w", err)
	}
	s.logf("creation", "Initial prompt observed.")
	// Get initial TTY settings.
	if err := s.RefreshSTTY(setupCtx, Prompt); err != nil {
		s.cleanup()
		return nil, nil, fmt.Errorf("cannot get initial STTY settings: %w", err)
	}
	return s, s.cleanup, nil
}