diff --git a/src/drf_pydantic/base_model.py b/src/drf_pydantic/base_model.py index c3669d5..c116cf2 100644 --- a/src/drf_pydantic/base_model.py +++ b/src/drf_pydantic/base_model.py @@ -9,7 +9,7 @@ from rest_framework import serializers from typing_extensions import dataclass_transform -from drf_pydantic.parse import create_serializer_from_model +from drf_pydantic.parse import SERIALIZER_REGISTRY, create_serializer_from_model @dataclass_transform(kw_only_default=True, field_specifiers=(pydantic.Field,)) @@ -50,3 +50,25 @@ def __new__( class BaseModel(pydantic.BaseModel, metaclass=ModelMetaclass): # Populated by the metaclass or manually set by the user drf_serializer: ClassVar[type[serializers.Serializer]] + + @classmethod + def model_rebuild( + cls, + *, + force: bool = False, + raise_errors: bool = True, + _parent_namespace_depth: int = 2, + _types_namespace: Optional[dict[str, Any]] = None, + ) -> bool | None: + ret = super().model_rebuild( + force=force, + raise_errors=raise_errors, + _parent_namespace_depth=_parent_namespace_depth, + _types_namespace=_types_namespace, + ) + + if cls in SERIALIZER_REGISTRY: + SERIALIZER_REGISTRY.pop(cls) + cls.drf_serializer = create_serializer_from_model(cls) + + return ret diff --git a/tests/test_models.py b/tests/test_models.py index b311f23..01f2213 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -118,6 +118,30 @@ class Person(BaseModel): assert isinstance(job.fields["salary"], serializers.FloatField) +def test_nested_recursive_model(): + class Task(BaseModel): + title: str + parent: "Task" + subtasks: list["Task"] + + Task.model_rebuild() + + serializer = Task.drf_serializer() + + # Parent model + assert serializer.__class__.__name__ == "TaskSerializer" + assert len(serializer.fields) == 3 + assert isinstance(serializer.fields["title"], serializers.CharField) + assert isinstance(serializer.fields["parent"], serializers.Serializer) + assert isinstance(serializer.fields["subtasks"], serializers.ListField) + + parent: serializers.Serializer = serializer.fields["parent"] + assert parent.__class__.__name__ == "TaskSerializer" + assert len(parent.fields) == 3 + assert isinstance(parent.fields["title"], serializers.CharField) + assert isinstance(parent.fields["parent"], serializers.Serializer) + + def test_list_of_nested_models(): class Apartment(BaseModel): floor: int