From 979fb8f94f73aaf5179a8157cdec50983610d014 Mon Sep 17 00:00:00 2001 From: Michael Tiemann <72577720+MichaelTiemannOSC@users.noreply.github.com> Date: Sun, 14 Jan 2024 12:30:30 +1300 Subject: [PATCH] Final tweaks for clean mypy pass Adjusted various declarations to use public Pandas APIs where possible, ignored things that cannot be fixed without fixing Pandas-Stubs, etc. Might need some more attention for different versions of Pandas/Python, but this works for Pandas 2.1.4 and Python 3.11. Signed-off-by: Michael Tiemann <72577720+MichaelTiemannOSC@users.noreply.github.com> --- .pre-commit-config.yaml | 1 - docs/conf.py | 4 +- pint_pandas/__init__.py | 2 +- pint_pandas/pint_array.py | 49 +++++++++++-------- .../testsuite/test_pandas_extensiontests.py | 10 ++-- 5 files changed, 36 insertions(+), 30 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 830f82e3..478dd97e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,5 +39,4 @@ repos: "pandas-stubs", "pint", "matplotlib-stubs", - "itr", ] diff --git a/docs/conf.py b/docs/conf.py index b90fe148..23e64859 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -5,7 +5,7 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html import datetime -from importlib.metadata import version +from importlib.metadata import version as metadata_version # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the @@ -22,7 +22,7 @@ # built documents. 
try: # pragma: no cover - version = version(project) + version = metadata_version(project) except Exception: # pragma: no cover # we seem to have a local copy not installed without setuptools # so the reported version will be unknown diff --git a/pint_pandas/__init__.py b/pint_pandas/__init__.py index 145a42dc..17dbc236 100644 --- a/pint_pandas/__init__.py +++ b/pint_pandas/__init__.py @@ -6,7 +6,7 @@ from importlib.metadata import version except ImportError: # Backport for Python < 3.8 - from importlib_metadata import version + from importlib_metadata import version # type: ignore try: # pragma: no cover __version__ = version("pint_pandas") diff --git a/pint_pandas/pint_array.py b/pint_pandas/pint_array.py index ede32f01..92906d10 100644 --- a/pint_pandas/pint_array.py +++ b/pint_pandas/pint_array.py @@ -2,7 +2,7 @@ import re import warnings from importlib.metadata import version -from typing import Optional +from typing import Any, Callable, Dict, Optional, cast import numpy as np import pandas as pd @@ -11,15 +11,15 @@ from pandas.api.extensions import ( ExtensionArray, ExtensionDtype, + ExtensionScalarOpsMixin, register_dataframe_accessor, register_extension_dtype, register_series_accessor, ) +from pandas.api.indexers import check_array_indexer from pandas.api.types import is_integer, is_list_like, is_object_dtype, is_string_dtype from pandas.compat import set_function_name -from pandas.core import nanops -from pandas.core.arrays.base import ExtensionOpsMixin -from pandas.core.indexers import check_array_indexer +from pandas.core import nanops # type: ignore from pint import Quantity as _Quantity from pint import Unit as _Unit from pint import compat, errors @@ -47,7 +47,7 @@ class PintType(ExtensionDtype): units: Optional[_Unit] = None # Filled in by `construct_from_..._string` _metadata = ("units",) _match = re.compile(r"(P|p)int\[(?P<units>.+)\]") - _cache = {} + _cache = {} # type: ignore ureg = pint.get_application_registry() @property @@ -78,11 +78,13 @@ def 
__new__(cls, units=None): units = cls.ureg.Quantity(1, units).units try: - return cls._cache["{:P}".format(units)] + # TODO: fix when Pint implements Callable typing + # TODO: wrap string into PintFormatStr class + return cls._cache["{:P}".format(units)] # type: ignore except KeyError: u = object.__new__(cls) u.units = units - cls._cache["{:P}".format(units)] = u + cls._cache["{:P}".format(units)] = u # type: ignore return u @classmethod @@ -193,9 +195,9 @@ def __repr__(self): _NumpyEADtype = ( - pd.core.dtypes.dtypes.PandasDtype + pd.core.dtypes.dtypes.PandasDtype # type: ignore if pandas_version_info < (2, 1) - else pd.core.dtypes.dtypes.NumpyEADtype + else pd.core.dtypes.dtypes.NumpyEADtype # type: ignore ) dtypemap = { @@ -215,7 +217,7 @@ def __repr__(self): dtypeunmap = {v: k for k, v in dtypemap.items()} -class PintArray(ExtensionArray, ExtensionOpsMixin): +class PintArray(ExtensionArray, ExtensionScalarOpsMixin): """Implements a class to describe an array of physical quantities: the product of an array of numerical values and a unit of measurement. @@ -234,7 +236,7 @@ class PintArray(ExtensionArray, ExtensionOpsMixin): """ - _data = np.array([]) + _data: ExtensionArray = cast(ExtensionArray, np.array([])) context_name = None context_units = None @@ -383,7 +385,7 @@ def isna(self): ------- missing : np.array """ - return self._data.isna() + return cast(np.ndarray, self._data.isna()) def astype(self, dtype, copy=True): """Cast to a NumPy array with 'dtype'. 
@@ -620,11 +622,11 @@ def unique(self): data = self._data return self._from_sequence(unique(data), dtype=self.dtype) - def __contains__(self, item) -> bool: + def __contains__(self, item) -> bool | np.bool_: if not isinstance(item, _Quantity): return False elif pd.isna(item.magnitude): - return self.isna().any() + return cast(np.ndarray, self.isna()).any() else: return super().__contains__(item) @@ -908,11 +910,12 @@ def _reduce(self, name, *, skipna: bool = True, keepdims: bool = False, **kwds): if isinstance(self._data, ExtensionArray): try: - result = self._data._reduce( + # TODO: https://github.com/pandas-dev/pandas-stubs/issues/850 + result = self._data._reduce( # type: ignore name, skipna=skipna, keepdims=keepdims, **kwds ) except NotImplementedError: - result = functions[name](self.numpy_data, **kwds) + result = cast(_Quantity, functions[name](self.numpy_data, **kwds)) if name in {"all", "any", "kurt", "skew"}: return result @@ -927,7 +930,9 @@ def _reduce(self, name, *, skipna: bool = True, keepdims: bool = False, **kwds): def _accumulate(self, name: str, *, skipna: bool = True, **kwds): if name == "cumprod": raise TypeError("cumprod not supported for pint arrays") - functions = { + functions: Dict[ + str, Callable[[np._typing._SupportsArray[np.dtype[Any]]], Any] + ] = { "cummin": np.minimum.accumulate, "cummax": np.maximum.accumulate, "cumsum": np.cumsum, @@ -935,7 +940,8 @@ def _accumulate(self, name: str, *, skipna: bool = True, **kwds): if isinstance(self._data, ExtensionArray): try: - result = self._data._accumulate(name, **kwds) + # TODO: https://github.com/pandas-dev/pandas-stubs/issues/850 + result = self._data._accumulate(name, **kwds) # type: ignore except NotImplementedError: result = functions[name](self.numpy_data, **kwds) @@ -1181,9 +1187,10 @@ def is_pint_type(obj): try: # for pint < 0.21 we need to explicitly register - compat.upcast_types.append(PintArray) + # TODO: fix when Pint is properly typed for mypy + 
compat.upcast_types.append(PintArray) # type: ignore except AttributeError: # for pint = 0.21 we need to add the full names of PintArray and DataFrame, # which is to be added in pint > 0.21 - compat.upcast_type_map.setdefault("pint_pandas.pint_array.PintArray", PintArray) - compat.upcast_type_map.setdefault("pandas.core.frame.DataFrame", DataFrame) + compat.upcast_type_map.setdefault("pint_pandas.pint_array.PintArray", PintArray) # type: ignore + compat.upcast_type_map.setdefault("pandas.core.frame.DataFrame", DataFrame) # type: ignore diff --git a/pint_pandas/testsuite/test_pandas_extensiontests.py b/pint_pandas/testsuite/test_pandas_extensiontests.py index 1427baa8..25da6d04 100644 --- a/pint_pandas/testsuite/test_pandas_extensiontests.py +++ b/pint_pandas/testsuite/test_pandas_extensiontests.py @@ -18,7 +18,6 @@ use_numpy, # noqa: F401 ) - from pint.errors import DimensionalityError from pint_pandas import PintArray, PintType @@ -381,9 +380,10 @@ def _get_expected_exception( return TypeError if isinstance(obj, pd.Series): try: - if obj.pint.m.dtype.kind == "c": + # PintSeriesAccessor is dynamically constructed; need stubs to make it mypy-compatible + if obj.pint.m.dtype.kind == "c": # type: ignore pytest.skip( - f"{obj.pint.m.dtype.name} {obj.dtype} does not support {op_name}" + f"{obj.pint.m.dtype.name} {obj.dtype} does not support {op_name}" # type: ignore ) return TypeError except AttributeError: @@ -392,9 +392,9 @@ def _get_expected_exception( return exc if isinstance(other, pd.Series): try: - if other.pint.m.dtype.kind == "c": + if other.pint.m.dtype.kind == "c": # type: ignore pytest.skip( - f"{other.pint.m.dtype.name} {other.dtype} does not support {op_name}" + f"{other.pint.m.dtype.name} {other.dtype} does not support {op_name}" # type: ignore ) return TypeError except AttributeError: