Skip to content

Commit

Permalink
feat: !17 ✨ Check scalar values for equality before marking fields as…
Browse files Browse the repository at this point in the history
… changed
  • Loading branch information
ddanier committed Nov 7, 2023
1 parent d395b7b commit 7bcb6d6
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 3 deletions.
35 changes: 32 additions & 3 deletions pydantic_changedetect/changedetect.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import decimal
import warnings
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -205,6 +206,12 @@ def model_set_changed(self, *fields: str, original: Any = NO_VALUE) -> None:
self.model_original[name] = original
self.model_self_changed_fields.add(name)

def _model_is_change_comparable_type(self, value: Any) -> bool:
return (
value is None
or isinstance(value, (str, int, float, bool, decimal.Decimal))
)

@no_type_check
def __setattr__(self, name, value) -> None: # noqa: ANN001
self_compat = PydanticCompat(self)
Expand All @@ -217,11 +224,33 @@ def __setattr__(self, name, value) -> None: # noqa: ANN001
super().__setattr__(name, value)
return

# Store changed data
# Get original value
original_update = {}
if name in self_compat.model_fields and name not in self.model_original:
self.model_original[name] = self.__dict__[name]
original_update[name] = self.__dict__[name]

# Store changed value using pydantic
super().__setattr__(name, value)
self.model_self_changed_fields.add(name)

# Check if value has actually been changed
has_changed = True
if name in self_compat.model_fields:
# Fetch original from original_update so we don't have to check everything again
original_value = original_update.get(name, None)
# Don't use value parameter directly, as pydantic validation might have changed it
# (when validate_assignment == True)
current_value = self.__dict__[name]
if (
self._model_is_change_comparable_type(original_value)
and self._model_is_change_comparable_type(current_value)
and original_value == current_value
):
has_changed = False

# Store changed state
if has_changed:
self.model_original.update(original_update)
self.model_self_changed_fields.add(name)

def __getstate__(self) -> Dict[str, Any]:
state = super().__getstate__()
Expand Down
39 changes: 39 additions & 0 deletions tests/test_changedetect.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import decimal
import pickle
from typing import Any, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -42,6 +43,15 @@ def __getstate__(self) -> Dict[str, Any]:
return super(ChangeDetectionMixin, self).__getstate__()


class SomethingWithDifferentValueTypes(ChangeDetectionMixin, pydantic.BaseModel):
s: Union[str, None] = None
i: Union[int, None] = None
f: Union[float, None] = None
b: Union[bool, None] = None
d: Union[decimal.Decimal, None] = None
m: Union[Something, None] = None


def test_initial_state():
obj = Something(id=1)

Expand Down Expand Up @@ -397,6 +407,35 @@ class SomethingPrivate(Something):
assert something.model_has_changed is False


@pytest.mark.parametrize(
("attr", "original", "changed", "expected"),
[
("s", "old", "new", True),
("s", "old", "old", False),
("i", 1, 2, True),
("i", 1, 1, False),
("f", 1.0, 2.0, True),
("f", 1.0, 1.0, False),
("b", True, False, True),
("b", True, True, False),
("d", decimal.Decimal(1), decimal.Decimal(2), True),
("d", decimal.Decimal(1), decimal.Decimal(1), False),
("m", Something(id=1), Something(id=2), True),
("m", Something(id=1), Something(id=1), True), # models will always be counted as changed
],
)
def test_value_types_checked_for_equality(
attr: str,
original: Any,
changed: Any,
expected: bool,
):
obj = SomethingWithDifferentValueTypes(**{attr: original})
setattr(obj, attr, changed)

assert obj.model_has_changed is expected


@pytest.mark.skipif(PYDANTIC_V1, reason="pydantic v1 does not support model_construct()")
def test_model_construct_works():
something = Something.model_construct(id=1)
Expand Down

0 comments on commit 7bcb6d6

Please sign in to comment.