Skip to content

Commit

Permalink
improve pw field (#194)
Browse files Browse the repository at this point in the history
* Changes:

- improve password field
- fix __set__ called on update insert

* Changes:

- manipulate dict directly instead of looping
- fix password field test
- update release notes
  • Loading branch information
devkral authored Oct 10, 2024
1 parent 2777041 commit 2f543ec
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 48 deletions.
33 changes: 28 additions & 5 deletions docs/fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class MyModel(edgy.Model):

##### Parameters:

* `max_length` - An integer indicating the total length of string.
* `max_length` - An integer indicating the total length of string. Required. Set to None for creating a field without a string length restriction.
* `min_length` - An integer indicating the minimum length of string.

#### ChoiceField
Expand Down Expand Up @@ -353,11 +353,14 @@ import edgy
class MyModel(edgy.Model):
email: str = edgy.EmailField(max_length=60, null=True)
...

```

Derives from the same as [CharField](#charfield) and validates the email value.

##### Parameters

- `max_length` - Integer/None. Default: 255.

#### ExcludeField

Remove inherited fields by masking them from the model.
Expand Down Expand Up @@ -640,15 +643,32 @@ Similar to [CharField](#charfield) but has no `max_length` restrictions.

```python
import edgy
import secrets

hasher = Hasher()

class MyModel(edgy.Model):
data: str = edgy.PasswordField(null=False, max_length=255)
pw: str = edgy.PasswordField(null=False, derive_fn=hasher.derive)
token: str = edgy.PasswordField(null=False, default=secrets.token_hex)
...

# we can check if the pw matches by providing a tuple
obj = await MyModel.query.create(pw=("foobar", "foobar"))
# now let's check the pw
hasher.compare_pw(obj.pw, "foobar")
obj.token == "<token>"
```

Similar to [CharField](#charfield) and it can be used to represent a password text.
Similar to [CharField](#charfield) and it can be used to represent a password text. The secret parameter defaults to `True`.

##### Parameters

- `max_length` - Integer/None. Default: 255.
- `derive_fn` - Callable. Default: None. When provided it automatically hashes an incoming string. Should be a good key deriving function.
- `keep_original` - Boolean. Default: `True` when `derive_fn` is provided `False` otherwise. When True, an attribute named: `<fieldname>_original` is added
whenever a password is manually set. It contains the password in plaintext. After saving/loading the attribute is set to `None`.

Ideally the key derivation function includes the parameters (and derive algorithm) used for deriving in the hash so a compare_pw function can reproduce the result.

#### TimeField

Expand Down Expand Up @@ -680,11 +700,14 @@ import edgy
class MyModel(edgy.Model):
url: str = fields.URLField(null=True, max_length=1024)
...

```

Derives from the same as [CharField](#charfield) and validates the value of an URL.

##### Parameters

- `max_length` - Integer/None. Default: 255.

#### UUIDField

```python
Expand Down
1 change: 1 addition & 0 deletions docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ hide:

- `model_dump_json` returns right result.
- `show_pk=False` can now be used to disable the inclusion of pk fields regardless of `__show_pk__`.
- `__setattr__` is called after insert/update. We have transform_input already.

## 0.17.3

Expand Down
2 changes: 1 addition & 1 deletion edgy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.17.3"
__version__ = "0.17.4"

from .cli.base import Migrate
from .conf import settings
Expand Down
94 changes: 57 additions & 37 deletions edgy/core/db/fields/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import enum
import ipaddress
import uuid
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from enum import EnumMeta
from functools import partial
from re import Pattern
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union, cast

import pydantic
import sqlalchemy
Expand Down Expand Up @@ -38,7 +38,6 @@ class CharField(FieldFactory, str):
def __new__( # type: ignore
cls,
*,
max_length: Optional[int] = 0,
min_length: Optional[int] = None,
regex: Union[str, Pattern] = None,
pattern: Union[str, Pattern] = None,
Expand All @@ -57,7 +56,7 @@ def __new__( # type: ignore
@classmethod
def validate(cls, kwargs: dict[str, Any]) -> None:
max_length = kwargs.get("max_length", 0)
if max_length <= 0:
if max_length is not None and max_length <= 0:
raise FieldDefinitionError(detail=f"'max_length' is required for {cls.__name__}")

min_length = kwargs.get("min_length")
Expand All @@ -68,38 +67,21 @@ def validate(cls, kwargs: dict[str, Any]) -> None:
assert pattern is None or isinstance(pattern, (str, Pattern))

@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.String(
length=kwargs.get("max_length"), collation=kwargs.get("collation")
def get_column_type(cls, max_length: Optional[int] = None, **kwargs: Any) -> Any:
return (
sqlalchemy.Text(collation=kwargs.get("collation"))
if max_length is None
else sqlalchemy.String(length=max_length, collation=kwargs.get("collation"))
)


class TextField(FieldFactory, str):
class TextField(CharField):
"""String representation of a text field which means no max_length required"""

field_type = str

def __new__(
cls,
*,
min_length: int = 0,
max_length: Optional[int] = None,
regex: Union[str, Pattern] = None,
pattern: Union[str, Pattern] = None,
**kwargs: Any,
) -> BaseFieldType:
if pattern is None:
pattern = regex
del regex
kwargs = {
**kwargs,
**{key: value for key, value in locals().items() if key not in CLASS_DEFAULTS},
}
return super().__new__(cls, **kwargs)

@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Text(collation=kwargs.get("collation"))
def validate(cls, kwargs: dict[str, Any]) -> None:
kwargs.setdefault("max_length", None)
super().validate(kwargs)


class IncrementOnSaveBaseField(Field):
Expand Down Expand Up @@ -541,7 +523,7 @@ class ChoiceField(FieldFactory):
def __new__( # type: ignore
cls,
choices: Optional[Sequence[Union[tuple[str, str], tuple[str, int]]]] = None,
**kwargs: dict[str, Any],
**kwargs: Any,
) -> BaseFieldType:
kwargs = {
**kwargs,
Expand All @@ -566,23 +548,61 @@ class PasswordField(CharField):
Representation of a Password
"""

def __new__( # type: ignore
cls,
derive_fn: Optional[Callable[[str], str]] = None,
**kwargs: Any,
) -> BaseFieldType:
kwargs.setdefault("keep_original", derive_fn is not None)
return super().__new__(cls, derive_fn=derive_fn, **kwargs)

@classmethod
def to_model(
cls, field_obj: BaseFieldType, field_name: str, value: Any, original_fn: Any = None
) -> dict[str, Any]:
if isinstance(value, (tuple, list)):
if value[0] != value[1]:
raise ValueError("Password doesn't match.")
else:
value = value[0]
retval: dict[str, Any] = {}
phase = CURRENT_PHASE.get()
derive_fn = cast(Optional[Callable[[str], str]], field_obj.derive_fn)
if phase in {"set", "init"} and derive_fn is not None:
retval[field_name] = derive_fn(value)
if getattr(field_obj, "keep_original", False):
retval[f"{field_name}_original"] = value
else:
retval[field_name] = value
# blank after saving or loading
if phase in {"post_insert", "post_update", "load"} and getattr(
field_obj, "keep_original", False
):
retval[f"{field_name}_original"] = None

return retval

@classmethod
def get_column_type(self, **kwargs: Any) -> sqlalchemy.String:
return sqlalchemy.String(length=kwargs.get("max_length"))
def validate(cls, kwargs: dict[str, Any]) -> None:
kwargs.setdefault("secret", True)
kwargs.setdefault("max_length", 255)
super().validate(kwargs)


class EmailField(CharField):
field_type = EmailStr

@classmethod
def get_column_type(self, **kwargs: Any) -> sqlalchemy.String:
return sqlalchemy.String(length=kwargs.get("max_length"))
def validate(cls, kwargs: dict[str, Any]) -> None:
kwargs.setdefault("max_length", 255)
super().validate(kwargs)


class URLField(CharField):
@classmethod
def get_column_type(self, **kwargs: Any) -> sqlalchemy.String:
return sqlalchemy.String(length=kwargs.get("max_length"))
def validate(cls, kwargs: dict[str, Any]) -> None:
kwargs.setdefault("max_length", 255)
super().validate(kwargs)


class IPAddressField(FieldFactory, str):
Expand Down
2 changes: 1 addition & 1 deletion edgy/core/db/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ class EdgyBaseModel(BaseModel, BaseModelType, metaclass=BaseModelMeta):
__db_model__: ClassVar[bool] = False
__reflected__: ClassVar[bool] = False
__show_pk__: ClassVar[bool] = False
__using_schema__: Union[str, None, Any] = Undefined
# private attribute
_loaded_or_deleted: bool = False
__using_schema__: Union[str, None, Any] = Undefined

def __init__(
self, *args: Any, __show_pk__: bool = False, __phase__: str = "init", **kwargs: Any
Expand Down
6 changes: 2 additions & 4 deletions edgy/core/db/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ async def _update(self, **kwargs: Any) -> Any:
new_kwargs = self.transform_input(
column_values, phase="post_update", instance=self
)
for k, v in new_kwargs.items():
setattr(self, k, v)
self.__dict__.update(new_kwargs)

# updates aren't required to change the db, they can also just affect the meta fields
await self.execute_post_save_hooks(
Expand Down Expand Up @@ -233,8 +232,7 @@ async def _insert(self, **kwargs: Any) -> "Model":
column_values[column.key] = autoincrement_value

new_kwargs = self.transform_input(column_values, phase="post_insert", instance=self)
for k, v in new_kwargs.items():
setattr(self, k, v)
self.__dict__.update(new_kwargs)

if self.meta.post_save_fields:
await self.execute_post_save_hooks(
Expand Down
110 changes: 110 additions & 0 deletions tests/fields/test_password_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import secrets
from hashlib import pbkdf2_hmac

import pytest
import sqlalchemy

import edgy
from edgy.core.db.fields import (
PasswordField,
)
from edgy.core.db.fields.base import BaseField
from edgy.testclient import DatabaseTestClient
from tests.settings import DATABASE_URL

pytestmark = pytest.mark.anyio
database = DatabaseTestClient(DATABASE_URL, drop_database=True, use_existing=False)
models = edgy.Registry(database=database)


class SampleHasher:
def derive(self, password: str, iterations: int = 10):
assert not password.startswith("pbkdf2")
# the default is not secure
return f"pbkdf2:{iterations}:{pbkdf2_hmac('sha256', password.encode(), salt=b'', iterations=iterations).hex()}"

def compare_pw(self, hash: str, password: str):
algo, iterations, _ = hash.split(":", 2)
print("pw2", password)
# this is not secure
derived = self.derive(password, int(iterations))
return hash == derived


hasher = SampleHasher()


class MyModel(edgy.Model):
pw = edgy.PasswordField(null=False, derive_fn=hasher.derive)
token = edgy.PasswordField(null=False, default=secrets.token_hex)

class Meta:
registry = models


@pytest.fixture(autouse=True, scope="function")
async def create_test_database():
async with models:
await models.create_all()
yield
if not database.drop:
await models.drop_all()


def test_can_create_password_field():
field = PasswordField(derive_fn=hasher.derive)

assert isinstance(field, BaseField)
assert field.min_length is None
assert field.max_length == 255
assert field.null is False
assert field.secret is True
assert field.is_required()
columns = field.get_columns("foo")
assert len(columns) == 1
assert columns[0].type.__class__ == sqlalchemy.String


def test_can_create_password_field2():
field = PasswordField(null=True, max_length=None, secret=False, derive_fn=hasher.derive)

assert isinstance(field, BaseField)
assert field.min_length is None
assert field.max_length is None
assert field.null is True
assert field.secret is False
assert not field.is_required()
columns = field.get_columns("foo")
assert len(columns) == 1
assert columns[0].type.__class__ == sqlalchemy.Text


async def test_pw_field_create_pw():
obj = await MyModel.query.create(pw="test")
assert obj.pw != "test"
assert hasher.compare_pw(obj.pw, "test")
obj.pw = "foobar"
assert obj.pw != "foobar"
assert obj.pw_original == "foobar"
with pytest.raises(ValueError):
obj.pw = ["foobar2", "test"]
assert obj.pw_original == "foobar"

await obj.save()
assert obj.pw_original is None
assert hasher.compare_pw(obj.pw, "foobar")


async def test_pw_field_create_token_and_validate():
obj = await MyModel.query.create(pw="test", token="234")
assert obj.token == "234"
obj.pw = ("foobar", "foobar")
assert obj.pw != "foobar"
assert obj.pw_original == "foobar"
await obj.save()
assert obj.pw_original is None


async def test_pw_field_create_fail():
with pytest.raises(ValueError):
await MyModel.query.create(pw=("test", "foobar"))

0 comments on commit 2f543ec

Please sign in to comment.