Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Lazy dak.num(..., axis=0) (or negative axis that is equivalent to axis=0) #333

Merged
merged 1 commit into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions src/dask_awkward/lib/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@
from typing import TYPE_CHECKING, Any

import awkward as ak
from dask.base import is_dask_collection
import numpy as np
from awkward._nplikes.typetracer import TypeTracerArray
from dask.base import is_dask_collection, tokenize
from dask.highlevelgraph import HighLevelGraph

from dask_awkward.layers import AwkwardMaterializedLayer
from dask_awkward.lib.core import (
Array,
PartitionCompatibility,
map_partitions,
new_known_scalar,
new_scalar_object,
partition_compatibility,
)
from dask_awkward.utils import (
Expand Down Expand Up @@ -562,6 +568,10 @@ def nan_to_num(
raise DaskAwkwardNotImplemented("TODO")


def _numaxis0(*integers):
return np.sum(np.array(integers))


@borrow_docstring(ak.num)
def num(
array: Any,
Expand All @@ -571,7 +581,28 @@ def num(
) -> Any:
if not highlevel:
raise ValueError("Only highlevel=True is supported")
if axis and axis != 0:
if axis == 0 or axis == -1 * array.ndim:
if array.known_divisions:
return new_known_scalar(array.defined_divisions[-1], label="num")

per_axis = map_partitions(
ak.num,
array,
axis=0,
meta=ak.Array(ak.Array([1, 1]).layout.to_typetracer(forget_length=True)),
)
name = f"numaxis0-{tokenize(array, axis)}"
keys = per_axis.__dask_keys__()
matlayer = AwkwardMaterializedLayer(
{(name, 0): (_numaxis0, *keys)}, previous_layer_names=[per_axis.name]
)
hlg = HighLevelGraph.from_collections(name, matlayer, dependencies=(per_axis,))
return new_scalar_object(
hlg,
name,
meta=TypeTracerArray._new(dtype=np.int64, shape=()),
)
else:
return map_partitions(
ak.num,
array,
Expand All @@ -580,9 +611,6 @@ def num(
behavior=behavior,
output_divisions=1,
)
if axis == 0:
return len(array)
raise DaskAwkwardNotImplemented("TODO")


@borrow_docstring(ak.ones_like)
Expand Down
9 changes: 4 additions & 5 deletions tests/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,21 @@ def test_flatten(caa: ak.Array, daa: dak.Array, axis: int | None) -> None:
assert_eq(cr, dr)


@pytest.mark.parametrize("axis", [0, 1, -1])
def test_num(caa: ak.Array, daa: dak.Array, axis: int | None) -> None:
@pytest.mark.parametrize("axis", [0, 1, -1, -2])
def test_num(caa: ak.Array, daa: dak.Array, axis: int) -> None:
da = daa["points"]
ca = caa["points"]

if axis == 0:
assert_eq(dak.num(da.x, axis=axis), ak.num(ca.x, axis=axis))
da.eager_compute_divisions()

assert_eq(dak.num(da.x, axis=axis), ak.num(ca.x, axis=axis))

if axis == 1:
c1 = dak.num(da.x, axis=axis) > 2
c2 = ak.num(ca.x, axis=axis) > 2
assert_eq(da[c1], ca[c2])

assert_eq(dak.num(da.x, axis=axis), ak.num(ca.x, axis=axis))


def test_zip_dict_input(caa: ak.Array, daa: dak.Array) -> None:
da1 = daa["points"]["x"]
Expand Down
Loading