From 5b65358146627b63005e1ba9233bce603822dd84 Mon Sep 17 00:00:00 2001 From: shreyas-goenka <88374338+shreyas-goenka@users.noreply.github.com> Date: Thu, 18 Jul 2024 15:17:59 +0530 Subject: [PATCH] Use local Terraform state only when lineage match (#1588) ## Changes DABs deployments should be isolated if `root_path` and workspace host are different. This PR fixes a bug where local terraform state gets piggybacked if the same cwd is used to deploy two isolated deployments for the same bundle target. This can happen if: 1. A user switches to a different identity on the same machine. 2. The workspace host URL the bundle/target points to is changed. 3. A user changes the `root_path` while doing bundle development. To solve this problem we rely on the lineage field available in the terraform state, which is a uuid identifying unique terraform deployments. There's a 1:1 mapping between a terraform deployment and a bundle deployment. For more details on how lineage works in terraform, see: https://developer.hashicorp.com/terraform/language/state/backends#manual-state-pull-push ## Tests Manually verified that changing the identity no longer results in the incorrect terraform state being used. Also, new unit tests are added. --- bundle/deploy/terraform/state_pull.go | 115 ++++++++----- bundle/deploy/terraform/state_pull_test.go | 177 +++++++++++++-------- bundle/deploy/terraform/state_push_test.go | 2 +- bundle/deploy/terraform/state_test.go | 6 +- bundle/deploy/terraform/util.go | 33 ---- bundle/deploy/terraform/util_test.go | 33 ---- 6 files changed, 186 insertions(+), 180 deletions(-) diff --git a/bundle/deploy/terraform/state_pull.go b/bundle/deploy/terraform/state_pull.go index cc7d342747..9a5b910076 100644 --- a/bundle/deploy/terraform/state_pull.go +++ b/bundle/deploy/terraform/state_pull.go @@ -1,8 +1,8 @@ package terraform import ( - "bytes" "context" + "encoding/json" "errors" "io" "io/fs" @@ -12,10 +12,14 @@ import ( "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/deploy" "github.com/databricks/cli/libs/diag" - "github.com/databricks/cli/libs/filer" "github.com/databricks/cli/libs/log" ) +type tfState struct { + Serial int64 `json:"serial"` + Lineage string `json:"lineage"` +} + type statePull struct { filerFactory deploy.FilerFactory } @@ -24,74 +28,105 @@ func (l *statePull) Name() string { return "terraform:state-pull" } -func (l *statePull) remoteState(ctx context.Context, f filer.Filer) (*bytes.Buffer, error) { - // Download state file from filer to local cache directory. - remote, err := f.Read(ctx, TerraformStateFileName) +func (l *statePull) remoteState(ctx context.Context, b *bundle.Bundle) (*tfState, []byte, error) { + f, err := l.filerFactory(b) if err != nil { - // On first deploy this state file doesn't yet exist. - if errors.Is(err, fs.ErrNotExist) { - return nil, nil - } - return nil, err + return nil, nil, err } - defer remote.Close() - - var buf bytes.Buffer - _, err = io.Copy(&buf, remote) + r, err := f.Read(ctx, TerraformStateFileName) if err != nil { - return nil, err + return nil, nil, err } + defer r.Close() - return &buf, nil -} + content, err := io.ReadAll(r) + if err != nil { + return nil, nil, err + } -func (l *statePull) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics { - f, err := l.filerFactory(b) + state := &tfState{} + err = json.Unmarshal(content, state) if err != nil { - return diag.FromErr(err) + return nil, nil, err } + return state, content, nil +} + +func (l *statePull) localState(ctx context.Context, b *bundle.Bundle) (*tfState, error) { dir, err := Dir(ctx, b) if err != nil { - return diag.FromErr(err) + return nil, err } - // Download state file from filer to local cache directory. - log.Infof(ctx, "Opening remote state file") - remote, err := l.remoteState(ctx, f) + content, err := os.ReadFile(filepath.Join(dir, TerraformStateFileName)) if err != nil { - log.Infof(ctx, "Unable to open remote state file: %s", err) - return diag.FromErr(err) + return nil, err } - if remote == nil { - log.Infof(ctx, "Remote state file does not exist") - return nil + + state := &tfState{} + err = json.Unmarshal(content, state) + if err != nil { + return nil, err } - // Expect the state file to live under dir. - local, err := os.OpenFile(filepath.Join(dir, TerraformStateFileName), os.O_CREATE|os.O_RDWR, 0600) + return state, nil +} + +func (l *statePull) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics { + dir, err := Dir(ctx, b) if err != nil { return diag.FromErr(err) } - defer local.Close() - if !IsLocalStateStale(local, bytes.NewReader(remote.Bytes())) { - log.Infof(ctx, "Local state is the same or newer, ignoring remote state") + localStatePath := filepath.Join(dir, TerraformStateFileName) + + // Case: Remote state file does not exist. In this case we fallback to using the + // local Terraform state. This allows users to change the "root_path" their bundle is + // configured with. + remoteState, remoteContent, err := l.remoteState(ctx, b) + if errors.Is(err, fs.ErrNotExist) { + log.Infof(ctx, "Remote state file does not exist. Using local Terraform state.") return nil } + if err != nil { + return diag.Errorf("failed to read remote state file: %v", err) + } - // Truncating the file before writing - local.Truncate(0) - local.Seek(0, 0) + // Expected invariant: remote state file should have a lineage UUID. Error + // if that's not the case. + if remoteState.Lineage == "" { + return diag.Errorf("remote state file does not have a lineage") + } - // Write file to disk. - log.Infof(ctx, "Writing remote state file to local cache directory") - _, err = io.Copy(local, bytes.NewReader(remote.Bytes())) + // Case: Local state file does not exist. In this case we should rely on the remote state file. + localState, err := l.localState(ctx, b) + if errors.Is(err, fs.ErrNotExist) { + log.Infof(ctx, "Local state file does not exist. Using remote Terraform state.") + err := os.WriteFile(localStatePath, remoteContent, 0600) + return diag.FromErr(err) + } if err != nil { + return diag.Errorf("failed to read local state file: %v", err) + } + + // If the lineage does not match, the Terraform state files do not correspond to the same deployment. + if localState.Lineage != remoteState.Lineage { + log.Infof(ctx, "Remote and local state lineages do not match. Using remote Terraform state. Invalidating local Terraform state.") + err := os.WriteFile(localStatePath, remoteContent, 0600) + return diag.FromErr(err) + } + + // If the remote state is newer than the local state, we should use the remote state. + if remoteState.Serial > localState.Serial { + log.Infof(ctx, "Remote state is newer than local state. Using remote Terraform state.") + err := os.WriteFile(localStatePath, remoteContent, 0600) return diag.FromErr(err) } + // default: local state is newer or equal to remote state in terms of serial sequence. + // It is also of the same lineage. Keep using the local state. return nil } diff --git a/bundle/deploy/terraform/state_pull_test.go b/bundle/deploy/terraform/state_pull_test.go index 26297bfcbe..39937a3cc2 100644 --- a/bundle/deploy/terraform/state_pull_test.go +++ b/bundle/deploy/terraform/state_pull_test.go @@ -17,7 +17,7 @@ import ( "github.com/stretchr/testify/mock" ) -func mockStateFilerForPull(t *testing.T, contents map[string]int, merr error) filer.Filer { +func mockStateFilerForPull(t *testing.T, contents map[string]any, merr error) filer.Filer { buf, err := json.Marshal(contents) assert.NoError(t, err) @@ -41,86 +41,123 @@ func statePullTestBundle(t *testing.T) *bundle.Bundle { } } -func TestStatePullLocalMissingRemoteMissing(t *testing.T) { - m := &statePull{ - identityFiler(mockStateFilerForPull(t, nil, os.ErrNotExist)), - } +func TestStatePullLocalErrorWhenRemoteHasNoLineage(t *testing.T) { + m := &statePull{} - ctx := context.Background() - b := statePullTestBundle(t) - diags := bundle.Apply(ctx, b, m) - assert.NoError(t, diags.Error()) + t.Run("no local state", func(t *testing.T) { + // setup remote state. + m.filerFactory = identityFiler(mockStateFilerForPull(t, map[string]any{"serial": 5}, nil)) - // Confirm that no local state file has been written. - _, err := os.Stat(localStateFile(t, ctx, b)) - assert.ErrorIs(t, err, fs.ErrNotExist) -} + ctx := context.Background() + b := statePullTestBundle(t) + diags := bundle.Apply(ctx, b, m) + assert.EqualError(t, diags.Error(), "remote state file does not have a lineage") + }) -func TestStatePullLocalMissingRemotePresent(t *testing.T) { - m := &statePull{ - identityFiler(mockStateFilerForPull(t, map[string]int{"serial": 5}, nil)), - } + t.Run("local state with lineage", func(t *testing.T) { + // setup remote state. + m.filerFactory = identityFiler(mockStateFilerForPull(t, map[string]any{"serial": 5}, nil)) - ctx := context.Background() - b := statePullTestBundle(t) - diags := bundle.Apply(ctx, b, m) - assert.NoError(t, diags.Error()) + ctx := context.Background() + b := statePullTestBundle(t) + writeLocalState(t, ctx, b, map[string]any{"serial": 5, "lineage": "aaaa"}) - // Confirm that the local state file has been updated. - localState := readLocalState(t, ctx, b) - assert.Equal(t, map[string]int{"serial": 5}, localState) + diags := bundle.Apply(ctx, b, m) + assert.EqualError(t, diags.Error(), "remote state file does not have a lineage") + }) } -func TestStatePullLocalStale(t *testing.T) { - m := &statePull{ - identityFiler(mockStateFilerForPull(t, map[string]int{"serial": 5}, nil)), - } - - ctx := context.Background() - b := statePullTestBundle(t) +func TestStatePullLocal(t *testing.T) { + tcases := []struct { + name string - // Write a stale local state file. - writeLocalState(t, ctx, b, map[string]int{"serial": 4}) - diags := bundle.Apply(ctx, b, m) - assert.NoError(t, diags.Error()) + // remote state before applying the pull mutators + remote map[string]any - // Confirm that the local state file has been updated. - localState := readLocalState(t, ctx, b) - assert.Equal(t, map[string]int{"serial": 5}, localState) -} + // local state before applying the pull mutators + local map[string]any -func TestStatePullLocalEqual(t *testing.T) { - m := &statePull{ - identityFiler(mockStateFilerForPull(t, map[string]int{"serial": 5, "some_other_key": 123}, nil)), + // expected local state after applying the pull mutators + expected map[string]any + }{ + { + name: "remote missing, local missing", + remote: nil, + local: nil, + expected: nil, + }, + { + name: "remote missing, local present", + remote: nil, + local: map[string]any{"serial": 5, "lineage": "aaaa"}, + // fallback to local state, since remote state is missing. + expected: map[string]any{"serial": float64(5), "lineage": "aaaa"}, + }, + { + name: "local stale", + remote: map[string]any{"serial": 10, "lineage": "aaaa", "some_other_key": 123}, + local: map[string]any{"serial": 5, "lineage": "aaaa"}, + // use remote, since remote is newer. + expected: map[string]any{"serial": float64(10), "lineage": "aaaa", "some_other_key": float64(123)}, + }, + { + name: "local equal", + remote: map[string]any{"serial": 5, "lineage": "aaaa", "some_other_key": 123}, + local: map[string]any{"serial": 5, "lineage": "aaaa"}, + // use local state, since they are equal in terms of serial sequence. + expected: map[string]any{"serial": float64(5), "lineage": "aaaa"}, + }, + { + name: "local newer", + remote: map[string]any{"serial": 5, "lineage": "aaaa", "some_other_key": 123}, + local: map[string]any{"serial": 6, "lineage": "aaaa"}, + // use local state, since local is newer. + expected: map[string]any{"serial": float64(6), "lineage": "aaaa"}, + }, + { + name: "remote and local have different lineages", + remote: map[string]any{"serial": 5, "lineage": "aaaa"}, + local: map[string]any{"serial": 10, "lineage": "bbbb"}, + // use remote, since lineages do not match. + expected: map[string]any{"serial": float64(5), "lineage": "aaaa"}, + }, + { + name: "local is missing lineage", + remote: map[string]any{"serial": 5, "lineage": "aaaa"}, + local: map[string]any{"serial": 10}, + // use remote, since local does not have lineage. + expected: map[string]any{"serial": float64(5), "lineage": "aaaa"}, + }, } - ctx := context.Background() - b := statePullTestBundle(t) - - // Write a local state file with the same serial as the remote. - writeLocalState(t, ctx, b, map[string]int{"serial": 5}) - diags := bundle.Apply(ctx, b, m) - assert.NoError(t, diags.Error()) - - // Confirm that the local state file has not been updated. - localState := readLocalState(t, ctx, b) - assert.Equal(t, map[string]int{"serial": 5}, localState) -} - -func TestStatePullLocalNewer(t *testing.T) { - m := &statePull{ - identityFiler(mockStateFilerForPull(t, map[string]int{"serial": 5, "some_other_key": 123}, nil)), + for _, tc := range tcases { + t.Run(tc.name, func(t *testing.T) { + m := &statePull{} + if tc.remote == nil { + // nil represents no remote state file. + m.filerFactory = identityFiler(mockStateFilerForPull(t, nil, os.ErrNotExist)) + } else { + m.filerFactory = identityFiler(mockStateFilerForPull(t, tc.remote, nil)) + } + + ctx := context.Background() + b := statePullTestBundle(t) + if tc.local != nil { + writeLocalState(t, ctx, b, tc.local) + } + + diags := bundle.Apply(ctx, b, m) + assert.NoError(t, diags.Error()) + + if tc.expected == nil { + // nil represents no local state file is expected. + _, err := os.Stat(localStateFile(t, ctx, b)) + assert.ErrorIs(t, err, fs.ErrNotExist) + } else { + localState := readLocalState(t, ctx, b) + assert.Equal(t, tc.expected, localState) + + } + }) } - - ctx := context.Background() - b := statePullTestBundle(t) - - // Write a local state file with a newer serial as the remote. - writeLocalState(t, ctx, b, map[string]int{"serial": 6}) - diags := bundle.Apply(ctx, b, m) - assert.NoError(t, diags.Error()) - - // Confirm that the local state file has not been updated. - localState := readLocalState(t, ctx, b) - assert.Equal(t, map[string]int{"serial": 6}, localState) } diff --git a/bundle/deploy/terraform/state_push_test.go b/bundle/deploy/terraform/state_push_test.go index e054773f31..ac74f345d2 100644 --- a/bundle/deploy/terraform/state_push_test.go +++ b/bundle/deploy/terraform/state_push_test.go @@ -55,7 +55,7 @@ func TestStatePush(t *testing.T) { b := statePushTestBundle(t) // Write a stale local state file. - writeLocalState(t, ctx, b, map[string]int{"serial": 4}) + writeLocalState(t, ctx, b, map[string]any{"serial": 4}) diags := bundle.Apply(ctx, b, m) assert.NoError(t, diags.Error()) } diff --git a/bundle/deploy/terraform/state_test.go b/bundle/deploy/terraform/state_test.go index ff32506255..73d7cb0dee 100644 --- a/bundle/deploy/terraform/state_test.go +++ b/bundle/deploy/terraform/state_test.go @@ -26,19 +26,19 @@ func localStateFile(t *testing.T, ctx context.Context, b *bundle.Bundle) string return filepath.Join(dir, TerraformStateFileName) } -func readLocalState(t *testing.T, ctx context.Context, b *bundle.Bundle) map[string]int { +func readLocalState(t *testing.T, ctx context.Context, b *bundle.Bundle) map[string]any { f, err := os.Open(localStateFile(t, ctx, b)) require.NoError(t, err) defer f.Close() - var contents map[string]int + var contents map[string]any dec := json.NewDecoder(f) err = dec.Decode(&contents) require.NoError(t, err) return contents } -func writeLocalState(t *testing.T, ctx context.Context, b *bundle.Bundle, contents map[string]int) { +func writeLocalState(t *testing.T, ctx context.Context, b *bundle.Bundle, contents map[string]any) { f, err := os.Create(localStateFile(t, ctx, b)) require.NoError(t, err) defer f.Close() diff --git a/bundle/deploy/terraform/util.go b/bundle/deploy/terraform/util.go index 1a8a83ac73..64d667b5f6 100644 --- a/bundle/deploy/terraform/util.go +++ b/bundle/deploy/terraform/util.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "errors" - "io" "os" "path/filepath" @@ -22,10 +21,6 @@ type resourcesState struct { const SupportedStateVersion = 4 -type serialState struct { - Serial int `json:"serial"` -} - type stateResource struct { Type string `json:"type"` Name string `json:"name"` @@ -41,34 +36,6 @@ type stateInstanceAttributes struct { ID string `json:"id"` } -func IsLocalStateStale(local io.Reader, remote io.Reader) bool { - localState, err := loadState(local) - if err != nil { - return true - } - - remoteState, err := loadState(remote) - if err != nil { - return false - } - - return localState.Serial < remoteState.Serial -} - -func loadState(input io.Reader) (*serialState, error) { - content, err := io.ReadAll(input) - if err != nil { - return nil, err - } - var s serialState - err = json.Unmarshal(content, &s) - if err != nil { - return nil, err - } - - return &s, nil -} - func ParseResourcesState(ctx context.Context, b *bundle.Bundle) (*resourcesState, error) { cacheDir, err := Dir(ctx, b) if err != nil { diff --git a/bundle/deploy/terraform/util_test.go b/bundle/deploy/terraform/util_test.go index 8949ebca82..251a7c256a 100644 --- a/bundle/deploy/terraform/util_test.go +++ b/bundle/deploy/terraform/util_test.go @@ -2,48 +2,15 @@ package terraform import ( "context" - "fmt" "os" "path/filepath" - "strings" "testing" - "testing/iotest" "github.com/databricks/cli/bundle" "github.com/databricks/cli/bundle/config" "github.com/stretchr/testify/assert" ) -func TestLocalStateIsNewer(t *testing.T) { - local := strings.NewReader(`{"serial": 5}`) - remote := strings.NewReader(`{"serial": 4}`) - assert.False(t, IsLocalStateStale(local, remote)) -} - -func TestLocalStateIsOlder(t *testing.T) { - local := strings.NewReader(`{"serial": 5}`) - remote := strings.NewReader(`{"serial": 6}`) - assert.True(t, IsLocalStateStale(local, remote)) -} - -func TestLocalStateIsTheSame(t *testing.T) { - local := strings.NewReader(`{"serial": 5}`) - remote := strings.NewReader(`{"serial": 5}`) - assert.False(t, IsLocalStateStale(local, remote)) -} - -func TestLocalStateMarkStaleWhenFailsToLoad(t *testing.T) { - local := iotest.ErrReader(fmt.Errorf("Random error")) - remote := strings.NewReader(`{"serial": 5}`) - assert.True(t, IsLocalStateStale(local, remote)) -} - -func TestLocalStateMarkNonStaleWhenRemoteFailsToLoad(t *testing.T) { - local := strings.NewReader(`{"serial": 5}`) - remote := iotest.ErrReader(fmt.Errorf("Random error")) - assert.False(t, IsLocalStateStale(local, remote)) -} - func TestParseResourcesStateWithNoFile(t *testing.T) { b := &bundle.Bundle{ RootPath: t.TempDir(),