Skip to content

Commit

Permalink
Use local Terraform state only when lineage match (#1588)
Browse files Browse the repository at this point in the history
## Changes
DABs deployments should be isolated if `root_path` and workspace host
are different. This PR fixes a bug where local terraform state gets
piggybacked if the same cwd is used to deploy two isolated deployments
for the same bundle target. This can happen if:
1. A user switches to a different identity on the same machine. 
2. The workspace host URL the bundle/target points to is changed.
3. A user changes the `root_path` while doing bundle development.

To solve this problem we rely on the lineage field available in the
terraform state, which is a uuid identifying unique terraform
deployments. There's a 1:1 mapping between a terraform deployment and a
bundle deployment.

For more details on how lineage works in terraform, see:
https://developer.hashicorp.com/terraform/language/state/backends#manual-state-pull-push

## Tests
Manually verified that changing the identity no longer results in the
incorrect terraform state being used. Also, new unit tests are added.
  • Loading branch information
shreyas-goenka authored Jul 18, 2024
1 parent af0114a commit 5b65358
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 180 deletions.
115 changes: 75 additions & 40 deletions bundle/deploy/terraform/state_pull.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package terraform

import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"io/fs"
Expand All @@ -12,10 +12,14 @@ import (
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/deploy"
"github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/filer"
"github.com/databricks/cli/libs/log"
)

type tfState struct {
Serial int64 `json:"serial"`
Lineage string `json:"lineage"`
}

type statePull struct {
filerFactory deploy.FilerFactory
}
Expand All @@ -24,74 +28,105 @@ func (l *statePull) Name() string {
return "terraform:state-pull"
}

func (l *statePull) remoteState(ctx context.Context, f filer.Filer) (*bytes.Buffer, error) {
// Download state file from filer to local cache directory.
remote, err := f.Read(ctx, TerraformStateFileName)
func (l *statePull) remoteState(ctx context.Context, b *bundle.Bundle) (*tfState, []byte, error) {
f, err := l.filerFactory(b)
if err != nil {
// On first deploy this state file doesn't yet exist.
if errors.Is(err, fs.ErrNotExist) {
return nil, nil
}
return nil, err
return nil, nil, err
}

defer remote.Close()

var buf bytes.Buffer
_, err = io.Copy(&buf, remote)
r, err := f.Read(ctx, TerraformStateFileName)
if err != nil {
return nil, err
return nil, nil, err
}
defer r.Close()

return &buf, nil
}
content, err := io.ReadAll(r)
if err != nil {
return nil, nil, err
}

func (l *statePull) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
f, err := l.filerFactory(b)
state := &tfState{}
err = json.Unmarshal(content, state)
if err != nil {
return diag.FromErr(err)
return nil, nil, err
}

return state, content, nil
}

func (l *statePull) localState(ctx context.Context, b *bundle.Bundle) (*tfState, error) {
dir, err := Dir(ctx, b)
if err != nil {
return diag.FromErr(err)
return nil, err
}

// Download state file from filer to local cache directory.
log.Infof(ctx, "Opening remote state file")
remote, err := l.remoteState(ctx, f)
content, err := os.ReadFile(filepath.Join(dir, TerraformStateFileName))
if err != nil {
log.Infof(ctx, "Unable to open remote state file: %s", err)
return diag.FromErr(err)
return nil, err
}
if remote == nil {
log.Infof(ctx, "Remote state file does not exist")
return nil

state := &tfState{}
err = json.Unmarshal(content, state)
if err != nil {
return nil, err
}

// Expect the state file to live under dir.
local, err := os.OpenFile(filepath.Join(dir, TerraformStateFileName), os.O_CREATE|os.O_RDWR, 0600)
return state, nil
}

func (l *statePull) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
dir, err := Dir(ctx, b)
if err != nil {
return diag.FromErr(err)
}
defer local.Close()

if !IsLocalStateStale(local, bytes.NewReader(remote.Bytes())) {
log.Infof(ctx, "Local state is the same or newer, ignoring remote state")
localStatePath := filepath.Join(dir, TerraformStateFileName)

// Case: Remote state file does not exist. In this case we fallback to using the
// local Terraform state. This allows users to change the "root_path" their bundle is
// configured with.
remoteState, remoteContent, err := l.remoteState(ctx, b)
if errors.Is(err, fs.ErrNotExist) {
log.Infof(ctx, "Remote state file does not exist. Using local Terraform state.")
return nil
}
if err != nil {
return diag.Errorf("failed to read remote state file: %v", err)
}

// Truncating the file before writing
local.Truncate(0)
local.Seek(0, 0)
// Expected invariant: remote state file should have a lineage UUID. Error
// if that's not the case.
if remoteState.Lineage == "" {
return diag.Errorf("remote state file does not have a lineage")
}

// Write file to disk.
log.Infof(ctx, "Writing remote state file to local cache directory")
_, err = io.Copy(local, bytes.NewReader(remote.Bytes()))
// Case: Local state file does not exist. In this case we should rely on the remote state file.
localState, err := l.localState(ctx, b)
if errors.Is(err, fs.ErrNotExist) {
log.Infof(ctx, "Local state file does not exist. Using remote Terraform state.")
err := os.WriteFile(localStatePath, remoteContent, 0600)
return diag.FromErr(err)
}
if err != nil {
return diag.Errorf("failed to read local state file: %v", err)
}

// If the lineage does not match, the Terraform state files do not correspond to the same deployment.
if localState.Lineage != remoteState.Lineage {
log.Infof(ctx, "Remote and local state lineages do not match. Using remote Terraform state. Invalidating local Terraform state.")
err := os.WriteFile(localStatePath, remoteContent, 0600)
return diag.FromErr(err)
}

// If the remote state is newer than the local state, we should use the remote state.
if remoteState.Serial > localState.Serial {
log.Infof(ctx, "Remote state is newer than local state. Using remote Terraform state.")
err := os.WriteFile(localStatePath, remoteContent, 0600)
return diag.FromErr(err)
}

// default: local state is newer or equal to remote state in terms of serial sequence.
// It is also of the same lineage. Keep using the local state.
return nil
}

Expand Down
177 changes: 107 additions & 70 deletions bundle/deploy/terraform/state_pull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/mock"
)

func mockStateFilerForPull(t *testing.T, contents map[string]int, merr error) filer.Filer {
func mockStateFilerForPull(t *testing.T, contents map[string]any, merr error) filer.Filer {
buf, err := json.Marshal(contents)
assert.NoError(t, err)

Expand All @@ -41,86 +41,123 @@ func statePullTestBundle(t *testing.T) *bundle.Bundle {
}
}

func TestStatePullLocalMissingRemoteMissing(t *testing.T) {
m := &statePull{
identityFiler(mockStateFilerForPull(t, nil, os.ErrNotExist)),
}
func TestStatePullLocalErrorWhenRemoteHasNoLineage(t *testing.T) {
m := &statePull{}

ctx := context.Background()
b := statePullTestBundle(t)
diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error())
t.Run("no local state", func(t *testing.T) {
// setup remote state.
m.filerFactory = identityFiler(mockStateFilerForPull(t, map[string]any{"serial": 5}, nil))

// Confirm that no local state file has been written.
_, err := os.Stat(localStateFile(t, ctx, b))
assert.ErrorIs(t, err, fs.ErrNotExist)
}
ctx := context.Background()
b := statePullTestBundle(t)
diags := bundle.Apply(ctx, b, m)
assert.EqualError(t, diags.Error(), "remote state file does not have a lineage")
})

func TestStatePullLocalMissingRemotePresent(t *testing.T) {
m := &statePull{
identityFiler(mockStateFilerForPull(t, map[string]int{"serial": 5}, nil)),
}
t.Run("local state with lineage", func(t *testing.T) {
// setup remote state.
m.filerFactory = identityFiler(mockStateFilerForPull(t, map[string]any{"serial": 5}, nil))

ctx := context.Background()
b := statePullTestBundle(t)
diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error())
ctx := context.Background()
b := statePullTestBundle(t)
writeLocalState(t, ctx, b, map[string]any{"serial": 5, "lineage": "aaaa"})

// Confirm that the local state file has been updated.
localState := readLocalState(t, ctx, b)
assert.Equal(t, map[string]int{"serial": 5}, localState)
diags := bundle.Apply(ctx, b, m)
assert.EqualError(t, diags.Error(), "remote state file does not have a lineage")
})
}

func TestStatePullLocalStale(t *testing.T) {
m := &statePull{
identityFiler(mockStateFilerForPull(t, map[string]int{"serial": 5}, nil)),
}

ctx := context.Background()
b := statePullTestBundle(t)
func TestStatePullLocal(t *testing.T) {
tcases := []struct {
name string

// Write a stale local state file.
writeLocalState(t, ctx, b, map[string]int{"serial": 4})
diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error())
// remote state before applying the pull mutators
remote map[string]any

// Confirm that the local state file has been updated.
localState := readLocalState(t, ctx, b)
assert.Equal(t, map[string]int{"serial": 5}, localState)
}
// local state before applying the pull mutators
local map[string]any

func TestStatePullLocalEqual(t *testing.T) {
m := &statePull{
identityFiler(mockStateFilerForPull(t, map[string]int{"serial": 5, "some_other_key": 123}, nil)),
// expected local state after applying the pull mutators
expected map[string]any
}{
{
name: "remote missing, local missing",
remote: nil,
local: nil,
expected: nil,
},
{
name: "remote missing, local present",
remote: nil,
local: map[string]any{"serial": 5, "lineage": "aaaa"},
// fallback to local state, since remote state is missing.
expected: map[string]any{"serial": float64(5), "lineage": "aaaa"},
},
{
name: "local stale",
remote: map[string]any{"serial": 10, "lineage": "aaaa", "some_other_key": 123},
local: map[string]any{"serial": 5, "lineage": "aaaa"},
// use remote, since remote is newer.
expected: map[string]any{"serial": float64(10), "lineage": "aaaa", "some_other_key": float64(123)},
},
{
name: "local equal",
remote: map[string]any{"serial": 5, "lineage": "aaaa", "some_other_key": 123},
local: map[string]any{"serial": 5, "lineage": "aaaa"},
// use local state, since they are equal in terms of serial sequence.
expected: map[string]any{"serial": float64(5), "lineage": "aaaa"},
},
{
name: "local newer",
remote: map[string]any{"serial": 5, "lineage": "aaaa", "some_other_key": 123},
local: map[string]any{"serial": 6, "lineage": "aaaa"},
// use local state, since local is newer.
expected: map[string]any{"serial": float64(6), "lineage": "aaaa"},
},
{
name: "remote and local have different lineages",
remote: map[string]any{"serial": 5, "lineage": "aaaa"},
local: map[string]any{"serial": 10, "lineage": "bbbb"},
// use remote, since lineages do not match.
expected: map[string]any{"serial": float64(5), "lineage": "aaaa"},
},
{
name: "local is missing lineage",
remote: map[string]any{"serial": 5, "lineage": "aaaa"},
local: map[string]any{"serial": 10},
// use remote, since local does not have lineage.
expected: map[string]any{"serial": float64(5), "lineage": "aaaa"},
},
}

ctx := context.Background()
b := statePullTestBundle(t)

// Write a local state file with the same serial as the remote.
writeLocalState(t, ctx, b, map[string]int{"serial": 5})
diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error())

// Confirm that the local state file has not been updated.
localState := readLocalState(t, ctx, b)
assert.Equal(t, map[string]int{"serial": 5}, localState)
}

func TestStatePullLocalNewer(t *testing.T) {
m := &statePull{
identityFiler(mockStateFilerForPull(t, map[string]int{"serial": 5, "some_other_key": 123}, nil)),
for _, tc := range tcases {
t.Run(tc.name, func(t *testing.T) {
m := &statePull{}
if tc.remote == nil {
// nil represents no remote state file.
m.filerFactory = identityFiler(mockStateFilerForPull(t, nil, os.ErrNotExist))
} else {
m.filerFactory = identityFiler(mockStateFilerForPull(t, tc.remote, nil))
}

ctx := context.Background()
b := statePullTestBundle(t)
if tc.local != nil {
writeLocalState(t, ctx, b, tc.local)
}

diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error())

if tc.expected == nil {
// nil represents no local state file is expected.
_, err := os.Stat(localStateFile(t, ctx, b))
assert.ErrorIs(t, err, fs.ErrNotExist)
} else {
localState := readLocalState(t, ctx, b)
assert.Equal(t, tc.expected, localState)

}
})
}

ctx := context.Background()
b := statePullTestBundle(t)

// Write a local state file with a newer serial as the remote.
writeLocalState(t, ctx, b, map[string]int{"serial": 6})
diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error())

// Confirm that the local state file has not been updated.
localState := readLocalState(t, ctx, b)
assert.Equal(t, map[string]int{"serial": 6}, localState)
}
2 changes: 1 addition & 1 deletion bundle/deploy/terraform/state_push_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestStatePush(t *testing.T) {
b := statePushTestBundle(t)

// Write a stale local state file.
writeLocalState(t, ctx, b, map[string]int{"serial": 4})
writeLocalState(t, ctx, b, map[string]any{"serial": 4})
diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error())
}
Loading

0 comments on commit 5b65358

Please sign in to comment.