From 343ae08b5806c360ad493cc74d0111858253c216 Mon Sep 17 00:00:00 2001 From: atharva-2001 Date: Mon, 25 Sep 2023 14:37:10 +0530 Subject: [PATCH 1/6] Add **kwargs for various subclasses of AbstractSyrupyExtension --- src/syrupy/extensions/amber/__init__.py | 4 ++-- src/syrupy/extensions/base.py | 23 +++++++++++++++++------ src/syrupy/extensions/json/__init__.py | 1 + src/syrupy/extensions/single_file.py | 23 +++++++++++++---------- 4 files changed, 33 insertions(+), 18 deletions(-) diff --git a/src/syrupy/extensions/amber/__init__.py b/src/syrupy/extensions/amber/__init__.py index 74dbc33b..8d38d018 100644 --- a/src/syrupy/extensions/amber/__init__.py +++ b/src/syrupy/extensions/amber/__init__.py @@ -47,7 +47,7 @@ def delete_snapshots( else: Path(snapshot_location).unlink() - def _read_snapshot_collection(self, snapshot_location: str) -> "SnapshotCollection": + def _read_snapshot_collection(self, snapshot_location: str, **kwargs: Any) -> "SnapshotCollection": return self.serializer_class.read_file(snapshot_location) @classmethod @@ -72,7 +72,7 @@ def _read_snapshot_data_from_location( @classmethod def _write_snapshot_collection( - cls, *, snapshot_collection: "SnapshotCollection" + cls, *, snapshot_collection: "SnapshotCollection", **kwargs: Any ) -> None: cls.serializer_class.write_file(snapshot_collection, merge=True) diff --git a/src/syrupy/extensions/base.py b/src/syrupy/extensions/base.py index 945cf20b..97b33845 100644 --- a/src/syrupy/extensions/base.py +++ b/src/syrupy/extensions/base.py @@ -15,6 +15,7 @@ Optional, Set, Tuple, + Any, ) from syrupy.constants import ( @@ -67,6 +68,7 @@ def serialize( exclude: Optional["PropertyFilter"] = None, include: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, + **kwargs: Any, ) -> "SerializedData": """ Serializes a python object / data structure into a string @@ -108,7 +110,7 @@ def is_snapshot_location(self, *, location: str) -> bool: return location.endswith(self._file_extension) def discover_snapshots( - self, *, test_location: "PyTestLocation" + self, *, test_location: "PyTestLocation", **kwargs: Any ) -> "SnapshotCollections": """ Returns all snapshot collections in test site @@ -216,7 +218,7 @@ def delete_snapshots( @abstractmethod def _read_snapshot_collection( - self, *, snapshot_location: str + self, *, snapshot_location: str, **kwargs: Any ) -> "SnapshotCollection": """ Read the snapshot location and construct a snapshot collection object @@ -235,7 +237,7 @@ def _read_snapshot_data_from_location( @classmethod @abstractmethod def _write_snapshot_collection( - cls, *, snapshot_collection: "SnapshotCollection" + cls, *, snapshot_collection: "SnapshotCollection", **kwargs: Any ) -> None: """ Adds the snapshot data to the snapshots in collection location @@ -243,7 +245,9 @@ def _write_snapshot_collection( raise NotImplementedError @classmethod - def dirname(cls, *, test_location: "PyTestLocation") -> str: + def dirname( + cls, *, test_location: "PyTestLocation", **kwargs: Any + ) -> str: test_dir = Path(test_location.filepath).parent return str(test_dir.joinpath(SNAPSHOT_DIRNAME)) @@ -259,7 +263,10 @@ class SnapshotReporter(ABC): _context_line_count = 1 def diff_snapshots( - self, serialized_data: "SerializedData", snapshot_data: "SerializedData" + self, + serialized_data: "SerializedData", + snapshot_data: "SerializedData", + **kwargs: Any, ) -> "SerializedData": env = {DISABLE_COLOR_ENV_VAR: "true"} attrs = {"_context_line_count": 0} @@ -267,7 +274,10 @@ def diff_snapshots( return "\n".join(self.diff_lines(serialized_data, snapshot_data)) def diff_lines( - self, serialized_data: "SerializedData", snapshot_data: "SerializedData" + self, + serialized_data: "SerializedData", + snapshot_data: "SerializedData", + **kwargs: Any, ) -> Iterator[str]: for line in self.__diff_lines(str(snapshot_data), str(serialized_data)): yield reset(line) @@ -407,6 +417,7 @@ def matches( *, serialized_data: "SerializableData", snapshot_data: "SerializableData", + **kwargs: Any, ) -> bool: """ Compares serialized data and snapshot data and returns diff --git a/src/syrupy/extensions/json/__init__.py b/src/syrupy/extensions/json/__init__.py index 5b52a8d5..ccc9488e 100644 --- a/src/syrupy/extensions/json/__init__.py +++ b/src/syrupy/extensions/json/__init__.py @@ -145,6 +145,7 @@ def serialize( exclude: Optional["PropertyFilter"] = None, include: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, + **kwargs: Any, ) -> "SerializedData": data = self._filter( data=data, diff --git a/src/syrupy/extensions/single_file.py b/src/syrupy/extensions/single_file.py index 0b216115..f405ffbf 100644 --- a/src/syrupy/extensions/single_file.py +++ b/src/syrupy/extensions/single_file.py @@ -1,13 +1,7 @@ from enum import Enum from gettext import gettext from pathlib import Path -from typing import ( - TYPE_CHECKING, - Optional, - Set, - Type, - Union, -) +from typing import TYPE_CHECKING, Optional, Set, Type, Union, Dict, Any from unicodedata import category from syrupy.constants import TEXT_ENCODING @@ -49,6 +43,7 @@ def serialize( exclude: Optional["PropertyFilter"] = None, include: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, + **kwargs: Any, ) -> "SerializedData": return self.get_supported_dataclass()(data) @@ -74,12 +69,17 @@ def _get_file_basename( return cls.get_snapshot_name(test_location=test_location, index=index) @classmethod - def dirname(cls, *, test_location: "PyTestLocation") -> str: + def dirname( + cls, *, test_location: "PyTestLocation", **kwargs: Any + ) -> str: original_dirname = AbstractSyrupyExtension.dirname(test_location=test_location) return str(Path(original_dirname).joinpath(test_location.basename)) def _read_snapshot_collection( - self, *, snapshot_location: str + self, + *, + snapshot_location: str, + **kwargs: Any, ) -> "SnapshotCollection": file_ext_len = len(self._file_extension) + 1 if self._file_extension else 0 filename_wo_ext = snapshot_location[:-file_ext_len] @@ -116,7 +116,10 @@ def get_write_encoding(cls) -> Optional[str]: @classmethod def _write_snapshot_collection( - cls, *, snapshot_collection: "SnapshotCollection" + cls, + *, + snapshot_collection: "SnapshotCollection", + **kwargs: Any, ) -> None: filepath, data = ( snapshot_collection.location, From 581a1155544ffc18f9f5ad7afcbf3bb3cce7956b Mon Sep 17 00:00:00 2001 From: atharva-2001 Date: Mon, 25 Sep 2023 14:44:55 +0530 Subject: [PATCH 2/6] Format using black --- src/syrupy/extensions/amber/__init__.py | 4 +++- src/syrupy/extensions/base.py | 4 +--- src/syrupy/extensions/single_file.py | 6 ++---- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/syrupy/extensions/amber/__init__.py b/src/syrupy/extensions/amber/__init__.py index 8d38d018..98e722bd 100644 --- a/src/syrupy/extensions/amber/__init__.py +++ b/src/syrupy/extensions/amber/__init__.py @@ -47,7 +47,9 @@ def delete_snapshots( else: Path(snapshot_location).unlink() - def _read_snapshot_collection(self, snapshot_location: str, **kwargs: Any) -> "SnapshotCollection": + def _read_snapshot_collection( + self, snapshot_location: str, **kwargs: Any + ) -> "SnapshotCollection": return self.serializer_class.read_file(snapshot_location) @classmethod diff --git a/src/syrupy/extensions/base.py b/src/syrupy/extensions/base.py index 97b33845..2d7ebe8c 100644 --- a/src/syrupy/extensions/base.py +++ b/src/syrupy/extensions/base.py @@ -245,9 +245,7 @@ def _write_snapshot_collection( raise NotImplementedError @classmethod - def dirname( - cls, *, test_location: "PyTestLocation", **kwargs: Any - ) -> str: + def dirname(cls, *, test_location: "PyTestLocation", **kwargs: Any) -> str: test_dir = Path(test_location.filepath).parent return str(test_dir.joinpath(SNAPSHOT_DIRNAME)) diff --git a/src/syrupy/extensions/single_file.py b/src/syrupy/extensions/single_file.py index f405ffbf..c87b4328 100644 --- a/src/syrupy/extensions/single_file.py +++ b/src/syrupy/extensions/single_file.py @@ -1,7 +1,7 @@ from enum import Enum from gettext import gettext from pathlib import Path -from typing import TYPE_CHECKING, Optional, Set, Type, Union, Dict, Any +from typing import TYPE_CHECKING, Optional, Set, Type, Union, Any from unicodedata import category from syrupy.constants import TEXT_ENCODING @@ -69,9 +69,7 @@ def _get_file_basename( return cls.get_snapshot_name(test_location=test_location, index=index) @classmethod - def dirname( - cls, *, test_location: "PyTestLocation", **kwargs: Any - ) -> str: + def dirname(cls, *, test_location: "PyTestLocation", **kwargs: Any) -> str: original_dirname = AbstractSyrupyExtension.dirname(test_location=test_location) return str(Path(original_dirname).joinpath(test_location.basename)) From ac538357ed954c6f2999e45ad444453bda84f6e4 Mon Sep 17 00:00:00 2001 From: atharva-2001 Date: Mon, 25 Sep 2023 14:55:00 +0530 Subject: [PATCH 3/6] Run isort --- src/syrupy/extensions/base.py | 2 +- src/syrupy/extensions/single_file.py | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/syrupy/extensions/base.py b/src/syrupy/extensions/base.py index 2d7ebe8c..9181f88e 100644 --- a/src/syrupy/extensions/base.py +++ b/src/syrupy/extensions/base.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import ( TYPE_CHECKING, + Any, Callable, Dict, Iterator, @@ -15,7 +16,6 @@ Optional, Set, Tuple, - Any, ) from syrupy.constants import ( diff --git a/src/syrupy/extensions/single_file.py b/src/syrupy/extensions/single_file.py index c87b4328..6f421360 100644 --- a/src/syrupy/extensions/single_file.py +++ b/src/syrupy/extensions/single_file.py @@ -1,7 +1,14 @@ from enum import Enum from gettext import gettext from pathlib import Path -from typing import TYPE_CHECKING, Optional, Set, Type, Union, Any +from typing import ( + TYPE_CHECKING, + Any, + Optional, + Set, + Type, + Union, +) from unicodedata import category from syrupy.constants import TEXT_ENCODING From 829921a0b540ebe42a2e90c3062663ec6dc7e0f0 Mon Sep 17 00:00:00 2001 From: atharva-2001 Date: Mon, 25 Sep 2023 15:09:33 +0530 Subject: [PATCH 4/6] kwargs in __call__ --- src/syrupy/assertion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index 6bdb4fcf..d65c099e 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -264,6 +264,7 @@ def __call__( extension_class: Optional[Type["AbstractSyrupyExtension"]] = None, matcher: Optional["PropertyMatcher"] = None, name: Optional["SnapshotIndex"] = None, + **kwargs: Any, ) -> "SnapshotAssertion": """ Modifies assertion instance options From c0238da67bc63ba801ad468c370b93fa99c9c9f4 Mon Sep 17 00:00:00 2001 From: atharva-2001 Date: Tue, 26 Sep 2023 19:11:55 +0530 Subject: [PATCH 5/6] extra args argument --- src/syrupy/assertion.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index d65c099e..135b9571 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -63,6 +63,7 @@ class SnapshotAssertion: include: Optional["PropertyFilter"] = None exclude: Optional["PropertyFilter"] = None matcher: Optional["PropertyMatcher"] = None + extra_args: Dict = field(default_factory=dict) _exclude: Optional["PropertyFilter"] = field( init=False, @@ -105,6 +106,7 @@ def __post_init__(self) -> None: self._include = self.include self._exclude = self.exclude self._matcher = self.matcher + self._extra_args = self.extra_args def __init_extension( self, extension_class: Type["AbstractSyrupyExtension"] @@ -178,6 +180,7 @@ def with_defaults( include: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, extension_class: Optional[Type["AbstractSyrupyExtension"]] = None, + extra_args: Optional[Dict] = None ) -> "SnapshotAssertion": """ Create new snapshot assertion fixture with provided values. This preserves @@ -191,6 +194,7 @@ def with_defaults( test_location=self.test_location, extension_class=extension_class or self.extension_class, session=self.session, + extra_args=extra_args or self.extra_args ) def use_extension( @@ -264,7 +268,7 @@ def __call__( extension_class: Optional[Type["AbstractSyrupyExtension"]] = None, matcher: Optional["PropertyMatcher"] = None, name: Optional["SnapshotIndex"] = None, - **kwargs: Any, + extra_args: Optional[Dict] = None, ) -> "SnapshotAssertion": """ Modifies assertion instance options @@ -281,6 +285,8 @@ def __call__( self.__with_prop("_custom_index", name) if diff is not None: self.__with_prop("_snapshot_diff", diff) + if extra_args: + self._extra_args = extra_args return self def __repr__(self) -> str: @@ -301,6 +307,11 @@ def _assert(self, data: "SerializableData") -> bool: matches = False assertion_success = False assertion_exception = None + matcher_options = None + for key,value in self._extra_args.items(): + if key == "matcher_options": + matcher_options = value + print("matcher_options", matcher_options) try: snapshot_data, tainted = self._recall_data(index=self.index) serialized_data = self._serialize(data) @@ -317,7 +328,7 @@ def _assert(self, data: "SerializableData") -> bool: not tainted and snapshot_data is not None and self.extension.matches( - serialized_data=serialized_data, snapshot_data=snapshot_data + serialized_data=serialized_data, snapshot_data=snapshot_data, **matcher_options ) ) assertion_success = matches From 07a51cc01e58d8e601c31e220a50c527435ec8d1 Mon Sep 17 00:00:00 2001 From: atharva-2001 Date: Wed, 27 Sep 2023 16:26:57 +0530 Subject: [PATCH 6/6] extra args with props --- src/syrupy/assertion.py | 34 ++++++++++++++++++++-------------- src/syrupy/extensions/base.py | 11 +++++++++-- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index 135b9571..5b48f698 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -180,7 +180,7 @@ def with_defaults( include: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, extension_class: Optional[Type["AbstractSyrupyExtension"]] = None, - extra_args: Optional[Dict] = None + extra_args: Optional[Dict] = None, ) -> "SnapshotAssertion": """ Create new snapshot assertion fixture with provided values. This preserves @@ -194,7 +194,7 @@ def with_defaults( test_location=self.test_location, extension_class=extension_class or self.extension_class, session=self.session, - extra_args=extra_args or self.extra_args + extra_args=extra_args or self.extra_args, ) def use_extension( @@ -209,9 +209,13 @@ def use_extension( def assert_match(self, data: "SerializableData") -> None: assert self == data - def _serialize(self, data: "SerializableData") -> "SerializedData": + def _serialize(self, data: "SerializableData", **kwargs: Any) -> "SerializedData": return self.extension.serialize( - data, exclude=self._exclude, include=self._include, matcher=self.__matcher + data, + exclude=self._exclude, + include=self._include, + matcher=self.__matcher, + **kwargs, ) def get_assert_diff(self) -> List[str]: @@ -286,7 +290,7 @@ def __call__( if diff is not None: self.__with_prop("_snapshot_diff", diff) if extra_args: - self._extra_args = extra_args + self.__with_prop("_extra_args", extra_args) return self def __repr__(self) -> str: @@ -307,28 +311,29 @@ def _assert(self, data: "SerializableData") -> bool: matches = False assertion_success = False assertion_exception = None - matcher_options = None - for key,value in self._extra_args.items(): - if key == "matcher_options": - matcher_options = value - print("matcher_options", matcher_options) + extra_args = getattr(self, "_extra_args", {}) try: snapshot_data, tainted = self._recall_data(index=self.index) - serialized_data = self._serialize(data) + serialized_data = self._serialize(data, **extra_args) snapshot_diff = getattr(self, "_snapshot_diff", None) if snapshot_diff is not None: - snapshot_data_diff, _ = self._recall_data(index=snapshot_diff) + snapshot_data_diff, _ = self._recall_data( + index=snapshot_diff, **extra_args + ) if snapshot_data_diff is None: raise SnapshotDoesNotExist() serialized_data = self.extension.diff_snapshots( serialized_data=serialized_data, snapshot_data=snapshot_data_diff, + **extra_args, ) matches = ( not tainted and snapshot_data is not None and self.extension.matches( - serialized_data=serialized_data, snapshot_data=snapshot_data, **matcher_options + serialized_data=serialized_data, + snapshot_data=snapshot_data, + **extra_args, ) ) assertion_success = matches @@ -373,7 +378,7 @@ def _post_assert(self) -> None: self._post_assert_actions.pop()() def _recall_data( - self, index: "SnapshotIndex" + self, index: "SnapshotIndex", **kwargs: Any ) -> Tuple[Optional["SerializableData"], bool]: try: return ( @@ -381,6 +386,7 @@ def _recall_data( test_location=self.test_location, index=index, session_id=str(id(self.session)), + **kwargs, ), False, ) diff --git a/src/syrupy/extensions/base.py b/src/syrupy/extensions/base.py index 9181f88e..3cb0c5e7 100644 --- a/src/syrupy/extensions/base.py +++ b/src/syrupy/extensions/base.py @@ -119,7 +119,7 @@ def discover_snapshots( for filepath in walk_snapshot_dir(self.dirname(test_location=test_location)): if self.is_snapshot_location(location=filepath): snapshot_collection = self._read_snapshot_collection( - snapshot_location=filepath + snapshot_location=filepath, **kwargs ) if not snapshot_collection.has_snapshots: snapshot_collection = SnapshotEmptyCollection(location=filepath) @@ -136,6 +136,7 @@ def read_snapshot( test_location: "PyTestLocation", index: "SnapshotIndex", session_id: str, + **kwargs: Any, ) -> "SerializedData": """ This method is _final_, do not override. You can override @@ -147,6 +148,7 @@ def read_snapshot( snapshot_location=snapshot_location, snapshot_name=snapshot_name, session_id=session_id, + **kwargs, ) if snapshot_data is None: raise SnapshotDoesNotExist() @@ -227,7 +229,12 @@ def _read_snapshot_collection( @abstractmethod def _read_snapshot_data_from_location( - self, *, snapshot_location: str, snapshot_name: str, session_id: str + self, + *, + snapshot_location: str, + snapshot_name: str, + session_id: str, + **kwargs: Any, ) -> Optional["SerializedData"]: """ Get only the snapshot data from location for assertion