diff --git a/go/vt/schemadiff/schema.go b/go/vt/schemadiff/schema.go index 405ad6c7f45..311ccd94896 100644 --- a/go/vt/schemadiff/schema.go +++ b/go/vt/schemadiff/schema.go @@ -214,6 +214,18 @@ func (s *Schema) normalize() error { return true } + // Utility map and function to only record one foreign-key error per table. We make this limitation + // because the search algorithm below could review the same table twice, thus potentially unnecessarily duplicating + // found errors. + entityFkErrors := map[string]error{} + addEntityFkError := func(e Entity, err error) error { + if _, ok := entityFkErrors[e.Name()]; ok { + // error already recorded for this entity + return nil + } + entityFkErrors[e.Name()] = err + return err + } // We now iterate all tables. We iterate "dependency levels": // - first we want all tables that don't have foreign keys or which only reference themselves // - then we only want tables that reference 1st level tables. these are 2nd level tables @@ -241,10 +253,12 @@ func (s *Schema) normalize() error { } referencedEntity, ok := s.named[referencedTableName] if !ok { - return &ForeignKeyNonexistentReferencedTableError{Table: name, ReferencedTable: referencedTableName} + errs = errors.Join(errs, addEntityFkError(t, &ForeignKeyNonexistentReferencedTableError{Table: name, ReferencedTable: referencedTableName})) + continue } if _, ok := referencedEntity.(*CreateViewEntity); ok { - return &ForeignKeyReferencesViewError{Table: name, ReferencedView: referencedTableName} + errs = errors.Join(errs, addEntityFkError(t, &ForeignKeyReferencesViewError{Table: name, ReferencedView: referencedTableName})) + continue } fkParents[referencedTableName] = true @@ -310,7 +324,8 @@ func (s *Schema) normalize() error { if _, ok := dependencyLevels[t.Name()]; !ok { // We _know_ that in this iteration, at least one foreign key is not found. // We return the first one. - return &ForeignKeyDependencyUnresolvedError{Table: t.Name()} + errs = errors.Join(errs, addEntityFkError(t, &ForeignKeyDependencyUnresolvedError{Table: t.Name()})) + s.sorted = append(s.sorted, t) } } for _, v := range s.views { @@ -364,7 +379,12 @@ func (s *Schema) normalize() error { continue } referencedTableName := check.ReferenceDefinition.ReferencedTable.Name.String() - referencedTable := s.Table(referencedTableName) // we know this exists because we validated foreign key dependencies earlier on + referencedTable := s.Table(referencedTableName) + if referencedTable == nil { + // This can happen because earlier, when we validated existence of reference table, we took note + // of nonexisting tables, but kept on going. + continue + } referencedColumns := map[string]*sqlparser.ColumnDefinition{} for _, col := range referencedTable.CreateTable.TableSpec.Columns { diff --git a/go/vt/schemadiff/schema_test.go b/go/vt/schemadiff/schema_test.go index 67de705d05c..8c410f54381 100644 --- a/go/vt/schemadiff/schema_test.go +++ b/go/vt/schemadiff/schema_test.go @@ -17,6 +17,7 @@ limitations under the License. package schemadiff import ( + "errors" "fmt" "math/rand" "sort" @@ -27,7 +28,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "vitess.io/vitess/go/errors" + vterrors "vitess.io/vitess/go/errors" "vitess.io/vitess/go/vt/sqlparser" ) @@ -161,7 +162,7 @@ func TestNewSchemaFromQueriesLoop(t *testing.T) { ) _, err := NewSchemaFromQueries(queries) require.Error(t, err) - err = errors.UnwrapFirst(err) + err = vterrors.UnwrapFirst(err) assert.EqualError(t, err, (&ViewDependencyUnresolvedError{View: "v7"}).Error()) } @@ -339,8 +340,11 @@ func TestInvalidSchema(t *testing.T) { expectErr: &ForeignKeyReferencesViewError{Table: "t11", ReferencedView: "v"}, }, { - schema: "create table t11 (id int primary key, i int, constraint f11 foreign key (i) references t12 (id) on delete restrict); create table t12 (id int primary key, i int, constraint f12 foreign key (i) references t11 (id) on delete restrict)", - expectErr: &ForeignKeyDependencyUnresolvedError{Table: "t11"}, + schema: "create table t11 (id int primary key, i int, constraint f11 foreign key (i) references t12 (id) on delete restrict); create table t12 (id int primary key, i int, constraint f12 foreign key (i) references t11 (id) on delete restrict)", + expectErr: errors.Join( + &ForeignKeyDependencyUnresolvedError{Table: "t11"}, + &ForeignKeyDependencyUnresolvedError{Table: "t12"}, + ), }, { schema: "create table t11 (id int primary key, i int, key ix(i), constraint f11 foreign key (i) references t11(id2) on delete restrict)", @@ -396,11 +400,20 @@ func TestInvalidSchema(t *testing.T) { func TestInvalidTableForeignKeyReference(t *testing.T) { { fkQueries := []string{ + "create table t10 (id int primary key)", "create table t11 (id int primary key, i int, constraint f12 foreign key (i) references t12(id) on delete restrict)", "create table t15(id int, primary key(id))", } - _, err := NewSchemaFromQueries(fkQueries) + s, err := NewSchemaFromQueries(fkQueries) assert.Error(t, err) + // Even though there's errors, we still expect the schema to have been created. + assert.NotNil(t, s) + // Even though t11 caused an error, we still expect the schema to have parsed all tables. + assert.Equal(t, 3, len(s.Entities())) + t11 := s.Table("t11") + assert.NotNil(t, t11) + // validate t11 table definition is complete, even though it was invalid. + assert.Equal(t, "create table t11 (\n\tid int,\n\ti int,\n\tprimary key (id),\n\tkey f12 (i),\n\tconstraint f12 foreign key (i) references t12 (id) on delete restrict\n)", t11.Create().StatementString()) assert.EqualError(t, err, (&ForeignKeyNonexistentReferencedTableError{Table: "t11", ReferencedTable: "t12"}).Error()) } { @@ -411,7 +424,9 @@ func TestInvalidTableForeignKeyReference(t *testing.T) { } _, err := NewSchemaFromQueries(fkQueries) assert.Error(t, err) - assert.EqualError(t, err, (&ForeignKeyDependencyUnresolvedError{Table: "t11"}).Error()) + assert.ErrorContains(t, err, (&ForeignKeyDependencyUnresolvedError{Table: "t11"}).Error()) + assert.ErrorContains(t, err, (&ForeignKeyDependencyUnresolvedError{Table: "t12"}).Error()) + assert.ErrorContains(t, err, (&ForeignKeyDependencyUnresolvedError{Table: "t13"}).Error()) } } @@ -716,7 +731,7 @@ func TestViewReferences(t *testing.T) { require.NotNil(t, schema) } else { require.Error(t, err) - err = errors.UnwrapFirst(err) + err = vterrors.UnwrapFirst(err) require.Equal(t, ts.expectErr, err, "received error: %v", err) } })