From 00ce851a5ab257b11366e4933ad2c65640a149f0 Mon Sep 17 00:00:00 2001
From: Remco de Boer <29308176+redeboer@users.noreply.github.com>
Date: Sun, 6 Nov 2022 13:12:10 +0100
Subject: [PATCH] FEAT: implement `ChainedDataTransformer` (#470)

---
 .cspell.json                      |  1 +
 src/tensorwaves/data/_attrs.py    |  9 +++++++
 src/tensorwaves/data/transform.py | 30 ++++++++++++++++++++++
 tests/data/test_transform.py      | 41 +++++++++++++++++++++++++++++--
 4 files changed, 79 insertions(+), 2 deletions(-)
 create mode 100644 src/tensorwaves/data/_attrs.py

diff --git a/.cspell.json b/.cspell.json
index 16e6b953..b0460087 100644
--- a/.cspell.json
+++ b/.cspell.json
@@ -169,6 +169,7 @@
     "qrules",
     "rightarrow",
     "rtfd",
+    "rtol",
     "scipy",
     "sdist",
     "seealso",
diff --git a/src/tensorwaves/data/_attrs.py b/src/tensorwaves/data/_attrs.py
new file mode 100644
index 00000000..e6631cea
--- /dev/null
+++ b/src/tensorwaves/data/_attrs.py
@@ -0,0 +1,9 @@
+from __future__ import annotations
+
+from typing import Iterable
+
+from tensorwaves.interface import DataTransformer
+
+
+def to_tuple(items: Iterable[DataTransformer]) -> tuple[DataTransformer, ...]:
+    return tuple(items)
diff --git a/src/tensorwaves/data/transform.py b/src/tensorwaves/data/transform.py
index 30abfa2c..a75c5faa 100644
--- a/src/tensorwaves/data/transform.py
+++ b/src/tensorwaves/data/transform.py
@@ -3,6 +3,8 @@
 
 from typing import TYPE_CHECKING, Mapping
 
+from attrs import field, frozen
+
 from tensorwaves.function import PositionalArgumentFunction
 from tensorwaves.function.sympy import (
     _get_free_symbols,  # pyright: ignore[reportPrivateUsage]
@@ -12,10 +14,38 @@
 )
 from tensorwaves.interface import DataSample, DataTransformer, Function
 
+from ._attrs import to_tuple
+
 if TYPE_CHECKING:  # pragma: no cover
     import sympy as sp
 
 
+@frozen
+class ChainedDataTransformer(DataTransformer):
+    """Combine multiple `.DataTransformer` classes into one.
+
+    Args:
+        transformers: Ordered list of transformers that you want to chain.
+        extend: Set to `True` in order to keep keys of each output `.DataSample` and
+            collect them into the final, chained `.DataSample`.
+    """
+
+    transformers: tuple[DataTransformer, ...] = field(converter=to_tuple)
+    extend: bool = True
+
+    def __call__(self, data: DataSample) -> DataSample:
+        new_data = dict(data)
+        weights = new_data.get("weights")
+        for transformer in self.transformers:
+            if self.extend:
+                new_data.update(transformer(new_data))
+            else:
+                new_data = transformer(new_data)
+        if weights is not None:
+            new_data["weights"] = weights
+        return new_data
+
+
 class IdentityTransformer(DataTransformer):
     """`.DataTransformer` that leaves a `.DataSample` intact."""
 
diff --git a/tests/data/test_transform.py b/tests/data/test_transform.py
index b3713578..d0b56369 100644
--- a/tests/data/test_transform.py
+++ b/tests/data/test_transform.py
@@ -1,10 +1,47 @@
-# pylint: disable=invalid-name
+from __future__ import annotations
+
 import numpy as np
 import pytest
 import sympy as sp
 from numpy import sqrt
 
-from tensorwaves.data import IdentityTransformer, SympyDataTransformer
+from tensorwaves.data.transform import (
+    ChainedDataTransformer,
+    IdentityTransformer,
+    SympyDataTransformer,
+)
+
+
+class TestChainedDataTransformer:
+    @pytest.mark.parametrize("extend", [False, True])
+    def test_identity_chain(self, extend: bool):
+        x, y, v, w = sp.symbols("x y v w")
+        transform1 = _create_transformer({v: 2 * x - 5, w: -0.2 * y + 3})
+        transform2 = _create_transformer({x: 0.5 * (v + 5), y: 5 * (3 - w)})
+        chained_transform = ChainedDataTransformer([transform1, transform2], extend)
+        rng = np.random.default_rng(seed=0)
+        data = {"x": rng.uniform(size=100), "y": rng.uniform(size=100)}
+        transformed_data = chained_transform(data)
+        for key in data:  # pylint: disable=consider-using-dict-items
+            np.testing.assert_allclose(data[key], transformed_data[key], rtol=1e-13)
+        if extend:
+            assert set(transformed_data) == {"x", "y", "v", "w"}
+        else:
+            assert set(transformed_data) == {"x", "y"}
+
+    def test_single_chain(self):
+        transform = IdentityTransformer()
+        chained_transform = ChainedDataTransformer([transform])
+        data = {
+            "x": np.ones(5),
+            "y": np.ones(5),
+        }
+        assert data == chained_transform(data)
+        assert data is not chained_transform(data)  # DataSample returned as new dict
+
+
+def _create_transformer(expressions: dict[sp.Symbol, sp.Expr]) -> SympyDataTransformer:
+    return SympyDataTransformer.from_sympy(expressions, backend="jax")
 
 
 class TestIdentityTransformer: