Skip to content

Commit

Permalink
more typing
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis committed Oct 17, 2023
1 parent 171302d commit f4696ff
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 41 deletions.
2 changes: 1 addition & 1 deletion src/dask_awkward/layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def mock(self) -> MaterializedLayer:

# when using Array.partitions we need to mock that we
# just want the first partition.
if len(task) == 2 and task[1] > 0:
if len(task) == 2 and isinstance(task[1], int) and task[1] > 0:
task = (task[0], 0)
return MaterializedLayer({(name, 0): task})
return self
Expand Down
55 changes: 24 additions & 31 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from awkward.types.type import Type
from dask.array.core import Array as DaskArray
from dask.bag.core import Bag as DaskBag
from dask.typing import Graph, Key, NestedKeys, PostComputeCallable
from numpy.typing import DTypeLike


Expand Down Expand Up @@ -122,13 +123,13 @@ def __init__(
self._meta: Any = self._check_meta(meta)
self._known_value: Any | None = known_value

def __dask_graph__(self) -> HighLevelGraph:
def __dask_graph__(self) -> Graph:
return self._dask

def __dask_keys__(self) -> list[Hashable]:
def __dask_keys__(self) -> NestedKeys:
return [self.key]

def __dask_layers__(self) -> tuple[str, ...]:
def __dask_layers__(self) -> Sequence[str]:
return (self.name,)

def __dask_tokenize__(self) -> Hashable:
Expand All @@ -140,18 +141,13 @@ def __dask_tokenize__(self) -> Hashable:

__dask_scheduler__ = staticmethod(threaded_get)

def __dask_postcompute__(self) -> tuple[Callable, tuple]:
def __dask_postcompute__(self) -> tuple[PostComputeCallable, tuple]:
return first, ()

def __dask_postpersist__(self) -> tuple[Callable, tuple]:
def __dask_postpersist__(self):
return self._rebuild, ()

def _rebuild(
self,
dsk: HighLevelGraph,
*,
rename: Mapping[str, str] | None = None,
) -> Any:
def _rebuild(self, dsk, *, rename=None):
name = self._name
if rename:
raise ValueError("rename= unsupported in dask-awkward")
Expand All @@ -169,7 +165,7 @@ def name(self) -> str:
return self._name

@property
def key(self) -> Hashable:
def key(self) -> Key:
return (self._name, 0)

def _check_meta(self, m: Any) -> Any | None:
Expand Down Expand Up @@ -227,10 +223,12 @@ def __getitem__(self, where: Any) -> Any:
token = tokenize(self, operator.getitem, where)
label = "getitem"
name = f"{label}-{token}"
d = self.to_delayed(optimize_graph=True)
task = {name: (operator.getitem, d.key, where)}
hlg = HighLevelGraph.from_collections(name, task, dependencies=(d,))
return Delayed(name, hlg)
task = AwkwardMaterializedLayer(
{(name, 0): (operator.getitem, self.key, where)},
previous_layer_names=[self.name],
)
hlg = HighLevelGraph.from_collections(name, task, dependencies=[self])
return new_scalar_object(hlg, name, meta=None)

def __getattr__(self, attr: str) -> Any:
d = self.to_delayed(optimize_graph=True)
Expand Down Expand Up @@ -343,7 +341,7 @@ def new_known_scalar(
dtype = np.dtype(type(s))
else:
dtype = np.dtype(dtype)
llg = {(name, 0): s}
llg = AwkwardMaterializedLayer({(name, 0): s}, previous_layer_names=[])
hlg = HighLevelGraph.from_collections(name, llg, dependencies=())
return Scalar(
hlg, name, meta=TypeTracerArray._new(dtype=dtype, shape=()), known_value=s
Expand Down Expand Up @@ -489,17 +487,17 @@ def __init__(
dsk: HighLevelGraph,
name: str,
meta: ak.Array,
divisions: tuple[int | None, ...],
divisions: tuple[int, ...] | tuple[None, ...],
) -> None:
self._dask: HighLevelGraph = dsk
self._name: str = name
self._divisions: tuple[int | None, ...] = divisions
self._divisions: tuple[int, ...] | tuple[None, ...] = divisions
self._meta: ak.Array = meta

def __dask_graph__(self) -> HighLevelGraph:
return self.dask

def __dask_keys__(self) -> list[Hashable]:
def __dask_keys__(self) -> NestedKeys:
return [(self.name, i) for i in range(self.npartitions)]

def __dask_layers__(self) -> tuple[str]:
Expand All @@ -511,7 +509,7 @@ def __dask_tokenize__(self) -> Hashable:
def __dask_postcompute__(self) -> tuple[Callable, tuple]:
return _finalize_array, ()

def __dask_postpersist__(self) -> tuple[Callable, tuple]:
def __dask_postpersist__(self):
return self._rebuild, ()

__dask_optimize__ = globalmethod(
Expand Down Expand Up @@ -540,12 +538,7 @@ def __setitem__(self, where: Any, what: Any) -> None:
self._dask = appended._dask
self._name = appended._name

def _rebuild(
self,
dsk: HighLevelGraph,
*,
rename: Mapping[str, str] | None = None,
) -> Array:
def _rebuild(self, dsk, *, rename=None):
name = self.name
if rename:
raise ValueError("rename= unsupported in dask-awkward")
Expand Down Expand Up @@ -666,7 +659,7 @@ def dask(self) -> HighLevelGraph:
return self._dask

@property
def keys(self) -> list[Hashable]:
def keys(self) -> NestedKeys:
"""Task graph keys."""
return self.__dask_keys__()

Expand All @@ -682,7 +675,7 @@ def ndim(self) -> int:
return self._meta.ndim

@property
def divisions(self) -> tuple[int | None, ...]:
def divisions(self) -> tuple[int, ...] | tuple[None, ...]:
"""Location of the collections partition boundaries."""
return self._divisions

Expand Down Expand Up @@ -1379,7 +1372,7 @@ def new_array_object(
meta: ak.Array | None = None,
behavior: dict | None = None,
npartitions: int | None = None,
divisions: tuple[int | None, ...] | None = None,
divisions: tuple[int, ...] | tuple[None, ...] | None = None,
) -> Array:
"""Instantiate a new Array collection object.
Expand Down Expand Up @@ -1411,7 +1404,7 @@ def new_array_object(
"""
if divisions is None:
if npartitions is not None:
divs: tuple[int | None, ...] = (None,) * (npartitions + 1)
divs: tuple[int, ...] | tuple[None, ...] = (None,) * (npartitions + 1)
else:
raise ValueError("One of either divisions or npartitions must be defined.")
else:
Expand Down
23 changes: 17 additions & 6 deletions src/dask_awkward/lib/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def from_delayed(
source: list[Delayed] | Delayed,
meta: ak.Array | None = None,
behavior: dict | None = None,
divisions: tuple[int | None, ...] | None = None,
divisions: tuple[int, ...] | tuple[None, ...] | None = None,
prefix: str = "from-delayed",
) -> Array:
"""Create an Array collection from a set of :class:`~dask.delayed.Delayed` objects.
Expand Down Expand Up @@ -185,17 +185,21 @@ def from_delayed(
name = f"{prefix}-{tokenize(parts)}"
dsk = AwkwardMaterializedLayer(
{(name, i): part.key for i, part in enumerate(parts)},
previous_layer_names=[parts[0].name],
previous_layer_names=[parts[0].key],
)
if divisions is None:
divs: tuple[int | None, ...] = (None,) * (len(parts) + 1)
divs: tuple[int, ...] | tuple[None, ...] = (None,) * (len(parts) + 1)
else:
divs = tuple(divisions)
divs = divisions
if len(divs) != len(parts) + 1:
raise ValueError("divisions must be a tuple of length len(source) + 1")
hlg = HighLevelGraph.from_collections(name, dsk, dependencies=parts)
return new_array_object(
hlg, name=name, meta=meta, behavior=behavior, divisions=divs
hlg,
name=name,
meta=meta,
behavior=behavior,
divisions=divs,
)


Expand Down Expand Up @@ -320,7 +324,14 @@ def to_dask_array(
for i, k in enumerate(flatten(array.__dask_keys__()))
}

graph = HighLevelGraph.from_collections(name, llg, dependencies=[array])
graph = HighLevelGraph.from_collections(
name,
AwkwardMaterializedLayer(
llg,
previous_layer_names=[array.name],
),
dependencies=[array],
)
return new_da_object(graph, name, meta=None, chunks=chunks, dtype=dtype)


Expand Down
7 changes: 6 additions & 1 deletion src/dask_awkward/lib/io/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from fsspec.core import get_fs_token_paths, url_to_fs
from fsspec.utils import infer_compression, read_block

from dask_awkward.layers.layers import AwkwardMaterializedLayer
from dask_awkward.lib.core import map_partitions, new_scalar_object, typetracer_array
from dask_awkward.lib.io.columnar import ColumnProjectionMixin
from dask_awkward.lib.io.io import (
Expand Down Expand Up @@ -746,7 +747,11 @@ def to_json(
map_res.dask.layers[map_res.name].annotations = {"ak_output": True}
name = f"to-json-{tokenize(array, path)}"
dsk = {(name, 0): (lambda *_: None, map_res.__dask_keys__())}
graph = HighLevelGraph.from_collections(name, dsk, dependencies=(map_res,))
graph = HighLevelGraph.from_collections(
name,
AwkwardMaterializedLayer(dsk, previous_layer_names=[map_res.name]),
dependencies=(map_res,),
)
res = new_scalar_object(graph, name=name, meta=None)
if compute:
res.compute()
Expand Down
7 changes: 6 additions & 1 deletion src/dask_awkward/lib/io/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from fsspec import AbstractFileSystem
from fsspec.core import get_fs_token_paths, url_to_fs

from dask_awkward.layers.layers import AwkwardMaterializedLayer
from dask_awkward.lib.core import Array, Scalar, map_partitions, new_scalar_object
from dask_awkward.lib.io.columnar import ColumnProjectionMixin
from dask_awkward.lib.io.io import from_map
Expand Down Expand Up @@ -609,7 +610,11 @@ def to_parquet(
else:
final_name = name + "-finalize"
dsk[(final_name, 0)] = (lambda *_: None, map_res.__dask_keys__())
graph = HighLevelGraph.from_collections(final_name, dsk, dependencies=[map_res])
graph = HighLevelGraph.from_collections(
final_name,
AwkwardMaterializedLayer(dsk, previous_layer_names=[map_res.name]),
dependencies=[map_res],
)
out = new_scalar_object(graph, final_name, meta=None)
if compute:
out.compute()
Expand Down
2 changes: 1 addition & 1 deletion src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def all_optimizations(
keys = tuple(flatten(keys))

if not isinstance(dsk, HighLevelGraph):
dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
dsk = HighLevelGraph.from_collections(str(id(dsk)), dsk, dependencies=())

else:
# Perform dask-awkward specific optimizations.
Expand Down

0 comments on commit f4696ff

Please sign in to comment.