Skip to content

Commit

Permalink
feat: support full UHI for rebinning
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed Jan 25, 2024
1 parent 3401e19 commit 15a8e42
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 15 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ packages =
boost_histogram._internal
boost_histogram.axis
install_requires =
uhi
numpy>=1.26.0b1;python_version>='3.12'
numpy;python_version<'3.12'
typing-extensions;python_version<'3.8'
Expand Down
1 change: 0 additions & 1 deletion src/boost_histogram/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from .tag import ( # pylint: disable=redefined-builtin
loc,
overflow,
rebin,
sum,
underflow,
)
Expand Down
39 changes: 34 additions & 5 deletions src/boost_histogram/_internal/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from boost_histogram import _core

from .axestuple import AxesTuple
from .axis import Axis
from .axis import Axis, Variable
from .enum import Kind
from .storage import Double, Storage
from .typing import Accumulator, ArrayLike, CppHistogram, SupportsIndex
Expand Down Expand Up @@ -672,7 +672,6 @@ def _compute_uhi_index(self, index: InnerIndexing, axis: int) -> SimpleIndexing:
if index is sum or hasattr(index, "factor"): # type: ignore[comparison-overlap]
return slice(None, None, index)

# General locators
# Note that MyPy doesn't like these very much - the fix
# will be to properly set input types
if callable(index):
Expand Down Expand Up @@ -854,13 +853,15 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:
if ind != slice(None):
merge = 1
if ind.step is not None:
if hasattr(ind.step, "factor"):
if ind.step.factor is not None:
merge = ind.step.factor
elif callable(ind.step):
if ind.step is sum:
integrations.add(i)
elif ind.step.groups is not None:
groups = ind.step.groups
else:
raise RuntimeError("Full UHI not supported yet")
raise NotImplementedError

if ind.start is not None or ind.stop is not None:
slices.append(
Expand All @@ -876,7 +877,10 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:

assert isinstance(start, int)
assert isinstance(stop, int)
slices.append(_core.algorithm.slice_and_rebin(i, start, stop, merge))
if not (ind.step is not None and ind.step.factor is None):
slices.append(
_core.algorithm.slice_and_rebin(i, start, stop, merge)
)

# Will be updated below
if slices or pick_set or pick_each or integrations:
Expand All @@ -885,6 +889,31 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:
logger.debug("Reduce actions are all empty, just making a copy")
reduced = copy.copy(self._hist)

# bin re-grouping
if (
hasattr(ind, "step")
and ind.step is not None
and ind.step.groups is not None
):
axes = [reduced.axis(i) for i in range(reduced.rank())]
reduced_view = reduced.view(flow=True)
new_axes_indices = [axes[i].edges[0]]
j: int = 0
for group in groups:
new_axes_indices += [axes[i].edges[j + 1 : j + group + 1][-1]]
j = group

variable_axis = Variable(new_axes_indices)
variable_axis.metadata = axes[i].metadata
axes[i] = variable_axis
reduced_view = np.take(reduced_view, range(len(reduced_view)), axis=i)

logger.debug("Axes: %s", axes)

new_reduced = reduced.__class__(axes)
new_reduced.view(flow=True)[...] = reduced_view
reduced = new_reduced

if pick_each:
tuple_slice = tuple(
pick_each.get(i, slice(None)) for i in range(reduced.rank())
Expand Down
57 changes: 48 additions & 9 deletions src/boost_histogram/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@

import copy
from builtins import sum
from typing import TypeVar
from typing import Mapping, Sequence, TypeVar

from uhi.typing.plottable import PlottableAxis

from ._internal.typing import AxisLike
from .axis import Regular, Variable

__all__ = ("Slicer", "Locator", "at", "loc", "overflow", "underflow", "rebin", "sum")
__all__ = ("Slicer", "Locator", "at", "loc", "overflow", "underflow", "Rebinner", "sum")


class Slicer:
Expand Down Expand Up @@ -107,13 +110,49 @@ def __call__(self, axis: AxisLike) -> int: # noqa: ARG002
return self.value


class rebin:
__slots__ = ("factor",)

def __init__(self, value: int) -> None:
class Rebinner:
__slots__ = (
"factor",
"groups",
"category_map",
)

def __init__(
self,
*,
value: int | None = None,
groups: Sequence[int] | None = None,
) -> None:
if (
sum(i is None for i in [value, groups]) == 2
or sum(i is not None for i in [value, groups]) > 1
):
raise ValueError("exactly one, a value or groups should be provided")
self.factor = value
self.groups = groups

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.factor})"

# TODO: Add __call__ to support UHI
repr_str = f"{self.__class__.__name__}"
args: dict[str, int | Sequence[int] | None] = {
"value": self.factor,
"groups": self.groups,
}
for k, v in args.items():
if v is not None:
return_str = f"{repr_str}({k}={v})"
break
return return_str

def __call__(
self, axis: PlottableAxis
) -> int | Sequence[int] | Mapping[int | str, Sequence[int | str]]:
if isinstance(axis, Regular):
if self.factor is None:
raise ValueError("must provide a value")
return self.factor
elif isinstance(axis, Variable): # noqa: RET505
if self.groups is None:
raise ValueError("must provide bin groups")
return self.groups
else:
raise NotImplementedError(axis)

0 comments on commit 15a8e42

Please sign in to comment.