Skip to content

Commit

Permalink
Changes:
Browse files Browse the repository at this point in the history
- update release notes
- expand tests for InspectDB
- expand tests for reflection
- test harder the correct reflection of fields
- expose max_digits for FloatField
  • Loading branch information
devkral committed Oct 22, 2024
1 parent cb970f1 commit c7e9366
Show file tree
Hide file tree
Showing 8 changed files with 259 additions and 50 deletions.
7 changes: 5 additions & 2 deletions docs/fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -453,10 +453,13 @@ import edgy
class MyModel(edgy.Model):
price: float = edgy.FloatField(null=True)
...

```

Derives from the same as [IntegerField](#integerfield) and validates the decimal float.
Derives from the same as [IntegerField](#integerfield) and validates the float.

##### Parameters

* `max_digits` - An integer indicating the total maximum digits. In contrast to DecimalField it is database-only and can be used for higher/lower precision fields. Optional.

#### ForeignKey

Expand Down
13 changes: 13 additions & 0 deletions docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@ hide:

# Release Notes

## Unreleased

### Added

- Add DurationField.
- Allow passing `max_digits` to FloatField.

### Fixed

- Triggering load on non-existent field when reflecting.
- InspectDB mapping was incorrect.


## 0.19.1

### Fixed
Expand Down
14 changes: 12 additions & 2 deletions edgy/core/db/fields/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pydantic
import sqlalchemy
from pydantic import EmailStr
from sqlalchemy.dialects import oracle

from edgy.core.db.context_vars import CURRENT_INSTANCE, CURRENT_PHASE, EXPLICIT_SPECIFIED_VALUES
from edgy.core.db.fields._internal import IPAddress
Expand Down Expand Up @@ -278,21 +279,30 @@ class FloatField(FieldFactory, float):
def __new__( # type: ignore
cls,
*,
max_digits: Optional[int] = None,
ge: Union[int, float, decimal.Decimal, None] = None,
gt: Union[int, float, decimal.Decimal, None] = None,
le: Union[int, float, decimal.Decimal, None] = None,
lt: Union[int, float, decimal.Decimal, None] = None,
**kwargs: Any,
) -> BaseFieldType:
# pydantic doesn't support max_digits for float, so rename it
column_max_digits = max_digits
del max_digits
kwargs = {
**kwargs,
**{key: value for key, value in locals().items() if key not in CLASS_DEFAULTS},
}
return super().__new__(cls, **kwargs)

@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Float(asdecimal=False)
def get_column_type(cls, column_max_digits: Optional[int] = None, **kwargs: Any) -> Any:
if column_max_digits is None:
return sqlalchemy.Float(asdecimal=False)
return sqlalchemy.Float(precision=column_max_digits, asdecimal=False).with_variant(
oracle.FLOAT(binary_precision=round(column_max_digits / 0.30103), asdecimal=False),
"oracle",
)


class BigIntegerField(IntegerField):
Expand Down
2 changes: 1 addition & 1 deletion edgy/core/db/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def __getattr__(self, name: str) -> Any:
name not in self.__dict__
and behavior != "passdown"
and not self._loaded_or_deleted
and field is not None
and (field is not None or self.__reflected__)
and name not in self.identifying_db_fields
and self.can_load
):
Expand Down
32 changes: 22 additions & 10 deletions edgy/utils/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class InspectDB:
Class that builds the inspection of a database.
"""

def __init__(self, database: str, schema: Optional[str]) -> None:
def __init__(self, database: str, schema: Optional[str] = None) -> None:
"""
Creates an instance of an InspectDB and triggers the proccess.
"""
Expand Down Expand Up @@ -185,22 +185,34 @@ def get_field_type(self, column: sqlalchemy.Column, is_fk: bool = False) -> Any:
field_params["max_digits"] = real_field.precision
field_params["decimal_places"] = real_field.scale

if field_type == "FloatField":
# Note: precision is maybe set to None when reflecting.
precision = getattr(real_field, "precision", None)
if precision is None:
# Oracle
precision = getattr(real_field, "binary_precision", None)
if precision is not None:
# invert calculation of binary_precision
precision = round(precision * 0.30103)
if precision is not None:
field_params["max_digits"] = precision

if field_type == "BinaryField":
field_params["max_length"] = getattr(real_field, "length", None)

return field_type, field_params

@classmethod
def get_meta(
cls, table: dict[str, Any], unique_constraints: set[str], _indexes: set[str]
cls, table_detail: dict[str, Any], unique_constraints: set[str], _indexes: set[str]
) -> NoReturn:
"""
Produces the Meta class.
"""
unique_together: list[edgy.UniqueConstraint] = []
unique_indexes: list[edgy.Index] = []
indexes = list(table["indexes"])
constraints = list(table["constraints"])
indexes = list(table_detail["indexes"])
constraints = list(table_detail["constraints"])

# Handle the unique together
for constraint in constraints:
Expand All @@ -226,7 +238,7 @@ def get_meta(
meta += [
" class Meta:\n",
" registry = registry\n",
" tablename = '{}'\n".format(table["tablename"]),
" tablename = '{}'\n".format(table_detail["tablename"]),
]

if unique_together:
Expand All @@ -242,7 +254,7 @@ def get_meta(

@classmethod
def write_output(
cls, tables: list[Any], connection_string: str, schema: Union[str, None] = None
cls, table_details: list[Any], connection_string: str, schema: Union[str, None] = None
) -> NoReturn:
"""
Writes to stdout and runs some internal validations.
Expand Down Expand Up @@ -276,17 +288,17 @@ def write_output(
yield registry

# Start writing the classes
for table in tables:
for table_detail in table_details:
unique_constraints: set[str] = set()
indexes: set[str] = set()

yield "\n"
yield "\n"
yield "\n"
yield "class {}({}.ReflectModel):\n".format(table["class_name"], DB_MODULE)
yield "class {}({}.ReflectModel):\n".format(table_detail["class_name"], DB_MODULE)
# yield " ...\n"

sqla_table: sqlalchemy.Table = table["table"]
sqla_table: sqlalchemy.Table = table_detail["table"]
columns = list(sqla_table.columns)

# Get the column information
Expand Down Expand Up @@ -339,4 +351,4 @@ def write_output(
yield f" {field_description}"

yield "\n"
yield from cls.get_meta(table, unique_constraints, indexes)
yield from cls.get_meta(table_detail, unique_constraints, indexes)
7 changes: 7 additions & 0 deletions tests/fields/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,13 @@ def test_can_create_float_field():
assert field.null is True


def test_can_create_max_digits_float_field():
field = FloatField(max_digits=10, null=True)

assert isinstance(field, BaseField)
assert field.column_type.precision == 10


def test_can_create_boolean_field():
field = BooleanField(default=False)

Expand Down
80 changes: 45 additions & 35 deletions tests/reflection/test_table_reflection.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,42 @@
import asyncio
import functools
import random
import string
from contextlib import redirect_stdout
from io import StringIO

import pytest

import edgy
from edgy.core.db.datastructures import Index
from edgy.testclient import DatabaseTestClient
from edgy.utils.inspect import InspectDB
from tests.settings import DATABASE_URL

pytestmark = pytest.mark.anyio

database = DatabaseTestClient(DATABASE_URL)
models = edgy.Registry(database=edgy.Database(database, force_rollback=True))
models = edgy.Registry(database)
second = edgy.Registry(database=edgy.Database(database, force_rollback=False))

expected_result1 = """
class Users(edgy.ReflectModel):
name = edgy.CharField(max_length=255, null=False)
title = edgy.CharField(max_length=255, null=True)
id = edgy.BigIntegerField(null=False, primary_key=True)
class Meta:
registry = registry
tablename = 'users'
""".strip()

expected_result2 = """
class Hubusers(edgy.ReflectModel):
name = edgy.CharField(max_length=255, null=False)
title = edgy.CharField(max_length=255, null=True)
description = edgy.CharField(max_length=255, null=True)
id = edgy.BigIntegerField(null=False, primary_key=True)
def get_random_string(length):
letters = string.ascii_lowercase
result_str = "".join(random.choice(letters) for i in range(length))
return result_str
class Meta:
registry = registry
tablename = 'hubusers'
""".strip()


class User(edgy.Model):
Expand Down Expand Up @@ -51,7 +68,7 @@ class ReflectedUser(edgy.ReflectModel):

class Meta:
tablename = "hubusers"
registry = models
registry = second


class NewReflectedUser(edgy.ReflectModel):
Expand All @@ -60,38 +77,19 @@ class NewReflectedUser(edgy.ReflectModel):

class Meta:
tablename = "hubusers"
registry = models
registry = second


@pytest.fixture(autouse=True, scope="module")
@pytest.fixture(autouse=True, scope="function")
async def create_test_database():
async with database:
async with models:
await models.create_all()
yield
async with second:
yield
if not database.drop:
await models.drop_all()


@pytest.fixture(autouse=True, scope="function")
async def rollback_transactions():
async with models.database:
yield


def async_adapter(wrapped_func):
"""
Decorator used to run async test cases.
"""

@functools.wraps(wrapped_func)
def run_sync(*args, **kwargs):
loop = asyncio.get_event_loop()
task = wrapped_func(*args, **kwargs)
return loop.run_until_complete(task)

return run_sync


async def test_can_reflect_existing_table():
await HubUser.query.create(name="Test", title="a title", description="desc")

Expand Down Expand Up @@ -134,7 +132,7 @@ async def test_can_reflect_existing_table_with_not_all_fields_and_create_record(
user = users[1]

assert user.name == "Test2"
assert not hasattr(user, "description")
assert "description" not in user.__dict__


async def test_can_reflect_and_edit_existing_table():
Expand Down Expand Up @@ -163,3 +161,15 @@ async def test_can_reflect_and_edit_existing_table():

assert user.name == "edgy"
assert user.description == "updated"


async def test_create_correct_inspect_db():
inflected = InspectDB(str(models.database.url))
out = StringIO()
with redirect_stdout(out):
inflected.inspect()
out.seek(0)
generated = out.read()
# indexes are not sorted and appear in any order so they are removed
assert expected_result1 in generated
assert expected_result2 in generated
Loading

0 comments on commit c7e9366

Please sign in to comment.