diff --git a/cherry/models/models.py b/cherry/models/models.py index 774c118..72a486b 100644 --- a/cherry/models/models.py +++ b/cherry/models/models.py @@ -810,25 +810,6 @@ def _get_field_type_by_model(cls, model: "Model") -> Tuple[str, RelationshipFiel RelationshipField, ) and isinstance(model, field.field_info.related_model): return field_name, field.field_info - # if ( - # isinstance( - # field.field_info, - # ForeignKeyField, - # ) - # and field.field_info.related_field_name - # and field.field_info.related_field - # and isinstance(model, field.field_info.related_field.related_model) - # ) or ( - # isinstance( - # field.field_info, - # (ReverseRelationshipField, ManyToManyField), - # ) - # and isinstance(model, field.field_info.related_field.related_model) - # ): - # return ( - # field.field_info.related_field_name, # type: ignore - # field.field_info.related_field, - # ) raise FieldTypeError( f"There are no related fields associated with {cls}.{model}", ) @@ -855,6 +836,7 @@ def _generate_sqlalchemy_column(cls): elif isinstance(field_info, ForeignKeyField): if not hasattr(field_info, "related_field_name"): for rname, rfield in field_info.related_model.__fields__.items(): + has_flag = False if ( isinstance(rfield.field_info, ReverseRelationshipField) and rfield.field_info.related_model is cls @@ -863,11 +845,25 @@ def _generate_sqlalchemy_column(cls): or rfield.field_info.related_field_name is None ) ): + has_flag = True field_info.related_field_name = rname field_info.related_field = rfield.field_info rfield.field_info.related_field_name = model_field.name rfield.field_info.related_field = field_info break + if not has_flag: + field_info.related_field_name = None + field_info.related_field = None + if ( + not hasattr(field_info, "related_field") + and field_info.related_field_name is not None + ): + field_info.related_field = cast( + ReverseRelationshipField, + field_info.related_model.__fields__[ + field_info.related_field_name + ].field_info, + ) if not hasattr(field_info, "foreign_key"): if len(field_info.related_model.__meta__.primary_key) != 1: raise PrimaryKeyMultipleError( @@ -956,6 +952,13 @@ def _generate_sqlalchemy_column(cls): field_info, ), ) + if not hasattr(field_info, "related_field"): + field_info.related_field = cast( + ForeignKeyField, + field_info.related_model.__fields__[ + field_info.related_field_name + ].field_info, + ) else: raise RelationSolveError( ( @@ -1000,6 +1003,20 @@ def _generate_sqlalchemy_column(cls): rfield.field_info.related_field_name = model_field.name rfield.field_info.related_field = field_info break + if not hasattr(field_info, "related_field_name"): + raise RelationSolveError( + ( + "There are no related fields associated with" + f" {cls}.{model_field.name}" + ), + ) + if not hasattr(field_info, "related_field"): + field_info.related_field = cast( + ManyToManyField, + field_info.related_model.__fields__[ + field_info.related_field_name + ].field_info, + ) cls.__meta__.many_to_many_fields[model_field.name] = field_info setattr( cls,