From 6b965d357e573e2f70e4c644b12d915553447abf Mon Sep 17 00:00:00 2001 From: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> Date: Thu, 7 Mar 2024 14:22:06 +0200 Subject: [PATCH 1/9] graph: GetCycleVertices() Signed-off-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> --- go/vt/graph/graph.go | 31 +++++++++++++++++++++++++++++-- go/vt/graph/graph_test.go | 17 ++++++++++------- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/go/vt/graph/graph.go b/go/vt/graph/graph.go index 54668027008..703d5d4bcfa 100644 --- a/go/vt/graph/graph.go +++ b/go/vt/graph/graph.go @@ -24,13 +24,15 @@ import ( // Graph is a generic graph implementation. type Graph[C comparable] struct { - edges map[C][]C + edges map[C][]C + orderedEdged []C } // NewGraph creates a new graph for the given comparable type. func NewGraph[C comparable]() *Graph[C] { return &Graph[C]{ - edges: map[C][]C{}, + edges: map[C][]C{}, + orderedEdged: []C{}, } } @@ -41,6 +43,7 @@ func (gr *Graph[C]) AddVertex(vertex C) { return } gr.edges[vertex] = []C{} + gr.orderedEdged = append(gr.orderedEdged, vertex) } // AddEdge adds an edge to the given Graph. @@ -94,6 +97,30 @@ func (gr *Graph[C]) HasCycles() bool { return false } +// HasCycles checks whether the given graph has a cycle or not. +// We are using a well-known DFS based colouring algorithm to check for cycles. +// Look at https://cp-algorithms.com/graph/finding-cycle.html for more details on the algorithm. +func (gr *Graph[C]) GetCycleVertices() (vertices []C) { + // If the graph is empty, then we don't need to check anything. + if gr.Empty() { + return nil + } + // Initialize the coloring map. + // 0 represents white. + // 1 represents grey. + // 2 represents black. + color := map[C]int{} + for _, vertex := range gr.orderedEdged { + // If any vertex is still white, we initiate a new DFS. + if color[vertex] == 0 { + if gr.hasCyclesDfs(color, vertex) { + vertices = append(vertices, vertex) + } + } + } + return vertices +} + // hasCyclesDfs is a utility function for checking for cycles in a graph. // It runs a dfs from the given vertex marking each vertex as grey. During the dfs, // if we encounter a grey vertex, we know we have a cycle. We mark the visited vertices black diff --git a/go/vt/graph/graph_test.go b/go/vt/graph/graph_test.go index bc334c7d225..64287490bd6 100644 --- a/go/vt/graph/graph_test.go +++ b/go/vt/graph/graph_test.go @@ -90,11 +90,12 @@ func TestIntegerGraph(t *testing.T) { // TestStringGraph tests that a graph with strings can be created and all graph functions work as intended. func TestStringGraph(t *testing.T) { testcases := []struct { - name string - edges [][2]string - wantedGraph string - wantEmpty bool - wantHasCycles bool + name string + edges [][2]string + wantedGraph string + wantEmpty bool + wantHasCycles bool + wantCycleVertices []string }{ { name: "empty graph", @@ -135,8 +136,9 @@ C - D - E E - F F - A`, - wantEmpty: false, - wantHasCycles: true, + wantEmpty: false, + wantHasCycles: true, + wantCycleVertices: []string{"A", "D"}, }, } for _, tt := range testcases { @@ -148,6 +150,7 @@ F - A`, require.Equal(t, tt.wantedGraph, graph.PrintGraph()) require.Equal(t, tt.wantEmpty, graph.Empty()) require.Equal(t, tt.wantHasCycles, graph.HasCycles()) + require.Equal(t, tt.wantCycleVertices, graph.GetCycleVertices()) }) } } From 7154e1e3fb6253df2bbc3a52da618fe34a485d44 Mon Sep 17 00:00:00 2001 From: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> Date: Thu, 7 Mar 2024 15:17:41 +0200 Subject: [PATCH 2/9] track complete list of cycles Signed-off-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> --- go/vt/graph/graph.go | 40 ++++++++++++++++++++++++--------------- go/vt/graph/graph_test.go | 23 +++++++++++++++++----- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/go/vt/graph/graph.go b/go/vt/graph/graph.go index 703d5d4bcfa..21280875808 100644 --- a/go/vt/graph/graph.go +++ b/go/vt/graph/graph.go @@ -18,10 +18,17 @@ package graph import ( "fmt" + "maps" "slices" "strings" ) +const ( + white int = iota + grey + black +) + // Graph is a generic graph implementation. type Graph[C comparable] struct { edges map[C][]C @@ -88,8 +95,8 @@ func (gr *Graph[C]) HasCycles() bool { color := map[C]int{} for vertex := range gr.edges { // If any vertex is still white, we initiate a new DFS. - if color[vertex] == 0 { - if gr.hasCyclesDfs(color, vertex) { + if color[vertex] == white { + if hasCycle, _ := gr.hasCyclesDfs(color, vertex); hasCycle { return true } } @@ -100,11 +107,12 @@ func (gr *Graph[C]) HasCycles() bool { // HasCycles checks whether the given graph has a cycle or not. // We are using a well-known DFS based colouring algorithm to check for cycles. // Look at https://cp-algorithms.com/graph/finding-cycle.html for more details on the algorithm. -func (gr *Graph[C]) GetCycleVertices() (vertices []C) { +func (gr *Graph[C]) GetCycleVertices() (vertices map[C][]C) { // If the graph is empty, then we don't need to check anything. if gr.Empty() { return nil } + vertices = make(map[C][]C) // Initialize the coloring map. // 0 represents white. // 1 represents grey. @@ -112,9 +120,10 @@ func (gr *Graph[C]) GetCycleVertices() (vertices []C) { color := map[C]int{} for _, vertex := range gr.orderedEdged { // If any vertex is still white, we initiate a new DFS. - if color[vertex] == 0 { - if gr.hasCyclesDfs(color, vertex) { - vertices = append(vertices, vertex) + if color[vertex] == white { + color := maps.Clone(color) + if hasCycle, cycle := gr.hasCyclesDfs(color, vertex); hasCycle { + vertices[vertex] = cycle } } } @@ -125,22 +134,23 @@ func (gr *Graph[C]) GetCycleVertices() (vertices []C) { // It runs a dfs from the given vertex marking each vertex as grey. During the dfs, // if we encounter a grey vertex, we know we have a cycle. We mark the visited vertices black // on finishing the dfs. -func (gr *Graph[C]) hasCyclesDfs(color map[C]int, vertex C) bool { +func (gr *Graph[C]) hasCyclesDfs(color map[C]int, vertex C) (bool, []C) { // Mark the vertex grey. - color[vertex] = 1 + color[vertex] = grey + result := []C{vertex} // Go over all the edges. for _, end := range gr.edges[vertex] { // If we encounter a white vertex, we continue the dfs. - if color[end] == 0 { - if gr.hasCyclesDfs(color, end) { - return true + if color[end] == white { + if hasCycle, cycle := gr.hasCyclesDfs(color, end); hasCycle { + return true, append(result, cycle...) } - } else if color[end] == 1 { + } else if color[end] == grey { // We encountered a grey vertex, we have a cycle. - return true + return true, append(result, end) } } // Mark the vertex black before finishing - color[vertex] = 2 - return false + color[vertex] = black + return false, nil } diff --git a/go/vt/graph/graph_test.go b/go/vt/graph/graph_test.go index 64287490bd6..b5acfc5e6d4 100644 --- a/go/vt/graph/graph_test.go +++ b/go/vt/graph/graph_test.go @@ -95,7 +95,7 @@ func TestStringGraph(t *testing.T) { wantedGraph string wantEmpty bool wantHasCycles bool - wantCycleVertices []string + wantCycleVertices map[string][]string }{ { name: "empty graph", @@ -136,9 +136,15 @@ C - D - E E - F F - A`, - wantEmpty: false, - wantHasCycles: true, - wantCycleVertices: []string{"A", "D"}, + wantEmpty: false, + wantHasCycles: true, + wantCycleVertices: map[string][]string{ + "A": {"A", "B", "E", "F", "A"}, + "B": {"B", "E", "F", "A", "B"}, + "D": {"D", "E", "F", "A", "B", "E"}, + "E": {"E", "F", "A", "B", "E"}, + "F": {"F", "A", "B", "E", "F"}, + }, }, } for _, tt := range testcases { @@ -150,7 +156,14 @@ F - A`, require.Equal(t, tt.wantedGraph, graph.PrintGraph()) require.Equal(t, tt.wantEmpty, graph.Empty()) require.Equal(t, tt.wantHasCycles, graph.HasCycles()) - require.Equal(t, tt.wantCycleVertices, graph.GetCycleVertices()) + if tt.wantCycleVertices == nil { + tt.wantCycleVertices = map[string][]string{} + } + actualCycleVertices := graph.GetCycleVertices() + if actualCycleVertices == nil { + actualCycleVertices = map[string][]string{} + } + require.Equal(t, tt.wantCycleVertices, actualCycleVertices) }) } } From d5fceb4a56d20ba899fece8db4c24c7780520c68 Mon Sep 17 00:00:00 2001 From: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> Date: Thu, 7 Mar 2024 15:43:16 +0200 Subject: [PATCH 3/9] Smart ForeignKeyLoopError self-identifies whether a table is part of the loop or just references a loop Signed-off-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> --- go/vt/schemadiff/errors.go | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/go/vt/schemadiff/errors.go b/go/vt/schemadiff/errors.go index 5268db76ff3..942488f39a3 100644 --- a/go/vt/schemadiff/errors.go +++ b/go/vt/schemadiff/errors.go @@ -294,8 +294,23 @@ type ForeignKeyLoopError struct { func (e *ForeignKeyLoopError) Error() string { tableIsInsideLoop := false - escaped := make([]string, len(e.Loop)) - for i, t := range e.Loop { + loop := e.Loop + // The tables in the loop could be e.g.: + // t1->t2->a->b->c->a + // In such case, the loop is a->b->c->a. The last item is always the head & tail of the loop. + // We want to distinguish between the case where the table is inside the loop and the case where it's outside, + // so we remove the prefix of the loop that doesn't participate in the actual cycle. + if len(loop) > 0 { + last := loop[len(loop)-1] + for i := range loop { + if loop[i] == last { + loop = loop[i:] + break + } + } + } + escaped := make([]string, len(loop)) + for i, t := range loop { escaped[i] = sqlescape.EscapeID(t) if t == e.Table { tableIsInsideLoop = true From 07956ad0c0821027988188b467f40003568e720d Mon Sep 17 00:00:00 2001 From: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> Date: Thu, 7 Mar 2024 15:55:35 +0200 Subject: [PATCH 4/9] better naming and comments Signed-off-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> --- go/vt/graph/graph.go | 7 +++++-- go/vt/graph/graph_test.go | 26 +++++++++++++------------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/go/vt/graph/graph.go b/go/vt/graph/graph.go index 21280875808..c5095638ed2 100644 --- a/go/vt/graph/graph.go +++ b/go/vt/graph/graph.go @@ -104,10 +104,11 @@ func (gr *Graph[C]) HasCycles() bool { return false } -// HasCycles checks whether the given graph has a cycle or not. +// GetCycles returns all known cycles in the graph. +// It returns a map of vertices to the cycle they are part of. // We are using a well-known DFS based colouring algorithm to check for cycles. // Look at https://cp-algorithms.com/graph/finding-cycle.html for more details on the algorithm. -func (gr *Graph[C]) GetCycleVertices() (vertices map[C][]C) { +func (gr *Graph[C]) GetCycles() (vertices map[C][]C) { // If the graph is empty, then we don't need to check anything. if gr.Empty() { return nil @@ -121,6 +122,8 @@ func (gr *Graph[C]) GetCycleVertices() (vertices map[C][]C) { for _, vertex := range gr.orderedEdged { // If any vertex is still white, we initiate a new DFS. if color[vertex] == white { + // We clone the colors because we wnt full coverage for all vertices. + // Otherwise, the algorithm is optimal and stop more-or-less after the first cycle. color := maps.Clone(color) if hasCycle, cycle := gr.hasCyclesDfs(color, vertex); hasCycle { vertices[vertex] = cycle diff --git a/go/vt/graph/graph_test.go b/go/vt/graph/graph_test.go index b5acfc5e6d4..3231998039e 100644 --- a/go/vt/graph/graph_test.go +++ b/go/vt/graph/graph_test.go @@ -90,12 +90,12 @@ func TestIntegerGraph(t *testing.T) { // TestStringGraph tests that a graph with strings can be created and all graph functions work as intended. func TestStringGraph(t *testing.T) { testcases := []struct { - name string - edges [][2]string - wantedGraph string - wantEmpty bool - wantHasCycles bool - wantCycleVertices map[string][]string + name string + edges [][2]string + wantedGraph string + wantEmpty bool + wantHasCycles bool + wantCycles map[string][]string }{ { name: "empty graph", @@ -138,7 +138,7 @@ E - F F - A`, wantEmpty: false, wantHasCycles: true, - wantCycleVertices: map[string][]string{ + wantCycles: map[string][]string{ "A": {"A", "B", "E", "F", "A"}, "B": {"B", "E", "F", "A", "B"}, "D": {"D", "E", "F", "A", "B", "E"}, @@ -156,14 +156,14 @@ F - A`, require.Equal(t, tt.wantedGraph, graph.PrintGraph()) require.Equal(t, tt.wantEmpty, graph.Empty()) require.Equal(t, tt.wantHasCycles, graph.HasCycles()) - if tt.wantCycleVertices == nil { - tt.wantCycleVertices = map[string][]string{} + if tt.wantCycles == nil { + tt.wantCycles = map[string][]string{} } - actualCycleVertices := graph.GetCycleVertices() - if actualCycleVertices == nil { - actualCycleVertices = map[string][]string{} + actualCycles := graph.GetCycles() + if actualCycles == nil { + actualCycles = map[string][]string{} } - require.Equal(t, tt.wantCycleVertices, actualCycleVertices) + require.Equal(t, tt.wantCycles, actualCycles) }) } } From cc828bceacc56a8715fa9a3c2b571fb3ece60469 Mon Sep 17 00:00:00 2001 From: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> Date: Sun, 10 Mar 2024 11:48:03 +0200 Subject: [PATCH 5/9] schemadiff: support valid foreign key cycles Signed-off-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> --- go/vt/schemadiff/diff_test.go | 51 +++++++++- go/vt/schemadiff/schema.go | 137 +++++++++++++++++---------- go/vt/schemadiff/schema_diff.go | 18 +++- go/vt/schemadiff/schema_diff_test.go | 42 +++++++- go/vt/schemadiff/schema_test.go | 98 ++++++++++++------- go/vt/schemadiff/types.go | 10 ++ go/vt/schemadiff/view_test.go | 2 +- 7 files changed, 263 insertions(+), 95 deletions(-) diff --git a/go/vt/schemadiff/diff_test.go b/go/vt/schemadiff/diff_test.go index fbe7238e3fd..65a8581c02c 100644 --- a/go/vt/schemadiff/diff_test.go +++ b/go/vt/schemadiff/diff_test.go @@ -313,7 +313,7 @@ func TestDiffTables(t *testing.T) { for _, ts := range tt { t.Run(ts.name, func(t *testing.T) { var fromCreateTable *sqlparser.CreateTable - hints := &DiffHints{} + hints := EmptyDiffHints() if ts.hints != nil { hints = ts.hints } @@ -448,7 +448,7 @@ func TestDiffViews(t *testing.T) { name: "none", }, } - hints := &DiffHints{} + hints := EmptyDiffHints() env := NewTestEnv() for _, ts := range tt { t.Run(ts.name, func(t *testing.T) { @@ -545,6 +545,7 @@ func TestDiffSchemas(t *testing.T) { cdiffs []string expectError string tableRename int + fkStrategy int }{ { name: "identical tables", @@ -799,6 +800,45 @@ func TestDiffSchemas(t *testing.T) { "CREATE TABLE `t5` (\n\t`id` int,\n\t`i` int,\n\tPRIMARY KEY (`id`),\n\tKEY `f5` (`i`),\n\tCONSTRAINT `f5` FOREIGN KEY (`i`) REFERENCES `t7` (`id`)\n)", }, }, + { + name: "create tables with foreign keys, with invalid fk reference", + from: "create table t (id int primary key)", + to: ` + create table t (id int primary key); + create table t11 (id int primary key, i int, constraint f1101a foreign key (i) references t12 (id) on delete restrict); + create table t12 (id int primary key, i int, constraint f1201a foreign key (i) references t9 (id) on delete set null); + `, + expectError: "table `t12` foreign key references nonexistent table `t9`", + }, + { + name: "create tables with foreign keys, with invalid fk reference", + from: "create table t (id int primary key)", + to: ` + create table t (id int primary key); + create table t11 (id int primary key, i int, constraint f1101b foreign key (i) references t12 (id) on delete restrict); + create table t12 (id int primary key, i int, constraint f1201b foreign key (i) references t9 (id) on delete set null); + `, + expectError: "table `t12` foreign key references nonexistent table `t9`", + fkStrategy: ForeignKeyCheckStrategyIgnore, + }, + { + name: "create tables with foreign keys, with valid cycle", + from: "create table t (id int primary key)", + to: ` + create table t (id int primary key); + create table t11 (id int primary key, i int, constraint f1101c foreign key (i) references t12 (id) on delete restrict); + create table t12 (id int primary key, i int, constraint f1201c foreign key (i) references t11 (id) on delete set null); + `, + diffs: []string{ + "create table t11 (\n\tid int,\n\ti int,\n\tprimary key (id),\n\tkey f1101 (i),\n\tconstraint f1101 foreign key (i) references t12 (id) on delete restrict\n)", + "create table t12 (\n\tid int,\n\ti int,\n\tprimary key (id),\n\tkey f1201 (i),\n\tconstraint f1201 foreign key (i) references t11 (id) on delete set null\n)", + }, + cdiffs: []string{ + "CREATE TABLE `t11` (\n\t`id` int,\n\t`i` int,\n\tPRIMARY KEY (`id`),\n\tKEY `f1101` (`i`),\n\tCONSTRAINT `f1101` FOREIGN KEY (`i`) REFERENCES `t12` (`id`) ON DELETE RESTRICT\n)", + "CREATE TABLE `t12` (\n\t`id` int,\n\t`i` int,\n\tPRIMARY KEY (`id`),\n\tKEY `f1201` (`i`),\n\tCONSTRAINT `f1201` FOREIGN KEY (`i`) REFERENCES `t11` (`id`) ON DELETE SET NULL\n)", + }, + fkStrategy: ForeignKeyCheckStrategyIgnore, + }, { name: "drop tables with foreign keys, expect specific order", from: "create table t7(id int primary key); create table t5 (id int primary key, i int, constraint f5 foreign key (i) references t7(id)); create table t4 (id int primary key, i int, constraint f4 foreign key (i) references t7(id));", @@ -932,14 +972,15 @@ func TestDiffSchemas(t *testing.T) { for _, ts := range tt { t.Run(ts.name, func(t *testing.T) { hints := &DiffHints{ - TableRenameStrategy: ts.tableRename, + TableRenameStrategy: ts.tableRename, + ForeignKeyCheckStrategy: ts.fkStrategy, } diff, err := DiffSchemasSQL(env, ts.from, ts.to, hints) if ts.expectError != "" { require.Error(t, err) assert.Contains(t, err.Error(), ts.expectError) } else { - assert.NoError(t, err) + require.NoError(t, err) diffs, err := diff.OrderedDiffs(ctx) assert.NoError(t, err) @@ -1024,7 +1065,7 @@ func TestSchemaApplyError(t *testing.T) { to: "create table t(id int); create view v1 as select * from t; create view v2 as select * from t", }, } - hints := &DiffHints{} + hints := EmptyDiffHints() env := NewTestEnv() for _, ts := range tt { t.Run(ts.name, func(t *testing.T) { diff --git a/go/vt/schemadiff/schema.go b/go/vt/schemadiff/schema.go index e3782fdbf0b..2506597af3d 100644 --- a/go/vt/schemadiff/schema.go +++ b/go/vt/schemadiff/schema.go @@ -18,10 +18,14 @@ package schemadiff import ( "errors" + "slices" "sort" "strings" + "golang.org/x/exp/maps" + "vitess.io/vitess/go/mysql/capabilities" + "vitess.io/vitess/go/vt/graph" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/semantics" ) @@ -72,7 +76,7 @@ func NewSchemaFromEntities(env *Environment, entities []Entity) (*Schema, error) return nil, &UnsupportedEntityError{Entity: c.Name(), Statement: c.Create().CanonicalStatementString()} } } - err := schema.normalize() + err := schema.normalize(EmptyDiffHints()) return schema, err } @@ -135,42 +139,6 @@ func getForeignKeyParentTableNames(createTable *sqlparser.CreateTable) (names [] return names } -// findForeignKeyLoop is a stateful recursive function that determines whether a given table participates in a foreign -// key loop or derives from one. It returns a list of table names that form a loop, or nil if no loop is found. -// The function updates and checks the stateful map s.foreignKeyLoopMap to avoid re-analyzing the same table twice. -func (s *Schema) findForeignKeyLoop(tableName string, seen []string) (loop []string) { - if loop := s.foreignKeyLoopMap[tableName]; loop != nil { - return loop - } - t := s.Table(tableName) - if t == nil { - return nil - } - seen = append(seen, tableName) - for i, seenTable := range seen { - if i == len(seen)-1 { - // as we've just appended the table name to the end of the slice, we should skip it. - break - } - if seenTable == tableName { - // This table alreay appears in `seen`. - // We only return the suffix of `seen` that starts (and now ends) with this table. - return seen[i:] - } - } - for _, referencedTableName := range getForeignKeyParentTableNames(t.CreateTable) { - if loop := s.findForeignKeyLoop(referencedTableName, seen); loop != nil { - // Found loop. Update cache. - // It's possible for one table to participate in more than one foreign key loop, but - // we suffice with one loop, since we already only ever report one foreign key error - // per table. - s.foreignKeyLoopMap[tableName] = loop - return loop - } - } - return nil -} - // getViewDependentTableNames analyzes a CREATE VIEW definition and extracts all tables/views read by this view func getViewDependentTableNames(createView *sqlparser.CreateView) (names []string) { _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { @@ -191,7 +159,7 @@ func getViewDependentTableNames(createView *sqlparser.CreateView) (names []strin // normalize is called as part of Schema creation process. The user may only get a hold of normalized schema. // It validates some cross-entity constraints, and orders entity based on dependencies (e.g. tables, views that read from tables, 2nd level views, etc.) -func (s *Schema) normalize() error { +func (s *Schema) normalize(hints *DiffHints) error { var errs error s.named = make(map[string]Entity, len(s.tables)+len(s.views)) @@ -284,8 +252,10 @@ func (s *Schema) normalize() error { } referencedEntity, ok := s.named[referencedTableName] if !ok { - errs = errors.Join(errs, addEntityFkError(t, &ForeignKeyNonexistentReferencedTableError{Table: name, ReferencedTable: referencedTableName})) - continue + if hints.ForeignKeyCheckStrategy == ForeignKeyCheckStrategyStrict { + errs = errors.Join(errs, addEntityFkError(t, &ForeignKeyNonexistentReferencedTableError{Table: name, ReferencedTable: referencedTableName})) + continue + } } if _, ok := referencedEntity.(*CreateViewEntity); ok { errs = errors.Join(errs, addEntityFkError(t, &ForeignKeyReferencesViewError{Table: name, ReferencedView: referencedTableName})) @@ -310,6 +280,77 @@ func (s *Schema) normalize() error { s.foreignKeyParents = append(s.foreignKeyParents, t) } } + if len(dependencyLevels) != len(s.tables) { + // We have leftover tables. This can happen if there's foreign key loops + for _, t := range s.tables { + if _, ok := dependencyLevels[t.Name()]; ok { + // known table + continue + } + // Table is part of a loop or references a loop + s.sorted = append(s.sorted, t) + dependencyLevels[t.Name()] = iterationLevel // all in same level + } + + // Now, let's see if the loop is valid or invalid. For example: + // users.avatar_id -> avatars.id + // avatars.creator_id -> users.id + // is a valid loop, because even though the two tables reference each other, the loop ends in different columns. + type tableCol struct { + tableName sqlparser.TableName + colNames sqlparser.Columns + } + var tableColHash = func(tc tableCol) string { + res := sqlparser.String(tc.tableName) + for _, colName := range tc.colNames { + res += "|" + sqlparser.String(colName) + } + return res + } + var decodeTableColHash = func(hash string) (tableName string, colNames []string) { + tokens := strings.Split(hash, "|") + return tokens[0], tokens[1:] + } + g := graph.NewGraph[string]() + for _, table := range s.tables { + for _, cfk := range table.TableSpec.Constraints { + check, ok := cfk.Details.(*sqlparser.ForeignKeyDefinition) + if !ok { + // Not a foreign key + continue + } + + parentVertex := tableCol{ + tableName: check.ReferenceDefinition.ReferencedTable, + colNames: check.ReferenceDefinition.ReferencedColumns, + } + childVertex := tableCol{ + tableName: table.Table, + colNames: check.Source, + } + g.AddEdge(tableColHash(parentVertex), tableColHash(childVertex)) + } + } + cycles := g.GetCycles() // map of table name to cycle + // golang maps have undefined iteration order. For consistent output, we sort the keys. + vertices := maps.Keys(cycles) + slices.Sort(vertices) + for _, vertex := range vertices { + cycle := cycles[vertex] + if len(cycle) == 0 { + continue + } + cycleTables := make([]string, len(cycle)) + for i := range cycle { + // Reduce tablename|colname(s) to just tablename + cycleTables[i], _ = decodeTableColHash(cycle[i]) + } + tableName := cycleTables[0] + s.foreignKeyLoopMap[tableName] = cycleTables + errs = errors.Join(errs, addEntityFkError(s.named[tableName], &ForeignKeyLoopError{Table: tableName, Loop: cycleTables})) + } + } + // We now iterate all views. We iterate "dependency levels": // - first we want all views that only depend on tables. These are 1st level views. // - then we only want views that depend on 1st level views or on tables. These are 2nd level views. @@ -347,14 +388,6 @@ func (s *Schema) normalize() error { } if len(s.sorted) != len(s.tables)+len(s.views) { - - for _, t := range s.tables { - if _, ok := dependencyLevels[t.Name()]; !ok { - if loop := s.findForeignKeyLoop(t.Name(), nil); loop != nil { - errs = errors.Join(errs, addEntityFkError(t, &ForeignKeyLoopError{Table: t.Name(), Loop: loop})) - } - } - } // We have leftover tables or views. This can happen if the schema definition is invalid: // - a table's foreign key references a nonexistent table // - two or more tables have circular FK dependency @@ -724,7 +757,7 @@ func (s *Schema) copy() *Schema { // apply attempts to apply given list of diffs to this object. // These diffs are CREATE/DROP/ALTER TABLE/VIEW. -func (s *Schema) apply(diffs []EntityDiff) error { +func (s *Schema) apply(diffs []EntityDiff, hints *DiffHints) error { for _, diff := range diffs { switch diff := diff.(type) { case *CreateTableEntityDiff: @@ -834,7 +867,7 @@ func (s *Schema) apply(diffs []EntityDiff) error { return &UnsupportedApplyOperationError{Statement: diff.CanonicalStatementString()} } } - if err := s.normalize(); err != nil { + if err := s.normalize(hints); err != nil { return err } return nil @@ -845,7 +878,7 @@ func (s *Schema) apply(diffs []EntityDiff) error { // The operation does not modify this object. Instead, if successful, a new (modified) Schema is returned. func (s *Schema) Apply(diffs []EntityDiff) (*Schema, error) { dup := s.copy() - if err := dup.apply(diffs); err != nil { + if err := dup.apply(diffs, EmptyDiffHints()); err != nil { return nil, err } return dup, nil @@ -861,7 +894,7 @@ func (s *Schema) SchemaDiff(other *Schema, hints *DiffHints) (*SchemaDiff, error if err != nil { return nil, err } - schemaDiff := NewSchemaDiff(s) + schemaDiff := NewSchemaDiff(s, hints) schemaDiff.loadDiffs(diffs) // Utility function to see whether the given diff has dependencies on diffs that operate on any of the given named entities, diff --git a/go/vt/schemadiff/schema_diff.go b/go/vt/schemadiff/schema_diff.go index d2f5e012220..3fbc1e6c9d3 100644 --- a/go/vt/schemadiff/schema_diff.go +++ b/go/vt/schemadiff/schema_diff.go @@ -165,6 +165,7 @@ func permDiff(ctx context.Context, a []EntityDiff, callback func([]EntityDiff) ( // Operations on SchemaDiff are not concurrency-safe. type SchemaDiff struct { schema *Schema + hints *DiffHints diffs []EntityDiff diffMap map[string]EntityDiff // key is diff's CanonicalStatementString() @@ -173,9 +174,10 @@ type SchemaDiff struct { r *mathutil.EquivalenceRelation // internal structure to help determine diffs } -func NewSchemaDiff(schema *Schema) *SchemaDiff { +func NewSchemaDiff(schema *Schema, hints *DiffHints) *SchemaDiff { return &SchemaDiff{ schema: schema, + hints: hints, dependencies: make(map[string]*DiffDependency), diffMap: make(map[string]EntityDiff), r: mathutil.NewEquivalenceRelation(), @@ -318,7 +320,7 @@ func (d *SchemaDiff) OrderedDiffs(ctx context.Context) ([]EntityDiff, error) { // We want to apply the changes one by one, and validate the schema after each change for i := range permutatedDiffs { // apply inline - if err := permutationSchema.apply(permutatedDiffs[i : i+1]); err != nil { + if err := permutationSchema.apply(permutatedDiffs[i:i+1], d.hints); err != nil { // permutation is invalid return false // continue searching } @@ -341,6 +343,18 @@ func (d *SchemaDiff) OrderedDiffs(ctx context.Context) ([]EntityDiff, error) { // Done taking care of this equivalence class. } + if d.hints.ForeignKeyCheckStrategy != ForeignKeyCheckStrategyStrict { + // We may have allowed invalid foreign key dependencies along the way. But we must then validate the final schema + // to ensure that all foreign keys are valid. + hints := *d.hints + hints.ForeignKeyCheckStrategy = ForeignKeyCheckStrategyStrict + if err := lastGoodSchema.normalize(&hints); err != nil { + return nil, &ImpossibleApplyDiffOrderError{ + UnorderedDiffs: d.UnorderedDiffs(), + ConflictingDiffs: d.UnorderedDiffs(), + } + } + } return orderedDiffs, nil } diff --git a/go/vt/schemadiff/schema_diff_test.go b/go/vt/schemadiff/schema_diff_test.go index 4fbc31a6492..f363236c784 100644 --- a/go/vt/schemadiff/schema_diff_test.go +++ b/go/vt/schemadiff/schema_diff_test.go @@ -272,6 +272,9 @@ func TestSchemaDiff(t *testing.T) { entityOrder []string // names of tables/views in expected diff order mysqlServerVersion string instantCapability InstantDDLCapability + fkStrategy int + expectError string + expectOrderedError string }{ { name: "no change", @@ -624,6 +627,33 @@ func TestSchemaDiff(t *testing.T) { sequential: true, instantCapability: InstantDDLCapabilityIrrelevant, }, + { + name: "create two tables valid fk cycle", + toQueries: append( + createQueries, + "create table t11 (id int primary key, i int, constraint f1101 foreign key (i) references t12 (id) on delete restrict);", + "create table t12 (id int primary key, i int, constraint f1201 foreign key (i) references t11 (id) on delete set null);", + ), + expectDiffs: 2, + expectDeps: 2, + sequential: true, + fkStrategy: ForeignKeyCheckStrategyStrict, + expectOrderedError: "no valid applicable order for diffs", + }, + { + name: "create two tables valid fk cycle, fk ignore", + toQueries: append( + createQueries, + "create table t12 (id int primary key, i int, constraint f1201 foreign key (i) references t11 (id) on delete set null);", + "create table t11 (id int primary key, i int, constraint f1101 foreign key (i) references t12 (id) on delete restrict);", + ), + expectDiffs: 2, + expectDeps: 2, + entityOrder: []string{"t11", "t12"}, // Note that the tables were reordered lexicographically + sequential: true, + instantCapability: InstantDDLCapabilityIrrelevant, + fkStrategy: ForeignKeyCheckStrategyIgnore, + }, { name: "add FK", toQueries: []string{ @@ -934,7 +964,13 @@ func TestSchemaDiff(t *testing.T) { require.NoError(t, err) require.NotNil(t, toSchema) - schemaDiff, err := fromSchema.SchemaDiff(toSchema, baseHints) + hints := *baseHints + hints.ForeignKeyCheckStrategy = tc.fkStrategy + schemaDiff, err := fromSchema.SchemaDiff(toSchema, &hints) + if tc.expectError != "" { + assert.ErrorContains(t, err, tc.expectError) + return + } require.NoError(t, err) allDiffs := schemaDiff.UnorderedDiffs() @@ -953,6 +989,10 @@ func TestSchemaDiff(t *testing.T) { assert.Equal(t, tc.sequential, schemaDiff.HasSequentialExecutionDependencies()) orderedDiffs, err := schemaDiff.OrderedDiffs(ctx) + if tc.expectOrderedError != "" { + assert.ErrorContains(t, err, tc.expectOrderedError) + return + } if tc.conflictingDiffs > 0 { assert.Error(t, err) impossibleOrderErr, ok := err.(*ImpossibleApplyDiffOrderError) diff --git a/go/vt/schemadiff/schema_test.go b/go/vt/schemadiff/schema_test.go index a979e521216..8e2cc651dd6 100644 --- a/go/vt/schemadiff/schema_test.go +++ b/go/vt/schemadiff/schema_test.go @@ -346,55 +346,80 @@ func TestInvalidSchema(t *testing.T) { }, { // t12<->t11 - schema: "create table t11 (id int primary key, i int, constraint f11 foreign key (i) references t12 (id) on delete restrict); create table t12 (id int primary key, i int, constraint f12 foreign key (i) references t11 (id) on delete restrict)", + schema: ` + create table t11 (id int primary key, i int, constraint f1103 foreign key (i) references t12 (id) on delete restrict); + create table t12 (id int primary key, i int, constraint f1203 foreign key (i) references t11 (id) on delete restrict) + `, + }, + { + // t12<->t11 + schema: ` + create table t11 (id int primary key, i int, constraint f1101 foreign key (i) references t12 (i) on delete restrict); + create table t12 (id int primary key, i int, constraint f1201 foreign key (i) references t11 (i) on delete set null) + `, expectErr: errors.Join( &ForeignKeyLoopError{Table: "t11", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t12", Loop: []string{"t11", "t12", "t11"}}, + &ForeignKeyLoopError{Table: "t12", Loop: []string{"t12", "t11", "t12"}}, ), expectLoopTables: 2, }, { // t10, t12<->t11 - schema: "create table t10(id int primary key); create table t11 (id int primary key, i int, constraint f11 foreign key (i) references t12 (id) on delete restrict); create table t12 (id int primary key, i int, constraint f12 foreign key (i) references t11 (id) on delete restrict)", - expectErr: errors.Join( - &ForeignKeyLoopError{Table: "t11", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t12", Loop: []string{"t11", "t12", "t11"}}, - ), - expectLoopTables: 2, + schema: ` + create table t10(id int primary key); + create table t11 (id int primary key, i int, constraint f1102 foreign key (i) references t12 (id) on delete restrict); + create table t12 (id int primary key, i int, constraint f1202 foreign key (i) references t11 (id) on delete restrict) + `, }, { // t10, t12<->t11<-t13 - schema: "create table t10(id int primary key); create table t11 (id int primary key, i int, constraint f11 foreign key (i) references t12 (id) on delete restrict); create table t12 (id int primary key, i int, constraint f12 foreign key (i) references t11 (id) on delete restrict); create table t13 (id int primary key, i int, constraint f13 foreign key (i) references t11 (id) on delete restrict)", - expectErr: errors.Join( - &ForeignKeyLoopError{Table: "t11", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t12", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t13", Loop: []string{"t11", "t12", "t11"}}, - ), - expectLoopTables: 3, + schema: ` + create table t10(id int primary key); + create table t11 (id int primary key, i int, constraint f1104 foreign key (i) references t12 (id) on delete restrict); + create table t12 (id int primary key, i int, constraint f1204 foreign key (i) references t11 (id) on delete restrict); + create table t13 (id int primary key, i int, constraint f13 foreign key (i) references t11 (id) on delete restrict)`, }, { // t10 // ^ // | //t12<->t11<-t13 - schema: "create table t10(id int primary key); create table t11 (id int primary key, i int, i10 int, constraint f11 foreign key (i) references t12 (id) on delete restrict, constraint f1110 foreign key (i10) references t10 (id) on delete restrict); create table t12 (id int primary key, i int, constraint f12 foreign key (i) references t11 (id) on delete restrict); create table t13 (id int primary key, i int, constraint f13 foreign key (i) references t11 (id) on delete restrict)", + schema: ` + create table t10(id int primary key); + create table t11 (id int primary key, i int, i10 int, constraint f111205 foreign key (i) references t12 (id) on delete restrict, constraint f111005 foreign key (i10) references t10 (id) on delete restrict); + create table t12 (id int primary key, i int, constraint f1205 foreign key (id) references t11 (i) on delete restrict); + create table t13 (id int primary key, i int, constraint f1305 foreign key (i) references t11 (id) on delete restrict) + `, expectErr: errors.Join( &ForeignKeyLoopError{Table: "t11", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t12", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t13", Loop: []string{"t11", "t12", "t11"}}, + &ForeignKeyLoopError{Table: "t12", Loop: []string{"t12", "t11", "t12"}}, ), - expectLoopTables: 3, + expectLoopTables: 2, }, { // t10, t12<->t11<-t13<-t14 - schema: "create table t10(id int primary key); create table t11 (id int primary key, i int, i10 int, constraint f11 foreign key (i) references t12 (id) on delete restrict, constraint f1110 foreign key (i10) references t10 (id) on delete restrict); create table t12 (id int primary key, i int, constraint f12 foreign key (i) references t11 (id) on delete restrict); create table t13 (id int primary key, i int, constraint f13 foreign key (i) references t11 (id) on delete restrict); create table t14 (id int primary key, i int, constraint f14 foreign key (i) references t13 (id) on delete restrict)", + schema: ` + create table t10(id int primary key); + create table t11 (id int primary key, i int, i10 int, constraint f1106 foreign key (i) references t12 (id) on delete restrict, constraint f111006 foreign key (i10) references t10 (id) on delete restrict); + create table t12 (id int primary key, i int, constraint f1206 foreign key (i) references t11 (id) on delete restrict); + create table t13 (id int primary key, i int, constraint f1306 foreign key (i) references t11 (id) on delete restrict); + create table t14 (id int primary key, i int, constraint f1406 foreign key (i) references t13 (id) on delete restrict) + `, + }, + { + // t10, t12<-t11<-t13<-t12 + schema: ` + create table t10(id int primary key); + create table t11 (id int primary key, i int, key i_idx (i), i10 int, constraint f1107 foreign key (i) references t12 (id), constraint f111007 foreign key (i10) references t10 (id)); + create table t12 (id int primary key, i int, key i_idx (i), constraint f1207 foreign key (id) references t13 (i)); + create table t13 (id int primary key, i int, key i_idx (i), constraint f1307 foreign key (i) references t11 (i)); + `, expectErr: errors.Join( - &ForeignKeyLoopError{Table: "t11", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t12", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t13", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t14", Loop: []string{"t11", "t12", "t11"}}, + &ForeignKeyLoopError{Table: "t11", Loop: []string{"t11", "t13", "t12", "t11"}}, + &ForeignKeyLoopError{Table: "t12", Loop: []string{"t12", "t11", "t13", "t12"}}, + &ForeignKeyLoopError{Table: "t13", Loop: []string{"t13", "t12", "t11", "t13"}}, ), - expectLoopTables: 4, + expectLoopTables: 3, }, { schema: "create table t11 (id int primary key, i int, key ix(i), constraint f11 foreign key (i) references t11(id2) on delete restrict)", @@ -492,7 +517,7 @@ func TestInvalidTableForeignKeyReference(t *testing.T) { // Even though there's errors, we still expect the schema to have been created. assert.NotNil(t, s) // Even though t11 caused an error, we still expect the schema to have parsed all tables. - assert.Equal(t, 3, len(s.Entities())) + assert.Equalf(t, 3, len(s.Entities()), "found: %+v", s.EntityNames()) t11 := s.Table("t11") assert.NotNil(t, t11) // validate t11 table definition is complete, even though it was invalid. @@ -506,10 +531,19 @@ func TestInvalidTableForeignKeyReference(t *testing.T) { "create table t12 (id int primary key, i int, constraint f13 foreign key (i) references t13(id) on delete restrict)", } _, err := NewSchemaFromQueries(NewTestEnv(), fkQueries) + assert.NoError(t, err) + } + { + fkQueries := []string{ + "create table t13 (id int primary key, i int, constraint f11 foreign key (i) references t11(i) on delete restrict)", + "create table t11 (id int primary key, i int, constraint f12 foreign key (i) references t12(i) on delete restrict)", + "create table t12 (id int primary key, i int, constraint f13 foreign key (i) references t13(i) on delete restrict)", + } + _, err := NewSchemaFromQueries(NewTestEnv(), fkQueries) assert.Error(t, err) - assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t11", Loop: []string{"t11", "t12", "t13", "t11"}}).Error()) - assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t12", Loop: []string{"t11", "t12", "t13", "t11"}}).Error()) - assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t13", Loop: []string{"t11", "t12", "t13", "t11"}}).Error()) + assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t11", Loop: []string{"t11", "t13", "t12", "t11"}}).Error()) + assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t12", Loop: []string{"t12", "t11", "t13", "t12"}}).Error()) + assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t13", Loop: []string{"t13", "t12", "t11", "t13"}}).Error()) } { fkQueries := []string{ @@ -520,8 +554,6 @@ func TestInvalidTableForeignKeyReference(t *testing.T) { _, err := NewSchemaFromQueries(NewTestEnv(), fkQueries) assert.Error(t, err) assert.ErrorContains(t, err, (&ForeignKeyNonexistentReferencedTableError{Table: "t11", ReferencedTable: "t0"}).Error()) - assert.ErrorContains(t, err, (&ForeignKeyDependencyUnresolvedError{Table: "t12"}).Error()) - assert.ErrorContains(t, err, (&ForeignKeyDependencyUnresolvedError{Table: "t13"}).Error()) } { fkQueries := []string{ @@ -532,8 +564,6 @@ func TestInvalidTableForeignKeyReference(t *testing.T) { _, err := NewSchemaFromQueries(NewTestEnv(), fkQueries) assert.Error(t, err) assert.ErrorContains(t, err, (&ForeignKeyNonexistentReferencedTableError{Table: "t11", ReferencedTable: "t0"}).Error()) - assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t12", Loop: []string{"t12", "t13", "t12"}}).Error()) - assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t13", Loop: []string{"t12", "t13", "t12"}}).Error()) } } @@ -943,7 +973,7 @@ func TestMassiveSchema(t *testing.T) { }) t.Run("evaluating diff", func(t *testing.T) { - schemaDiff, err := schema0.SchemaDiff(schema1, &DiffHints{}) + schemaDiff, err := schema0.SchemaDiff(schema1, EmptyDiffHints()) require.NoError(t, err) diffs := schemaDiff.UnorderedDiffs() require.NotEmpty(t, diffs) diff --git a/go/vt/schemadiff/types.go b/go/vt/schemadiff/types.go index a4edb09ec9b..e5e17229cd6 100644 --- a/go/vt/schemadiff/types.go +++ b/go/vt/schemadiff/types.go @@ -124,6 +124,11 @@ const ( EnumReorderStrategyReject ) +const ( + ForeignKeyCheckStrategyStrict int = iota + ForeignKeyCheckStrategyIgnore +) + // DiffHints is an assortment of rules for diffing entities type DiffHints struct { StrictIndexOrdering bool @@ -137,6 +142,11 @@ type DiffHints struct { TableQualifierHint int AlterTableAlgorithmStrategy int EnumReorderStrategy int + ForeignKeyCheckStrategy int +} + +func EmptyDiffHints() *DiffHints { + return &DiffHints{} } const ( diff --git a/go/vt/schemadiff/view_test.go b/go/vt/schemadiff/view_test.go index e5be9055970..2aade1dc3e8 100644 --- a/go/vt/schemadiff/view_test.go +++ b/go/vt/schemadiff/view_test.go @@ -145,7 +145,7 @@ func TestCreateViewDiff(t *testing.T) { cdiff: "ALTER ALGORITHM = TEMPTABLE VIEW `v1` AS SELECT `a` FROM `t`", }, } - hints := &DiffHints{} + hints := EmptyDiffHints() env := NewTestEnv() for _, ts := range tt { t.Run(ts.name, func(t *testing.T) { From f81cbdbc2039f22b14b5c90785586c1780021553 Mon Sep 17 00:00:00 2001 From: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> Date: Sun, 10 Mar 2024 12:20:31 +0200 Subject: [PATCH 6/9] Add test that adds a valid foreign key cycle Signed-off-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> --- go/vt/schemadiff/diff_test.go | 8 ++++---- go/vt/schemadiff/schema_diff_test.go | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/go/vt/schemadiff/diff_test.go b/go/vt/schemadiff/diff_test.go index 65a8581c02c..3fe94e3b0b5 100644 --- a/go/vt/schemadiff/diff_test.go +++ b/go/vt/schemadiff/diff_test.go @@ -830,12 +830,12 @@ func TestDiffSchemas(t *testing.T) { create table t12 (id int primary key, i int, constraint f1201c foreign key (i) references t11 (id) on delete set null); `, diffs: []string{ - "create table t11 (\n\tid int,\n\ti int,\n\tprimary key (id),\n\tkey f1101 (i),\n\tconstraint f1101 foreign key (i) references t12 (id) on delete restrict\n)", - "create table t12 (\n\tid int,\n\ti int,\n\tprimary key (id),\n\tkey f1201 (i),\n\tconstraint f1201 foreign key (i) references t11 (id) on delete set null\n)", + "create table t11 (\n\tid int,\n\ti int,\n\tprimary key (id),\n\tkey f1101c (i),\n\tconstraint f1101c foreign key (i) references t12 (id) on delete restrict\n)", + "create table t12 (\n\tid int,\n\ti int,\n\tprimary key (id),\n\tkey f1201c (i),\n\tconstraint f1201c foreign key (i) references t11 (id) on delete set null\n)", }, cdiffs: []string{ - "CREATE TABLE `t11` (\n\t`id` int,\n\t`i` int,\n\tPRIMARY KEY (`id`),\n\tKEY `f1101` (`i`),\n\tCONSTRAINT `f1101` FOREIGN KEY (`i`) REFERENCES `t12` (`id`) ON DELETE RESTRICT\n)", - "CREATE TABLE `t12` (\n\t`id` int,\n\t`i` int,\n\tPRIMARY KEY (`id`),\n\tKEY `f1201` (`i`),\n\tCONSTRAINT `f1201` FOREIGN KEY (`i`) REFERENCES `t11` (`id`) ON DELETE SET NULL\n)", + "CREATE TABLE `t11` (\n\t`id` int,\n\t`i` int,\n\tPRIMARY KEY (`id`),\n\tKEY `f1101c` (`i`),\n\tCONSTRAINT `f1101c` FOREIGN KEY (`i`) REFERENCES `t12` (`id`) ON DELETE RESTRICT\n)", + "CREATE TABLE `t12` (\n\t`id` int,\n\t`i` int,\n\tPRIMARY KEY (`id`),\n\tKEY `f1201c` (`i`),\n\tCONSTRAINT `f1201c` FOREIGN KEY (`i`) REFERENCES `t11` (`id`) ON DELETE SET NULL\n)", }, fkStrategy: ForeignKeyCheckStrategyIgnore, }, diff --git a/go/vt/schemadiff/schema_diff_test.go b/go/vt/schemadiff/schema_diff_test.go index f363236c784..9f1aea50efd 100644 --- a/go/vt/schemadiff/schema_diff_test.go +++ b/go/vt/schemadiff/schema_diff_test.go @@ -680,6 +680,20 @@ func TestSchemaDiff(t *testing.T) { entityOrder: []string{"tp", "t2"}, instantCapability: InstantDDLCapabilityImpossible, }, + { + name: "add two valid fk cycle references", + toQueries: []string{ + "create table t1 (id int primary key, info int not null, i int, constraint f1 foreign key (i) references t2 (id) on delete restrict);", + "create table t2 (id int primary key, ts timestamp, i int, constraint f2 foreign key (i) references t1 (id) on delete set null);", + "create view v1 as select id from t1", + }, + expectDiffs: 2, + expectDeps: 2, + sequential: false, + fkStrategy: ForeignKeyCheckStrategyStrict, + entityOrder: []string{"t1", "t2"}, + instantCapability: InstantDDLCapabilityImpossible, + }, { name: "add FK, unrelated alter", toQueries: []string{ From e7e3b0a426c41b237ddbf4994015efadf0ac7221 Mon Sep 17 00:00:00 2001 From: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> Date: Sun, 10 Mar 2024 12:23:53 +0200 Subject: [PATCH 7/9] add test: new table and new foreign key Signed-off-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> --- go/vt/schemadiff/schema_diff_test.go | 30 ++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/go/vt/schemadiff/schema_diff_test.go b/go/vt/schemadiff/schema_diff_test.go index 9f1aea50efd..5aff4a0b408 100644 --- a/go/vt/schemadiff/schema_diff_test.go +++ b/go/vt/schemadiff/schema_diff_test.go @@ -694,6 +694,36 @@ func TestSchemaDiff(t *testing.T) { entityOrder: []string{"t1", "t2"}, instantCapability: InstantDDLCapabilityImpossible, }, + { + name: "add a table and a valid fk cycle references", + toQueries: []string{ + "create table t0 (id int primary key, info int not null, i int, constraint f1 foreign key (i) references t2 (id) on delete restrict);", + "create table t1 (id int primary key, info int not null);", + "create table t2 (id int primary key, ts timestamp, i int, constraint f2 foreign key (i) references t0 (id) on delete set null);", + "create view v1 as select id from t1", + }, + expectDiffs: 2, + expectDeps: 2, + sequential: true, + fkStrategy: ForeignKeyCheckStrategyStrict, + entityOrder: []string{"t0", "t2"}, + instantCapability: InstantDDLCapabilityImpossible, + }, + { + name: "add a table and a valid fk cycle references, lelxicographically desc", + toQueries: []string{ + "create table t1 (id int primary key, info int not null);", + "create table t2 (id int primary key, ts timestamp, i int, constraint f2 foreign key (i) references t9 (id) on delete set null);", + "create table t9 (id int primary key, info int not null, i int, constraint f1 foreign key (i) references t2 (id) on delete restrict);", + "create view v1 as select id from t1", + }, + expectDiffs: 2, + expectDeps: 2, + sequential: true, + fkStrategy: ForeignKeyCheckStrategyStrict, + entityOrder: []string{"t9", "t2"}, + instantCapability: InstantDDLCapabilityImpossible, + }, { name: "add FK, unrelated alter", toQueries: []string{ From 27cf79cf16fee4a9ba51928196f9398465d465d1 Mon Sep 17 00:00:00 2001 From: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> Date: Mon, 11 Mar 2024 09:08:43 +0200 Subject: [PATCH 8/9] fix var name Signed-off-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> --- go/vt/graph/graph.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/go/vt/graph/graph.go b/go/vt/graph/graph.go index c5095638ed2..cc5f837d6f7 100644 --- a/go/vt/graph/graph.go +++ b/go/vt/graph/graph.go @@ -31,15 +31,15 @@ const ( // Graph is a generic graph implementation. type Graph[C comparable] struct { - edges map[C][]C - orderedEdged []C + edges map[C][]C + orderedVertices []C } // NewGraph creates a new graph for the given comparable type. func NewGraph[C comparable]() *Graph[C] { return &Graph[C]{ - edges: map[C][]C{}, - orderedEdged: []C{}, + edges: map[C][]C{}, + orderedVertices: []C{}, } } @@ -50,7 +50,7 @@ func (gr *Graph[C]) AddVertex(vertex C) { return } gr.edges[vertex] = []C{} - gr.orderedEdged = append(gr.orderedEdged, vertex) + gr.orderedVertices = append(gr.orderedVertices, vertex) } // AddEdge adds an edge to the given Graph. @@ -119,7 +119,7 @@ func (gr *Graph[C]) GetCycles() (vertices map[C][]C) { // 1 represents grey. // 2 represents black. color := map[C]int{} - for _, vertex := range gr.orderedEdged { + for _, vertex := range gr.orderedVertices { // If any vertex is still white, we initiate a new DFS. if color[vertex] == white { // We clone the colors because we wnt full coverage for all vertices. From d3b8c8e2b1777877bb1e2297d39b2f8dd9773331 Mon Sep 17 00:00:00 2001 From: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> Date: Mon, 11 Mar 2024 12:11:31 +0200 Subject: [PATCH 9/9] Intorduce ForeignKeyTableColumns, formalize ForeignKeyLoopError.Loop, include column names in error message Signed-off-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> --- go/vt/schemadiff/errors.go | 10 +++++----- go/vt/schemadiff/schema.go | 13 +++++-------- go/vt/schemadiff/schema_test.go | 32 ++++++++++++++------------------ go/vt/schemadiff/types.go | 21 +++++++++++++++++++++ 4 files changed, 45 insertions(+), 31 deletions(-) diff --git a/go/vt/schemadiff/errors.go b/go/vt/schemadiff/errors.go index 942488f39a3..dc73acdb9a0 100644 --- a/go/vt/schemadiff/errors.go +++ b/go/vt/schemadiff/errors.go @@ -288,7 +288,7 @@ func (e *ForeignKeyDependencyUnresolvedError) Error() string { type ForeignKeyLoopError struct { Table string - Loop []string + Loop []*ForeignKeyTableColumns } func (e *ForeignKeyLoopError) Error() string { @@ -303,16 +303,16 @@ func (e *ForeignKeyLoopError) Error() string { if len(loop) > 0 { last := loop[len(loop)-1] for i := range loop { - if loop[i] == last { + if loop[i].Table == last.Table { loop = loop[i:] break } } } escaped := make([]string, len(loop)) - for i, t := range loop { - escaped[i] = sqlescape.EscapeID(t) - if t == e.Table { + for i, fk := range loop { + escaped[i] = fk.Escaped() + if fk.Table == e.Table { tableIsInsideLoop = true } } diff --git a/go/vt/schemadiff/schema.go b/go/vt/schemadiff/schema.go index 2506597af3d..8081c6eaeea 100644 --- a/go/vt/schemadiff/schema.go +++ b/go/vt/schemadiff/schema.go @@ -41,7 +41,6 @@ type Schema struct { foreignKeyParents []*CreateTableEntity // subset of tables foreignKeyChildren []*CreateTableEntity // subset of tables - foreignKeyLoopMap map[string][]string // map of table name that either participate, or directly or indirectly reference foreign key loops env *Environment } @@ -56,7 +55,6 @@ func newEmptySchema(env *Environment) *Schema { foreignKeyParents: []*CreateTableEntity{}, foreignKeyChildren: []*CreateTableEntity{}, - foreignKeyLoopMap: map[string][]string{}, env: env, } @@ -307,9 +305,9 @@ func (s *Schema) normalize(hints *DiffHints) error { } return res } - var decodeTableColHash = func(hash string) (tableName string, colNames []string) { + var decodeTableColHash = func(hash string) *ForeignKeyTableColumns { tokens := strings.Split(hash, "|") - return tokens[0], tokens[1:] + return &ForeignKeyTableColumns{tokens[0], tokens[1:]} } g := graph.NewGraph[string]() for _, table := range s.tables { @@ -340,13 +338,12 @@ func (s *Schema) normalize(hints *DiffHints) error { if len(cycle) == 0 { continue } - cycleTables := make([]string, len(cycle)) + cycleTables := make([]*ForeignKeyTableColumns, len(cycle)) for i := range cycle { // Reduce tablename|colname(s) to just tablename - cycleTables[i], _ = decodeTableColHash(cycle[i]) + cycleTables[i] = decodeTableColHash(cycle[i]) } - tableName := cycleTables[0] - s.foreignKeyLoopMap[tableName] = cycleTables + tableName := cycleTables[0].Table errs = errors.Join(errs, addEntityFkError(s.named[tableName], &ForeignKeyLoopError{Table: tableName, Loop: cycleTables})) } } diff --git a/go/vt/schemadiff/schema_test.go b/go/vt/schemadiff/schema_test.go index 8e2cc651dd6..19a1b95e186 100644 --- a/go/vt/schemadiff/schema_test.go +++ b/go/vt/schemadiff/schema_test.go @@ -310,9 +310,8 @@ func TestTableForeignKeyOrdering(t *testing.T) { func TestInvalidSchema(t *testing.T) { tt := []struct { - schema string - expectErr error - expectLoopTables int + schema string + expectErr error }{ { schema: "create table t11 (id int primary key, i int, key ix(i), constraint f11 foreign key (i) references t11(id) on delete restrict)", @@ -358,10 +357,9 @@ func TestInvalidSchema(t *testing.T) { create table t12 (id int primary key, i int, constraint f1201 foreign key (i) references t11 (i) on delete set null) `, expectErr: errors.Join( - &ForeignKeyLoopError{Table: "t11", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t12", Loop: []string{"t12", "t11", "t12"}}, + &ForeignKeyLoopError{Table: "t11", Loop: []*ForeignKeyTableColumns{{"t11", []string{"i"}}, {"t12", []string{"i"}}, {"t11", []string{"i"}}}}, + &ForeignKeyLoopError{Table: "t12", Loop: []*ForeignKeyTableColumns{{"t12", []string{"i"}}, {"t11", []string{"i"}}, {"t12", []string{"i"}}}}, ), - expectLoopTables: 2, }, { // t10, t12<->t11 @@ -391,10 +389,9 @@ func TestInvalidSchema(t *testing.T) { create table t13 (id int primary key, i int, constraint f1305 foreign key (i) references t11 (id) on delete restrict) `, expectErr: errors.Join( - &ForeignKeyLoopError{Table: "t11", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t12", Loop: []string{"t12", "t11", "t12"}}, + &ForeignKeyLoopError{Table: "t11", Loop: []*ForeignKeyTableColumns{{"t11", []string{"i"}}, {"t12", []string{"id"}}, {"t11", []string{"i"}}}}, + &ForeignKeyLoopError{Table: "t12", Loop: []*ForeignKeyTableColumns{{"t12", []string{"id"}}, {"t11", []string{"i"}}, {"t12", []string{"id"}}}}, ), - expectLoopTables: 2, }, { // t10, t12<->t11<-t13<-t14 @@ -415,11 +412,10 @@ func TestInvalidSchema(t *testing.T) { create table t13 (id int primary key, i int, key i_idx (i), constraint f1307 foreign key (i) references t11 (i)); `, expectErr: errors.Join( - &ForeignKeyLoopError{Table: "t11", Loop: []string{"t11", "t13", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t12", Loop: []string{"t12", "t11", "t13", "t12"}}, - &ForeignKeyLoopError{Table: "t13", Loop: []string{"t13", "t12", "t11", "t13"}}, + &ForeignKeyLoopError{Table: "t11", Loop: []*ForeignKeyTableColumns{{"t11", []string{"i"}}, {"t13", []string{"i"}}, {"t12", []string{"id"}}, {"t11", []string{"i"}}}}, + &ForeignKeyLoopError{Table: "t12", Loop: []*ForeignKeyTableColumns{{"t12", []string{"id"}}, {"t11", []string{"i"}}, {"t13", []string{"i"}}, {"t12", []string{"id"}}}}, + &ForeignKeyLoopError{Table: "t13", Loop: []*ForeignKeyTableColumns{{"t13", []string{"i"}}, {"t12", []string{"id"}}, {"t11", []string{"i"}}, {"t13", []string{"i"}}}}, ), - expectLoopTables: 3, }, { schema: "create table t11 (id int primary key, i int, key ix(i), constraint f11 foreign key (i) references t11(id2) on delete restrict)", @@ -493,14 +489,13 @@ func TestInvalidSchema(t *testing.T) { for _, ts := range tt { t.Run(ts.schema, func(t *testing.T) { - s, err := NewSchemaFromSQL(NewTestEnv(), ts.schema) + _, err := NewSchemaFromSQL(NewTestEnv(), ts.schema) if ts.expectErr == nil { assert.NoError(t, err) } else { assert.Error(t, err) assert.EqualError(t, err, ts.expectErr.Error()) } - assert.Equal(t, ts.expectLoopTables, len(s.foreignKeyLoopMap)) }) } } @@ -541,9 +536,10 @@ func TestInvalidTableForeignKeyReference(t *testing.T) { } _, err := NewSchemaFromQueries(NewTestEnv(), fkQueries) assert.Error(t, err) - assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t11", Loop: []string{"t11", "t13", "t12", "t11"}}).Error()) - assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t12", Loop: []string{"t12", "t11", "t13", "t12"}}).Error()) - assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t13", Loop: []string{"t13", "t12", "t11", "t13"}}).Error()) + + assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t11", Loop: []*ForeignKeyTableColumns{{"t11", []string{"i"}}, {"t13", []string{"i"}}, {"t12", []string{"i"}}, {"t11", []string{"i"}}}}).Error()) + assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t12", Loop: []*ForeignKeyTableColumns{{"t12", []string{"i"}}, {"t11", []string{"i"}}, {"t13", []string{"i"}}, {"t12", []string{"i"}}}}).Error()) + assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t13", Loop: []*ForeignKeyTableColumns{{"t13", []string{"i"}}, {"t12", []string{"i"}}, {"t11", []string{"i"}}, {"t13", []string{"i"}}}}).Error()) } { fkQueries := []string{ diff --git a/go/vt/schemadiff/types.go b/go/vt/schemadiff/types.go index e5e17229cd6..b42408376b8 100644 --- a/go/vt/schemadiff/types.go +++ b/go/vt/schemadiff/types.go @@ -17,6 +17,9 @@ limitations under the License. package schemadiff import ( + "strings" + + "vitess.io/vitess/go/sqlescape" "vitess.io/vitess/go/vt/sqlparser" ) @@ -154,3 +157,21 @@ const ( ApplyDiffsInOrder = "ApplyDiffsInOrder" ApplyDiffsSequential = "ApplyDiffsSequential" ) + +type ForeignKeyTableColumns struct { + Table string + Columns []string +} + +func (f ForeignKeyTableColumns) Escaped() string { + var b strings.Builder + b.WriteString(sqlescape.EscapeID(f.Table)) + b.WriteString(" (") + escapedColumns := make([]string, len(f.Columns)) + for i, column := range f.Columns { + escapedColumns[i] = sqlescape.EscapeID(column) + } + b.WriteString(strings.Join(escapedColumns, ", ")) + b.WriteString(")") + return b.String() +}