diff --git a/go/vt/schemadiff/diff_test.go b/go/vt/schemadiff/diff_test.go index 291049a22ad..8676e1bab29 100644 --- a/go/vt/schemadiff/diff_test.go +++ b/go/vt/schemadiff/diff_test.go @@ -17,6 +17,7 @@ limitations under the License. package schemadiff import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -403,6 +404,7 @@ func TestDiffViews(t *testing.T) { } func TestDiffSchemas(t *testing.T) { + ctx := context.Background() tt := []struct { name string from string @@ -806,7 +808,7 @@ func TestDiffSchemas(t *testing.T) { } else { assert.NoError(t, err) - diffs, err := diff.OrderedDiffs() + diffs, err := diff.OrderedDiffs(ctx) assert.NoError(t, err) statements := []string{} cstatements := []string{} @@ -858,6 +860,7 @@ func TestDiffSchemas(t *testing.T) { } func TestSchemaApplyError(t *testing.T) { + ctx := context.Background() tt := []struct { name string from string @@ -900,7 +903,7 @@ func TestSchemaApplyError(t *testing.T) { { diff, err := schema1.SchemaDiff(schema2, hints) require.NoError(t, err) - diffs, err := diff.OrderedDiffs() + diffs, err := diff.OrderedDiffs(ctx) assert.NoError(t, err) assert.NotEmpty(t, diffs) _, err = schema1.Apply(diffs) @@ -911,7 +914,7 @@ func TestSchemaApplyError(t *testing.T) { { diff, err := schema2.SchemaDiff(schema1, hints) require.NoError(t, err) - diffs, err := diff.OrderedDiffs() + diffs, err := diff.OrderedDiffs(ctx) assert.NoError(t, err) assert.NotEmpty(t, diffs, "schema1: %v, schema2: %v", schema1.ToSQL(), schema2.ToSQL()) _, err = schema2.Apply(diffs) diff --git a/go/vt/schemadiff/schema_diff.go b/go/vt/schemadiff/schema_diff.go index e4e61786283..920584d20e2 100644 --- a/go/vt/schemadiff/schema_diff.go +++ b/go/vt/schemadiff/schema_diff.go @@ -17,6 +17,7 @@ limitations under the License. package schemadiff import ( + "context" "fmt" "math" "sort" @@ -109,35 +110,38 @@ Modified to have an early break // permutateDiffs calls `callback` with each permutation of a. If the function returns `true`, that means // the callback has returned `true` for an early break, thus possibly not all permutations have been evaluated. -func permutateDiffs(diffs []EntityDiff, callback func([]EntityDiff) (earlyBreak bool)) (earlyBreak bool) { +func permutateDiffs(ctx context.Context, diffs []EntityDiff, callback func([]EntityDiff) (earlyBreak bool)) (earlyBreak bool, err error) { if len(diffs) == 0 { - return false + return false, nil } // Sort by a heristic (DROPs first, ALTERs next, CREATEs last). This ordering is then used first in the permutation // search and serves as seed for the rest of permutations. sortDiffsHeuristically(diffs) - return permDiff(diffs, callback, 0) + return permDiff(ctx, diffs, callback, 0) } // permDiff is a recursive function to permutate given `a` and call `callback` for each permutation. // If `callback` returns `true`, then so does this function, and this indicates a request for an early // break, in which case this function will not be called again. -func permDiff(a []EntityDiff, callback func([]EntityDiff) (earlyBreak bool), i int) (earlyBreak bool) { +func permDiff(ctx context.Context, a []EntityDiff, callback func([]EntityDiff) (earlyBreak bool), i int) (earlyBreak bool, err error) { + if err := ctx.Err(); err != nil { + return true, err // early break + } if i > len(a) { - return callback(a) + return callback(a), nil } - if permDiff(a, callback, i+1) { - return true + if brk, err := permDiff(ctx, a, callback, i+1); brk { + return true, err } for j := i + 1; j < len(a); j++ { a[i], a[j] = a[j], a[i] - if permDiff(a, callback, i+1) { - return true + if brk, err := permDiff(ctx, a, callback, i+1); brk { + return true, err } a[i], a[j] = a[j], a[i] } - return false + return false, nil } // SchemaDiff is a rich diff between two schemas. It includes the following: @@ -269,7 +273,7 @@ func (d *SchemaDiff) HasSequentialExecutionDependencies() bool { // OrderedDiffs returns the list of diff in applicable order, if possible. This is a linearized representation // where diffs may be applied in-order one after another, keeping the schema in valid state at all times. -func (d *SchemaDiff) OrderedDiffs() ([]EntityDiff, error) { +func (d *SchemaDiff) OrderedDiffs(ctx context.Context) ([]EntityDiff, error) { lastGoodSchema := d.schema var orderedDiffs []EntityDiff m := d.r.Map() @@ -286,7 +290,7 @@ func (d *SchemaDiff) OrderedDiffs() ([]EntityDiff, error) { } // We will now permutate the diffs in this equivalence class, and hopefully find // a valid permutation (one where if we apply the diffs in-order, the schema remains valid throughout the process) - foundValidPathForClass := permutateDiffs(classDiffs, func(permutatedDiffs []EntityDiff) bool { + foundValidPathForClass, err := permutateDiffs(ctx, classDiffs, func(permutatedDiffs []EntityDiff) bool { permutationSchema := lastGoodSchema.copy() // We want to apply the changes one by one, and validate the schema after each change for i := range permutatedDiffs { @@ -301,6 +305,9 @@ func (d *SchemaDiff) OrderedDiffs() ([]EntityDiff, error) { lastGoodSchema = permutationSchema return true // early break! No need to keep searching }) + if err != nil { + return nil, err + } if !foundValidPathForClass { // In this equivalence class, there is no valid permutation. We cannot linearize the diffs. return nil, &ImpossibleApplyDiffOrderError{ diff --git a/go/vt/schemadiff/schema_diff_test.go b/go/vt/schemadiff/schema_diff_test.go index bd7073325ed..35c2c7d0871 100644 --- a/go/vt/schemadiff/schema_diff_test.go +++ b/go/vt/schemadiff/schema_diff_test.go @@ -17,6 +17,7 @@ limitations under the License. package schemadiff import ( + "context" "strings" "testing" @@ -25,6 +26,7 @@ import ( ) func TestPermutations(t *testing.T) { + ctx := context.Background() tt := []struct { name string fromQueries []string @@ -115,7 +117,7 @@ func TestPermutations(t *testing.T) { allDiffs := schemaDiff.UnorderedDiffs() originalSingleString := toSingleString(allDiffs) numEquals := 0 - earlyBreak := permutateDiffs(allDiffs, func(pdiffs []EntityDiff) (earlyBreak bool) { + earlyBreak, err := permutateDiffs(ctx, allDiffs, func(pdiffs []EntityDiff) (earlyBreak bool) { defer func() { iteration++ }() // cover all permutations singleString := toSingleString(pdiffs) @@ -125,6 +127,7 @@ func TestPermutations(t *testing.T) { } return false }) + assert.NoError(t, err) if len(allDiffs) > 0 { assert.Equal(t, numEquals, 1) } @@ -135,7 +138,7 @@ func TestPermutations(t *testing.T) { allPerms := map[string]bool{} allDiffs := schemaDiff.UnorderedDiffs() originalSingleString := toSingleString(allDiffs) - earlyBreak := permutateDiffs(allDiffs, func(pdiffs []EntityDiff) (earlyBreak bool) { + earlyBreak, err := permutateDiffs(ctx, allDiffs, func(pdiffs []EntityDiff) (earlyBreak bool) { // Single visit allPerms[toSingleString(pdiffs)] = true // First permutation should be the same as original @@ -143,6 +146,7 @@ func TestPermutations(t *testing.T) { // early break; this callback function should not be invoked again return true }) + assert.NoError(t, err) if len(allDiffs) > 0 { assert.True(t, earlyBreak) assert.Equal(t, 1, len(allPerms)) @@ -156,7 +160,20 @@ func TestPermutations(t *testing.T) { } } +func TestPermutationsContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + allDiffs := []EntityDiff{&DropViewEntityDiff{}} + earlyBreak, err := permutateDiffs(ctx, allDiffs, func(pdiffs []EntityDiff) (earlyBreak bool) { + return false + }) + assert.True(t, earlyBreak) // proves that termination was due to context cancel + assert.Error(t, err) // proves that termination was due to context cancel +} + func TestSchemaDiff(t *testing.T) { + ctx := context.Background() var ( createQueries = []string{ "create table t1 (id int primary key, info int not null);", @@ -681,7 +698,7 @@ func TestSchemaDiff(t *testing.T) { assert.Equalf(t, tc.expectDeps, len(deps), "found deps: %v", depsKeys) assert.Equal(t, tc.sequential, schemaDiff.HasSequentialExecutionDependencies()) - orderedDiffs, err := schemaDiff.OrderedDiffs() + orderedDiffs, err := schemaDiff.OrderedDiffs(ctx) if tc.conflictingDiffs > 0 { require.Greater(t, tc.conflictingDiffs, 1) // self integrity. If there's a conflict, then obviously there's at least two conflicting diffs (a single diff has nothing to conflict with) assert.Error(t, err)