Skip to content

Commit

Permalink
feat: Improve type checking support (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
jefersondaniel authored Apr 25, 2024
1 parent 4c187c3 commit d8bb03a
Show file tree
Hide file tree
Showing 13 changed files with 70 additions and 50 deletions.
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ pip install pydantic-mongo
```python
from bson import ObjectId
from pydantic import BaseModel
from pydantic_mongo import AbstractRepository, ObjectIdField
from pydantic_mongo import AbstractRepository, PydanticObjectId
from pymongo import MongoClient
from typing import List
from typing import Optional, List
import os

class Foo(BaseModel):
Expand All @@ -31,7 +31,8 @@ class Bar(BaseModel):
banana: str = 'y'

class Spam(BaseModel):
id: ObjectIdField = None
# PydanticObjectId is an alias to Annotated[ObjectId, ObjectIdAnnotation]
id: Optional[PydanticObjectId] = None
foo: Foo
bars: List[Bar]

Expand Down Expand Up @@ -64,7 +65,7 @@ spam_repository.delete(spam)
# Find One By Id
result = spam_repository.find_one_by_id(spam.id)

# Find One By Id using string if the id attribute is a ObjectIdField
# Find One By Id using string if the id attribute is a PydanticObjectId
result = spam_repository.find_one_by_id(ObjectId('611827f2878b88b49ebb69fc'))
assert result.foo.count == 2

Expand Down
1 change: 1 addition & 0 deletions integration_test/test_readme.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import re


def extract_python_snippets(content):
# Regular expression pattern for finding Python code blocks
pattern = r'```python(.*?)```'
Expand Down
2 changes: 2 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[mypy]
plugins = pydantic.mypy
4 changes: 1 addition & 3 deletions phulpyfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ def integration_test(phulpy):

@task
def typecheck(phulpy):
result = system(
r'find ./pydantic_mongo -name "*.py" -exec mypy --ignore-missing-imports --follow-imports=skip --strict-optional {} \+'
)
result = system('mypy pydantic_mongo test --check-untyped-defs')
if result:
raise Exception("lint test failed")
9 changes: 7 additions & 2 deletions pydantic_mongo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from .abstract_repository import AbstractRepository
from .fields import ObjectIdField
from .fields import ObjectIdAnnotation, ObjectIdField, PydanticObjectId
from .version import __version__ # noqa: F401

__all__ = ["ObjectIdField", "AbstractRepository"]
__all__ = [
"AbstractRepository",
"ObjectIdField",
"ObjectIdAnnotation",
"PydanticObjectId",
]
31 changes: 18 additions & 13 deletions pydantic_mongo/abstract_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Type,
TypeVar,
Union,
cast,
)

from pydantic import BaseModel
Expand All @@ -26,10 +27,13 @@

T = TypeVar("T", bound=BaseModel)
OutputT = TypeVar("OutputT", bound=BaseModel)

Sort = Sequence[Tuple[str, int]]


class ModelWithId(BaseModel):
id: Any


class AbstractRepository(Generic[T]):
class Meta:
collection_name: str
Expand All @@ -53,8 +57,6 @@ def get_collection(self) -> Collection:
return self.__database[self.__collection_name]

def __validate(self):
if not issubclass(self.__document_class, BaseModel):
raise Exception("Document class should inherit BaseModel")
if "id" not in self.__document_class.model_fields:
raise Exception("Document class should have id field")
if not self.__collection_name:
Expand All @@ -67,10 +69,11 @@ def to_document(model: T) -> dict:
:param model:
:return: dict
"""
data = model.model_dump()
model_with_id = cast(ModelWithId, model)
data = model_with_id.model_dump()
data.pop("id")
if model.id:
data["_id"] = model.id
if model_with_id.id:
data["_id"] = model_with_id.id
return data

def __map_id(self, data: dict) -> dict:
Expand Down Expand Up @@ -109,15 +112,16 @@ def save(self, model: T) -> Union[InsertOneResult, UpdateResult]:
Save entity to database. It will update the entity if it has id, otherwise it will insert it.
"""
document = self.to_document(model)
model_with_id = cast(ModelWithId, model)

if model.id:
if model_with_id.id:
mongo_id = document.pop("_id")
return self.get_collection().update_one(
{"_id": mongo_id}, {"$set": document}, upsert=True
)

result = self.get_collection().insert_one(document)
model.id = result.inserted_id
model_with_id.id = result.inserted_id
return result

def save_many(self, models: Iterable[T]):
Expand All @@ -128,7 +132,8 @@ def save_many(self, models: Iterable[T]):
models_to_update = []

for model in models:
if model.id:
model_with_id = cast(ModelWithId, model)
if model_with_id.id:
models_to_update.append(model)
else:
models_to_insert.append(model)
Expand All @@ -138,7 +143,7 @@ def save_many(self, models: Iterable[T]):
)

for idx, inserted_id in enumerate(result.inserted_ids):
models_to_insert[idx].id = inserted_id
cast(ModelWithId, models_to_insert[idx]).id = inserted_id

if len(models_to_update) == 0:
return
Expand All @@ -152,7 +157,7 @@ def save_many(self, models: Iterable[T]):
self.get_collection().bulk_write(bulk_operations)

def delete(self, model: T):
return self.get_collection().delete_one({"_id": model.id})
return self.get_collection().delete_one({"_id": cast(ModelWithId, model).id})

def delete_by_id(self, _id: Any):
return self.get_collection().delete_one({"_id": _id})
Expand Down Expand Up @@ -198,7 +203,7 @@ def find_by_with_output_type(
cursor.limit(limit)
if skip:
cursor.skip(skip)
if sort:
if mapped_sort:
cursor.sort(mapped_sort)
return map(lambda doc: self.to_model_custom(output_type, doc), cursor)

Expand Down Expand Up @@ -285,7 +290,7 @@ def paginate_with_output_type(
)

return map(
lambda model: Edge[T](
lambda model: Edge[OutputT](
node=model,
cursor=encode_pagination_cursor(
get_pagination_cursor_payload(model, sort_keys)
Expand Down
13 changes: 12 additions & 1 deletion pydantic_mongo/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from bson import ObjectId
from pydantic_core import core_schema
from typing_extensions import Annotated


class ObjectIdField(ObjectId):
class ObjectIdAnnotation:
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: Any
Expand All @@ -31,3 +32,13 @@ def validate(cls, value):
raise ValueError("Invalid id")

return ObjectId(value)


# Deprecated, use PydanticObjectId instead.
class ObjectIdField(ObjectId):
@classmethod
def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: Any):
return ObjectIdAnnotation.__get_pydantic_core_schema__(_source_type, _handler)


PydanticObjectId = Annotated[ObjectId, ObjectIdAnnotation]
2 changes: 1 addition & 1 deletion pydantic_mongo/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Edge(BaseModel, Generic[DataT]):


def encode_pagination_cursor(data: List) -> str:
byte_data = bson.BSON.encode({"v": data})
byte_data: bytes = bson.BSON.encode({"v": data})
byte_data = zlib.compress(byte_data, 9)
return b64encode(byte_data).decode("utf-8")

Expand Down
2 changes: 1 addition & 1 deletion requirements_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pytest==8.1.1
pytest-cov==4.1.0
pytest-mock==3.12.0
mongomock==4.1.2
pydantic>=2.0.2
pydantic==2.7.1
pymongo==4.6.3
mypy==1.10.0
mypy-extensions==1.0.0
Expand Down
8 changes: 4 additions & 4 deletions test/test_enhance_meta.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import pytest
from pydantic import BaseModel, Field

from pydantic_mongo import AbstractRepository, ObjectIdField
from pydantic_mongo import AbstractRepository, PydanticObjectId
from typing_extensions import Optional


class HamModel(BaseModel):
id: ObjectIdField = Field(default=None)
id: Optional[PydanticObjectId]
name: str


Expand All @@ -31,7 +31,7 @@ def test_repository_with_v2_meta(ham_repo):


def test_save_with_new_repo(clean_ham_collection, ham_repo):
m = HamModel(name="wilfred")
m = HamModel(id=None, name="wilfred")
assert m.id is None, "should have no id"
ham_repo.save(m)
assert m.id
6 changes: 4 additions & 2 deletions test/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import pytest
from typing import Optional
from bson import ObjectId
from pydantic import BaseModel, ValidationError

from pydantic_mongo import ObjectIdField


class User(BaseModel):
id: ObjectIdField = None
id: ObjectIdField


class TestFields:
Expand All @@ -26,5 +27,6 @@ def test_modify_schema(self):
assert {
"title": "User",
"type": "object",
"properties": {"id": {"default": None, "title": "Id", "type": "string"}},
"properties": {"id": {"title": "Id", "type": "string"}},
"required": ["id"],
} == schema
7 changes: 3 additions & 4 deletions test/test_pagination.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import datetime
from typing import List
from typing import List, Optional

from bson import ObjectId
from pydantic import BaseModel
Expand All @@ -13,7 +12,7 @@

class Foo(BaseModel):
count: int
size: float = None
size: Optional[float] = None


class Bar(BaseModel):
Expand All @@ -22,7 +21,7 @@ class Bar(BaseModel):


class Spam(BaseModel):
id: str = None
id: Optional[str] = None
foo: Foo
bars: List[Bar]

Expand Down
26 changes: 11 additions & 15 deletions test/test_repository.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from typing import List
from typing import List, Optional, cast

import mongomock
import pytest
from bson import ObjectId
from pydantic import BaseModel, Field

from pydantic_mongo import AbstractRepository, ObjectIdField
from pydantic_mongo import AbstractRepository, PydanticObjectId
from pydantic_mongo.errors import PaginationError


class Foo(BaseModel):
count: int
size: float = None
size: Optional[float] = None


class Bar(BaseModel):
Expand All @@ -20,9 +20,9 @@ class Bar(BaseModel):


class Spam(BaseModel):
id: ObjectIdField = None
foo: Foo = None
bars: List[Bar] = None
id: Optional[PydanticObjectId] = None
foo: Optional[Foo] = None
bars: Optional[List[Bar]] = None


class SpamRepository(AbstractRepository[Spam]):
Expand All @@ -49,7 +49,7 @@ def test_save(self, database):
"bars": [{"apple": "x", "banana": "y"}],
} == database["spams"].find()[0]

spam.foo.count = 2
cast(Foo, spam.foo).count = 2
spam_repository.save(spam)

assert {
Expand Down Expand Up @@ -147,6 +147,8 @@ def test_find_by_id(self, database):
spam_repository = SpamRepository(database=database)
result = spam_repository.find_one_by_id(spam_id)

assert result is not None
assert result.bars is not None
assert issubclass(Spam, type(result))
assert spam_id == result.id
assert "x" == result.bars[0].apple
Expand All @@ -171,6 +173,8 @@ def test_find_by(self, database):
result = spam_repository.find_by({})
results = [x for x in result]
assert 2 == len(results)
assert results[0].foo is not None
assert results[1].foo is not None
assert 2 == results[0].foo.count
assert 3 == results[1].foo.count

Expand All @@ -181,14 +185,6 @@ def test_find_by(self, database):
results = [x for x in result]
assert 0 == len(results)

def test_invalid_model_class(self, database):
class BrokenRepository(AbstractRepository[int]):
class Meta:
collection_name = "spams"

with pytest.raises(Exception):
BrokenRepository(database=database)

def test_invalid_model_id_field(self, database):
class NoIdModel(BaseModel):
something: str
Expand Down

0 comments on commit d8bb03a

Please sign in to comment.