From d11929500151250aff222416d23ade17b1b0b5ac Mon Sep 17 00:00:00 2001 From: Manan Gupta <35839558+GuptaManan100@users.noreply.github.com> Date: Wed, 13 Mar 2024 21:34:53 +0530 Subject: [PATCH] Fix cycle detection for foreign keys (#15458) Signed-off-by: Manan Gupta --- go/vt/graph/graph.go | 10 +- go/vt/graph/graph_test.go | 6 +- go/vt/sqlparser/ast_funcs.go | 10 ++ go/vt/vterrors/code.go | 2 +- go/vt/vtgate/vschema_manager.go | 61 ++++++---- go/vt/vtgate/vschema_manager_test.go | 169 ++++++++++++++++++++++++++- 6 files changed, 223 insertions(+), 35 deletions(-) diff --git a/go/vt/graph/graph.go b/go/vt/graph/graph.go index cc5f837d6f7..1938cf4bf1c 100644 --- a/go/vt/graph/graph.go +++ b/go/vt/graph/graph.go @@ -83,10 +83,10 @@ func (gr *Graph[C]) Empty() bool { // HasCycles checks whether the given graph has a cycle or not. // We are using a well-known DFS based colouring algorithm to check for cycles. // Look at https://cp-algorithms.com/graph/finding-cycle.html for more details on the algorithm. -func (gr *Graph[C]) HasCycles() bool { +func (gr *Graph[C]) HasCycles() (bool, []C) { // If the graph is empty, then we don't need to check anything. if gr.Empty() { - return false + return false, nil } // Initialize the coloring map. // 0 represents white. @@ -96,12 +96,12 @@ func (gr *Graph[C]) HasCycles() bool { for vertex := range gr.edges { // If any vertex is still white, we initiate a new DFS. if color[vertex] == white { - if hasCycle, _ := gr.hasCyclesDfs(color, vertex); hasCycle { - return true + if hasCycle, cycle := gr.hasCyclesDfs(color, vertex); hasCycle { + return true, cycle } } } - return false + return false, nil } // GetCycles returns all known cycles in the graph. 
diff --git a/go/vt/graph/graph_test.go b/go/vt/graph/graph_test.go index 3231998039e..3f762552556 100644 --- a/go/vt/graph/graph_test.go +++ b/go/vt/graph/graph_test.go @@ -82,7 +82,8 @@ func TestIntegerGraph(t *testing.T) { } require.Equal(t, tt.wantedGraph, graph.PrintGraph()) require.Equal(t, tt.wantEmpty, graph.Empty()) - require.Equal(t, tt.wantHasCycles, graph.HasCycles()) + hasCycle, _ := graph.HasCycles() + require.Equal(t, tt.wantHasCycles, hasCycle) }) } } @@ -155,7 +156,8 @@ F - A`, } require.Equal(t, tt.wantedGraph, graph.PrintGraph()) require.Equal(t, tt.wantEmpty, graph.Empty()) - require.Equal(t, tt.wantHasCycles, graph.HasCycles()) + hasCycle, _ := graph.HasCycles() + require.Equal(t, tt.wantHasCycles, hasCycle) if tt.wantCycles == nil { tt.wantCycles = map[string][]string{} } diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index ddb5251dbb3..4335e2432f9 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -2556,6 +2556,16 @@ func (ra ReferenceAction) IsRestrict() bool { } } +// IsCascade returns true if the reference action is of cascade type. +func (ra ReferenceAction) IsCascade() bool { + switch ra { + case Cascade: + return true + default: + return false + } +} + // IsLiteral returns true if the expression is of a literal type. 
func IsLiteral(expr Expr) bool { switch expr.(type) { diff --git a/go/vt/vterrors/code.go b/go/vt/vterrors/code.go index 574ac7c2cdf..4ade1b22368 100644 --- a/go/vt/vterrors/code.go +++ b/go/vt/vterrors/code.go @@ -89,7 +89,7 @@ var ( VT09016 = errorWithState("VT09016", vtrpcpb.Code_FAILED_PRECONDITION, RowIsReferenced2, "Cannot delete or update a parent row: a foreign key constraint fails", "SET DEFAULT is not supported by InnoDB") VT09017 = errorWithoutState("VT09017", vtrpcpb.Code_FAILED_PRECONDITION, "%s", "Invalid syntax for the statement type.") VT09018 = errorWithoutState("VT09018", vtrpcpb.Code_FAILED_PRECONDITION, "%s", "Invalid syntax for the vindex function statement.") - VT09019 = errorWithoutState("VT09019", vtrpcpb.Code_FAILED_PRECONDITION, "keyspace '%s' has cyclic foreign keys", "Vitess doesn't support cyclic foreign keys.") + VT09019 = errorWithoutState("VT09019", vtrpcpb.Code_FAILED_PRECONDITION, "keyspace '%s' has cyclic foreign keys. Cycle exists between %v", "Vitess doesn't support cyclic foreign keys.") VT09020 = errorWithoutState("VT09020", vtrpcpb.Code_FAILED_PRECONDITION, "can not use multiple vindex hints for table %s", "Vitess does not allow using multiple vindex hints on the same table.") VT09021 = errorWithState("VT09021", vtrpcpb.Code_FAILED_PRECONDITION, KeyDoesNotExist, "Vindex '%s' does not exist in table '%s'", "Vindex hints have to reference an existing vindex, and no such vindex could be found for the given table.") diff --git a/go/vt/vtgate/vschema_manager.go b/go/vt/vtgate/vschema_manager.go index f215fd9df11..ad054807045 100644 --- a/go/vt/vtgate/vschema_manager.go +++ b/go/vt/vtgate/vschema_manager.go @@ -260,19 +260,6 @@ func (vm *VSchemaManager) updateFromSchema(vschema *vindexes.VSchema) { } } -type tableCol struct { - tableName sqlparser.TableName - colNames sqlparser.Columns -} - -var tableColHash = func(tc tableCol) string { - res := sqlparser.String(tc.tableName) - for _, colName := range tc.colNames { - res += "|" + 
sqlparser.String(colName) - } - return res -} - func markErrorIfCyclesInFk(vschema *vindexes.VSchema) { for ksName, ks := range vschema.Keyspaces { // Only check cyclic foreign keys for keyspaces that have @@ -280,23 +267,53 @@ func markErrorIfCyclesInFk(vschema *vindexes.VSchema) { if ks.ForeignKeyMode != vschemapb.Keyspace_managed { continue } + /* + 3 cases for creating the graph for cycle detection: + 1. ON DELETE RESTRICT ON UPDATE RESTRICT: This is the simplest case where no update/delete is required on the child table; we only need to verify whether a value exists or not. So we don't need to add any edge for this case. + 2. ON DELETE SET NULL, ON UPDATE SET NULL, ON UPDATE CASCADE: In this case having any update/delete on any of the columns in the parent side of the foreign key will make a corresponding delete/update on all the columns in the child side of the foreign key. So we will add an edge from all the columns in the parent side to all the columns in the child side. + 3. ON DELETE CASCADE: This is a special case wherein a deletion on the parent table will affect all the columns in the child table irrespective of the columns involved in the foreign key! So, we'll add an edge from all the columns in the parent side of the foreign key to all the columns of the child table. + */ g := graph.NewGraph[string]() for _, table := range ks.Tables { for _, cfk := range table.ChildForeignKeys { + // Check for case 1. + if cfk.OnUpdate.IsRestrict() && cfk.OnDelete.IsRestrict() { + continue + } + childTable := cfk.Table - parentVertex := tableCol{ - tableName: table.GetTableName(), - colNames: cfk.ParentColumns, + var parentVertices []string + var childVertices []string + for _, column := range cfk.ParentColumns { + parentVertices = append(parentVertices, sqlparser.String(sqlparser.NewColNameWithQualifier(column.String(), table.GetTableName()))) } - childVertex := tableCol{ - tableName: childTable.GetTableName(), - colNames: cfk.ChildColumns, + + // Check for case 3. 
+ if cfk.OnDelete.IsCascade() { + for _, column := range childTable.Columns { + childVertices = append(childVertices, sqlparser.String(sqlparser.NewColNameWithQualifier(column.Name.String(), childTable.GetTableName()))) + } + } else { + // Case 2. + for _, column := range cfk.ChildColumns { + childVertices = append(childVertices, sqlparser.String(sqlparser.NewColNameWithQualifier(column.String(), childTable.GetTableName()))) + } } - g.AddEdge(tableColHash(parentVertex), tableColHash(childVertex)) + addCrossEdges(g, parentVertices, childVertices) } } - if g.HasCycles() { - ks.Error = vterrors.VT09019(ksName) + hasCycle, cycle := g.HasCycles() + if hasCycle { + ks.Error = vterrors.VT09019(ksName, cycle) + } + } +} + +// addCrossEdges adds the edges from all the vertices in the first list to all the vertices in the second list. +func addCrossEdges(g *graph.Graph[string], from []string, to []string) { + for _, fromStr := range from { + for _, toStr := range to { + g.AddEdge(fromStr, toStr) } } } diff --git a/go/vt/vtgate/vschema_manager_test.go b/go/vt/vtgate/vschema_manager_test.go index f810d7c42af..c7ee5f34e8d 100644 --- a/go/vt/vtgate/vschema_manager_test.go +++ b/go/vt/vtgate/vschema_manager_test.go @@ -449,7 +449,7 @@ func TestMarkErrorIfCyclesInFk(t *testing.T) { errWanted string }{ { - name: "Has a cycle", + name: "Has a direct cycle", getVschema: func() *vindexes.VSchema { vschema := &vindexes.VSchema{ Keyspaces: map[string]*vindexes.KeyspaceSchema{ @@ -472,13 +472,44 @@ func TestMarkErrorIfCyclesInFk(t *testing.T) { }, }, } - _ = vschema.AddForeignKey("ks", "t2", createFkDefinition([]string{"col"}, "t1", []string{"col"}, sqlparser.Cascade, sqlparser.Cascade)) - _ = vschema.AddForeignKey("ks", "t3", createFkDefinition([]string{"col"}, "t2", []string{"col"}, sqlparser.Cascade, sqlparser.Cascade)) - _ = vschema.AddForeignKey("ks", "t1", createFkDefinition([]string{"col"}, "t3", []string{"col"}, sqlparser.Cascade, sqlparser.Cascade)) + _ = 
vschema.AddForeignKey("ks", "t2", createFkDefinition([]string{"col"}, "t1", []string{"col"}, sqlparser.SetNull, sqlparser.SetNull)) + _ = vschema.AddForeignKey("ks", "t3", createFkDefinition([]string{"col"}, "t2", []string{"col"}, sqlparser.SetNull, sqlparser.SetNull)) + _ = vschema.AddForeignKey("ks", "t1", createFkDefinition([]string{"col"}, "t3", []string{"col"}, sqlparser.SetNull, sqlparser.SetNull)) return vschema }, errWanted: "VT09019: keyspace 'ks' has cyclic foreign keys", }, + { + name: "Has a direct cycle but there is a restrict constraint in between", + getVschema: func() *vindexes.VSchema { + vschema := &vindexes.VSchema{ + Keyspaces: map[string]*vindexes.KeyspaceSchema{ + ksName: { + ForeignKeyMode: vschemapb.Keyspace_managed, + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: keyspace, + }, + "t2": { + Name: sqlparser.NewIdentifierCS("t2"), + Keyspace: keyspace, + }, + "t3": { + Name: sqlparser.NewIdentifierCS("t3"), + Keyspace: keyspace, + }, + }, + }, + }, + } + _ = vschema.AddForeignKey("ks", "t2", createFkDefinition([]string{"col"}, "t1", []string{"col"}, sqlparser.SetNull, sqlparser.SetNull)) + _ = vschema.AddForeignKey("ks", "t3", createFkDefinition([]string{"col"}, "t2", []string{"col"}, sqlparser.Restrict, sqlparser.Restrict)) + _ = vschema.AddForeignKey("ks", "t1", createFkDefinition([]string{"col"}, "t3", []string{"col"}, sqlparser.SetNull, sqlparser.SetNull)) + return vschema + }, + errWanted: "", + }, { name: "No cycle", getVschema: func() *vindexes.VSchema { @@ -508,6 +539,134 @@ func TestMarkErrorIfCyclesInFk(t *testing.T) { return vschema }, errWanted: "", + }, { + name: "Self-referencing foreign key with delete cascade", + getVschema: func() *vindexes.VSchema { + vschema := &vindexes.VSchema{ + Keyspaces: map[string]*vindexes.KeyspaceSchema{ + ksName: { + ForeignKeyMode: vschemapb.Keyspace_managed, + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), 
+ Keyspace: keyspace, + Columns: []vindexes.Column{ + { + Name: sqlparser.NewIdentifierCI("id"), + }, + { + Name: sqlparser.NewIdentifierCI("manager_id"), + }, + }, + }, + }, + }, + }, + } + _ = vschema.AddForeignKey("ks", "t1", createFkDefinition([]string{"manager_id"}, "t1", []string{"id"}, sqlparser.SetNull, sqlparser.Cascade)) + return vschema + }, + errWanted: "VT09019: keyspace 'ks' has cyclic foreign keys. Cycle exists between [ks.t1.id ks.t1.id]", + }, { + name: "Self-referencing foreign key without delete cascade", + getVschema: func() *vindexes.VSchema { + vschema := &vindexes.VSchema{ + Keyspaces: map[string]*vindexes.KeyspaceSchema{ + ksName: { + ForeignKeyMode: vschemapb.Keyspace_managed, + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: keyspace, + Columns: []vindexes.Column{ + { + Name: sqlparser.NewIdentifierCI("id"), + }, + { + Name: sqlparser.NewIdentifierCI("manager_id"), + }, + }, + }, + }, + }, + }, + } + _ = vschema.AddForeignKey("ks", "t1", createFkDefinition([]string{"manager_id"}, "t1", []string{"id"}, sqlparser.SetNull, sqlparser.SetNull)) + return vschema + }, + errWanted: "", + }, { + name: "Has an indirect cycle because of cascades", + getVschema: func() *vindexes.VSchema { + vschema := &vindexes.VSchema{ + Keyspaces: map[string]*vindexes.KeyspaceSchema{ + ksName: { + ForeignKeyMode: vschemapb.Keyspace_managed, + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: keyspace, + Columns: []vindexes.Column{ + { + Name: sqlparser.NewIdentifierCI("a"), + }, + { + Name: sqlparser.NewIdentifierCI("b"), + }, + { + Name: sqlparser.NewIdentifierCI("c"), + }, + }, + }, + "t2": { + Name: sqlparser.NewIdentifierCS("t2"), + Keyspace: keyspace, + Columns: []vindexes.Column{ + { + Name: sqlparser.NewIdentifierCI("d"), + }, + { + Name: sqlparser.NewIdentifierCI("e"), + }, + { + Name: sqlparser.NewIdentifierCI("f"), + }, + }, + }, + }, + }, + }, + } + _ = 
vschema.AddForeignKey("ks", "t2", createFkDefinition([]string{"f"}, "t1", []string{"a"}, sqlparser.SetNull, sqlparser.Cascade)) + _ = vschema.AddForeignKey("ks", "t1", createFkDefinition([]string{"b"}, "t2", []string{"e"}, sqlparser.SetNull, sqlparser.Cascade)) + return vschema + }, + errWanted: "VT09019: keyspace 'ks' has cyclic foreign keys", + }, { + name: "Cycle part of a multi-column foreign key", + getVschema: func() *vindexes.VSchema { + vschema := &vindexes.VSchema{ + Keyspaces: map[string]*vindexes.KeyspaceSchema{ + ksName: { + ForeignKeyMode: vschemapb.Keyspace_managed, + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: keyspace, + }, + "t2": { + Name: sqlparser.NewIdentifierCS("t2"), + Keyspace: keyspace, + }, + }, + }, + }, + } + _ = vschema.AddForeignKey("ks", "t2", createFkDefinition([]string{"e", "f"}, "t1", []string{"a", "b"}, sqlparser.SetNull, sqlparser.SetNull)) + _ = vschema.AddForeignKey("ks", "t1", createFkDefinition([]string{"b"}, "t2", []string{"e"}, sqlparser.SetNull, sqlparser.SetNull)) + return vschema + }, + errWanted: "VT09019: keyspace 'ks' has cyclic foreign keys", }, } for _, tt := range tests { @@ -515,7 +674,7 @@ func TestMarkErrorIfCyclesInFk(t *testing.T) { vschema := tt.getVschema() markErrorIfCyclesInFk(vschema) if tt.errWanted != "" { - require.EqualError(t, vschema.Keyspaces[ksName].Error, tt.errWanted) + require.ErrorContains(t, vschema.Keyspaces[ksName].Error, tt.errWanted) return } require.NoError(t, vschema.Keyspaces[ksName].Error)