Skip to content

Commit

Permalink
schemadiff: optimize permutation evaluation (#16435)
Browse files Browse the repository at this point in the history
Signed-off-by: Shlomi Noach <[email protected]>
  • Loading branch information
shlomi-noach authored Jul 22, 2024
1 parent efd8292 commit 485d736
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 19 deletions.
54 changes: 41 additions & 13 deletions go/vt/schemadiff/schema_diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,28 +88,50 @@ Modified to have an early break

// permutateDiffs calls `callback` with each permutation of `diffs`. If permutateDiffs returns `earlyBreak == true`,
// the callback has requested an early break (or the context was done), so possibly not all permutations have been evaluated.
func permutateDiffs(ctx context.Context, diffs []EntityDiff, hints *DiffHints, callback func([]EntityDiff, *DiffHints) (earlyBreak bool)) (earlyBreak bool, err error) {
// callback's `errorIndex` indicates the first index at which the permutation has an error, or -1 if there's no such error.
// When `errorIndex` is non-negative, the algorithm skips any further permutations that only rearrange elements positioned after `errorIndex`.
func permutateDiffs(
ctx context.Context,
diffs []EntityDiff,
hints *DiffHints,
callback func([]EntityDiff, *DiffHints) (earlyBreak bool, errorIndex int),
) (earlyBreak bool, err error) {
if len(diffs) == 0 {
return false, nil
}
// Sort by a heuristic (DROPs first, ALTERs next, CREATEs last). This ordering is then used first in the permutation
// search and serves as seed for the rest of permutations.

return permDiff(ctx, diffs, hints, callback, 0)
earlyBreak, _, err = permDiff(ctx, diffs, hints, callback, 0)
return earlyBreak, err
}

// permDiff is a recursive function that permutates the given `a` in place and calls `callback` for each permutation.
// If `callback` returns `earlyBreak == true`, then so does this function, and this indicates a request for an early
// break, in which case no further permutations are evaluated.
func permDiff(ctx context.Context, a []EntityDiff, hints *DiffHints, callback func([]EntityDiff, *DiffHints) (earlyBreak bool), i int) (earlyBreak bool, err error) {
func permDiff(
ctx context.Context,
a []EntityDiff,
hints *DiffHints,
callback func([]EntityDiff, *DiffHints) (earlyBreak bool, errorIndex int),
i int,
) (earlyBreak bool, errorIndex int, err error) {
if err := ctx.Err(); err != nil {
return true, err // early break
return true, -1, err // early break (due to context)
}
if i > len(a) {
return callback(a, hints), nil
earlyBreak, errorIndex := callback(a, hints)
return earlyBreak, errorIndex, nil
}
if brk, err := permDiff(ctx, a, hints, callback, i+1); brk {
return true, err
earlyBreak, errorIndex, err = permDiff(ctx, a, hints, callback, i+1)
if errorIndex >= 0 && i > errorIndex {
// Means the current permutation failed at `errorIndex`, and we're beyond that point. There's no
// point in continuing to permutate the rest of the array.
return false, errorIndex, err
}
if earlyBreak {
// Found a valid permutation, no need to continue
return true, -1, err
}
for j := i + 1; j < len(a); j++ {
// An optimization: we don't really need all possible permutations. We can skip some of the recursive search.
Expand Down Expand Up @@ -150,12 +172,18 @@ func permDiff(ctx context.Context, a []EntityDiff, hints *DiffHints, callback fu
}
// End of optimization
a[i], a[j] = a[j], a[i]
if brk, err := permDiff(ctx, a, hints, callback, i+1); brk {
return true, err
earlyBreak, errorIndex, err = permDiff(ctx, a, hints, callback, i+1)
if errorIndex >= 0 && i > errorIndex {
// Means the current permutation failed at `errorIndex`, and we're beyond that point. There's no
// point in continuing to permutate the rest of the array.
return false, errorIndex, err
}
if earlyBreak {
return true, -1, err
}
a[i], a[j] = a[j], a[i]
}
return false, nil
return false, -1, nil
}

// SchemaDiff is a rich diff between two schemas. It includes the following:
Expand Down Expand Up @@ -316,7 +344,7 @@ func (d *SchemaDiff) OrderedDiffs(ctx context.Context) ([]EntityDiff, error) {
// We will now permutate the diffs in this equivalence class, and hopefully find
// a valid permutation (one where if we apply the diffs in-order, the schema remains valid throughout the process)
tryPermutateDiffs := func(hints *DiffHints) (bool, error) {
return permutateDiffs(ctx, classDiffs, hints, func(permutatedDiffs []EntityDiff, hints *DiffHints) bool {
return permutateDiffs(ctx, classDiffs, hints, func(permutatedDiffs []EntityDiff, hints *DiffHints) (bool, int) {
permutationSchema := lastGoodSchema.copy()
// We want to apply the changes one by one, and validate the schema after each change
for i := range permutatedDiffs {
Expand All @@ -338,14 +366,14 @@ func (d *SchemaDiff) OrderedDiffs(ctx context.Context) ([]EntityDiff, error) {
}
if err := permutationSchema.apply(permutatedDiffs[i:i+1], applyHints); err != nil {
// permutation is invalid
return false // continue searching
return false, i // let the algorithm know there's no point in pursuing any path after `i`
}
}

// Good news, we managed to apply all of the permutations!
orderedDiffs = append(orderedDiffs, permutatedDiffs...)
lastGoodSchema = permutationSchema
return true // early break! No need to keep searching
return true, -1 // early break! No need to keep searching
})
}
// We prefer stricter strategy, because that gives best chance of finding a valid path.
Expand Down
50 changes: 44 additions & 6 deletions go/vt/schemadiff/schema_diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package schemadiff

import (
"context"
"fmt"
"os"
"strings"
"testing"

Expand Down Expand Up @@ -195,7 +197,7 @@ func TestPermutations(t *testing.T) {
allDiffs := schemaDiff.UnorderedDiffs()
originalSingleString := toSingleString(allDiffs)
numEquals := 0
earlyBreak, err := permutateDiffs(ctx, allDiffs, hints, func(pdiffs []EntityDiff, hints *DiffHints) (earlyBreak bool) {
earlyBreak, err := permutateDiffs(ctx, allDiffs, hints, func(pdiffs []EntityDiff, hints *DiffHints) (earlyBreak bool, errorIndex int) {
defer func() { iteration++ }()
// cover all permutations
singleString := toSingleString(pdiffs)
Expand All @@ -204,7 +206,7 @@ func TestPermutations(t *testing.T) {
if originalSingleString == singleString {
numEquals++
}
return false
return false, -1
})
assert.NoError(t, err)
if len(allDiffs) > 0 {
Expand All @@ -218,13 +220,13 @@ func TestPermutations(t *testing.T) {
allPerms := map[string]bool{}
allDiffs := schemaDiff.UnorderedDiffs()
originalSingleString := toSingleString(allDiffs)
earlyBreak, err := permutateDiffs(ctx, allDiffs, hints, func(pdiffs []EntityDiff, hints *DiffHints) (earlyBreak bool) {
earlyBreak, err := permutateDiffs(ctx, allDiffs, hints, func(pdiffs []EntityDiff, hints *DiffHints) (earlyBreak bool, errorIndex int) {
// Single visit
allPerms[toSingleString(pdiffs)] = true
// First permutation should be the same as original
require.Equal(t, originalSingleString, toSingleString(pdiffs))
// early break; this callback function should not be invoked again
return true
return true, -1
})
assert.NoError(t, err)
if len(allDiffs) > 0 {
Expand All @@ -246,8 +248,8 @@ func TestPermutationsContext(t *testing.T) {

hints := &DiffHints{RangeRotationStrategy: RangeRotationDistinctStatements}
allDiffs := []EntityDiff{&DropViewEntityDiff{}}
earlyBreak, err := permutateDiffs(ctx, allDiffs, hints, func(pdiffs []EntityDiff, hints *DiffHints) (earlyBreak bool) {
return false
earlyBreak, err := permutateDiffs(ctx, allDiffs, hints, func(pdiffs []EntityDiff, hints *DiffHints) (earlyBreak bool, errorIndex int) {
return false, -1
})
assert.True(t, earlyBreak) // proves that termination was due to context cancel
assert.Error(t, err) // proves that termination was due to context cancel
Expand Down Expand Up @@ -1322,3 +1324,39 @@ func TestSchemaDiff(t *testing.T) {

}
}

// TestDiffFiles is an ad-hoc helper test: it reads two schema files from the local file system, computes the
// diff from the first to the second, and prints the ordered diff statements to standard output.
// The two paths are read from the $TEST_SCHEMADIFF_DIFF_FILES environment variable, given as a comma-separated
// pair, e.g. "/tmp/from.sql,/tmp/to.sql". When the variable is empty or unset, the test is skipped.
func TestDiffFiles(t *testing.T) {
	const envName = "TEST_SCHEMADIFF_DIFF_FILES"

	ctx := context.Background()

	pathList := os.Getenv(envName)
	if pathList == "" {
		t.Skipf("no diff files specified in $%s", envName)
	}
	paths := strings.Split(pathList, ",")
	require.Len(t, paths, 2, "expecting two files in $%s: <file1>,<file2>", envName)

	// Read both files up front, so that a missing or unreadable file is reported before any SQL is parsed.
	var sqls [2]string
	for idx, path := range paths {
		content, readErr := os.ReadFile(path)
		require.NoError(t, readErr)
		sqls[idx] = string(content)
	}

	env := NewTestEnv()
	source, err := NewSchemaFromSQL(env, sqls[0])
	require.NoError(t, err)
	target, err := NewSchemaFromSQL(env, sqls[1])
	require.NoError(t, err)

	diffHints := &DiffHints{RangeRotationStrategy: RangeRotationDistinctStatements}
	sdiff, err := source.SchemaDiff(target, diffHints)
	require.NoError(t, err)
	t.Logf("diff length: %v", len(sdiff.UnorderedDiffs()))

	ordered, err := sdiff.OrderedDiffs(ctx)
	require.NoError(t, err)
	t.Logf("ordered diffs length: %v", len(ordered))

	// Emit each statement terminated by a semicolon, one per line.
	for idx := range ordered {
		fmt.Printf("%s;\n", ordered[idx].CanonicalStatementString())
	}
}

0 comments on commit 485d736

Please sign in to comment.