Skip to content

Commit

Permalink
context-based termination
Browse files Browse the repository at this point in the history
Signed-off-by: Shlomi Noach <[email protected]>
  • Loading branch information
shlomi-noach committed Oct 11, 2023
1 parent af40f85 commit 25c92fe
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 18 deletions.
9 changes: 6 additions & 3 deletions go/vt/schemadiff/diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package schemadiff

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -403,6 +404,7 @@ func TestDiffViews(t *testing.T) {
}

func TestDiffSchemas(t *testing.T) {
ctx := context.Background()
tt := []struct {
name string
from string
Expand Down Expand Up @@ -806,7 +808,7 @@ func TestDiffSchemas(t *testing.T) {
} else {
assert.NoError(t, err)

diffs, err := diff.OrderedDiffs()
diffs, err := diff.OrderedDiffs(ctx)
assert.NoError(t, err)
statements := []string{}
cstatements := []string{}
Expand Down Expand Up @@ -858,6 +860,7 @@ func TestDiffSchemas(t *testing.T) {
}

func TestSchemaApplyError(t *testing.T) {
ctx := context.Background()
tt := []struct {
name string
from string
Expand Down Expand Up @@ -900,7 +903,7 @@ func TestSchemaApplyError(t *testing.T) {
{
diff, err := schema1.SchemaDiff(schema2, hints)
require.NoError(t, err)
diffs, err := diff.OrderedDiffs()
diffs, err := diff.OrderedDiffs(ctx)
assert.NoError(t, err)
assert.NotEmpty(t, diffs)
_, err = schema1.Apply(diffs)
Expand All @@ -911,7 +914,7 @@ func TestSchemaApplyError(t *testing.T) {
{
diff, err := schema2.SchemaDiff(schema1, hints)
require.NoError(t, err)
diffs, err := diff.OrderedDiffs()
diffs, err := diff.OrderedDiffs(ctx)
assert.NoError(t, err)
assert.NotEmpty(t, diffs, "schema1: %v, schema2: %v", schema1.ToSQL(), schema2.ToSQL())
_, err = schema2.Apply(diffs)
Expand Down
31 changes: 19 additions & 12 deletions go/vt/schemadiff/schema_diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package schemadiff

import (
"context"
"fmt"
"math"
"sort"
Expand Down Expand Up @@ -109,35 +110,38 @@ Modified to have an early break

// permutateDiffs calls `callback` with each permutation of a. If the function returns `true`, that means
// the callback has returned `true` for an early break, thus possibly not all permutations have been evaluated.
func permutateDiffs(diffs []EntityDiff, callback func([]EntityDiff) (earlyBreak bool)) (earlyBreak bool) {
func permutateDiffs(ctx context.Context, diffs []EntityDiff, callback func([]EntityDiff) (earlyBreak bool)) (earlyBreak bool, err error) {
if len(diffs) == 0 {
return false
return false, nil
}
// Sort by a heristic (DROPs first, ALTERs next, CREATEs last). This ordering is then used first in the permutation
// search and serves as seed for the rest of permutations.
sortDiffsHeuristically(diffs)

return permDiff(diffs, callback, 0)
return permDiff(ctx, diffs, callback, 0)
}

// permDiff is a recursive function to permutate given `a` and call `callback` for each permutation.
// If `callback` returns `true`, then so does this function, and this indicates a request for an early
// break, in which case this function will not be called again.
func permDiff(a []EntityDiff, callback func([]EntityDiff) (earlyBreak bool), i int) (earlyBreak bool) {
func permDiff(ctx context.Context, a []EntityDiff, callback func([]EntityDiff) (earlyBreak bool), i int) (earlyBreak bool, err error) {
if err := ctx.Err(); err != nil {
return true, err // early break
}
if i > len(a) {
return callback(a)
return callback(a), nil
}
if permDiff(a, callback, i+1) {
return true
if brk, err := permDiff(ctx, a, callback, i+1); brk {
return true, err
}
for j := i + 1; j < len(a); j++ {
a[i], a[j] = a[j], a[i]
if permDiff(a, callback, i+1) {
return true
if brk, err := permDiff(ctx, a, callback, i+1); brk {
return true, err
}
a[i], a[j] = a[j], a[i]
}
return false
return false, nil
}

// SchemaDiff is a rich diff between two schemas. It includes the following:
Expand Down Expand Up @@ -269,7 +273,7 @@ func (d *SchemaDiff) HasSequentialExecutionDependencies() bool {

// OrderedDiffs returns the list of diff in applicable order, if possible. This is a linearized representation
// where diffs may be applied in-order one after another, keeping the schema in valid state at all times.
func (d *SchemaDiff) OrderedDiffs() ([]EntityDiff, error) {
func (d *SchemaDiff) OrderedDiffs(ctx context.Context) ([]EntityDiff, error) {
lastGoodSchema := d.schema
var orderedDiffs []EntityDiff
m := d.r.Map()
Expand All @@ -286,7 +290,7 @@ func (d *SchemaDiff) OrderedDiffs() ([]EntityDiff, error) {
}
// We will now permutate the diffs in this equivalence class, and hopefully find
// a valid permutation (one where if we apply the diffs in-order, the schema remains valid throughout the process)
foundValidPathForClass := permutateDiffs(classDiffs, func(permutatedDiffs []EntityDiff) bool {
foundValidPathForClass, err := permutateDiffs(ctx, classDiffs, func(permutatedDiffs []EntityDiff) bool {
permutationSchema := lastGoodSchema.copy()
// We want to apply the changes one by one, and validate the schema after each change
for i := range permutatedDiffs {
Expand All @@ -301,6 +305,9 @@ func (d *SchemaDiff) OrderedDiffs() ([]EntityDiff, error) {
lastGoodSchema = permutationSchema
return true // early break! No need to keep searching
})
if err != nil {
return nil, err
}
if !foundValidPathForClass {
// In this equivalence class, there is no valid permutation. We cannot linearize the diffs.
return nil, &ImpossibleApplyDiffOrderError{
Expand Down
23 changes: 20 additions & 3 deletions go/vt/schemadiff/schema_diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package schemadiff

import (
"context"
"strings"
"testing"

Expand All @@ -25,6 +26,7 @@ import (
)

func TestPermutations(t *testing.T) {
ctx := context.Background()
tt := []struct {
name string
fromQueries []string
Expand Down Expand Up @@ -115,7 +117,7 @@ func TestPermutations(t *testing.T) {
allDiffs := schemaDiff.UnorderedDiffs()
originalSingleString := toSingleString(allDiffs)
numEquals := 0
earlyBreak := permutateDiffs(allDiffs, func(pdiffs []EntityDiff) (earlyBreak bool) {
earlyBreak, err := permutateDiffs(ctx, allDiffs, func(pdiffs []EntityDiff) (earlyBreak bool) {
defer func() { iteration++ }()
// cover all permutations
singleString := toSingleString(pdiffs)
Expand All @@ -125,6 +127,7 @@ func TestPermutations(t *testing.T) {
}
return false
})
assert.NoError(t, err)
if len(allDiffs) > 0 {
assert.Equal(t, numEquals, 1)
}
Expand All @@ -135,14 +138,15 @@ func TestPermutations(t *testing.T) {
allPerms := map[string]bool{}
allDiffs := schemaDiff.UnorderedDiffs()
originalSingleString := toSingleString(allDiffs)
earlyBreak := permutateDiffs(allDiffs, func(pdiffs []EntityDiff) (earlyBreak bool) {
earlyBreak, err := permutateDiffs(ctx, allDiffs, func(pdiffs []EntityDiff) (earlyBreak bool) {
// Single visit
allPerms[toSingleString(pdiffs)] = true
// First permutation should be the same as original
require.Equal(t, originalSingleString, toSingleString(pdiffs))
// early break; this callback function should not be invoked again
return true
})
assert.NoError(t, err)
if len(allDiffs) > 0 {
assert.True(t, earlyBreak)
assert.Equal(t, 1, len(allPerms))
Expand All @@ -156,7 +160,20 @@ func TestPermutations(t *testing.T) {
}
}

func TestPermutationsContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()

allDiffs := []EntityDiff{&DropViewEntityDiff{}}
earlyBreak, err := permutateDiffs(ctx, allDiffs, func(pdiffs []EntityDiff) (earlyBreak bool) {
return false
})
assert.True(t, earlyBreak) // proves that termination was due to context cancel
assert.Error(t, err) // proves that termination was due to context cancel
}

func TestSchemaDiff(t *testing.T) {
ctx := context.Background()
var (
createQueries = []string{
"create table t1 (id int primary key, info int not null);",
Expand Down Expand Up @@ -681,7 +698,7 @@ func TestSchemaDiff(t *testing.T) {
assert.Equalf(t, tc.expectDeps, len(deps), "found deps: %v", depsKeys)
assert.Equal(t, tc.sequential, schemaDiff.HasSequentialExecutionDependencies())

orderedDiffs, err := schemaDiff.OrderedDiffs()
orderedDiffs, err := schemaDiff.OrderedDiffs(ctx)
if tc.conflictingDiffs > 0 {
require.Greater(t, tc.conflictingDiffs, 1) // self integrity. If there's a conflict, then obviously there's at least two conflicting diffs (a single diff has nothing to conflict with)
assert.Error(t, err)
Expand Down

0 comments on commit 25c92fe

Please sign in to comment.