Skip to content

Commit

Permalink
sqlmodel: several bugfixes
Browse files Browse the repository at this point in the history
These were found by testing a more comprehensive use case:

https://github.com/adsharma/fastapi-shopping/blob/main/models.py
  • Loading branch information
adsharma committed Jan 18, 2025
1 parent cf2897e commit 972163f
Showing 1 changed file with 35 additions and 15 deletions.
50 changes: 35 additions & 15 deletions fquery/sqlmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def many_to_one(key_column=None, back_populates=None):
if key_column is not None:
ret.metadata["SQL"]["key_column"] = key_column
if back_populates is not None:
ret.metadata["SQL"][back_populates] = back_populates
ret.metadata["SQL"]["back_populates"] = back_populates
return ret


Expand Down Expand Up @@ -134,8 +134,13 @@ def get_field_type(field, cls):
type_class = field.type
other_class = type_class.__args__[0]
if has_many_to_one_relationship:
type_class = get_type_hints(cls)[field.name]
return Optional[other_class.__sqlmodel__]
try:
type_class = get_type_hints(cls)[field.name]
except NameError:
# TODO: log exception?
pass
else:
return Optional[other_class.__sqlmodel__]
return field.type

def patch_back_populates_types(field, back_populates, cls, sqlmodel_cls):
Expand All @@ -145,7 +150,16 @@ def patch_back_populates_types(field, back_populates, cls, sqlmodel_cls):
if has_relationship:
if has_many_to_one_relationship:
type_class = field.type
other_class = type_class.__args__[0].__sqlmodel__
try:
type_class = get_type_hints(cls)[field.name]
except NameError:
# TODO: log exception?
pass
inner = type_class.__args__[0]
if isinstance(inner, ForwardRef):
# can't patch right now. Try at a later time via back_populates
return
other_class = inner.__sqlmodel__
old = other_class.__annotations__[back_populates]
# Should be sqlalchemy.orm.base.Mapped[typing.List[ForwardRef('T')]]
# replace it with Mapped[List[sqlmodel_cls]]
Expand All @@ -156,17 +170,23 @@ def patch_back_populates_types(field, back_populates, cls, sqlmodel_cls):
List[sqlmodel_cls]
]
other_class.sqlmodel_rebuild()
else:
# Replace Optional['T'] with Optional[TSQLModel]
old = field.type
origin = get_origin(old)
inner = get_args(old)
if (
origin == Union
and len(inner)
and inner[0] == ForwardRef(cls.__name__)
):
sqlmodel_cls.__annotations__[field.name] = Optional[sqlmodel_cls]

# Replace Optional['T'] with Optional[TSQLModel]
old = field.type
origin = get_origin(old)
inner = get_args(old)
needs_rebuild = False
if origin == Union and len(inner) and inner[0] == ForwardRef(cls.__name__):
sqlmodel_cls.__annotations__[field.name] = Optional[sqlmodel_cls]
needs_rebuild = True

# Replace Optional[T] with Optional[TSQLModel] if T is a dataclass
if origin == Union and len(inner) and is_dataclass(inner[0]):
sqlmodel_cls.__annotations__[field.name] = Optional[inner[0].__sqlmodel__]
needs_rebuild = True

if needs_rebuild:
sqlmodel_cls.sqlmodel_rebuild()

def default_table_name(clsname: str) -> str:
return inflection.underscore(inflection.pluralize(clsname))
Expand Down

0 comments on commit 972163f

Please sign in to comment.