Skip to content

Commit

Permalink
Pint array: strip and reattach appropriate units
Browse files Browse the repository at this point in the history
Closes #163
  • Loading branch information
dcherian committed Jan 24, 2023
1 parent cd6eeb5 commit 8ee7488
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 12 deletions.
73 changes: 62 additions & 11 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import copy
import warnings
from functools import partial
from typing import Callable

import numpy as np
import numpy_groupies as npg
Expand Down Expand Up @@ -114,6 +115,7 @@ def __init__(
dtypes=None,
final_dtype=None,
reduction_type="reduce",
units_func: Callable | None = None,
):
"""
Blueprint for computing grouped aggregations.
Expand Down Expand Up @@ -156,6 +158,8 @@ def __init__(
per reduction in ``chunk`` as a tuple.
final_dtype : DType, optional
DType for output. By default, uses dtype of array being reduced.
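units_func : callable, optional
Function called with a ``pint.Quantity`` carrying the units of the input array;
the units of its return value are used as the units of the output.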
"""
self.name = name
# preprocess before blockwise
Expand Down Expand Up @@ -187,6 +191,8 @@ def __init__(
# The following are set by _initialize_aggregation
self.finalize_kwargs = {}
self.min_count = None
self.units_func = units_func
self.units = None

def _normalize_dtype_fill_value(self, value, name):
value = _atleast_1d(value)
Expand Down Expand Up @@ -235,17 +241,44 @@ def __repr__(self):
final_dtype=np.intp,
)


def identity(x):
    """Return ``x`` unchanged.

    Passed as ``units_func`` for aggregations whose result carries the same
    units as the input (for example sum, mean, min and max).
    """
    return x


def square(x):
    """Return ``x`` raised to the second power.

    Passed as ``units_func`` for the variance aggregations, whose result
    carries the square of the input units.
    """
    squared = x**2
    return squared


def raise_units_error(x):
    """Unconditionally reject unit handling for product-style aggregations.

    Passed as ``units_func`` for ``prod`` and ``nanprod``: the units of a
    product depend on how many members each group has, so no single output
    unit can be attached.

    Parameters
    ----------
    x : Any
        Ignored. Present only so the function matches the ``units_func``
        call signature.

    Raises
    ------
    ValueError
        Always.
    """
    raise ValueError(
        "Units cannot be supported for prod in general. "
        "We can only attach units when there is an "
        "equal number of members in each group. "
        "Please strip units and then reattach units "
        "to the output manually."
    )


# note that the fill values are the result of np.func([np.nan, np.nan])
# final_fill_value is used for groups that don't exist. This is usually np.nan
sum_ = Aggregation("sum", chunk="sum", combine="sum", fill_value=0)
nansum = Aggregation("nansum", chunk="nansum", combine="sum", fill_value=0)
prod = Aggregation("prod", chunk="prod", combine="prod", fill_value=1, final_fill_value=1)
sum_ = Aggregation("sum", chunk="sum", combine="sum", fill_value=0, units_func=identity)
nansum = Aggregation("nansum", chunk="nansum", combine="sum", fill_value=0, units_func=identity)
prod = Aggregation(
"prod",
chunk="prod",
combine="prod",
fill_value=1,
final_fill_value=1,
units_func=raise_units_error,
)
nanprod = Aggregation(
"nanprod",
chunk="nanprod",
combine="prod",
fill_value=1,
final_fill_value=dtypes.NA,
units_func=raise_units_error,
)


Expand All @@ -262,6 +295,7 @@ def _mean_finalize(sum_, count):
fill_value=(0, 0),
dtypes=(None, np.intp),
final_dtype=np.floating,
units_func=identity,
)
nanmean = Aggregation(
"nanmean",
Expand All @@ -271,6 +305,7 @@ def _mean_finalize(sum_, count):
fill_value=(0, 0),
dtypes=(None, np.intp),
final_dtype=np.floating,
units_func=identity,
)


Expand All @@ -296,6 +331,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
final_fill_value=np.nan,
dtypes=(None, None, np.intp),
final_dtype=np.floating,
units_func=square,
)
nanvar = Aggregation(
"nanvar",
Expand All @@ -306,6 +342,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
final_fill_value=np.nan,
dtypes=(None, None, np.intp),
final_dtype=np.floating,
units_func=square,
)
std = Aggregation(
"std",
Expand All @@ -316,6 +353,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
final_fill_value=np.nan,
dtypes=(None, None, np.intp),
final_dtype=np.floating,
units_func=identity,
)
nanstd = Aggregation(
"nanstd",
Expand All @@ -329,10 +367,14 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
)


min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF)
nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan)
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF)
nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan)
min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF, units_func=identity)
nanmin = Aggregation(
"nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan, units_func=identity
)
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF, units_func=identity)
nanmax = Aggregation(
"nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan, units_func=identity
)


def argreduce_preprocess(array, axis):
Expand Down Expand Up @@ -420,10 +462,14 @@ def _pick_second(*x):
final_dtype=np.intp,
)

first = Aggregation("first", chunk=None, combine=None, fill_value=0)
last = Aggregation("last", chunk=None, combine=None, fill_value=0)
nanfirst = Aggregation("nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=np.nan)
nanlast = Aggregation("nanlast", chunk="nanlast", combine="nanlast", fill_value=np.nan)
first = Aggregation("first", chunk=None, combine=None, fill_value=0, units_func=identity)
last = Aggregation("last", chunk=None, combine=None, fill_value=0, units_func=identity)
nanfirst = Aggregation(
"nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=np.nan, units_func=identity
)
nanlast = Aggregation(
"nanlast", chunk="nanlast", combine="nanlast", fill_value=np.nan, units_func=identity
)

all_ = Aggregation(
"all",
Expand Down Expand Up @@ -483,6 +529,7 @@ def _initialize_aggregation(
dtype,
array_dtype,
fill_value,
array_units,
min_count: int | None,
finalize_kwargs,
) -> Aggregation:
Expand Down Expand Up @@ -547,4 +594,8 @@ def _initialize_aggregation(
agg.dtype["intermediate"] += (np.intp,)
agg.dtype["numpy"] += (np.intp,)

if array_units is not None and agg.units_func is not None:
import pint

agg.units = agg.units_func(pint.Quantity([1], units=array_units))
return agg
10 changes: 9 additions & 1 deletion flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
generic_aggregate,
)
from .cache import memoize
from .pint_compat import _reattach_units, _strip_units
from .xrutils import is_duck_array, is_duck_dask_array, isnull

if TYPE_CHECKING:
Expand Down Expand Up @@ -1702,6 +1703,8 @@ def groupby_reduce(
by_is_dask = tuple(is_duck_dask_array(b) for b in bys)
any_by_dask = any(by_is_dask)

array, *bys, units = _strip_units(array, *bys)

if method in ["split-reduce", "cohorts"] and any_by_dask:
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")

Expand Down Expand Up @@ -1803,7 +1806,9 @@ def groupby_reduce(
fill_value = np.nan

kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count, finalize_kwargs)
agg = _initialize_aggregation(
func, dtype, array.dtype, fill_value, units[0], min_count, finalize_kwargs
)

if not has_dask:
results = _reduce_blockwise(
Expand Down Expand Up @@ -1862,4 +1867,7 @@ def groupby_reduce(

if _is_minmax_reduction(func) and is_bool_array:
result = result.astype(bool)

units[0] = agg.units
result, *groups = _reattach_units(result, *groups, units=units)
return (result, *groups)
16 changes: 16 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@
except ImportError:
xr_types = () # type: ignore

try:
import pint

pint_types = pint.Quantity
except ImportError:
pint_types = () # type: ignore


def _importorskip(modname, minversion=None):
try:
Expand All @@ -46,6 +53,7 @@ def LooseVersion(vstring):


has_dask, requires_dask = _importorskip("dask")
has_pint, requires_pint = _importorskip("pint")
has_xarray, requires_xarray = _importorskip("xarray")


Expand Down Expand Up @@ -95,6 +103,14 @@ def assert_equal(a, b, tolerance=None):
xr.testing.assert_identical(a, b)
return

if has_pint and isinstance(a, pint_types) or isinstance(b, pint_types):
assert isinstance(a, pint_types)
assert isinstance(b, pint_types)
assert a.units == b.units

a = a.magnitude
b = b.magnitude

if tolerance is None and (
np.issubdtype(a.dtype, np.float64) | np.issubdtype(b.dtype, np.float64)
):
Expand Down
37 changes: 37 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
has_dask,
raise_if_dask_computes,
requires_dask,
requires_pint,
)

labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])
Expand Down Expand Up @@ -1321,3 +1322,39 @@ def test_negative_index_factorize_race_condition():
for f in func
]
[dask.compute(out, scheduler="threads") for _ in range(5)]


@requires_pint
@pytest.mark.parametrize("func", ["all", "count", "sum", "var"])
# The chunked case builds a dask array, so skip it (rather than error) when
# dask is not installed.
@pytest.mark.parametrize("chunk", [pytest.param(True, marks=requires_dask), False])
def test_pint(chunk, func):
    """Check that groupby_reduce attaches the expected units to pint input.

    The reduction is run once on a ``pint.Quantity`` and once on its bare
    magnitude. For "count" and "all" the result is expected to carry no
    units; otherwise the expected units are those of the matching numpy
    reduction applied to the quantity.
    """
    import pint

    if chunk:
        d = dask.array.array([1, 2, 3])
    else:
        d = np.array([1, 2, 3])
    q = pint.Quantity(d, units="m")

    actual, _ = groupby_reduce(q, [0, 0, 1], func=func)
    expected, _ = groupby_reduce(q.magnitude, [0, 0, 1], func=func)

    units = None if func in ["count", "all"] else getattr(np, func)(q).units
    if units is not None:
        expected = pint.Quantity(expected, units=units)
    assert_equal(expected, actual)


@requires_pint
# The chunked case builds a dask array, so skip it (rather than error) when
# dask is not installed.
@pytest.mark.parametrize("chunk", [pytest.param(True, marks=requires_dask), False])
def test_pint_prod_error(chunk):
    """Check that reducing a pint quantity with "prod" raises ValueError."""
    import pint

    if chunk:
        d = dask.array.array([1, 2, 3])
    else:
        d = np.array([1, 2, 3])
    q = pint.Quantity(d, units="m")

    with pytest.raises(ValueError):
        groupby_reduce(q, [0, 0, 1], func="prod")

0 comments on commit 8ee7488

Please sign in to comment.