Skip to content

Commit

Permalink
add duration field & expose max_digits/precision of FloatField (#211)
Browse files Browse the repository at this point in the history
Changes:

- add duration field
- move field tests to fields
- remove duplicate test file in run_sync
- add duration field test
- expand tests for InspectDB
- expand tests for reflection
- test harder the correct reflection of fields
- expose max_digits for FloatField
  • Loading branch information
devkral authored Oct 22, 2024
1 parent d6f6987 commit 0124e17
Show file tree
Hide file tree
Showing 14 changed files with 375 additions and 160 deletions.
26 changes: 23 additions & 3 deletions docs/fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,22 @@ DateTimeField supports int, float, string (isoformat), date object and of course
!!! Note
`auto_now` and `auto_now_add` set the `read_only` flag by default. You can explicitly set `read_only` to `False` to be still able to update the field manually.

#### DurationField

A DurationField can save the amount of time of a process. This is useful in case there is no clear start/stop timepoints.
For example the time worked on a project.


```python
import datetime
import edgy

class Project(edgy.Model):
worked: datetime.timedelta = edgy.DurationField(default=datetime.timedelta())
estimated_time: datetime.timedelta = edgy.DurationField()
...
```

#### DecimalField

```python
Expand All @@ -364,7 +380,6 @@ import edgy
class MyModel(edgy.Model):
price: decimal.Decimal = edgy.DecimalField(max_digits=5, decimal_places=2, null=True)
...

```

##### Parameters
Expand Down Expand Up @@ -440,10 +455,15 @@ import edgy
class MyModel(edgy.Model):
price: float = edgy.FloatField(null=True)
...

```

Derives from the same as [IntegerField](#integerfield) and validates the decimal float.
Derives from the same as [IntegerField](#integerfield) and validates the float.

##### Parameters

* `max_digits` - An integer indicating the total maximum digits.
In contrast to DecimalField it is database-only and can be used for higher/lower precision fields.
It is also available under the name `precision` with a higher priority. Optional.

#### ForeignKey

Expand Down
13 changes: 13 additions & 0 deletions docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@ hide:

# Release Notes

## Unreleased

### Added

- Add DurationField.
- Allow passing `max_digits` to FloatField.

### Fixed

- Triggering load on non-existent field when reflecting.
- InspectDB mapping was incorrect.


## 0.19.1

### Fixed
Expand Down
2 changes: 2 additions & 0 deletions edgy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
DateField,
DateTimeField,
DecimalField,
DurationField,
EmailField,
ExcludeField,
FloatField,
Expand Down Expand Up @@ -62,6 +63,7 @@
"DatabaseURL",
"DateField",
"DateTimeField",
"DurationField",
"DecimalField",
"EdgyExtra",
"EdgySettings",
Expand Down
2 changes: 2 additions & 0 deletions edgy/core/db/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DateField,
DateTimeField,
DecimalField,
DurationField,
EmailField,
FloatField,
IntegerField,
Expand Down Expand Up @@ -45,6 +46,7 @@
"CompositeField",
"DateField",
"DateTimeField",
"DurationField",
"DecimalField",
"EmailField",
"ExcludeField",
Expand Down
29 changes: 25 additions & 4 deletions edgy/core/db/fields/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pydantic
import sqlalchemy
from pydantic import EmailStr
from sqlalchemy.dialects import oracle

from edgy.core.db.context_vars import CURRENT_INSTANCE, CURRENT_PHASE, EXPLICIT_SPECIFIED_VALUES
from edgy.core.db.fields._internal import IPAddress
Expand Down Expand Up @@ -278,21 +279,31 @@ class FloatField(FieldFactory, float):
def __new__( # type: ignore
cls,
*,
max_digits: Optional[int] = None,
ge: Union[int, float, decimal.Decimal, None] = None,
gt: Union[int, float, decimal.Decimal, None] = None,
le: Union[int, float, decimal.Decimal, None] = None,
lt: Union[int, float, decimal.Decimal, None] = None,
**kwargs: Any,
) -> BaseFieldType:
# pydantic doesn't support max_digits for float, so rename it
if max_digits is not None:
kwargs.setdefault("precision", max_digits)
del max_digits
kwargs = {
**kwargs,
**{key: value for key, value in locals().items() if key not in CLASS_DEFAULTS},
}
return super().__new__(cls, **kwargs)

@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Float(asdecimal=False)
def get_column_type(cls, precision: Optional[int] = None, **kwargs: Any) -> Any:
if precision is None:
return sqlalchemy.Float(asdecimal=False)
return sqlalchemy.Float(precision=precision, asdecimal=False).with_variant(
oracle.FLOAT(binary_precision=round(precision / 0.30103), asdecimal=False), # type: ignore
"oracle",
)


class BigIntegerField(IntegerField):
Expand Down Expand Up @@ -532,12 +543,22 @@ def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Date()


class DurationField(FieldFactory, datetime.timedelta):
"""Representation of a time field"""

field_type = datetime.timedelta

@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Interval()


class TimeField(FieldFactory, datetime.time):
"""Representation of a time field"""

field_type = datetime.time

def __new__(cls, **kwargs: Any) -> BaseFieldType: # type: ignore
def __new__(cls, with_timezone: bool = False, **kwargs: Any) -> BaseFieldType: # type: ignore
kwargs = {
**kwargs,
**{k: v for k, v in locals().items() if k not in CLASS_DEFAULTS},
Expand Down Expand Up @@ -573,7 +594,7 @@ def __new__(cls, *, max_length: Optional[int] = None, **kwargs: Any) -> BaseFiel

@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.LargeBinary(kwargs.get("max_length"))
return sqlalchemy.LargeBinary(length=kwargs.get("max_length"))


class UUIDField(FieldFactory, uuid.UUID):
Expand Down
2 changes: 1 addition & 1 deletion edgy/core/db/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def __getattr__(self, name: str) -> Any:
name not in self.__dict__
and behavior != "passdown"
and not self._loaded_or_deleted
and field is not None
and (field is not None or self.__reflected__)
and name not in self.identifying_db_fields
and self.can_load
):
Expand Down
69 changes: 41 additions & 28 deletions edgy/utils/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,24 @@
printer = Print()

SQL_GENERIC_TYPES = {
sqltypes.BigInteger: edgy.BigIntegerField,
sqltypes.Integer: edgy.IntegerField,
sqltypes.JSON: edgy.JSONField,
sqltypes.Date: edgy.DateField,
sqltypes.String: edgy.CharField,
sqltypes.Unicode: edgy.CharField,
sqltypes.BINARY: edgy.BinaryField,
sqltypes.Boolean: edgy.BooleanField,
sqltypes.Enum: edgy.ChoiceField,
sqltypes.DateTime: edgy.DateTimeField,
sqltypes.Numeric: edgy.DecimalField,
sqltypes.Float: edgy.FloatField,
sqltypes.Double: edgy.FloatField,
sqltypes.SmallInteger: edgy.SmallIntegerField,
sqltypes.Text: edgy.TextField,
sqltypes.Time: edgy.TimeField,
sqltypes.Uuid: edgy.UUIDField,
sqltypes.BigInteger: edgy.fields.BigIntegerField,
sqltypes.Integer: edgy.fields.IntegerField,
sqltypes.JSON: edgy.fields.JSONField,
sqltypes.Date: edgy.fields.DateField,
sqltypes.String: edgy.fields.CharField,
sqltypes.Unicode: edgy.fields.CharField,
sqltypes.LargeBinary: edgy.fields.BinaryField,
sqltypes.Boolean: edgy.fields.BooleanField,
sqltypes.Enum: edgy.fields.ChoiceField,
sqltypes.DateTime: edgy.fields.DateTimeField,
sqltypes.Interval: edgy.fields.DurationField,
sqltypes.Numeric: edgy.fields.DecimalField,
sqltypes.Float: edgy.fields.FloatField,
sqltypes.Double: edgy.fields.FloatField,
sqltypes.SmallInteger: edgy.fields.SmallIntegerField,
sqltypes.Text: edgy.fields.TextField,
sqltypes.Time: edgy.fields.TimeField,
sqltypes.Uuid: edgy.fields.UUIDField,
}

DB_MODULE = "edgy"
Expand All @@ -51,7 +52,7 @@ class InspectDB:
Class that builds the inspection of a database.
"""

def __init__(self, database: str, schema: Optional[str]) -> None:
def __init__(self, database: str, schema: Optional[str] = None) -> None:
"""
Creates an instance of an InspectDB and triggers the proccess.
"""
Expand Down Expand Up @@ -184,22 +185,34 @@ def get_field_type(self, column: sqlalchemy.Column, is_fk: bool = False) -> Any:
field_params["max_digits"] = real_field.precision
field_params["decimal_places"] = real_field.scale

if field_type == "FloatField":
# Note: precision is maybe set to None when reflecting.
precision = getattr(real_field, "precision", None)
if precision is None:
# Oracle
precision = getattr(real_field, "binary_precision", None)
if precision is not None:
# invert calculation of binary_precision
precision = round(precision * 0.30103)
if precision is not None:
field_params["max_digits"] = precision

if field_type == "BinaryField":
field_params["sql_nullable"] = getattr(real_field, "none_as_null", False)
field_params["max_length"] = getattr(real_field, "length", None)

return field_type, field_params

@classmethod
def get_meta(
cls, table: dict[str, Any], unique_constraints: set[str], _indexes: set[str]
cls, table_detail: dict[str, Any], unique_constraints: set[str], _indexes: set[str]
) -> NoReturn:
"""
Produces the Meta class.
"""
unique_together: list[edgy.UniqueConstraint] = []
unique_indexes: list[edgy.Index] = []
indexes = list(table["indexes"])
constraints = list(table["constraints"])
indexes = list(table_detail["indexes"])
constraints = list(table_detail["constraints"])

# Handle the unique together
for constraint in constraints:
Expand All @@ -225,7 +238,7 @@ def get_meta(
meta += [
" class Meta:\n",
" registry = registry\n",
" tablename = '{}'\n".format(table["tablename"]),
" tablename = '{}'\n".format(table_detail["tablename"]),
]

if unique_together:
Expand All @@ -241,7 +254,7 @@ def get_meta(

@classmethod
def write_output(
cls, tables: list[Any], connection_string: str, schema: Union[str, None] = None
cls, table_details: list[Any], connection_string: str, schema: Union[str, None] = None
) -> NoReturn:
"""
Writes to stdout and runs some internal validations.
Expand Down Expand Up @@ -275,17 +288,17 @@ def write_output(
yield registry

# Start writing the classes
for table in tables:
for table_detail in table_details:
unique_constraints: set[str] = set()
indexes: set[str] = set()

yield "\n"
yield "\n"
yield "\n"
yield "class {}({}.ReflectModel):\n".format(table["class_name"], DB_MODULE)
yield "class {}({}.ReflectModel):\n".format(table_detail["class_name"], DB_MODULE)
# yield " ...\n"

sqla_table: sqlalchemy.Table = table["table"]
sqla_table: sqlalchemy.Table = table_detail["table"]
columns = list(sqla_table.columns)

# Get the column information
Expand Down Expand Up @@ -338,4 +351,4 @@ def write_output(
yield f" {field_description}"

yield "\n"
yield from cls.get_meta(table, unique_constraints, indexes)
yield from cls.get_meta(table_detail, unique_constraints, indexes)
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


class User(edgy.Model):
name: str = edgy.CharField(max_length=255, secret=True)
name: str = edgy.CharField(max_length=255)
created_at: datetime.datetime = edgy.DateTimeField(auto_now_add=True)
updated_at: datetime.datetime = edgy.DateTimeField(auto_now=True)

Expand Down
File renamed without changes.
58 changes: 58 additions & 0 deletions tests/fields/test_duration_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from datetime import timedelta

import pytest

import edgy
from edgy.testclient import DatabaseTestClient
from tests.settings import DATABASE_URL

database = DatabaseTestClient(DATABASE_URL)
models = edgy.Registry(database=edgy.Database(database, force_rollback=True))

pytestmark = pytest.mark.anyio


class User(edgy.Model):
name = edgy.CharField(max_length=100)
language = edgy.CharField(max_length=200, null=True)
age: timedelta = edgy.fields.DurationField(null=True)

class Meta:
registry = models


@pytest.fixture(autouse=True, scope="module")
async def create_test_database():
async with database:
await models.create_all()
yield
if not database.drop:
await models.drop_all()


@pytest.fixture(autouse=True, scope="function")
async def rollback_transactions():
async with models.database:
yield


async def test_model_save():
user = await User.query.create(name="Jane", age=timedelta(days=365 * 20))

assert user.age == timedelta(days=365 * 20)
await user.save()

user = await User.query.get(pk=user.pk)

assert user.age == timedelta(days=365 * 20)


async def test_model_save_without():
user = await User.query.create(name="Jane")

assert user.age is None
await user.save()

user = await User.query.get(pk=user.pk)

assert user.age is None
Loading

0 comments on commit 0124e17

Please sign in to comment.