From c7e93667ac0357430630369ead9fe0889d485e00 Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 22 Oct 2024 11:36:07 +0200 Subject: [PATCH] Changes: - update release notes - expand tests for InspectDB - expand tests for reflection - test harder the correct reflection of fields - expose max_digits for FloatField --- docs/fields.md | 7 +- docs/release-notes.md | 13 ++ edgy/core/db/fields/core.py | 14 +- edgy/core/db/models/base.py | 2 +- edgy/utils/inspect.py | 32 ++-- tests/fields/test_fields.py | 7 + tests/reflection/test_table_reflection.py | 80 +++++---- .../test_table_reflection_special_fields.py | 154 ++++++++++++++++++ 8 files changed, 259 insertions(+), 50 deletions(-) create mode 100644 tests/reflection/test_table_reflection_special_fields.py diff --git a/docs/fields.md b/docs/fields.md index 96c6eb55..722380a6 100644 --- a/docs/fields.md +++ b/docs/fields.md @@ -453,10 +453,13 @@ import edgy class MyModel(edgy.Model): price: float = edgy.FloatField(null=True) ... - ``` -Derives from the same as [IntegerField](#integerfield) and validates the decimal float. +Derives from the same as [IntegerField](#integerfield) and validates the float. + +##### Parameters + +* `max_digits` - An integer indicating the total maximum digits. In contrast to DecimalField it is database-only and can be used for higher/lower precision fields. Optional. #### ForeignKey diff --git a/docs/release-notes.md b/docs/release-notes.md index 3fa5d0a1..78153c2b 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -6,6 +6,19 @@ hide: # Release Notes +## Unreleased + +### Added + +- Add DurationField. +- Allow passing `max_digits` to FloatField. + +### Fixed + +- Triggering load on non-existent field when reflecting. +- InspectDB mapping was incorrect. + + ## 0.19.1 ### Fixed diff --git a/edgy/core/db/fields/core.py b/edgy/core/db/fields/core.py index 0aad00b5..957cbbc7 100644 --- a/edgy/core/db/fields/core.py +++ b/edgy/core/db/fields/core.py @@ -12,6 +12,7 @@ import pydantic import sqlalchemy from pydantic import EmailStr +from sqlalchemy.dialects import oracle from edgy.core.db.context_vars import CURRENT_INSTANCE, CURRENT_PHASE, EXPLICIT_SPECIFIED_VALUES from edgy.core.db.fields._internal import IPAddress @@ -278,12 +279,16 @@ class FloatField(FieldFactory, float): def __new__( # type: ignore cls, *, + max_digits: Optional[int] = None, ge: Union[int, float, decimal.Decimal, None] = None, gt: Union[int, float, decimal.Decimal, None] = None, le: Union[int, float, decimal.Decimal, None] = None, lt: Union[int, float, decimal.Decimal, None] = None, **kwargs: Any, ) -> BaseFieldType: + # pydantic doesn't support max_digits for float, so rename it + column_max_digits = max_digits + del max_digits kwargs = { **kwargs, **{key: value for key, value in locals().items() if key not in CLASS_DEFAULTS}, @@ -291,8 +296,13 @@ def __new__( # type: ignore return super().__new__(cls, **kwargs) @classmethod - def get_column_type(cls, **kwargs: Any) -> Any: - return sqlalchemy.Float(asdecimal=False) + def get_column_type(cls, column_max_digits: Optional[int] = None, **kwargs: Any) -> Any: + if column_max_digits is None: + return sqlalchemy.Float(asdecimal=False) + return sqlalchemy.Float(precision=column_max_digits, asdecimal=False).with_variant( + oracle.FLOAT(binary_precision=round(column_max_digits / 0.30103), asdecimal=False), + "oracle", + ) class BigIntegerField(IntegerField): diff --git a/edgy/core/db/models/base.py b/edgy/core/db/models/base.py index 94bb1b3f..305d8fd4 100644 --- a/edgy/core/db/models/base.py +++ b/edgy/core/db/models/base.py @@ -484,7 +484,7 @@ def __getattr__(self, name: str) -> Any: name not in self.__dict__ and behavior != "passdown" and not self._loaded_or_deleted - and field is not None + and (field is not None or self.__reflected__) and name not in self.identifying_db_fields and self.can_load ): diff --git a/edgy/utils/inspect.py b/edgy/utils/inspect.py index 6db72dd1..bc56dbfb 100644 --- a/edgy/utils/inspect.py +++ b/edgy/utils/inspect.py @@ -52,7 +52,7 @@ class InspectDB: Class that builds the inspection of a database. """ - def __init__(self, database: str, schema: Optional[str]) -> None: + def __init__(self, database: str, schema: Optional[str] = None) -> None: """ Creates an instance of an InspectDB and triggers the proccess. """ @@ -185,6 +185,18 @@ def get_field_type(self, column: sqlalchemy.Column, is_fk: bool = False) -> Any: field_params["max_digits"] = real_field.precision field_params["decimal_places"] = real_field.scale + if field_type == "FloatField": + # Note: precision is maybe set to None when reflecting. + precision = getattr(real_field, "precision", None) + if precision is None: + # Oracle + precision = getattr(real_field, "binary_precision", None) + if precision is not None: + # invert calculation of binary_precision + precision = round(precision * 0.30103) + if precision is not None: + field_params["max_digits"] = precision + if field_type == "BinaryField": field_params["max_length"] = getattr(real_field, "length", None) @@ -192,15 +204,15 @@ def get_field_type(self, column: sqlalchemy.Column, is_fk: bool = False) -> Any: @classmethod def get_meta( - cls, table: dict[str, Any], unique_constraints: set[str], _indexes: set[str] + cls, table_detail: dict[str, Any], unique_constraints: set[str], _indexes: set[str] ) -> NoReturn: """ Produces the Meta class. """ unique_together: list[edgy.UniqueConstraint] = [] unique_indexes: list[edgy.Index] = [] - indexes = list(table["indexes"]) - constraints = list(table["constraints"]) + indexes = list(table_detail["indexes"]) + constraints = list(table_detail["constraints"]) # Handle the unique together for constraint in constraints: @@ -226,7 +238,7 @@ def get_meta( meta += [ " class Meta:\n", " registry = registry\n", - " tablename = '{}'\n".format(table["tablename"]), + " tablename = '{}'\n".format(table_detail["tablename"]), ] if unique_together: @@ -242,7 +254,7 @@ def get_meta( @classmethod def write_output( - cls, tables: list[Any], connection_string: str, schema: Union[str, None] = None + cls, table_details: list[Any], connection_string: str, schema: Union[str, None] = None ) -> NoReturn: """ Writes to stdout and runs some internal validations. @@ -276,17 +288,17 @@ def write_output( yield registry # Start writing the classes - for table in tables: + for table_detail in table_details: unique_constraints: set[str] = set() indexes: set[str] = set() yield "\n" yield "\n" yield "\n" - yield "class {}({}.ReflectModel):\n".format(table["class_name"], DB_MODULE) + yield "class {}({}.ReflectModel):\n".format(table_detail["class_name"], DB_MODULE) # yield " ...\n" - sqla_table: sqlalchemy.Table = table["table"] + sqla_table: sqlalchemy.Table = table_detail["table"] columns = list(sqla_table.columns) # Get the column information @@ -339,4 +351,4 @@ def write_output( yield f" {field_description}" yield "\n" - yield from cls.get_meta(table, unique_constraints, indexes) + yield from cls.get_meta(table_detail, unique_constraints, indexes) diff --git a/tests/fields/test_fields.py b/tests/fields/test_fields.py index 8ddf31b5..d1aa5d74 100644 --- a/tests/fields/test_fields.py +++ b/tests/fields/test_fields.py @@ -165,6 +165,13 @@ def test_can_create_float_field(): assert field.null is True +def test_can_create_max_digits_float_field(): + field = FloatField(max_digits=10, null=True) + + assert isinstance(field, BaseField) + assert field.column_type.precision == 10 + + def test_can_create_boolean_field(): field = BooleanField(default=False) diff --git a/tests/reflection/test_table_reflection.py b/tests/reflection/test_table_reflection.py index 55702d9e..266f1b08 100644 --- a/tests/reflection/test_table_reflection.py +++ b/tests/reflection/test_table_reflection.py @@ -1,25 +1,42 @@ -import asyncio -import functools -import random -import string +from contextlib import redirect_stdout +from io import StringIO import pytest import edgy from edgy.core.db.datastructures import Index from edgy.testclient import DatabaseTestClient +from edgy.utils.inspect import InspectDB from tests.settings import DATABASE_URL pytestmark = pytest.mark.anyio database = DatabaseTestClient(DATABASE_URL) -models = edgy.Registry(database=edgy.Database(database, force_rollback=True)) +models = edgy.Registry(database) +second = edgy.Registry(database=edgy.Database(database, force_rollback=False)) +expected_result1 = """ +class Users(edgy.ReflectModel): + name = edgy.CharField(max_length=255, null=False) + title = edgy.CharField(max_length=255, null=True) + id = edgy.BigIntegerField(null=False, primary_key=True) + + class Meta: + registry = registry + tablename = 'users' +""".strip() + +expected_result2 = """ +class Hubusers(edgy.ReflectModel): + name = edgy.CharField(max_length=255, null=False) + title = edgy.CharField(max_length=255, null=True) + description = edgy.CharField(max_length=255, null=True) + id = edgy.BigIntegerField(null=False, primary_key=True) -def get_random_string(length): - letters = string.ascii_lowercase - result_str = "".join(random.choice(letters) for i in range(length)) - return result_str + class Meta: + registry = registry + tablename = 'hubusers' +""".strip() class User(edgy.Model): @@ -51,7 +68,7 @@ class ReflectedUser(edgy.ReflectModel): class Meta: tablename = "hubusers" - registry = models + registry = second class NewReflectedUser(edgy.ReflectModel): @@ -60,38 +77,19 @@ class NewReflectedUser(edgy.ReflectModel): class Meta: tablename = "hubusers" - registry = models + registry = second -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture(autouse=True, scope="function") async def create_test_database(): - async with database: + async with models: await models.create_all() - yield + async with second: + yield if not database.drop: await models.drop_all() -@pytest.fixture(autouse=True, scope="function") -async def rollback_transactions(): - async with models.database: - yield - - -def async_adapter(wrapped_func): - """ - Decorator used to run async test cases. - """ - - @functools.wraps(wrapped_func) - def run_sync(*args, **kwargs): - loop = asyncio.get_event_loop() - task = wrapped_func(*args, **kwargs) - return loop.run_until_complete(task) - - return run_sync - - async def test_can_reflect_existing_table(): await HubUser.query.create(name="Test", title="a title", description="desc") @@ -134,7 +132,7 @@ async def test_can_reflect_existing_table_with_not_all_fields_and_create_record( user = users[1] assert user.name == "Test2" - assert not hasattr(user, "description") + assert "description" not in user.__dict__ async def test_can_reflect_and_edit_existing_table(): @@ -163,3 +161,15 @@ async def test_can_reflect_and_edit_existing_table(): assert user.name == "edgy" assert user.description == "updated" + + +async def test_create_correct_inspect_db(): + inflected = InspectDB(str(models.database.url)) + out = StringIO() + with redirect_stdout(out): + inflected.inspect() + out.seek(0) + generated = out.read() + # indexes are not sorted and appear in any order so they are removed + assert expected_result1 in generated + assert expected_result2 in generated diff --git a/tests/reflection/test_table_reflection_special_fields.py b/tests/reflection/test_table_reflection_special_fields.py new file mode 100644 index 00000000..a13f9ea7 --- /dev/null +++ b/tests/reflection/test_table_reflection_special_fields.py @@ -0,0 +1,154 @@ +import sys +from contextlib import redirect_stdout +from datetime import timedelta +from io import StringIO +from uuid import uuid4 + +import pytest +import sqlalchemy + +import edgy +from edgy.core.db.datastructures import Index +from edgy.testclient import DatabaseTestClient +from edgy.utils.inspect import InspectDB +from tests.settings import DATABASE_URL + +pytestmark = pytest.mark.anyio + +database = DatabaseTestClient(DATABASE_URL) +models = edgy.Registry(database) +second = edgy.Registry(database=edgy.Database(database, force_rollback=False)) +# not connected at all +third = edgy.Registry(database=edgy.Database(database, force_rollback=False)) + +expected_result1 = """ +class Products(edgy.ReflectModel): + name = edgy.CharField(max_length=255, null=False) + title = edgy.CharField(max_length=255, null=True) + price = edgy.FloatField(null=False) + uuid = edgy.UUIDField(null=False) + duration = edgy.DurationField(null=False) + extra = edgy.JSONField(null=False) + id = edgy.BigIntegerField(null=False, primary_key=True) + + class Meta: + registry = registry + tablename = 'products' +""".strip() +expected_result_full_info = """ +class Products(edgy.ReflectModel): + name = edgy.CharField(max_length=255, null=False, index=True) + title = edgy.CharField(max_length=255, null=True) + price = edgy.FloatField(max_digits=4, null=False) + uuid = edgy.UUIDField(null=False) + duration = edgy.DurationField(null=False) + extra = edgy.JSONField(null=False) + id = edgy.BigIntegerField(null=False, primary_key=True) + + class Meta: + registry = registry + tablename = 'products' +""".strip() + + +class Product(edgy.Model): + name = edgy.fields.CharField(max_length=255, index=True) + title = edgy.fields.CharField(max_length=255, null=True) + price = edgy.fields.FloatField(max_digits=4) + uuid = edgy.fields.UUIDField(default=uuid4) + duration = edgy.fields.DurationField() + extra = edgy.fields.JSONField(default=dict) + + class Meta: + registry = models + indexes = [Index(fields=["name", "title"], name="idx_name_title")] + + +class ProductThird(edgy.Model): + name = edgy.fields.CharField(max_length=255, index=True) + title = edgy.fields.CharField(max_length=255, null=True) + price = edgy.fields.FloatField(max_digits=4) + uuid = edgy.fields.UUIDField() + duration = edgy.fields.DurationField() + extra = edgy.fields.JSONField() + + class Meta: + tablename = "products" + registry = third + + +class ReflectedProduct(edgy.ReflectModel): + name = edgy.fields.CharField(max_length=50) + + class Meta: + tablename = "products" + registry = second + + +@pytest.fixture(autouse=True, scope="function") +async def create_test_database(): + async with models: + await models.create_all() + async with second: + yield + if not database.drop: + await models.drop_all() + + +async def test_can_reflect_correct_columns(): + second.invalidate_models() + assert ReflectedProduct.table.c.price.type.as_generic().__class__ == sqlalchemy.Float + assert ReflectedProduct.table.c.uuid.type.as_generic().__class__ == sqlalchemy.Uuid + assert ReflectedProduct.table.c.duration.type.as_generic().__class__ == sqlalchemy.Interval + # now the tables should be initialized + assert second.metadata.tables["products"].c.uuid.type.as_generic().__class__ == sqlalchemy.Uuid + assert ( + second.metadata.tables["products"].c.duration.type.as_generic().__class__ + == sqlalchemy.Interval + ) + + +async def test_create_correct_inspect_db(): + inflected = InspectDB(str(models.database.url)) + out = StringIO() + with redirect_stdout(out): + inflected.inspect() + out.seek(0) + generated = out.read() + generated = "\n".join(generated.splitlines()[:-1]) + # remove indexes as they tend to be instable (last line) + assert generated.strip().endswith(expected_result1) + + +async def test_create_correct_inspect_db_with_full_info_avail(): + # Here we generate from an original metadata a file + # this will however not happen often + tables, _ = InspectDB.generate_table_information(third.metadata) + + out = StringIO() + with redirect_stdout(out): + for line in InspectDB.write_output(tables, str(database.url), schema=None): + sys.stdout.writelines(line) # type: ignore + out.seek(0) + generated = out.read() + generated = "\n".join(generated.splitlines()[:-1]) + # remove indexes as they tend to be instable (last line) + assert generated.strip().endswith(expected_result_full_info) + + +async def test_can_read_update_fields(): + await Product.query.create( + name="Ice cream", title="yummy ice cream", duration=timedelta(hours=3), price=1.4 + ) + + product = await ReflectedProduct.query.get() + assert product.name == "Ice cream" + assert product.title == "yummy ice cream" + assert product.duration == timedelta(hours=3) + assert product.extra == {} + product.name = "Chocolate" + await product.save() + + # check first table + old_product = await Product.query.get(pk=product) + assert old_product.name == "Chocolate"