Skip to content

Commit

Permalink
Turn of ak's typetracer _new checks during optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
martindurant committed Mar 15, 2024
1 parent a0652da commit 5eba9fe
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dask.local import get_sync

from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardInputLayer
from dask_awkward.lib.utils import typetracer_nochecks
from dask_awkward.utils import first

if TYPE_CHECKING:
Expand Down Expand Up @@ -45,7 +46,8 @@ def all_optimizations(dsk: Mapping, keys: Sequence[Key], **_: Any) -> Mapping:
dsk = HighLevelGraph.from_collections(str(id(dsk)), dsk, dependencies=())

# Perform dask-awkward specific optimizations.
dsk = optimize(dsk, keys=keys)
with typetracer_nochecks():
dsk = optimize(dsk, keys=keys)
# Perform Blockwise optimizations for HLG input
dsk = optimize_blockwise(dsk, keys=keys)
# fuse nearby layers
Expand Down
14 changes: 14 additions & 0 deletions src/dask_awkward/lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
__all__ = ("trace_form_structure", "buffer_keys_required_to_compute_shapes")

from collections.abc import Callable, Iterable, Iterator, Mapping, MutableMapping
from contextlib import contextmanager
from typing import TYPE_CHECKING, TypedDict, TypeVar

import awkward as ak
Expand Down Expand Up @@ -164,3 +165,16 @@ def impl(form: Form, key: str) -> None:
form = ak.forms.from_dict(form.to_dict())
impl(form, key)
return form


@contextmanager
def typetracer_nochecks():
from awkward._nplikes.typetracer import TypeTracerArray

oldval = getattr(TypeTracerArray, "runtime_typechecks", None)
TypeTracerArray.runtime_typechecks = False
yield
if oldval is not None:
TypeTracerArray.runtime_typechecks = oldval
else:
del TypeTracerArray.runtime_typechecks
10 changes: 10 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest

from dask_awkward.lib.utils import typetracer_nochecks
from dask_awkward.utils import (
LazyInputsDict,
field_access_to_front,
Expand Down Expand Up @@ -78,3 +79,12 @@ def test_field_access_to_front(pairs):
res = field_access_to_front(pairs[0])
assert res[0] == pairs[1]
assert res[1] == pairs[2]


def test_nocheck_context():
from awkward._nplikes.typetracer import TypeTracerArray

assert getattr(TypeTracerArray, "runtime_typechecks", True)
with typetracer_nochecks():
assert not TypeTracerArray.runtime_typechecks
assert getattr(TypeTracerArray, "runtime_typechecks", True)

0 comments on commit 5eba9fe

Please sign in to comment.