Skip to content

Commit

Permalink
Add basic tests for CLI app initialization and execution
Browse files Browse the repository at this point in the history
  • Loading branch information
hslatman committed Sep 30, 2024
1 parent 7a54e1c commit 11f58c3
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 33 deletions.
74 changes: 41 additions & 33 deletions cmd/step/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"errors"
"fmt"
"io"
"os"
"reflect"
"regexp"
Expand All @@ -15,10 +16,10 @@ import (
"github.com/smallstep/cli-utils/command"
"github.com/smallstep/cli-utils/step"
"github.com/smallstep/cli-utils/ui"
"github.com/smallstep/cli-utils/usage"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"

"github.com/smallstep/cli-utils/usage"
"github.com/smallstep/cli/command/version"
"github.com/smallstep/cli/internal/plugin"
"github.com/smallstep/cli/utils"
Expand Down Expand Up @@ -66,6 +67,42 @@ func main() {

defer panicHandler()

// create new instance of app
app := newApp(os.Stdout, os.Stderr)

if err := app.Run(os.Args); err != nil {
var messenger interface {
Message() string
}
if errors.As(err, &messenger) {
if os.Getenv("STEPDEBUG") == "1" {
fmt.Fprintf(os.Stderr, "%+v\n\n%s", err, messenger.Message())
} else {
fmt.Fprintln(os.Stderr, messenger.Message())
fmt.Fprintln(os.Stderr, "Re-run with STEPDEBUG=1 for more info.")
}
} else {
if os.Getenv("STEPDEBUG") == "1" {
fmt.Fprintf(os.Stderr, "%+v\n", err)
} else {
fmt.Fprintln(os.Stderr, err)
}
}
//nolint:gocritic // ignore exitAfterDefer error because the defer is required for recovery.
os.Exit(1)
}
}

func newApp(stdout, stderr io.Writer) *cli.App {
// Define default file writers and prompters for go.step.sm/crypto
pemutil.WriteFile = utils.WriteFile
pemutil.PromptPassword = func(msg string) ([]byte, error) {
return ui.PromptPassword(msg)
}
jose.PromptPassword = func(msg string) ([]byte, error) {
return ui.PromptPassword(msg)
}

// Override global framework components
cli.VersionPrinter = func(c *cli.Context) {
version.Command(c)
Expand Down Expand Up @@ -111,39 +148,10 @@ func main() {
}

// All non-successful output should be written to stderr
app.Writer = os.Stdout
app.ErrWriter = os.Stderr

// Define default file writers and prompters for go.step.sm/crypto
pemutil.WriteFile = utils.WriteFile
pemutil.PromptPassword = func(msg string) ([]byte, error) {
return ui.PromptPassword(msg)
}
jose.PromptPassword = func(msg string) ([]byte, error) {
return ui.PromptPassword(msg)
}
app.Writer = stdout
app.ErrWriter = stderr

if err := app.Run(os.Args); err != nil {
var messenger interface {
Message() string
}
if errors.As(err, &messenger) {
if os.Getenv("STEPDEBUG") == "1" {
fmt.Fprintf(os.Stderr, "%+v\n\n%s", err, messenger.Message())
} else {
fmt.Fprintln(os.Stderr, messenger.Message())
fmt.Fprintln(os.Stderr, "Re-run with STEPDEBUG=1 for more info.")
}
} else {
if os.Getenv("STEPDEBUG") == "1" {
fmt.Fprintf(os.Stderr, "%+v\n", err)
} else {
fmt.Fprintln(os.Stderr, err)
}
}
//nolint:gocritic // ignore exitAfterDefer error because the defer is required for recovery.
os.Exit(1)
}
return app
}

func panicHandler() {
Expand Down
46 changes: 46 additions & 0 deletions cmd/step/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package main

import (
"bytes"
"regexp"
"testing"

"github.com/stretchr/testify/require"
)

func TestAppHasAllCommands(t *testing.T) {
app := newApp(&bytes.Buffer{}, &bytes.Buffer{})
require.NotNil(t, app)

require.Equal(t, "step", app.Name)
require.Equal(t, "step", app.HelpName)

var names = make([]string, 0, len(app.Commands))
for _, c := range app.Commands {
names = append(names, c.Name)
}
require.Equal(t, []string{
"help", "api", "path", "base64", "fileserver",
"certificate", "completion", "context", "crl",
"crypto", "oauth", "version", "ca", "beta", "ssh",
}, names)
}

const ansi = "[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))"

var ansiRegex = regexp.MustCompile(ansi)

func TestAppRuns(t *testing.T) {
stdout := &bytes.Buffer{}
stderr := &bytes.Buffer{}

app := newApp(stdout, stderr)
require.NotNil(t, app)

err := app.Run([]string{"step"})
require.NoError(t, err)
require.Empty(t, stderr.Bytes())

output := ansiRegex.ReplaceAllString(stdout.String(), "")
require.Contains(t, output, "step -- plumbing for distributed systems")
}

0 comments on commit 11f58c3

Please sign in to comment.