Skip to content

Commit

Permalink
Adding warning if you're about to commit an out-of-balance state (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
marstr authored Feb 2, 2019
1 parent d41ef19 commit 2db50bd
Show file tree
Hide file tree
Showing 2 changed files with 252 additions and 1 deletion.
114 changes: 113 additions & 1 deletion cmd/commit.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
package cmd

import (
"bufio"
"context"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"time"

"github.com/marstr/envelopes"
Expand Down Expand Up @@ -59,6 +63,13 @@ const (
timeUsage = "The time and date when this transaction occurred."
)

const (
forceFlag = "force"
forceShorthand = "f"
forceDefault = false
forceUsage = "Ignore warnings, commit the transaction anyway."
)

var commitConfig = viper.New()

var commitCmd = &cobra.Command{
Expand All @@ -82,7 +93,7 @@ var commitCmd = &cobra.Command{
return cobra.NoArgs(cmd, args)
},
Run: func(_ *cobra.Command, _ []string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()

amount, err := envelopes.ParseBalance(commitConfig.GetString(amountFlag))
Expand All @@ -107,6 +118,35 @@ var commitCmd = &cobra.Command{
logrus.Fatal(err)
}

budgetBal := budget.RecursiveBalance()
var accountsBal envelopes.Balance
for _, entry := range accounts {
accountsBal += entry
}

if budgetBal != accountsBal {
logrus.Warnf(
"accounts (%s) and budget (%s) balance are not equal by %s.",
accountsBal,
budgetBal,
accountsBal-budgetBal)

if !commitConfig.GetBool(forceFlag) {
shouldContinue, err := promptToContinue(
ctx,
"proceed despite imbalance?",
os.Stdout,
os.Stdin)
if err != nil {
logrus.Fatal(err)
}

if !shouldContinue {
return
}
}
}

persister := persist.FileSystem{
Root: filepath.Join(targetDir, index.RepoName),
}
Expand Down Expand Up @@ -151,9 +191,81 @@ func init() {
commitCmd.PersistentFlags().StringP(commentFlag, commentShorthand, commentDefault, commentUsage)
commitCmd.PersistentFlags().StringP(timeFlag, timeShorthand, timeDefault, timeUsage)
commitCmd.PersistentFlags().StringP(amountFlag, amountShorthand, amountDefault, amountUsage)
commitCmd.PersistentFlags().BoolP(forceFlag, forceShorthand, forceDefault, forceUsage)

err := commitConfig.BindPFlags(commitCmd.PersistentFlags())
if err != nil {
logrus.Fatal(err)
}
}

func promptToContinue(ctx context.Context, message string, output io.Writer, input io.Reader) (bool, error) {
results := make(chan bool, 1)
errs := make(chan error, 1)

go func(ctx context.Context) {
for {
select {
case <-ctx.Done():
errs <- ctx.Err()
return
default:
// Intentionally Left Blank
}

_, err := fmt.Fprintf(output, "%s (y/N): ", message)
if err != nil {
errs <- err
return
}

// If `ctx` expires while we're waiting for user response here, this goroutine will leak. There are a lot of
// different ways to organize around this problem, but until there is a Reader that allows for
// cancellation in the standard library (or something we're willing to take a dependency on) there's not
// actually anyway to get around this leak.
//
// Given that this function is expected be executed in very short-lived programs, and realistically this
// will leak one or zero times for any-given process before immediately being terminated, I'm not worried
// about it.
reader := bufio.NewReader(input)
response, err := reader.ReadString('\n')
if err != nil {
errs <- err
return
}

response = strings.TrimSpace(response)

switch {
case strings.EqualFold(response, "yes"):
fallthrough
case strings.EqualFold(response, "y"):
results <- true
return
case strings.EqualFold(response, "quit"):
fallthrough
case strings.EqualFold(response, "q"):
fallthrough
case strings.EqualFold(response, ""):
fallthrough
case strings.EqualFold(response, "no"):
fallthrough
case strings.EqualFold(response, "n"):
results <- false
return
default:
// Intentionally Left Blank
// The loop should be re-executed until an answer in an expected format is provided.
}
}
}(ctx)

select {
case <-ctx.Done():
return false, ctx.Err()
case err := <-errs:
return false, err
case result := <-results:
return result, nil
}
}
139 changes: 139 additions & 0 deletions cmd/commit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package cmd

import (
"bytes"
"context"
"fmt"
"strings"
"testing"
"time"
)

func Test_promptToContinue(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()

t.Run("affirmative", getTestAffirmativePromptReponses(ctx))
t.Run("negative", getTestNegativePromptResponses(ctx))
t.Run("prompt", getTestPromptText(ctx))
}

func getTestAffirmativePromptReponses(ctx context.Context) func(*testing.T) {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, 2*time.Minute)

return func(t *testing.T) {
defer cancel()

testCases := []string{
"y",
"Y",
"yes",
"Yes",
"YES",
"YEs",
" yes",
"yes\n",
"yes",
"yes\r\n",
}

output, input := &bytes.Buffer{}, &bytes.Buffer{}
for _, tc := range testCases {
output.Reset()
input.Reset()

_, err := fmt.Fprintln(input, tc)
if err != nil {
t.Error(err)
continue
}

result, err := promptToContinue(ctx, "want to proceed?", output, input)
if err != nil {
t.Error(err)
} else if !result {
t.Logf("returned false for: %q", tc)
t.Fail()
}
}
}
}

func getTestNegativePromptResponses(ctx context.Context) func(t *testing.T) {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, 2*time.Minute)

return func(t *testing.T) {
defer cancel()
testCases := []string{
"n",
"N",
"no",
"No",
"q",
"Q",
"quit",
"QuIt",
"",
" ",
"\t",
"\r\n",
}

output, input := &bytes.Buffer{}, &bytes.Buffer{}

for _, tc := range testCases {
output.Reset()
input.Reset()

_, err := fmt.Fprintln(input, tc)
if err != nil {
t.Error(err)
continue
}

result, err := promptToContinue(ctx, "want to proceed?", output, input)
if err != nil {
t.Error(err)
} else if result {
t.Logf("returned true for %q", tc)
t.Fail()
}
}
}
}

func getTestPromptText(ctx context.Context) func(*testing.T) {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, 2*time.Minute)

return func(t *testing.T) {
defer cancel()

testCases := []string{
"want to proceed?",
}

input, output := &bytes.Buffer{}, &bytes.Buffer{}

for _, tc := range testCases {
input.Reset()
output.Reset()

fmt.Fprintln(input)

_, err := promptToContinue(ctx, tc, output, input)
if err != nil {
t.Error(err)
continue
}

want := fmt.Sprintf("%s (y/N): ", strings.TrimSpace(tc))
if got := output.String(); got != want {
t.Logf("got: %q\nwant: %q", got, want)
t.Fail()
}
}
}
}

0 comments on commit 2db50bd

Please sign in to comment.