diff --git a/cmd/commit.go b/cmd/commit.go index e249fa0..5af0bf5 100644 --- a/cmd/commit.go +++ b/cmd/commit.go @@ -16,10 +16,14 @@ package cmd import ( + "bufio" "context" "errors" "fmt" + "io" + "os" "path/filepath" + "strings" "time" "github.com/marstr/envelopes" @@ -59,6 +63,13 @@ const ( timeUsage = "The time and date when this transaction occurred." ) +const ( + forceFlag = "force" + forceShorthand = "f" + forceDefault = false + forceUsage = "Ignore warnings, commit the transaction anyway." +) + var commitConfig = viper.New() var commitCmd = &cobra.Command{ @@ -82,7 +93,7 @@ var commitCmd = &cobra.Command{ return cobra.NoArgs(cmd, args) }, Run: func(_ *cobra.Command, _ []string) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) defer cancel() amount, err := envelopes.ParseBalance(commitConfig.GetString(amountFlag)) @@ -107,6 +118,35 @@ var commitCmd = &cobra.Command{ logrus.Fatal(err) } + budgetBal := budget.RecursiveBalance() + var accountsBal envelopes.Balance + for _, entry := range accounts { + accountsBal += entry + } + + if budgetBal != accountsBal { + logrus.Warnf( + "accounts (%s) and budget (%s) balance are not equal by %s.", + accountsBal, + budgetBal, + accountsBal-budgetBal) + + if !commitConfig.GetBool(forceFlag) { + shouldContinue, err := promptToContinue( + ctx, + "proceed despite imbalance?", + os.Stdout, + os.Stdin) + if err != nil { + logrus.Fatal(err) + } + + if !shouldContinue { + return + } + } + } + persister := persist.FileSystem{ Root: filepath.Join(targetDir, index.RepoName), } @@ -151,9 +191,81 @@ func init() { commitCmd.PersistentFlags().StringP(commentFlag, commentShorthand, commentDefault, commentUsage) commitCmd.PersistentFlags().StringP(timeFlag, timeShorthand, timeDefault, timeUsage) commitCmd.PersistentFlags().StringP(amountFlag, amountShorthand, amountDefault, amountUsage) + commitCmd.PersistentFlags().BoolP(forceFlag, forceShorthand, forceDefault, forceUsage) err := commitConfig.BindPFlags(commitCmd.PersistentFlags()) if err != nil { logrus.Fatal(err) } } + +func promptToContinue(ctx context.Context, message string, output io.Writer, input io.Reader) (bool, error) { + results := make(chan bool, 1) + errs := make(chan error, 1) + + go func(ctx context.Context) { + for { + select { + case <-ctx.Done(): + errs <- ctx.Err() + return + default: + // Intentionally Left Blank + } + + _, err := fmt.Fprintf(output, "%s (y/N): ", message) + if err != nil { + errs <- err + return + } + + // If `ctx` expires while we're waiting for user response here, this goroutine will leak. There are a lot of + // different ways to organize around this problem, but until there is a Reader that allows for + // cancellation in the standard library (or something we're willing to take a dependency on) there's not + // actually anyway to get around this leak. + // + // Given that this function is expected be executed in very short-lived programs, and realistically this + // will leak one or zero times for any-given process before immediately being terminated, I'm not worried + // about it. + reader := bufio.NewReader(input) + response, err := reader.ReadString('\n') + if err != nil { + errs <- err + return + } + + response = strings.TrimSpace(response) + + switch { + case strings.EqualFold(response, "yes"): + fallthrough + case strings.EqualFold(response, "y"): + results <- true + return + case strings.EqualFold(response, "quit"): + fallthrough + case strings.EqualFold(response, "q"): + fallthrough + case strings.EqualFold(response, ""): + fallthrough + case strings.EqualFold(response, "no"): + fallthrough + case strings.EqualFold(response, "n"): + results <- false + return + default: + // Intentionally Left Blank + // The loop should be re-executed until an answer in an expected format is provided. + } + } + }(ctx) + + select { + case <-ctx.Done(): + return false, ctx.Err() + case err := <-errs: + return false, err + case result := <-results: + return result, nil + } +} diff --git a/cmd/commit_test.go b/cmd/commit_test.go new file mode 100644 index 0000000..c907af1 --- /dev/null +++ b/cmd/commit_test.go @@ -0,0 +1,139 @@ +package cmd + +import ( + "bytes" + "context" + "fmt" + "strings" + "testing" + "time" +) + +func Test_promptToContinue(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + t.Run("affirmative", getTestAffirmativePromptReponses(ctx)) + t.Run("negative", getTestNegativePromptResponses(ctx)) + t.Run("prompt", getTestPromptText(ctx)) +} + +func getTestAffirmativePromptReponses(ctx context.Context) func(*testing.T) { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, 2*time.Minute) + + return func(t *testing.T) { + defer cancel() + + testCases := []string{ + "y", + "Y", + "yes", + "Yes", + "YES", + "YEs", + " yes", + "yes\n", + "yes", + "yes\r\n", + } + + output, input := &bytes.Buffer{}, &bytes.Buffer{} + for _, tc := range testCases { + output.Reset() + input.Reset() + + _, err := fmt.Fprintln(input, tc) + if err != nil { + t.Error(err) + continue + } + + result, err := promptToContinue(ctx, "want to proceed?", output, input) + if err != nil { + t.Error(err) + } else if !result { + t.Logf("returned false for: %q", tc) + t.Fail() + } + } + } +} + +func getTestNegativePromptResponses(ctx context.Context) func(t *testing.T) { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, 2*time.Minute) + + return func(t *testing.T) { + defer cancel() + testCases := []string{ + "n", + "N", + "no", + "No", + "q", + "Q", + "quit", + "QuIt", + "", + " ", + "\t", + "\r\n", + } + + output, input := &bytes.Buffer{}, &bytes.Buffer{} + + for _, tc := range testCases { + output.Reset() + input.Reset() + + _, err := fmt.Fprintln(input, tc) + if err != nil { + t.Error(err) + continue + } + + result, err := promptToContinue(ctx, "want to proceed?", output, input) + if err != nil { + t.Error(err) + } else if result { + t.Logf("returned true for %q", tc) + t.Fail() + } + } + } +} + +func getTestPromptText(ctx context.Context) func(*testing.T) { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, 2*time.Minute) + + return func(t *testing.T) { + defer cancel() + + testCases := []string{ + "want to proceed?", + } + + input, output := &bytes.Buffer{}, &bytes.Buffer{} + + for _, tc := range testCases { + input.Reset() + output.Reset() + + fmt.Fprintln(input) + + _, err := promptToContinue(ctx, tc, output, input) + if err != nil { + t.Error(err) + continue + } + + want := fmt.Sprintf("%s (y/N): ", strings.TrimSpace(tc)) + if got := output.String(); got != want { + t.Logf("got: %q\nwant: %q", got, want) + t.Fail() + } + } + } +}