From 6cddc1095e440a22b6937af91df46165abd9a585 Mon Sep 17 00:00:00 2001 From: andrewgsavage Date: Sun, 9 Jun 2024 21:59:04 +0100 Subject: [PATCH] Implement Array ufunc (#160) --- CHANGES | 2 +- pint_pandas/pint_array.py | 59 +++++++++++++++++++ .../testsuite/test_pandas_extensiontests.py | 11 ++++ 3 files changed, 71 insertions(+), 1 deletion(-) diff --git a/CHANGES b/CHANGES index 7bfcbf9..561992f 100644 --- a/CHANGES +++ b/CHANGES @@ -6,7 +6,7 @@ pint-pandas Changelog - Fix dequantify duplicate column failure #202 - Fix astype issue #196 - +- Support for `__array_ufunc__` and unary ops. #160 0.5 (2023-09-07) ---------------- diff --git a/pint_pandas/pint_array.py b/pint_pandas/pint_array.py index 2bf34e8..5ff8b5c 100644 --- a/pint_pandas/pint_array.py +++ b/pint_pandas/pint_array.py @@ -1,4 +1,5 @@ import copy +import numbers import re import warnings from importlib.metadata import version @@ -218,6 +219,15 @@ def __repr__(self): dtypeunmap = {v: k for k, v in dtypemap.items()} +def convert_np_inputs(inputs): + if isinstance(inputs, tuple): + return tuple(x.quantity if isinstance(x, PintArray) else x for x in inputs) + if isinstance(inputs, dict): + return { + item: (x.quantity if isinstance(x, PintArray) else x) for item, x in inputs + } + + class PintArray(ExtensionArray, ExtensionScalarOpsMixin): """Implements a class to describe an array of physical quantities: the product of an array of numerical values and a unit of measurement. @@ -240,6 +250,7 @@ class PintArray(ExtensionArray, ExtensionScalarOpsMixin): _data: ExtensionArray = cast(ExtensionArray, np.array([])) context_name = None context_units = None + _HANDLED_TYPES = (np.ndarray, numbers.Number, _Quantity) def __init__(self, values, dtype=None, copy=False): if dtype is None: @@ -281,6 +292,54 @@ def __setstate__(self, dct): self.__dict__.update(dct) self._Q = self.dtype.ureg.Quantity + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + out = kwargs.get("out", ()) + for x in inputs + out: + # Only support operations with instances of _HANDLED_TYPES. + # Use ArrayLike instead of type(self) for isinstance to + # allow subclasses that don't override __array_ufunc__ to + # handle ArrayLike objects. + if not isinstance(x, self._HANDLED_TYPES + (PintArray,)): + return NotImplemented + + # Defer to pint's implementation of the ufunc. + inputs = convert_np_inputs(inputs) + if out: + kwargs["out"] = convert_np_inputs(out) + print(inputs) + result = getattr(ufunc, method)(*inputs, **kwargs) + return self._convert_np_result(result) + + def _convert_np_result(self, result): + if isinstance(result, _Quantity) and is_list_like(result.m): + return PintArray.from_1darray_quantity(result) + elif isinstance(result, _Quantity): + return result + elif type(result) is tuple: + # multiple return values + return tuple(type(self)(x) for x in result) + elif isinstance(result, np.ndarray) and all( + isinstance(item, _Quantity) for item in result + ): + return PintArray._from_sequence(result) + elif result is None: + # no return value + return result + elif pd.api.types.is_bool_dtype(result): + return result + else: + # one return value + return type(self)(result) + + def __pos__(self): + return 1 * self + + def __neg__(self): + return -1 * self + + def __abs__(self): + return self._Q(np.abs(self._data), self._dtype.units) + @property def dtype(self): # type: () -> ExtensionDtype diff --git a/pint_pandas/testsuite/test_pandas_extensiontests.py b/pint_pandas/testsuite/test_pandas_extensiontests.py index 25da6d0..44a4caf 100644 --- a/pint_pandas/testsuite/test_pandas_extensiontests.py +++ b/pint_pandas/testsuite/test_pandas_extensiontests.py @@ -657,6 +657,17 @@ def test_setitem_2d_values(self, data): assert (df.loc[1, :] == original[0]).all() +class TestUnaryOps(base.BaseUnaryOpsTests): + @pytest.mark.xfail(run=True, reason="invert not implemented") + def test_invert(self, data): + base.BaseUnaryOpsTests.test_invert(self, data) + + @pytest.mark.xfail(run=True, reason="np.positive requires pint 0.21") + @pytest.mark.parametrize("ufunc", [np.positive, np.negative, np.abs]) + def test_unary_ufunc_dunder_equivalence(self, data, ufunc): + base.BaseUnaryOpsTests.test_unary_ufunc_dunder_equivalence(self, data, ufunc) + + class TestAccumulate(base.BaseAccumulateTests): @pytest.mark.parametrize("skipna", [True, False]) def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):