From d8bb03a9ad0c75a722a0cf92252fc461642adcc1 Mon Sep 17 00:00:00 2001 From: Jeferson Daniel Date: Thu, 25 Apr 2024 00:00:38 -0300 Subject: [PATCH] feat: Improve type checking support (#103) --- README.md | 9 ++++---- integration_test/test_readme.py | 1 + mypy.ini | 2 ++ phulpyfile.py | 4 +--- pydantic_mongo/__init__.py | 9 ++++++-- pydantic_mongo/abstract_repository.py | 31 ++++++++++++++++----------- pydantic_mongo/fields.py | 13 ++++++++++- pydantic_mongo/pagination.py | 2 +- requirements_test.txt | 2 +- test/test_enhance_meta.py | 8 +++---- test/test_fields.py | 6 ++++-- test/test_pagination.py | 7 +++--- test/test_repository.py | 26 ++++++++++------------ 13 files changed, 70 insertions(+), 50 deletions(-) create mode 100644 mypy.ini diff --git a/README.md b/README.md index 7deeba1..90e2e46 100644 --- a/README.md +++ b/README.md @@ -17,9 +17,9 @@ pip install pydantic-mongo ```python from bson import ObjectId from pydantic import BaseModel -from pydantic_mongo import AbstractRepository, ObjectIdField +from pydantic_mongo import AbstractRepository, PydanticObjectId from pymongo import MongoClient -from typing import List +from typing import Optional, List import os class Foo(BaseModel): @@ -31,7 +31,8 @@ class Bar(BaseModel): banana: str = 'y' class Spam(BaseModel): - id: ObjectIdField = None + # PydanticObjectId is an alias to Annotated[ObjectId, ObjectIdAnnotation] + id: Optional[PydanticObjectId] = None foo: Foo bars: List[Bar] @@ -64,7 +65,7 @@ spam_repository.delete(spam) # Find One By Id result = spam_repository.find_one_by_id(spam.id) -# Find One By Id using string if the id attribute is a ObjectIdField +# Find One By Id using string if the id attribute is a PydanticObjectId result = spam_repository.find_one_by_id(ObjectId('611827f2878b88b49ebb69fc')) assert result.foo.count == 2 diff --git a/integration_test/test_readme.py b/integration_test/test_readme.py index 94c64e5..703e0e0 100644 --- a/integration_test/test_readme.py +++ b/integration_test/test_readme.py @@ -3,6 +3,7 @@ import os import re + def extract_python_snippets(content): # Regular expression pattern for finding Python code blocks pattern = r'```python(.*?)```' diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..895701c --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +plugins = pydantic.mypy \ No newline at end of file diff --git a/phulpyfile.py b/phulpyfile.py index a9df97a..d0fc7fa 100644 --- a/phulpyfile.py +++ b/phulpyfile.py @@ -46,8 +46,6 @@ def integration_test(phulpy): @task def typecheck(phulpy): - result = system( - r'find ./pydantic_mongo -name "*.py" -exec mypy --ignore-missing-imports --follow-imports=skip --strict-optional {} \+' - ) + result = system('mypy pydantic_mongo test --check-untyped-defs') if result: raise Exception("lint test failed") diff --git a/pydantic_mongo/__init__.py b/pydantic_mongo/__init__.py index 26bfd07..83bfe65 100644 --- a/pydantic_mongo/__init__.py +++ b/pydantic_mongo/__init__.py @@ -1,5 +1,10 @@ from .abstract_repository import AbstractRepository -from .fields import ObjectIdField +from .fields import ObjectIdAnnotation, ObjectIdField, PydanticObjectId from .version import __version__ # noqa: F401 -__all__ = ["ObjectIdField", "AbstractRepository"] +__all__ = [ + "AbstractRepository", + "ObjectIdField", + "ObjectIdAnnotation", + "PydanticObjectId", +] diff --git a/pydantic_mongo/abstract_repository.py b/pydantic_mongo/abstract_repository.py index 0c4a4eb..112f961 100644 --- a/pydantic_mongo/abstract_repository.py +++ b/pydantic_mongo/abstract_repository.py @@ -9,6 +9,7 @@ Type, TypeVar, Union, + cast, ) from pydantic import BaseModel @@ -26,10 +27,13 @@ T = TypeVar("T", bound=BaseModel) OutputT = TypeVar("OutputT", bound=BaseModel) - Sort = Sequence[Tuple[str, int]] +class ModelWithId(BaseModel): + id: Any + + class AbstractRepository(Generic[T]): class Meta: collection_name: str @@ -53,8 +57,6 @@ def get_collection(self) -> Collection: return self.__database[self.__collection_name] def __validate(self): - if not issubclass(self.__document_class, BaseModel): - raise Exception("Document class should inherit BaseModel") if "id" not in self.__document_class.model_fields: raise Exception("Document class should have id field") if not self.__collection_name: @@ -67,10 +69,11 @@ def to_document(model: T) -> dict: :param model: :return: dict """ - data = model.model_dump() + model_with_id = cast(ModelWithId, model) + data = model_with_id.model_dump() data.pop("id") - if model.id: - data["_id"] = model.id + if model_with_id.id: + data["_id"] = model_with_id.id return data def __map_id(self, data: dict) -> dict: @@ -109,15 +112,16 @@ def save(self, model: T) -> Union[InsertOneResult, UpdateResult]: Save entity to database. It will update the entity if it has id, otherwise it will insert it. """ document = self.to_document(model) + model_with_id = cast(ModelWithId, model) - if model.id: + if model_with_id.id: mongo_id = document.pop("_id") return self.get_collection().update_one( {"_id": mongo_id}, {"$set": document}, upsert=True ) result = self.get_collection().insert_one(document) - model.id = result.inserted_id + model_with_id.id = result.inserted_id return result def save_many(self, models: Iterable[T]): @@ -128,7 +132,8 @@ def save_many(self, models: Iterable[T]): models_to_update = [] for model in models: - if model.id: + model_with_id = cast(ModelWithId, model) + if model_with_id.id: models_to_update.append(model) else: models_to_insert.append(model) @@ -138,7 +143,7 @@ def save_many(self, models: Iterable[T]): ) for idx, inserted_id in enumerate(result.inserted_ids): - models_to_insert[idx].id = inserted_id + cast(ModelWithId, models_to_insert[idx]).id = inserted_id if len(models_to_update) == 0: return @@ -152,7 +157,7 @@ def save_many(self, models: Iterable[T]): self.get_collection().bulk_write(bulk_operations) def delete(self, model: T): - return self.get_collection().delete_one({"_id": model.id}) + return self.get_collection().delete_one({"_id": cast(ModelWithId, model).id}) def delete_by_id(self, _id: Any): return self.get_collection().delete_one({"_id": _id}) @@ -198,7 +203,7 @@ def find_by_with_output_type( cursor.limit(limit) if skip: cursor.skip(skip) - if sort: + if mapped_sort: cursor.sort(mapped_sort) return map(lambda doc: self.to_model_custom(output_type, doc), cursor) @@ -285,7 +290,7 @@ def paginate_with_output_type( ) return map( - lambda model: Edge[T]( + lambda model: Edge[OutputT]( node=model, cursor=encode_pagination_cursor( get_pagination_cursor_payload(model, sort_keys) diff --git a/pydantic_mongo/fields.py b/pydantic_mongo/fields.py index 3e6ef2c..2d4b171 100644 --- a/pydantic_mongo/fields.py +++ b/pydantic_mongo/fields.py @@ -2,9 +2,10 @@ from bson import ObjectId from pydantic_core import core_schema +from typing_extensions import Annotated -class ObjectIdField(ObjectId): +class ObjectIdAnnotation: @classmethod def __get_pydantic_core_schema__( cls, _source_type: Any, _handler: Any @@ -31,3 +32,13 @@ def validate(cls, value): raise ValueError("Invalid id") return ObjectId(value) + + +# Deprecated, use PydanticObjectId instead. +class ObjectIdField(ObjectId): + @classmethod + def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: Any): + return ObjectIdAnnotation.__get_pydantic_core_schema__(_source_type, _handler) + + +PydanticObjectId = Annotated[ObjectId, ObjectIdAnnotation] diff --git a/pydantic_mongo/pagination.py b/pydantic_mongo/pagination.py index 5d66752..fc243fb 100644 --- a/pydantic_mongo/pagination.py +++ b/pydantic_mongo/pagination.py @@ -16,7 +16,7 @@ class Edge(BaseModel, Generic[DataT]): def encode_pagination_cursor(data: List) -> str: - byte_data = bson.BSON.encode({"v": data}) + byte_data: bytes = bson.BSON.encode({"v": data}) byte_data = zlib.compress(byte_data, 9) return b64encode(byte_data).decode("utf-8") diff --git a/requirements_test.txt b/requirements_test.txt index ad30968..b369ae6 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -7,7 +7,7 @@ pytest==8.1.1 pytest-cov==4.1.0 pytest-mock==3.12.0 mongomock==4.1.2 -pydantic>=2.0.2 +pydantic==2.7.1 pymongo==4.6.3 mypy==1.10.0 mypy-extensions==1.0.0 diff --git a/test/test_enhance_meta.py b/test/test_enhance_meta.py index a993cfd..221e2a1 100644 --- a/test/test_enhance_meta.py +++ b/test/test_enhance_meta.py @@ -1,11 +1,11 @@ import pytest from pydantic import BaseModel, Field - -from pydantic_mongo import AbstractRepository, ObjectIdField +from pydantic_mongo import AbstractRepository, PydanticObjectId +from typing_extensions import Optional class HamModel(BaseModel): - id: ObjectIdField = Field(default=None) + id: Optional[PydanticObjectId] name: str @@ -31,7 +31,7 @@ def test_repository_with_v2_meta(ham_repo): def test_save_with_new_repo(clean_ham_collection, ham_repo): - m = HamModel(name="wilfred") + m = HamModel(id=None, name="wilfred") assert m.id is None, "should have no id" ham_repo.save(m) assert m.id diff --git a/test/test_fields.py b/test/test_fields.py index 6a7043e..6ff156d 100644 --- a/test/test_fields.py +++ b/test/test_fields.py @@ -1,4 +1,5 @@ import pytest +from typing import Optional from bson import ObjectId from pydantic import BaseModel, ValidationError @@ -6,7 +7,7 @@ class User(BaseModel): - id: ObjectIdField = None + id: ObjectIdField class TestFields: @@ -26,5 +27,6 @@ def test_modify_schema(self): assert { "title": "User", "type": "object", - "properties": {"id": {"default": None, "title": "Id", "type": "string"}}, + "properties": {"id": {"title": "Id", "type": "string"}}, + "required": ["id"], } == schema diff --git a/test/test_pagination.py b/test/test_pagination.py index e539f04..6ace3ce 100644 --- a/test/test_pagination.py +++ b/test/test_pagination.py @@ -1,5 +1,4 @@ -import datetime -from typing import List +from typing import List, Optional from bson import ObjectId from pydantic import BaseModel @@ -13,7 +12,7 @@ class Foo(BaseModel): count: int - size: float = None + size: Optional[float] = None class Bar(BaseModel): @@ -22,7 +21,7 @@ class Bar(BaseModel): class Spam(BaseModel): - id: str = None + id: Optional[str] = None foo: Foo bars: List[Bar] diff --git a/test/test_repository.py b/test/test_repository.py index c80e3a4..81d51bc 100644 --- a/test/test_repository.py +++ b/test/test_repository.py @@ -1,17 +1,17 @@ -from typing import List +from typing import List, Optional, cast import mongomock import pytest from bson import ObjectId from pydantic import BaseModel, Field -from pydantic_mongo import AbstractRepository, ObjectIdField +from pydantic_mongo import AbstractRepository, PydanticObjectId from pydantic_mongo.errors import PaginationError class Foo(BaseModel): count: int - size: float = None + size: Optional[float] = None class Bar(BaseModel): @@ -20,9 +20,9 @@ class Bar(BaseModel): class Spam(BaseModel): - id: ObjectIdField = None - foo: Foo = None - bars: List[Bar] = None + id: Optional[PydanticObjectId] = None + foo: Optional[Foo] = None + bars: Optional[List[Bar]] = None class SpamRepository(AbstractRepository[Spam]): @@ -49,7 +49,7 @@ def test_save(self, database): "bars": [{"apple": "x", "banana": "y"}], } == database["spams"].find()[0] - spam.foo.count = 2 + cast(Foo, spam.foo).count = 2 spam_repository.save(spam) assert { @@ -147,6 +147,8 @@ def test_find_by_id(self, database): spam_repository = SpamRepository(database=database) result = spam_repository.find_one_by_id(spam_id) + assert result is not None + assert result.bars is not None assert issubclass(Spam, type(result)) assert spam_id == result.id assert "x" == result.bars[0].apple @@ -171,6 +173,8 @@ def test_find_by(self, database): result = spam_repository.find_by({}) results = [x for x in result] assert 2 == len(results) + assert results[0].foo is not None + assert results[1].foo is not None assert 2 == results[0].foo.count assert 3 == results[1].foo.count @@ -181,14 +185,6 @@ def test_find_by(self, database): results = [x for x in result] assert 0 == len(results) - def test_invalid_model_class(self, database): - class BrokenRepository(AbstractRepository[int]): - class Meta: - collection_name = "spams" - - with pytest.raises(Exception): - BrokenRepository(database=database) - def test_invalid_model_id_field(self, database): class NoIdModel(BaseModel): something: str