Skip to content

Commit

Permalink
Add support for json field
Browse files Browse the repository at this point in the history
  • Loading branch information
TheSuperiorStanislav committed Mar 25, 2024
1 parent 706a17b commit e0d1bf4
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 50 deletions.
19 changes: 18 additions & 1 deletion saritasa_sqlalchemy_tools/auto_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pydantic
import pydantic_core
import sqlalchemy.dialects.postgresql.ranges
import sqlalchemy.dialects.postgresql
import sqlalchemy.orm

from . import models
Expand Down Expand Up @@ -292,6 +292,7 @@ def _get_db_types_mapping(
sqlalchemy.Numeric: cls._generate_numeric_field,
sqlalchemy.Interval: cls._generate_interval_field,
sqlalchemy.ARRAY: cls._generate_array_field,
sqlalchemy.dialects.postgresql.JSON: cls._generate_postgres_json_field, # noqa: E501
}

@classmethod
Expand Down Expand Up @@ -531,6 +532,22 @@ def _generate_array_field(
else list[list_type] # type: ignore
), pydantic_core.PydanticUndefined

@classmethod
def _generate_postgres_json_field(
cls,
model: models.SQLAlchemyModel,
field: str,
model_attribute: models.ModelAttribute,
model_type: models.ModelType,
extra_field_config: MetaExtraFieldConfig,
) -> PydanticFieldConfig:
"""Generate postgres json field."""
return (
dict[str, str | int | float] | None
if model_attribute.nullable
else dict[str, str | int | float]
), pydantic_core.PydanticUndefined


ModelAutoSchemaT = typing.TypeVar(
"ModelAutoSchemaT",
Expand Down
2 changes: 2 additions & 0 deletions tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class TestModelFactory(
date_time = factory.Faker("date_time")
date = factory.Faker("date_between")
timedelta = factory.Faker("time_delta")
json_field = factory.Faker("pydict", allowed_types=[str, int, float])

class Meta:
model = models.TestModel
Expand Down Expand Up @@ -90,6 +91,7 @@ class SoftDeleteTestModelFactory(
date_time = factory.Faker("date_time")
date = factory.Faker("date_between")
timedelta = factory.Faker("time_delta")
json_field = factory.Faker("pydict", allowed_types=[str, int, float])

class Meta:
model = models.SoftDeleteTestModel
Expand Down
15 changes: 15 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import enum
import typing

import sqlalchemy.dialects.postgresql
import sqlalchemy.orm

import saritasa_sqlalchemy_tools
Expand Down Expand Up @@ -193,6 +194,20 @@ class TextEnum(enum.StrEnum):
)
)

json_field: sqlalchemy.orm.Mapped[dict[str, str | int | float]] = (
sqlalchemy.orm.mapped_column(
sqlalchemy.dialects.postgresql.JSON(),
nullable=False,
)
)

json_field_nullable: sqlalchemy.orm.Mapped[
dict[str, str | int | float]
] = sqlalchemy.orm.mapped_column(
sqlalchemy.dialects.postgresql.JSON(),
nullable=True,
)

@property
def custom_property(self) -> str:
"""Implement property."""
Expand Down
118 changes: 69 additions & 49 deletions tests/test_auto_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,43 @@
from . import models, repositories


@pytest.mark.parametrize(
"field",
[
"id",
"created",
"modified",
"text",
"text_nullable",
"text_enum",
"text_enum_nullable",
"number",
"number_nullable",
"small_number",
"small_number_nullable",
"decimal_number",
"decimal_number_nullable",
"boolean",
"boolean_nullable",
"text_list",
"text_list_nullable",
"date_time",
"date_time_nullable",
"date",
"date_nullable",
"timedelta",
"timedelta_nullable",
"related_model_id",
"related_model_id_nullable",
"custom_property",
"custom_property_nullable",
"json_field",
"json_field_nullable",
],
)
async def test_auto_schema_generation(
test_model: models.TestModel,
field: str,
) -> None:
"""Test schema generation picks correct types from model for schema."""

Expand All @@ -23,47 +58,41 @@ class Meta:
from_attributes=True,
validate_assignment=True,
)
fields = (
"id",
"created",
"modified",
"text",
"text_nullable",
"text_enum",
"text_enum_nullable",
"number",
"number_nullable",
"small_number",
"small_number_nullable",
"decimal_number",
"decimal_number_nullable",
"boolean",
"boolean_nullable",
"text_list",
"text_list_nullable",
"date_time",
"date_time_nullable",
"date",
"date_nullable",
"timedelta",
"timedelta_nullable",
"related_model_id",
"related_model_id_nullable",
"custom_property",
"custom_property_nullable",
)
fields = (field,)

schema = AutoSchema.get_schema()
model = schema.model_validate(test_model)
for field in AutoSchema.Meta.fields:
assert getattr(model, field) == getattr(test_model, field)
if "nullable" not in field and "property" not in field:
with pytest.raises(pydantic.ValidationError):
setattr(model, field, None)
assert getattr(model, field) == getattr(test_model, field)
if "nullable" not in field and "property" not in field:
with pytest.raises(pydantic.ValidationError):
setattr(model, field, None)


@pytest.mark.parametrize(
[
"field",
"field_type",
],
[
["text", str | None],
["text_enum", models.TestModel.TextEnum | None],
["number", int | None],
["small_number", int | None],
["decimal_number", decimal.Decimal | None],
["boolean", bool | None],
["text_list", list[str] | None],
["date_time", datetime.datetime | None],
["date", datetime.date | None],
["timedelta", datetime.timedelta | None],
["json_field", dict[str, typing.Any] | None],
["custom_property", str | None],
["related_model_id", int | None],
],
)
async def test_auto_schema_type_override_generation(
test_model: models.TestModel,
field: str,
field_type: type,
) -> None:
"""Test that in auto schema generation you can override type.
Expand All @@ -76,25 +105,16 @@ class AutoSchema(saritasa_sqlalchemy_tools.ModelAutoSchema):
class Meta:
model = models.TestModel
fields = (
("text", str | None),
("text_enum", models.TestModel.TextEnum | None),
("number", int | None),
("small_number", int | None),
("decimal_number", decimal.Decimal | None),
("boolean", bool | None),
("text_list", list[str] | None),
("date_time", datetime.datetime | None),
("date", datetime.date | None),
("timedelta", datetime.timedelta | None),
("custom_property", str | None),
("related_model_id", int | None),
(
field,
field_type,
),
)

schema = AutoSchema.get_schema()
model = schema.model_validate(test_model)
for field, _ in AutoSchema.Meta.fields:
if "property" not in field:
setattr(model, field, None)
if "property" not in field:
setattr(model, field, None)


async def test_auto_schema_type_invalid_field_config(
Expand Down

0 comments on commit e0d1bf4

Please sign in to comment.