Skip to content

Commit

Permalink
fix dak.num(..., axis=0) (#333)
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis authored Jul 27, 2023
1 parent 0c3493d commit 84e0546
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 10 deletions.
38 changes: 33 additions & 5 deletions src/dask_awkward/lib/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@
from typing import TYPE_CHECKING, Any

import awkward as ak
from dask.base import is_dask_collection
import numpy as np
from awkward._nplikes.typetracer import TypeTracerArray
from dask.base import is_dask_collection, tokenize
from dask.highlevelgraph import HighLevelGraph

from dask_awkward.layers import AwkwardMaterializedLayer
from dask_awkward.lib.core import (
Array,
PartitionCompatibility,
map_partitions,
new_known_scalar,
new_scalar_object,
partition_compatibility,
)
from dask_awkward.utils import (
Expand Down Expand Up @@ -562,6 +568,10 @@ def nan_to_num(
raise DaskAwkwardNotImplemented("TODO")


def _numaxis0(*integers):
    """Return the sum of the given integer counts as an int64 scalar.

    Parameters
    ----------
    *integers
        Zero or more integer-like values to add together.

    Returns
    -------
    numpy.int64
        The total of all inputs; ``0`` when no inputs are given.

    """
    # Pin the dtype explicitly: without it an empty input yields a float64
    # array (and so a float 0.0 result), and the integer width follows the
    # platform default rather than the int64 the result is expected to be.
    return np.sum(np.array(integers, dtype=np.int64))


@borrow_docstring(ak.num)
def num(
array: Any,
Expand All @@ -571,7 +581,28 @@ def num(
) -> Any:
if not highlevel:
raise ValueError("Only highlevel=True is supported")
if axis and axis != 0:
if axis == 0 or axis == -1 * array.ndim:
if array.known_divisions:
return new_known_scalar(array.defined_divisions[-1], label="num")

per_axis = map_partitions(
ak.num,
array,
axis=0,
meta=ak.Array(ak.Array([1, 1]).layout.to_typetracer(forget_length=True)),
)
name = f"numaxis0-{tokenize(array, axis)}"
keys = per_axis.__dask_keys__()
matlayer = AwkwardMaterializedLayer(
{(name, 0): (_numaxis0, *keys)}, previous_layer_names=[per_axis.name]
)
hlg = HighLevelGraph.from_collections(name, matlayer, dependencies=(per_axis,))
return new_scalar_object(
hlg,
name,
meta=TypeTracerArray._new(dtype=np.int64, shape=()),
)
else:
return map_partitions(
ak.num,
array,
Expand All @@ -580,9 +611,6 @@ def num(
behavior=behavior,
output_divisions=1,
)
if axis == 0:
return len(array)
raise DaskAwkwardNotImplemented("TODO")


@borrow_docstring(ak.ones_like)
Expand Down
9 changes: 4 additions & 5 deletions tests/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,21 @@ def test_flatten(caa: ak.Array, daa: dak.Array, axis: int | None) -> None:
assert_eq(cr, dr)


@pytest.mark.parametrize("axis", [0, 1, -1])
def test_num(caa: ak.Array, daa: dak.Array, axis: int | None) -> None:
@pytest.mark.parametrize("axis", [0, 1, -1, -2])
def test_num(caa: ak.Array, daa: dak.Array, axis: int) -> None:
da = daa["points"]
ca = caa["points"]

if axis == 0:
assert_eq(dak.num(da.x, axis=axis), ak.num(ca.x, axis=axis))
da.eager_compute_divisions()

assert_eq(dak.num(da.x, axis=axis), ak.num(ca.x, axis=axis))

if axis == 1:
c1 = dak.num(da.x, axis=axis) > 2
c2 = ak.num(ca.x, axis=axis) > 2
assert_eq(da[c1], ca[c2])

assert_eq(dak.num(da.x, axis=axis), ak.num(ca.x, axis=axis))


def test_zip_dict_input(caa: ak.Array, daa: dak.Array) -> None:
da1 = daa["points"]["x"]
Expand Down

0 comments on commit 84e0546

Please sign in to comment.