Skip to content

Commit

Permalink
refactor: scaffolding to support custom context in extensions
Browse files Browse the repository at this point in the history
NOTE: Since syrupy v4 migrated from instance methods to classmethods, this new context is not actual usable. This lays the groundwork for a switch back to instance methods though (if we continue along this path).
  • Loading branch information
atharva-2001 authored and Noah Negin-Ulster committed Sep 27, 2023
1 parent b240712 commit 2189962
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 21 deletions.
31 changes: 24 additions & 7 deletions src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class SnapshotAssertion:
exclude: Optional["PropertyFilter"] = None
matcher: Optional["PropertyMatcher"] = None

# context is reserved exclusively for custom extensions
context: Optional[Dict[str, Any]] = None

_exclude: Optional["PropertyFilter"] = field(
init=False,
default=None,
Expand Down Expand Up @@ -109,7 +112,8 @@ def __post_init__(self) -> None:
def __init_extension(
self, extension_class: Type["AbstractSyrupyExtension"]
) -> "AbstractSyrupyExtension":
return extension_class()
kwargs = {"context": self.context} if self.context else {}
return extension_class(**kwargs)

@property
def extension(self) -> "AbstractSyrupyExtension":
Expand Down Expand Up @@ -178,6 +182,7 @@ def with_defaults(
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
context: Optional[Dict[str, Any]] = None,
) -> "SnapshotAssertion":
"""
Create new snapshot assertion fixture with provided values. This preserves
Expand All @@ -191,6 +196,7 @@ def with_defaults(
test_location=self.test_location,
extension_class=extension_class or self.extension_class,
session=self.session,
context=context or self.context,
)

def use_extension(
Expand All @@ -207,7 +213,10 @@ def assert_match(self, data: "SerializableData") -> None:

def _serialize(self, data: "SerializableData") -> "SerializedData":
return self.extension.serialize(
data, exclude=self._exclude, include=self._include, matcher=self.__matcher
data,
exclude=self._exclude,
include=self._include,
matcher=self.__matcher,
)

def get_assert_diff(self) -> List[str]:
Expand Down Expand Up @@ -264,6 +273,7 @@ def __call__(
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
matcher: Optional["PropertyMatcher"] = None,
name: Optional["SnapshotIndex"] = None,
context: Optional[Dict[str, Any]] = None,
) -> "SnapshotAssertion":
"""
Modifies assertion instance options
Expand All @@ -272,14 +282,18 @@ def __call__(
self.__with_prop("_exclude", exclude)
if include:
self.__with_prop("_include", include)
if extension_class:
self.__with_prop("_extension", self.__init_extension(extension_class))
if matcher:
self.__with_prop("_matcher", matcher)
if name:
self.__with_prop("_custom_index", name)
if diff is not None:
self.__with_prop("_snapshot_diff", diff)
if context and context != self.context:
self.__with_prop("context", context)
# We need to force the extension to be re-initialized if the context changes
extension_class = extension_class or self.extension_class
if extension_class:
self.__with_prop("_extension", self.__init_extension(extension_class))
return self

def __repr__(self) -> str:
Expand All @@ -290,10 +304,12 @@ def __eq__(self, other: "SerializableData") -> bool:

def _assert(self, data: "SerializableData") -> bool:
snapshot_location = self.extension.get_location(
test_location=self.test_location, index=self.index
test_location=self.test_location,
index=self.index,
)
snapshot_name = self.extension.get_snapshot_name(
test_location=self.test_location, index=self.index
test_location=self.test_location,
index=self.index,
)
snapshot_data: Optional["SerializedData"] = None
serialized_data: Optional["SerializedData"] = None
Expand All @@ -316,7 +332,8 @@ def _assert(self, data: "SerializableData") -> bool:
not tainted
and snapshot_data is not None
and self.extension.matches(
serialized_data=serialized_data, snapshot_data=snapshot_data
serialized_data=serialized_data,
snapshot_data=snapshot_data,
)
)
assertion_success = matches
Expand Down
19 changes: 15 additions & 4 deletions src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ class SnapshotCollectionStorage(ABC):

@classmethod
def get_snapshot_name(
cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" = 0
cls,
*,
test_location: "PyTestLocation",
index: "SnapshotIndex" = 0,
) -> str:
"""Get the snapshot name for the assertion index in a test location"""
index_suffix = ""
Expand Down Expand Up @@ -225,7 +228,11 @@ def _read_snapshot_collection(

@abstractmethod
def _read_snapshot_data_from_location(
self, *, snapshot_location: str, snapshot_name: str, session_id: str
self,
*,
snapshot_location: str,
snapshot_name: str,
session_id: str,
) -> Optional["SerializedData"]:
"""
Get only the snapshot data from location for assertion
Expand Down Expand Up @@ -259,15 +266,19 @@ class SnapshotReporter(ABC):
_context_line_count = 1

def diff_snapshots(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
self,
serialized_data: "SerializedData",
snapshot_data: "SerializedData",
) -> "SerializedData":
env = {DISABLE_COLOR_ENV_VAR: "true"}
attrs = {"_context_line_count": 0}
with env_context(**env), obj_attrs(self, attrs):
return "\n".join(self.diff_lines(serialized_data, snapshot_data))

def diff_lines(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
self,
serialized_data: "SerializedData",
snapshot_data: "SerializedData",
) -> Iterator[str]:
for line in self.__diff_lines(str(snapshot_data), str(serialized_data)):
yield reset(line)
Expand Down
1 change: 1 addition & 0 deletions src/syrupy/extensions/json/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def serialize(
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
**kwargs: Any,
) -> "SerializedData":
data = self._filter(
data=data,
Expand Down
19 changes: 15 additions & 4 deletions src/syrupy/extensions/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ def serialize(

@classmethod
def get_snapshot_name(
cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" = 0
cls,
*,
test_location: "PyTestLocation",
index: "SnapshotIndex" = 0,
) -> str:
return cls.__clean_filename(
AbstractSyrupyExtension.get_snapshot_name(
Expand All @@ -79,7 +82,9 @@ def dirname(cls, *, test_location: "PyTestLocation") -> str:
return str(Path(original_dirname).joinpath(test_location.basename))

def _read_snapshot_collection(
self, *, snapshot_location: str
self,
*,
snapshot_location: str,
) -> "SnapshotCollection":
file_ext_len = len(self._file_extension) + 1 if self._file_extension else 0
filename_wo_ext = snapshot_location[:-file_ext_len]
Expand All @@ -90,7 +95,11 @@ def _read_snapshot_collection(
return snapshot_collection

def _read_snapshot_data_from_location(
self, *, snapshot_location: str, snapshot_name: str, session_id: str
self,
*,
snapshot_location: str,
snapshot_name: str,
session_id: str,
) -> Optional["SerializableData"]:
try:
with open(
Expand All @@ -116,7 +125,9 @@ def get_write_encoding(cls) -> Optional[str]:

@classmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
cls,
*,
snapshot_collection: "SnapshotCollection",
) -> None:
filepath, data = (
snapshot_collection.location,
Expand Down
31 changes: 27 additions & 4 deletions src/syrupy/session.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pickle
from collections import defaultdict
from dataclasses import (
dataclass,
Expand Down Expand Up @@ -54,7 +55,7 @@ class SnapshotSession:
)

_queued_snapshot_writes: Dict[
Tuple[Type["AbstractSyrupyExtension"], str],
Tuple[Type["AbstractSyrupyExtension"], Optional[bytes], str],
List[Tuple["SerializedData", "PyTestLocation", "SnapshotIndex"]],
] = field(default_factory=dict)

Expand All @@ -68,19 +69,41 @@ def queue_snapshot_write(
snapshot_location = extension.get_location(
test_location=test_location, index=index
)
key = (extension.__class__, snapshot_location)

extension_context = getattr(extension, "context", None)

try:
extension_kwargs_bytes = (
pickle.dumps(extension_context) if extension_context else None
)
except pickle.PicklingError:
print("Extension context must be serializable.")
raise

key = (extension.__class__, extension_kwargs_bytes, snapshot_location)
queue = self._queued_snapshot_writes.get(key, [])
queue.append((data, test_location, index))
self._queued_snapshot_writes[key] = queue

def flush_snapshot_write_queue(self) -> None:
for (
extension_class,
extension_kwargs_bytes,
snapshot_location,
), queued_write in self._queued_snapshot_writes.items():
if queued_write:
extension_class.write_snapshot(
snapshot_location=snapshot_location, snapshots=queued_write
# It's possible to instantiate an extension with context. We need to
# ensure we never lose context between instantiations (since we may
# instantiate multiple times in a test session).
extension_kwargs = (
{"context": pickle.loads(extension_kwargs_bytes)}
if extension_kwargs_bytes
else {}
)
extension = extension_class(**extension_kwargs)
extension.write_snapshot(
snapshot_location=snapshot_location,
snapshots=queued_write,
)
self._queued_snapshot_writes = {}

Expand Down
6 changes: 4 additions & 2 deletions tests/examples/test_custom_snapshot_name.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Example: Custom Snapshot Name
"""
from typing import Any

import pytest

from syrupy.extensions.amber import AmberSnapshotExtension
Expand All @@ -11,10 +13,10 @@
class CanadianNameExtension(AmberSnapshotExtension):
@classmethod
def get_snapshot_name(
cls, *, test_location: "PyTestLocation", index: "SnapshotIndex"
cls, *, test_location: "PyTestLocation", index: "SnapshotIndex", **kwargs: Any
) -> str:
original_name = AmberSnapshotExtension.get_snapshot_name(
test_location=test_location, index=index
test_location=test_location, index=index, **kwargs
)
return f"{original_name}🇨🇦"

Expand Down

1 comment on commit 2189962

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 2189962 Previous: b240712 Ratio
benchmarks/test_1000x.py::test_1000x_reads 0.79030033544977 iter/sec (stddev: 0.050948223822124356) 0.7172343042658249 iter/sec (stddev: 0.04888383581177893) 0.91
benchmarks/test_1000x.py::test_1000x_writes 0.7807394633637691 iter/sec (stddev: 0.06168038169209299) 0.7125160400292763 iter/sec (stddev: 0.07169036131462862) 0.91
benchmarks/test_standard.py::test_standard 0.7758155425583666 iter/sec (stddev: 0.056108112251705394) 0.7023221367731313 iter/sec (stddev: 0.08922127742323605) 0.91

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.