From e79519c5b754fc040daedf17f126464d6ff970dd Mon Sep 17 00:00:00 2001 From: Max Cai Date: Tue, 2 Apr 2024 08:29:20 -0400 Subject: [PATCH] use env variable to control state command --- internal/wrapper/checkargs.go | 43 +++++++++++++++++++++++------- internal/wrapper/checkargs_test.go | 33 +++++++++++++---------- internal/wrapper/wrapper.go | 5 ++-- 3 files changed, 54 insertions(+), 27 deletions(-) diff --git a/internal/wrapper/checkargs.go b/internal/wrapper/checkargs.go index b177cef..321c81a 100644 --- a/internal/wrapper/checkargs.go +++ b/internal/wrapper/checkargs.go @@ -2,36 +2,48 @@ package wrapper import ( "errors" + "os" + "strings" "github.com/Masterminds/semver/v3" ) -func checkStateCommand(args []string, version *semver.Version) ([]string, error) { +func checkStateCommand(args []string, version *semver.Version) error { versionImport, _ := semver.NewConstraint(">= 1.5.0") versionMoved, _ := semver.NewConstraint(">= 1.1.0") + versionRemoved, _ := semver.NewConstraint(">= 1.7.0") + STATE_COMMAND_VAR := "TF_DEMUX_ALLOW_STATE_COMMANDS" if checkArgsExists(args, "import") >= 0 && versionImport.Check(version) { - force_pos := checkArgsExists(args, "--force") - if force_pos > 0 { - return append(args[:force_pos], args[force_pos+1:]...), nil + if allowStateCommand(STATE_COMMAND_VAR) { + return nil } else { - return args, errors.New("--force flag is required for the 'import' command. Consider using Terraform configuration import block instead") + return errors.New("need set TF_DEMUX_ALLOW_STATE_COMMANDS=true for the 'import' command. Consider using Terraform configuration import block instead") } } if checkArgsExists(args, "state") >= 0 && checkArgsExists(args, "mv") >= 0 && versionMoved.Check(version) { - force_pos := checkArgsExists(args, "--force") - if force_pos > 0 { - return append(args[:force_pos], args[force_pos+1:]...), nil + if allowStateCommand(STATE_COMMAND_VAR) { + return nil } else { - return args, errors.New("--force flag is required for the 'state mv' command. Consider using Terraform configuration moved block instead") + return errors.New("need set TF_DEMUX_ALLOW_STATE_COMMANDS=true for the 'state mv' command. Consider using Terraform configuration moved block instead") } } - return args, nil + if checkArgsExists(args, "state") >= 0 && + checkArgsExists(args, "rm") >= 0 && + versionRemoved.Check(version) { + if allowStateCommand(STATE_COMMAND_VAR) { + return nil + } else { + return errors.New("need set TF_DEMUX_ALLOW_STATE_COMMANDS=true for the 'state rm' command. Consider using Terraform configuration removed block instead") + } + } + + return nil } func checkArgsExists(args []string, cmd string) int { @@ -42,3 +54,14 @@ func checkArgsExists(args []string, cmd string) int { } return -1 } + +func allowStateCommand(envVarName string) bool { + validValues := []string{"1", "true", "yes"} + value := strings.ToLower(os.Getenv(envVarName)) + for _, valid := range validValues { + if value == valid { + return true + } + } + return false +} diff --git a/internal/wrapper/checkargs_test.go b/internal/wrapper/checkargs_test.go index 9ba6327..d9c6345 100644 --- a/internal/wrapper/checkargs_test.go +++ b/internal/wrapper/checkargs_test.go @@ -1,46 +1,51 @@ package wrapper import ( - "slices" + "os" "testing" "github.com/Masterminds/semver/v3" ) func TestCheckStateCommand(t *testing.T) { - t.Run("Valid state import command with --force flag on 1.5.0", func(t *testing.T) { + STATE_COMMAND_VAR := "TF_DEMUX_ALLOW_STATE_COMMANDS" + t.Run("Valid state import command with TF_DEMUX_ALLOW_STATE_COMMANDS on 1.5.0", func(t *testing.T) { args := []string{"import", "--force"} version, _ := semver.NewVersion("1.5.0") - result, err := checkStateCommand(args, version) - if err != nil || !slices.Equal(result, []string{"import"}) { - t.Errorf("Expected no error, got: %v, %v", err, result) + os.Setenv(STATE_COMMAND_VAR, "true") + err := checkStateCommand(args, version) + if err != nil { + t.Errorf("Expected no error, got: %v", err) } }) - t.Run("Valid state import command without --force flag on 1.4.7", func(t *testing.T) { + t.Run("Valid state import command without TF_DEMUX_ALLOW_STATE_COMMANDS on 1.4.7", func(t *testing.T) { args := []string{"import"} version, _ := semver.NewVersion("1.4.7") - result, err := checkStateCommand(args, version) - if err != nil || !slices.Equal(result, []string{"import"}) { + os.Setenv(STATE_COMMAND_VAR, "true") + err := checkStateCommand(args, version) + if err != nil { t.Errorf("Expected no error, got: %v", err) } }) - t.Run("Invalid state import command without --force flag on 1.5.0", func(t *testing.T) { + t.Run("Invalid state import command without TF_DEMUX_ALLOW_STATE_COMMANDS on 1.5.0", func(t *testing.T) { args := []string{"import"} version, _ := semver.NewVersion("1.6.0") - _, err := checkStateCommand(args, version) + os.Setenv(STATE_COMMAND_VAR, "") + err := checkStateCommand(args, version) if err == nil { t.Errorf("Expected error, got: %v", err) } }) - t.Run("Valid state mv command with --force flag on 1.6.0", func(t *testing.T) { + t.Run("Valid state mv command with TF_DEMUX_ALLOW_STATE_COMMANDS on 1.6.0", func(t *testing.T) { args := []string{"state", "mv", "--force"} version, _ := semver.NewVersion("1.6.0") - result, err := checkStateCommand(args, version) - if err != nil || !slices.Equal(result, []string{"state", "mv"}) { - t.Errorf("Expected no error, got: %v, %v", err, result) + os.Setenv(STATE_COMMAND_VAR, "true") + err := checkStateCommand(args, version) + if err != nil { + t.Errorf("Expected no error, got: %v", err) } }) } diff --git a/internal/wrapper/wrapper.go b/internal/wrapper/wrapper.go index f88cb20..25cfae2 100644 --- a/internal/wrapper/wrapper.go +++ b/internal/wrapper/wrapper.go @@ -53,8 +53,7 @@ func RunTerraform(args []string, arch string) (int, error) { log.Printf("version '%s' matches all constraints", matchingRelease.Version) - newArgs, err := checkStateCommand(args, matchingRelease.Version) - if err != nil { + if checkStateCommand(args, matchingRelease.Version) != nil { log.SetOutput(os.Stderr) log.Fatal("error: ", err) } @@ -65,7 +64,7 @@ func RunTerraform(args []string, arch string) (int, error) { return 1, err } - return runTerraform(executablePath, newArgs) + return runTerraform(executablePath, args) } func ensureCacheDirectory() (string, error) {