Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add duration field & expose max_digits of FloatField #211

Merged
merged 6 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions docs/fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,20 @@ DateTimeField supports int, float, string (isoformat), date object and of course
!!! Note
`auto_now` and `auto_now_add` set the `read_only` flag by default. You can explicitly set `read_only` to `False` to be still able to update the field manually.

#### DurationField

```python
import datetime
import edgy


class MyModel(edgy.Model):
worked: datetime.timedelta = edgy.DurationField()
...

```

devkral marked this conversation as resolved.
Show resolved Hide resolved

#### DecimalField

```python
Expand All @@ -364,7 +378,6 @@ import edgy
class MyModel(edgy.Model):
price: decimal.Decimal = edgy.DecimalField(max_digits=5, decimal_places=2, null=True)
...

```

##### Parameters
Expand Down Expand Up @@ -440,10 +453,15 @@ import edgy
class MyModel(edgy.Model):
price: float = edgy.FloatField(null=True)
...

```

Derives from the same as [IntegerField](#integerfield) and validates the decimal float.
Derives from the same as [IntegerField](#integerfield) and validates the float.

##### Parameters

* `max_digits` - An integer indicating the total maximum digits.
In contrast to DecimalField it is database-only and can be used for higher/lower precision fields.
It is also available under the name `precision` with a higher priority. Optional.

#### ForeignKey

Expand Down
13 changes: 13 additions & 0 deletions docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@ hide:

# Release Notes

## Unreleased

### Added

- Add DurationField.
- Allow passing `max_digits` to FloatField.

### Fixed

- Triggering load on non-existent field when reflecting.
- InspectDB mapping was incorrect.


## 0.19.1

### Fixed
Expand Down
2 changes: 2 additions & 0 deletions edgy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
DateField,
DateTimeField,
DecimalField,
DurationField,
EmailField,
ExcludeField,
FloatField,
Expand Down Expand Up @@ -62,6 +63,7 @@
"DatabaseURL",
"DateField",
"DateTimeField",
"DurationField",
"DecimalField",
"EdgyExtra",
"EdgySettings",
Expand Down
2 changes: 2 additions & 0 deletions edgy/core/db/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DateField,
DateTimeField,
DecimalField,
DurationField,
EmailField,
FloatField,
IntegerField,
Expand Down Expand Up @@ -45,6 +46,7 @@
"CompositeField",
"DateField",
"DateTimeField",
"DurationField",
"DecimalField",
"EmailField",
"ExcludeField",
Expand Down
29 changes: 25 additions & 4 deletions edgy/core/db/fields/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pydantic
import sqlalchemy
from pydantic import EmailStr
from sqlalchemy.dialects import oracle

from edgy.core.db.context_vars import CURRENT_INSTANCE, CURRENT_PHASE, EXPLICIT_SPECIFIED_VALUES
from edgy.core.db.fields._internal import IPAddress
Expand Down Expand Up @@ -278,21 +279,31 @@ class FloatField(FieldFactory, float):
def __new__( # type: ignore
cls,
*,
max_digits: Optional[int] = None,
ge: Union[int, float, decimal.Decimal, None] = None,
gt: Union[int, float, decimal.Decimal, None] = None,
le: Union[int, float, decimal.Decimal, None] = None,
lt: Union[int, float, decimal.Decimal, None] = None,
**kwargs: Any,
) -> BaseFieldType:
# pydantic doesn't support max_digits for float, so rename it
if max_digits is not None:
kwargs.setdefault("precision", max_digits)
del max_digits
kwargs = {
**kwargs,
**{key: value for key, value in locals().items() if key not in CLASS_DEFAULTS},
}
return super().__new__(cls, **kwargs)

@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Float(asdecimal=False)
def get_column_type(cls, precision: Optional[int] = None, **kwargs: Any) -> Any:
if precision is None:
return sqlalchemy.Float(asdecimal=False)
return sqlalchemy.Float(precision=precision, asdecimal=False).with_variant(
oracle.FLOAT(binary_precision=round(precision / 0.30103), asdecimal=False), # type: ignore
"oracle",
)


class BigIntegerField(IntegerField):
Expand Down Expand Up @@ -532,12 +543,22 @@ def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Date()


class DurationField(FieldFactory, datetime.timedelta):
"""Representation of a time field"""

field_type = datetime.timedelta

@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Interval()


class TimeField(FieldFactory, datetime.time):
"""Representation of a time field"""

field_type = datetime.time

def __new__(cls, **kwargs: Any) -> BaseFieldType: # type: ignore
def __new__(cls, with_timezone: bool = False, **kwargs: Any) -> BaseFieldType: # type: ignore
kwargs = {
**kwargs,
**{k: v for k, v in locals().items() if k not in CLASS_DEFAULTS},
Expand Down Expand Up @@ -573,7 +594,7 @@ def __new__(cls, *, max_length: Optional[int] = None, **kwargs: Any) -> BaseFiel

@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.LargeBinary(kwargs.get("max_length"))
return sqlalchemy.LargeBinary(length=kwargs.get("max_length"))


class UUIDField(FieldFactory, uuid.UUID):
Expand Down
2 changes: 1 addition & 1 deletion edgy/core/db/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def __getattr__(self, name: str) -> Any:
name not in self.__dict__
and behavior != "passdown"
and not self._loaded_or_deleted
and field is not None
and (field is not None or self.__reflected__)
and name not in self.identifying_db_fields
and self.can_load
):
Expand Down
69 changes: 41 additions & 28 deletions edgy/utils/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,24 @@
printer = Print()

SQL_GENERIC_TYPES = {
sqltypes.BigInteger: edgy.BigIntegerField,
sqltypes.Integer: edgy.IntegerField,
sqltypes.JSON: edgy.JSONField,
sqltypes.Date: edgy.DateField,
sqltypes.String: edgy.CharField,
sqltypes.Unicode: edgy.CharField,
sqltypes.BINARY: edgy.BinaryField,
sqltypes.Boolean: edgy.BooleanField,
sqltypes.Enum: edgy.ChoiceField,
sqltypes.DateTime: edgy.DateTimeField,
sqltypes.Numeric: edgy.DecimalField,
sqltypes.Float: edgy.FloatField,
sqltypes.Double: edgy.FloatField,
sqltypes.SmallInteger: edgy.SmallIntegerField,
sqltypes.Text: edgy.TextField,
sqltypes.Time: edgy.TimeField,
sqltypes.Uuid: edgy.UUIDField,
sqltypes.BigInteger: edgy.fields.BigIntegerField,
sqltypes.Integer: edgy.fields.IntegerField,
sqltypes.JSON: edgy.fields.JSONField,
sqltypes.Date: edgy.fields.DateField,
sqltypes.String: edgy.fields.CharField,
sqltypes.Unicode: edgy.fields.CharField,
sqltypes.LargeBinary: edgy.fields.BinaryField,
sqltypes.Boolean: edgy.fields.BooleanField,
sqltypes.Enum: edgy.fields.ChoiceField,
sqltypes.DateTime: edgy.fields.DateTimeField,
sqltypes.Interval: edgy.fields.DurationField,
sqltypes.Numeric: edgy.fields.DecimalField,
sqltypes.Float: edgy.fields.FloatField,
sqltypes.Double: edgy.fields.FloatField,
sqltypes.SmallInteger: edgy.fields.SmallIntegerField,
sqltypes.Text: edgy.fields.TextField,
sqltypes.Time: edgy.fields.TimeField,
sqltypes.Uuid: edgy.fields.UUIDField,
}

DB_MODULE = "edgy"
Expand All @@ -51,7 +52,7 @@ class InspectDB:
Class that builds the inspection of a database.
"""

def __init__(self, database: str, schema: Optional[str]) -> None:
def __init__(self, database: str, schema: Optional[str] = None) -> None:
"""
Creates an instance of an InspectDB and triggers the proccess.
"""
Expand Down Expand Up @@ -184,22 +185,34 @@ def get_field_type(self, column: sqlalchemy.Column, is_fk: bool = False) -> Any:
field_params["max_digits"] = real_field.precision
field_params["decimal_places"] = real_field.scale

if field_type == "FloatField":
# Note: precision is maybe set to None when reflecting.
precision = getattr(real_field, "precision", None)
if precision is None:
# Oracle
precision = getattr(real_field, "binary_precision", None)
if precision is not None:
# invert calculation of binary_precision
precision = round(precision * 0.30103)
if precision is not None:
field_params["max_digits"] = precision

if field_type == "BinaryField":
field_params["sql_nullable"] = getattr(real_field, "none_as_null", False)
field_params["max_length"] = getattr(real_field, "length", None)

return field_type, field_params

@classmethod
def get_meta(
cls, table: dict[str, Any], unique_constraints: set[str], _indexes: set[str]
cls, table_detail: dict[str, Any], unique_constraints: set[str], _indexes: set[str]
) -> NoReturn:
"""
Produces the Meta class.
"""
unique_together: list[edgy.UniqueConstraint] = []
unique_indexes: list[edgy.Index] = []
indexes = list(table["indexes"])
constraints = list(table["constraints"])
indexes = list(table_detail["indexes"])
constraints = list(table_detail["constraints"])

# Handle the unique together
for constraint in constraints:
Expand All @@ -225,7 +238,7 @@ def get_meta(
meta += [
" class Meta:\n",
" registry = registry\n",
" tablename = '{}'\n".format(table["tablename"]),
" tablename = '{}'\n".format(table_detail["tablename"]),
]

if unique_together:
Expand All @@ -241,7 +254,7 @@ def get_meta(

@classmethod
def write_output(
cls, tables: list[Any], connection_string: str, schema: Union[str, None] = None
cls, table_details: list[Any], connection_string: str, schema: Union[str, None] = None
) -> NoReturn:
"""
Writes to stdout and runs some internal validations.
Expand Down Expand Up @@ -275,17 +288,17 @@ def write_output(
yield registry

# Start writing the classes
for table in tables:
for table_detail in table_details:
unique_constraints: set[str] = set()
indexes: set[str] = set()

yield "\n"
yield "\n"
yield "\n"
yield "class {}({}.ReflectModel):\n".format(table["class_name"], DB_MODULE)
yield "class {}({}.ReflectModel):\n".format(table_detail["class_name"], DB_MODULE)
# yield " ...\n"

sqla_table: sqlalchemy.Table = table["table"]
sqla_table: sqlalchemy.Table = table_detail["table"]
columns = list(sqla_table.columns)

# Get the column information
Expand Down Expand Up @@ -338,4 +351,4 @@ def write_output(
yield f" {field_description}"

yield "\n"
yield from cls.get_meta(table, unique_constraints, indexes)
yield from cls.get_meta(table_detail, unique_constraints, indexes)
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


class User(edgy.Model):
name: str = edgy.CharField(max_length=255, secret=True)
name: str = edgy.CharField(max_length=255)
created_at: datetime.datetime = edgy.DateTimeField(auto_now_add=True)
updated_at: datetime.datetime = edgy.DateTimeField(auto_now=True)

Expand Down
File renamed without changes.
58 changes: 58 additions & 0 deletions tests/fields/test_duration_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from datetime import timedelta

import pytest

import edgy
from edgy.testclient import DatabaseTestClient
from tests.settings import DATABASE_URL

database = DatabaseTestClient(DATABASE_URL)
models = edgy.Registry(database=edgy.Database(database, force_rollback=True))

pytestmark = pytest.mark.anyio


class User(edgy.Model):
name = edgy.CharField(max_length=100)
language = edgy.CharField(max_length=200, null=True)
age: timedelta = edgy.fields.DurationField(null=True)

class Meta:
registry = models


@pytest.fixture(autouse=True, scope="module")
async def create_test_database():
async with database:
await models.create_all()
yield
if not database.drop:
await models.drop_all()


@pytest.fixture(autouse=True, scope="function")
async def rollback_transactions():
async with models.database:
yield


async def test_model_save():
user = await User.query.create(name="Jane", age=timedelta(days=365 * 20))

assert user.age == timedelta(days=365 * 20)
await user.save()

user = await User.query.get(pk=user.pk)

assert user.age == timedelta(days=365 * 20)


async def test_model_save_without():
user = await User.query.create(name="Jane")

assert user.age is None
await user.save()

user = await User.query.get(pk=user.pk)

assert user.age is None
Loading