From d06d5f8c4e8a13720b19a55fef12e3c73034f571 Mon Sep 17 00:00:00 2001 From: JargeZ Date: Wed, 14 Aug 2024 16:27:55 +1000 Subject: [PATCH] Rebuild serializer on model refs update --- src/drf_pydantic/base_model.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/src/drf_pydantic/base_model.py b/src/drf_pydantic/base_model.py index c3669d5..5f09f45 100644 --- a/src/drf_pydantic/base_model.py +++ b/src/drf_pydantic/base_model.py @@ -9,7 +9,7 @@ from rest_framework import serializers from typing_extensions import dataclass_transform -from drf_pydantic.parse import create_serializer_from_model +from drf_pydantic.parse import SERIALIZER_REGISTRY, create_serializer_from_model @dataclass_transform(kw_only_default=True, field_specifiers=(pydantic.Field,)) @@ -50,3 +50,27 @@ def __new__( class BaseModel(pydantic.BaseModel, metaclass=ModelMetaclass): # Populated by the metaclass or manually set by the user drf_serializer: ClassVar[type[serializers.Serializer]] + + @classmethod + def model_rebuild( + cls, + *, + force: bool = False, + raise_errors: bool = True, + _parent_namespace_depth: int = 2, + _types_namespace: Optional[dict[str, Any]] = None, + ) -> bool | None: + ret = super().model_rebuild( + force=force, + raise_errors=raise_errors, + _parent_namespace_depth=_parent_namespace_depth, + _types_namespace=_types_namespace, + ) + + if not hasattr(cls, "drf_serializer") or getattr(cls, "drf_serializer") in ( + getattr(base, "drf_serializer", None) for base in cls.__mro__[1:] + ): + SERIALIZER_REGISTRY.pop(cls) + cls.drf_serializer = create_serializer_from_model(cls) + + return ret