From 71964f00dc43d67688a240438284a0aa42077dff Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Tue, 12 Mar 2024 16:42:11 +0530 Subject: [PATCH] feat: fix the cycle detection logic Signed-off-by: Manan Gupta --- go/vt/sqlparser/ast_funcs.go | 10 ++++++ go/vt/vtgate/vschema_manager.go | 56 +++++++++++++++++++++------------ 2 files changed, 46 insertions(+), 20 deletions(-) diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index ddb5251dbb3..4335e2432f9 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -2556,6 +2556,16 @@ func (ra ReferenceAction) IsRestrict() bool { } } +// IsCascade returns true if the reference action is of cascade type. +func (ra ReferenceAction) IsCascade() bool { + switch ra { + case Cascade: + return true + default: + return false + } +} + // IsLiteral returns true if the expression is of a literal type. func IsLiteral(expr Expr) bool { switch expr.(type) { diff --git a/go/vt/vtgate/vschema_manager.go b/go/vt/vtgate/vschema_manager.go index f215fd9df11..426270e6218 100644 --- a/go/vt/vtgate/vschema_manager.go +++ b/go/vt/vtgate/vschema_manager.go @@ -260,19 +260,6 @@ func (vm *VSchemaManager) updateFromSchema(vschema *vindexes.VSchema) { } } -type tableCol struct { - tableName sqlparser.TableName - colNames sqlparser.Columns -} - -var tableColHash = func(tc tableCol) string { - res := sqlparser.String(tc.tableName) - for _, colName := range tc.colNames { - res += "|" + sqlparser.String(colName) - } - return res -} - func markErrorIfCyclesInFk(vschema *vindexes.VSchema) { for ksName, ks := range vschema.Keyspaces { // Only check cyclic foreign keys for keyspaces that have @@ -280,19 +267,39 @@ func markErrorIfCyclesInFk(vschema *vindexes.VSchema) { if ks.ForeignKeyMode != vschemapb.Keyspace_managed { continue } + /* + 3 cases for creating the graph for cycle detection: + 1. ON DELETE RESTRICT ON UPDATE RESTRICT: This is the simplest case where no update/delete is required on the child table, we only need to verify whether a value exists or not. So we don't need to add any edge for this case. + 2. ON DELETE SET NULL, ON UPDATE SET NULL, ON UPDATE CASCADE: In this case having any update/delete on any of the columns in the parent side of the foreign key will make a corresponding delete/update on all the column in the child side of the foreign key. So we will add an edge from all the columns in the parent side to all the columns in the child side. + 3. ON DELETE CASCADE: This is a special case wherein a deletion on the parent table will affect all the columns in the child table irrespective of the columns involved in the foreign key! So, we'll add an edge from all the columns in the parent side of the foreign key to all the columns of the child table. + */ g := graph.NewGraph[string]() for _, table := range ks.Tables { for _, cfk := range table.ChildForeignKeys { childTable := cfk.Table - parentVertex := tableCol{ - tableName: table.GetTableName(), - colNames: cfk.ParentColumns, + + // Check for case 1. + if cfk.OnUpdate.IsRestrict() && cfk.OnDelete.IsRestrict() { + continue } - childVertex := tableCol{ - tableName: childTable.GetTableName(), - colNames: cfk.ChildColumns, + var parentVertices []string + var childVertices []string + for _, column := range cfk.ParentColumns { + parentVertices = append(parentVertices, sqlparser.String(sqlparser.NewColNameWithQualifier(column.String(), table.GetTableName()))) } - g.AddEdge(tableColHash(parentVertex), tableColHash(childVertex)) + + // Check for case 3. + if cfk.OnDelete.IsCascade() { + for _, column := range childTable.Columns { + childVertices = append(childVertices, sqlparser.String(sqlparser.NewColNameWithQualifier(column.Name.String(), childTable.GetTableName()))) + } + } else { + // Case 2. + for _, column := range cfk.ChildColumns { + childVertices = append(childVertices, sqlparser.String(sqlparser.NewColNameWithQualifier(column.String(), childTable.GetTableName()))) + } + } + addCrossEdges(g, parentVertices, childVertices) } } if g.HasCycles() { @@ -301,6 +308,15 @@ func markErrorIfCyclesInFk(vschema *vindexes.VSchema) { } } +// addCrossEdges adds the edges from all the vertices in the first list to all the vertices in the second list. +func addCrossEdges(g *graph.Graph[string], from []string, to []string) { + for _, fromStr := range from { + for _, toStr := range to { + g.AddEdge(fromStr, toStr) + } + } +} + func setColumns(ks *vindexes.KeyspaceSchema, tblName string, columns []vindexes.Column) *vindexes.Table { vTbl := ks.Tables[tblName] if vTbl == nil {