Skip to content

Commit

Permalink
Add support for date range fields
Browse files Browse the repository at this point in the history
  • Loading branch information
TheSuperiorStanislav committed Apr 9, 2024
1 parent bd544d4 commit ed88460
Show file tree
Hide file tree
Showing 11 changed files with 188 additions and 8 deletions.
13 changes: 11 additions & 2 deletions saritasa_sqlalchemy_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,29 @@
)

with contextlib.suppress(ImportError):
from .testing import AsyncSQLAlchemyModelFactory, AsyncSQLAlchemyOptions
from .testing import (
AsyncSQLAlchemyModelFactory,
AsyncSQLAlchemyOptions,
DateRangeFactory,
)

with contextlib.suppress(ImportError):
from .alembic import AlembicMigrations

with contextlib.suppress(ImportError):
from .auto_schema import (
from .schema import (
ModelAutoSchema,
ModelAutoSchemaError,
ModelAutoSchemaT,
PostgresRange,
PostgresRangeTypeT,
)

__all__ = (
"AlembicMigrations",
"AsyncSQLAlchemyModelFactory",
"AsyncSQLAlchemyOptions",
"DateRangeFactory",
"Session",
"SessionFactory",
"get_async_db_session",
Expand Down Expand Up @@ -111,4 +118,6 @@
"ModelAutoSchema",
"ModelAutoSchemaError",
"ModelAutoSchemaT",
"PostgresRange",
"PostgresRangeTypeT",
)
8 changes: 8 additions & 0 deletions saritasa_sqlalchemy_tools/schema/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .auto_schema import (
ModelAutoSchema,
ModelAutoSchemaError,
ModelAutoSchemaT,
UnableProcessTypeError,
UnableToExtractEnumClassError,
)
from .fields import PostgresRange, PostgresRangeTypeT
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import sqlalchemy.dialects.postgresql
import sqlalchemy.orm

from . import models
from .. import models
from . import fields

PydanticFieldConfig: typing.TypeAlias = tuple[
types.UnionType
Expand Down Expand Up @@ -295,6 +296,7 @@ def _get_db_types_mapping(
sqlalchemy.Interval: cls._generate_interval_field,
sqlalchemy.ARRAY: cls._generate_array_field,
sqlalchemy.dialects.postgresql.JSON: cls._generate_postgres_json_field, # noqa: E501
sqlalchemy.dialects.postgresql.ranges.DATERANGE: cls._generate_date_range, # noqa: E501
}

@classmethod
Expand Down Expand Up @@ -550,6 +552,22 @@ def _generate_postgres_json_field(
else dict[str, str | int | float]
), pydantic_core.PydanticUndefined

@classmethod
def _generate_date_range(
cls,
model: models.SQLAlchemyModel,
field: str,
model_attribute: models.ModelAttribute,
model_type: models.ModelType,
extra_field_config: MetaExtraFieldConfig,
) -> PydanticFieldConfig:
"""Generate date range field."""
return (
fields.PostgresRange[datetime.date] | None
if model_attribute.nullable
else fields.PostgresRange[datetime.date]
), pydantic_core.PydanticUndefined


ModelAutoSchemaT = typing.TypeVar(
"ModelAutoSchemaT",
Expand Down
50 changes: 50 additions & 0 deletions saritasa_sqlalchemy_tools/schema/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import datetime
import typing

import pydantic
import sqlalchemy.dialects.postgresql

PostgresRangeTypeT = typing.TypeVar("PostgresRangeTypeT", bound=typing.Any)


class PostgresRange(pydantic.BaseModel, typing.Generic[PostgresRangeTypeT]):
"""Representation of sqlalchemy.dialects.postgresql.Range."""

model_config = pydantic.ConfigDict(from_attributes=True)

lower: PostgresRangeTypeT | None
upper: PostgresRangeTypeT | None
bounds: sqlalchemy.dialects.postgresql.ranges._BoundsType = "[]"

def to_postgres(
self,
) -> sqlalchemy.dialects.postgresql.Range[PostgresRangeTypeT]:
"""Convert to postgres range."""
return sqlalchemy.dialects.postgresql.Range(
lower=self.lower,
upper=self.upper,
bounds=self.bounds,
)

def model_post_init(self, __context: dict[str, typing.Any]) -> None:
"""Adjust limit depending on bounds.
Postgres always keeps and returns ranges in `[)` no matter how you
save it. Thats why we need to correct it for frontend.
"""
match self.bounds:
case "[]":
return # pragma: no cover
case "(]":
if self.lower: # pragma: no cover
self.lower = self.lower + datetime.timedelta(days=1)
case "[)":
if self.upper:
self.upper = self.upper - datetime.timedelta(days=1)
case "()": # pragma: no cover
if self.lower: # pragma: no cover
self.lower = self.lower + datetime.timedelta(days=1)
if self.upper: # pragma: no cover
self.upper = self.upper - datetime.timedelta(days=1)
self.bounds = "[]"
1 change: 1 addition & 0 deletions saritasa_sqlalchemy_tools/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .factories import AsyncSQLAlchemyModelFactory, AsyncSQLAlchemyOptions
from .fields import DateRangeFactory
45 changes: 45 additions & 0 deletions saritasa_sqlalchemy_tools/testing/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import datetime

import factory
import faker
import sqlalchemy.dialects.postgresql


class DateRangeFactory(factory.LazyAttribute):
"""Generate date range."""

def __init__(
self,
start_date: str = "-3d",
end_date: str = "+3d",
) -> None:
self.start_date = start_date
self.end_date = end_date
super().__init__(
function=self._generate_date_range,
)

def _generate_date_range(
self,
*args, # noqa: ANN002
**kwargs,
) -> sqlalchemy.dialects.postgresql.Range[datetime.date]:
"""Generate range."""
fake = faker.Faker()
lower = fake.date_between(
start_date=self.start_date,
end_date=self.end_date,
)
upper = fake.date_between(
start_date=lower,
end_date=self.end_date,
)
# Need to make sure that the dates are not the same
if upper == lower:
upper += datetime.timedelta(days=1)

return sqlalchemy.dialects.postgresql.Range(
lower=lower,
upper=upper,
bounds="[)",
)
23 changes: 22 additions & 1 deletion tests/alembic/versions/0001_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Revision ID: 0001
Revises:
Create Date: 2024-04-04 12:15:38.086722
Create Date: 2024-04-09 14:46:52.693175
"""

Expand All @@ -19,6 +19,7 @@

def upgrade() -> None:
"""Apply migrations to database."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"related_model",
sa.Column("test_model_id", sa.Integer(), nullable=True),
Expand Down Expand Up @@ -83,6 +84,16 @@ def upgrade() -> None:
postgresql.JSON(astext_type=sa.Text()),
nullable=True,
),
sa.Column(
"date_range",
postgresql.DATERANGE(),
nullable=False,
),
sa.Column(
"date_range_nullable",
postgresql.DATERANGE(),
nullable=True,
),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("deleted", sa.DateTime(), nullable=True),
sa.Column(
Expand Down Expand Up @@ -141,6 +152,16 @@ def upgrade() -> None:
postgresql.JSON(astext_type=sa.Text()),
nullable=True,
),
sa.Column(
"date_range",
postgresql.DATERANGE(),
nullable=False,
),
sa.Column(
"date_range_nullable",
postgresql.DATERANGE(),
nullable=True,
),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"created",
Expand Down
2 changes: 1 addition & 1 deletion tests/alembic/versions/0002_init_alter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Revision ID: 0002
Revises: 0001
Create Date: 2024-04-04 12:16:43.007252
Create Date: 2024-04-09 14:48:14.415826
"""

Expand Down
2 changes: 2 additions & 0 deletions tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class TestModelFactory(
date = factory.Faker("date_between")
timedelta = factory.Faker("time_delta")
json_field = factory.Faker("pydict", allowed_types=[str, int, float])
date_range = saritasa_sqlalchemy_tools.DateRangeFactory()

class Meta:
model = models.TestModel
Expand Down Expand Up @@ -96,6 +97,7 @@ class SoftDeleteTestModelFactory(
date = factory.Faker("date_between")
timedelta = factory.Faker("time_delta")
json_field = factory.Faker("pydict", allowed_types=[str, int, float])
date_range = saritasa_sqlalchemy_tools.DateRangeFactory()

class Meta:
model = models.SoftDeleteTestModel
Expand Down
13 changes: 13 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,19 @@ class TextEnum(enum.StrEnum):
nullable=True,
)

date_range: sqlalchemy.orm.Mapped[
sqlalchemy.dialects.postgresql.Range[datetime.date]
] = sqlalchemy.orm.mapped_column(
sqlalchemy.dialects.postgresql.DATERANGE,
)

date_range_nullable: sqlalchemy.orm.Mapped[
sqlalchemy.dialects.postgresql.Range[datetime.date]
] = sqlalchemy.orm.mapped_column(
sqlalchemy.dialects.postgresql.DATERANGE,
nullable=True,
)

@property
def custom_property(self) -> str:
"""Implement property."""
Expand Down
19 changes: 16 additions & 3 deletions tests/test_auto_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@

from . import models, repositories

SPECIAL_POSTGRES_TYPES = {
saritasa_sqlalchemy_tools.PostgresRange[datetime.date],
}


@pytest.mark.parametrize(
"field",
Expand Down Expand Up @@ -43,6 +47,8 @@
"custom_property_nullable",
"json_field",
"json_field_nullable",
"date_range",
"date_range_nullable",
],
)
async def test_auto_schema_generation(
Expand All @@ -62,7 +68,10 @@ class Meta:

schema = AutoSchema.get_schema()
model = schema.model_validate(test_model)
assert getattr(model, field) == getattr(test_model, field)
value = getattr(model, field)
if value.__class__ in SPECIAL_POSTGRES_TYPES:
value = value.to_postgres()
assert value == getattr(test_model, field)
if "nullable" not in field and "property" not in field:
with pytest.raises(pydantic.ValidationError):
setattr(model, field, None)
Expand All @@ -87,6 +96,10 @@ class Meta:
["json_field", dict[str, typing.Any] | None],
["custom_property", str | None],
["related_model_id", int | None],
[
"date_range",
saritasa_sqlalchemy_tools.PostgresRange[datetime.date] | None,
],
],
)
async def test_auto_schema_type_override_generation(
Expand Down Expand Up @@ -133,7 +146,7 @@ class Meta:
fields = (("id", int, 1),)

with pytest.raises(
saritasa_sqlalchemy_tools.auto_schema.UnableProcessTypeError,
saritasa_sqlalchemy_tools.schema.UnableProcessTypeError,
match=re.escape(
"Can't process the following field ('id', <class 'int'>, 1)",
),
Expand All @@ -160,7 +173,7 @@ class Meta:
)

with pytest.raises(
saritasa_sqlalchemy_tools.auto_schema.UnableProcessTypeError,
saritasa_sqlalchemy_tools.schema.UnableProcessTypeError,
match=re.escape(
"Schema generation is not supported for relationship "
"fields(related_model), please use auto-schema or pydantic class",
Expand Down

0 comments on commit ed88460

Please sign in to comment.