summaryrefslogtreecommitdiffhomepage
path: root/pkg/test/testutil/sh.go
diff options
context:
space:
mode:
authorEtienne Perot <eperot@google.com>2021-01-08 14:49:23 -0800
committergVisor bot <gvisor-bot@google.com>2021-01-08 14:51:50 -0800
commit11787a601e2aba4d022aadd468f729963b9a09e6 (patch)
treec9029299ee056c12505a9bfae592454d4aa3618e /pkg/test/testutil/sh.go
parent5c13c2152ecc313353b745bdfb82ee601e38a867 (diff)
Create console test library.
This creates a TTY pair and runs `/bin/sh` in interactive mode within it. It provides useful helper functions to interact with the shell and read the output of commands run within it. This is meant to be used for testing upcoming changes allowing `runsc exec` to work in `-detach=false -tty=true` mode. PiperOrigin-RevId: 350841006
Diffstat (limited to 'pkg/test/testutil/sh.go')
-rw-r--r--pkg/test/testutil/sh.go515
1 files changed, 515 insertions, 0 deletions
diff --git a/pkg/test/testutil/sh.go b/pkg/test/testutil/sh.go
new file mode 100644
index 000000000..1c77562be
--- /dev/null
+++ b/pkg/test/testutil/sh.go
@@ -0,0 +1,515 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "strings"
+ "syscall"
+ "time"
+
+ "github.com/kr/pty"
+)
+
+// Prompt is used as shell prompt.
+// It is meant to be unique enough to not be seen in command outputs.
+const Prompt = "PROMPT> "
+
+// Simplistic shell string escape.
+func shellEscape(s string) string {
+ // specialChars is used to determine whether s needs quoting at all.
+ const specialChars = "\\'\"`${[|&;<>()*?! \t\n"
+ // If s needs quoting, escapedChars is the set of characters that are
+ // escaped with a backslash.
+ const escapedChars = "\\\"$`"
+ if len(s) == 0 {
+ return "''"
+ }
+ if !strings.ContainsAny(s, specialChars) {
+ return s
+ }
+ var b bytes.Buffer
+ b.WriteString("\"")
+ for _, c := range s {
+ if strings.ContainsAny(string(c), escapedChars) {
+ b.WriteString("\\")
+ }
+ b.WriteRune(c)
+ }
+ b.WriteString("\"")
+ return b.String()
+}
+
+type byteOrError struct {
+ b byte
+ err error
+}
+
+// Shell manages a /bin/sh invocation with convenience functions to handle I/O.
+// The shell is run in its own interactive TTY and should present its prompt.
+type Shell struct {
+ // cmd is a reference to the underlying sh process.
+ cmd *exec.Cmd
+ // cmdFinished is closed when cmd exits.
+ cmdFinished chan struct{}
+
+ // echo is whether the shell will echo input back to us.
+ // This helps setting expectations of getting feedback of written bytes.
+ echo bool
+ // Control characters we expect to see in the shell.
+ controlCharIntr string
+ controlCharEOF string
+
+ // ptyMaster and ptyReplica are the TTY pair associated with the shell.
+ ptyMaster *os.File
+ ptyReplica *os.File
+ // readCh is a channel where everything read from ptyMaster is written.
+ readCh chan byteOrError
+
+ // logger is used for logging. It may be nil.
+ logger Logger
+}
+
+// cleanup kills the shell process and closes the TTY.
+// Users of this library get a reference to this function with NewShell.
+func (s *Shell) cleanup() {
+ s.logf("cleanup", "Shell cleanup started.")
+ if s.cmd.ProcessState == nil {
+ if err := s.cmd.Process.Kill(); err != nil {
+ s.logf("cleanup", "cannot kill shell process: %v", err)
+ }
+ // We don't log the error returned by Wait because the monitorExit
+ // goroutine will already do so.
+ s.cmd.Wait()
+ }
+ s.ptyReplica.Close()
+ s.ptyMaster.Close()
+ // Wait for monitorExit goroutine to write exit status to the debug log.
+ <-s.cmdFinished
+ // Empty out everything in the readCh, but don't wait too long for it.
+ var extraBytes bytes.Buffer
+ unreadTimeout := time.After(100 * time.Millisecond)
+unreadLoop:
+ for {
+ select {
+ case r, ok := <-s.readCh:
+ if !ok {
+ break unreadLoop
+ } else if r.err == nil {
+ extraBytes.WriteByte(r.b)
+ }
+ case <-unreadTimeout:
+ break unreadLoop
+ }
+ }
+ if extraBytes.Len() > 0 {
+ s.logIO("unread", extraBytes.Bytes(), nil)
+ }
+ s.logf("cleanup", "Shell cleanup complete.")
+}
+
+// logIO logs byte I/O to both standard logging and the test log, if provided.
+func (s *Shell) logIO(prefix string, b []byte, err error) {
+ var sb strings.Builder
+ if len(b) > 0 {
+ sb.WriteString(fmt.Sprintf("%q", b))
+ } else {
+ sb.WriteString("(nothing)")
+ }
+ if err != nil {
+ sb.WriteString(fmt.Sprintf(" [error: %v]", err))
+ }
+ s.logf(prefix, "%s", sb.String())
+}
+
+// logf logs something to both standard logging and the test log, if provided.
+func (s *Shell) logf(prefix, format string, values ...interface{}) {
+ if s.logger != nil {
+ s.logger.Logf("[%s] %s", prefix, fmt.Sprintf(format, values...))
+ }
+}
+
+// monitorExit waits for the shell process to exit and logs the exit result.
+func (s *Shell) monitorExit() {
+ if err := s.cmd.Wait(); err != nil {
+ s.logf("cmd", "shell process terminated: %v", err)
+ } else {
+ s.logf("cmd", "shell process terminated successfully")
+ }
+ close(s.cmdFinished)
+}
+
+// reader continuously reads the shell output and populates readCh.
+func (s *Shell) reader(ctx context.Context) {
+ b := make([]byte, 4096)
+ defer close(s.readCh)
+ for {
+ select {
+ case <-s.cmdFinished:
+ // Shell process terminated; stop trying to read.
+ return
+ case <-ctx.Done():
+ // Shell process will also have terminated in this case;
+ // stop trying to read.
+ // We don't print an error here because doing so would print this in the
+ // normal case where the context passed to NewShell is canceled at the
+ // end of a successful test.
+ return
+ default:
+ // Shell still running, try reading.
+ }
+ if got, err := s.ptyMaster.Read(b); err != nil {
+ s.readCh <- byteOrError{err: err}
+ if err == io.EOF {
+ return
+ }
+ } else {
+ for i := 0; i < got; i++ {
+ s.readCh <- byteOrError{b: b[i]}
+ }
+ }
+ }
+}
+
+// readByte reads a single byte, respecting the context.
+func (s *Shell) readByte(ctx context.Context) (byte, error) {
+ select {
+ case <-ctx.Done():
+ return 0, ctx.Err()
+ case r := <-s.readCh:
+ return r.b, r.err
+ }
+}
+
+// readLoop reads as many bytes as possible until the context expires, b is
+// full, or a short time passes. It returns how many bytes it has successfully
+// read.
+func (s *Shell) readLoop(ctx context.Context, b []byte) (int, error) {
+ soonCtx, soonCancel := context.WithTimeout(ctx, 5*time.Second)
+ defer soonCancel()
+ var i int
+ for i = 0; i < len(b) && soonCtx.Err() == nil; i++ {
+ next, err := s.readByte(soonCtx)
+ if err != nil {
+ if i > 0 {
+ s.logIO("read", b[:i-1], err)
+ } else {
+ s.logIO("read", nil, err)
+ }
+ return i, err
+ }
+ b[i] = next
+ }
+ s.logIO("read", b[:i], soonCtx.Err())
+ return i, soonCtx.Err()
+}
+
+// readLine reads a single line. Strips out all \r characters for convenience.
+// Upon error, it will still return what it has read so far.
+// It will also exit quickly if the line content it has read so far (without a
+// line break) matches `prompt`.
+func (s *Shell) readLine(ctx context.Context, prompt string) ([]byte, error) {
+ soonCtx, soonCancel := context.WithTimeout(ctx, 5*time.Second)
+ defer soonCancel()
+ var lineData bytes.Buffer
+ var b byte
+ var err error
+ for soonCtx.Err() == nil && b != '\n' {
+ b, err = s.readByte(soonCtx)
+ if err != nil {
+ data := lineData.Bytes()
+ s.logIO("read", data, err)
+ return data, err
+ }
+ if b != '\r' {
+ lineData.WriteByte(b)
+ }
+ if bytes.Equal(lineData.Bytes(), []byte(prompt)) {
+ // Assume that there will not be any further output if we get the prompt.
+ // This avoids waiting for the read deadline just to read the prompt.
+ break
+ }
+ }
+ data := lineData.Bytes()
+ s.logIO("read", data, soonCtx.Err())
+ return data, soonCtx.Err()
+}
+
+// Expect verifies that the next `len(want)` bytes we read match `want`.
+func (s *Shell) Expect(ctx context.Context, want []byte) error {
+ errPrefix := fmt.Sprintf("want(%q)", want)
+ b := make([]byte, len(want))
+ got, err := s.readLoop(ctx, b)
+ if err != nil {
+ if ctx.Err() != nil {
+ return fmt.Errorf("%s: context done (%w), got: %q", errPrefix, err, b[:got])
+ }
+ return fmt.Errorf("%s: %w", errPrefix, err)
+ }
+ if got < len(want) {
+ return fmt.Errorf("%s: short read (read %d bytes, expected %d): %q", errPrefix, got, len(want), b[:got])
+ }
+ if !bytes.Equal(b, want) {
+ return fmt.Errorf("got %q want %q", b, want)
+ }
+ return nil
+}
+
+// ExpectString verifies that the next `len(want)` bytes we read match `want`.
+func (s *Shell) ExpectString(ctx context.Context, want string) error {
+ return s.Expect(ctx, []byte(want))
+}
+
+// ExpectPrompt verifies that the next few bytes we read are the shell prompt.
+func (s *Shell) ExpectPrompt(ctx context.Context) error {
+ return s.ExpectString(ctx, Prompt)
+}
+
+// ExpectEmptyLine verifies that the next few bytes we read are an empty line,
+// as defined by any number of carriage or line break characters.
+func (s *Shell) ExpectEmptyLine(ctx context.Context) error {
+ line, err := s.readLine(ctx, Prompt)
+ if err != nil {
+ return fmt.Errorf("cannot read line: %w", err)
+ }
+ if strings.Trim(string(line), "\r\n") != "" {
+ return fmt.Errorf("line was not empty: %q", line)
+ }
+ return nil
+}
+
+// ExpectLine verifies that the next `len(want)` bytes we read match `want`,
+// followed by carriage returns or newline characters.
+func (s *Shell) ExpectLine(ctx context.Context, want string) error {
+ if err := s.ExpectString(ctx, want); err != nil {
+ return err
+ }
+ if err := s.ExpectEmptyLine(ctx); err != nil {
+ return fmt.Errorf("ExpectLine(%q): no line break: %w", want, err)
+ }
+ return nil
+}
+
+// Write writes `b` to the shell and verifies that all of them get written.
+func (s *Shell) Write(b []byte) error {
+ written, err := s.ptyMaster.Write(b)
+ s.logIO("write", b[:written], err)
+ if err != nil {
+ return fmt.Errorf("write(%q): %w", b, err)
+ }
+ if written != len(b) {
+ return fmt.Errorf("write(%q): wrote %d of %d bytes (%q)", b, written, len(b), b[:written])
+ }
+ return nil
+}
+
+// WriteLine writes `line` (to which \n will be appended) to the shell.
+// If the shell is in `echo` mode, it will also check that we got these bytes
+// back to read.
+func (s *Shell) WriteLine(ctx context.Context, line string) error {
+ if err := s.Write([]byte(line + "\n")); err != nil {
+ return err
+ }
+ if s.echo {
+ // We expect to see everything we've typed.
+ if err := s.ExpectLine(ctx, line); err != nil {
+ return fmt.Errorf("echo: %w", err)
+ }
+ }
+ return nil
+}
+
+// StartCommand is a convenience wrapper for WriteLine that mimics entering a
+// command line and pressing Enter. It does some basic shell argument escaping.
+func (s *Shell) StartCommand(ctx context.Context, cmd ...string) error {
+ escaped := make([]string, len(cmd))
+ for i, arg := range cmd {
+ escaped[i] = shellEscape(arg)
+ }
+ return s.WriteLine(ctx, strings.Join(escaped, " "))
+}
+
+// GetCommandOutput gets all following bytes until the prompt is encountered.
+// This is useful for matching the output of a command.
+// All \r are removed for ease of matching.
+func (s *Shell) GetCommandOutput(ctx context.Context) ([]byte, error) {
+ return s.ReadUntil(ctx, Prompt)
+}
+
+// ReadUntil gets all following bytes until a certain line is encountered.
+// This final line is not returned as part of the output, but everything before
+// it (including the \n) is included.
+// This is useful for matching the output of a command.
+// All \r are removed for ease of matching.
+func (s *Shell) ReadUntil(ctx context.Context, finalLine string) ([]byte, error) {
+ var output bytes.Buffer
+ for ctx.Err() == nil {
+ line, err := s.readLine(ctx, finalLine)
+ if err != nil {
+ return nil, err
+ }
+ if bytes.Equal(line, []byte(finalLine)) {
+ break
+ }
+ // readLine ensures that `line` either matches `finalLine` or contains \n.
+ // Thus we can be confident that `line` has a \n here.
+ output.Write(line)
+ }
+ return output.Bytes(), ctx.Err()
+}
+
+// RunCommand is a convenience wrapper for StartCommand + GetCommandOutput.
+func (s *Shell) RunCommand(ctx context.Context, cmd ...string) ([]byte, error) {
+ if err := s.StartCommand(ctx, cmd...); err != nil {
+ return nil, err
+ }
+ return s.GetCommandOutput(ctx)
+}
+
+// RefreshSTTY interprets output from `stty -a` to check whether we are in echo
+// mode and other settings.
+// It will assume that any line matching `expectPrompt` means the end of
+// the `stty -a` output.
+// Why do this rather than using `tcgets`? Because this function can be used in
+// conjunction with sub-shell processes that can allocate their own TTYs.
+func (s *Shell) RefreshSTTY(ctx context.Context, expectPrompt string) error {
+ // Temporarily assume we will not get any output.
+ // If echo is actually on, we'll get the "stty -a" line as if it was command
+ // output. This is OK because we parse the output generously.
+ s.echo = false
+ if err := s.WriteLine(ctx, "stty -a"); err != nil {
+ return fmt.Errorf("could not run `stty -a`: %w", err)
+ }
+ sttyOutput, err := s.ReadUntil(ctx, expectPrompt)
+ if err != nil {
+ return fmt.Errorf("cannot get `stty -a` output: %w", err)
+ }
+
+ // Set default control characters in case we can't see them in the output.
+ s.controlCharIntr = "^C"
+ s.controlCharEOF = "^D"
+ // stty output has two general notations:
+ // `a = b;` (for control characters), and `option` vs `-option` (for boolean
+ // options). We parse both kinds here.
+ // For `a = b;`, `controlChar` contains `a`, and `previousToken` is used to
+ // set `controlChar` to `previousToken` when we see an "=" token.
+ var previousToken, controlChar string
+ for _, token := range strings.Fields(string(sttyOutput)) {
+ if controlChar != "" {
+ value := strings.TrimSuffix(token, ";")
+ switch controlChar {
+ case "intr":
+ s.controlCharIntr = value
+ case "eof":
+ s.controlCharEOF = value
+ }
+ controlChar = ""
+ } else {
+ switch token {
+ case "=":
+ controlChar = previousToken
+ case "-echo":
+ s.echo = false
+ case "echo":
+ s.echo = true
+ }
+ }
+ previousToken = token
+ }
+ s.logf("stty", "refreshed settings: echo=%v, intr=%q, eof=%q", s.echo, s.controlCharIntr, s.controlCharEOF)
+ return nil
+}
+
+// sendControlCode sends `code` to the shell and expects to see `repr`.
+// If `expectLinebreak` is true, it also expects to see a linebreak.
+func (s *Shell) sendControlCode(ctx context.Context, code byte, repr string, expectLinebreak bool) error {
+ if err := s.Write([]byte{code}); err != nil {
+ return fmt.Errorf("cannot send %q: %w", code, err)
+ }
+ if err := s.ExpectString(ctx, repr); err != nil {
+ return fmt.Errorf("did not see %s: %w", repr, err)
+ }
+ if expectLinebreak {
+ if err := s.ExpectEmptyLine(ctx); err != nil {
+ return fmt.Errorf("linebreak after %s: %v", repr, err)
+ }
+ }
+ return nil
+}
+
+// SendInterrupt sends the \x03 (Ctrl+C) control character to the shell.
+func (s *Shell) SendInterrupt(ctx context.Context, expectLinebreak bool) error {
+ return s.sendControlCode(ctx, 0x03, s.controlCharIntr, expectLinebreak)
+}
+
+// SendEOF sends the \x04 (Ctrl+D) control character to the shell.
+func (s *Shell) SendEOF(ctx context.Context, expectLinebreak bool) error {
+ return s.sendControlCode(ctx, 0x04, s.controlCharEOF, expectLinebreak)
+}
+
+// NewShell returns a new managed sh process along with a cleanup function.
+// The caller is expected to call this function once it no longer needs the
+// shell.
+// The optional passed-in logger will be used for logging.
+func NewShell(ctx context.Context, logger Logger) (*Shell, func(), error) {
+ ptyMaster, ptyReplica, err := pty.Open()
+ if err != nil {
+ return nil, nil, fmt.Errorf("cannot create PTY: %w", err)
+ }
+ cmd := exec.CommandContext(ctx, "/bin/sh", "--noprofile", "--norc", "-i")
+ cmd.Stdin = ptyReplica
+ cmd.Stdout = ptyReplica
+ cmd.Stderr = ptyReplica
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Setsid: true,
+ Setctty: true,
+ Ctty: 0,
+ }
+ cmd.Env = append(cmd.Env, fmt.Sprintf("PS1=%s", Prompt))
+ if err := cmd.Start(); err != nil {
+ return nil, nil, fmt.Errorf("cannot start shell: %w", err)
+ }
+ s := &Shell{
+ cmd: cmd,
+ cmdFinished: make(chan struct{}),
+ ptyMaster: ptyMaster,
+ ptyReplica: ptyReplica,
+ readCh: make(chan byteOrError, 1<<20),
+ logger: logger,
+ }
+ s.logf("creation", "Shell spawned.")
+ go s.monitorExit()
+ go s.reader(ctx)
+ setupCtx, setupCancel := context.WithTimeout(ctx, 5*time.Second)
+ defer setupCancel()
+ // We expect to see the prompt immediately on startup,
+ // since the shell is started in interactive mode.
+ if err := s.ExpectPrompt(setupCtx); err != nil {
+ s.cleanup()
+ return nil, nil, fmt.Errorf("did not get initial prompt: %w", err)
+ }
+ s.logf("creation", "Initial prompt observed.")
+ // Get initial TTY settings.
+ if err := s.RefreshSTTY(setupCtx, Prompt); err != nil {
+ s.cleanup()
+ return nil, nil, fmt.Errorf("cannot get initial STTY settings: %w", err)
+ }
+ return s, s.cleanup, nil
+}