diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py index 41612edd..a1ab9d6a 100644 --- a/src/dask_awkward/lib/optimize.py +++ b/src/dask_awkward/lib/optimize.py @@ -14,6 +14,7 @@ from dask.local import get_sync from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardInputLayer +from dask_awkward.lib.utils import typetracer_nochecks from dask_awkward.utils import first if TYPE_CHECKING: @@ -45,7 +46,8 @@ def all_optimizations(dsk: Mapping, keys: Sequence[Key], **_: Any) -> Mapping: dsk = HighLevelGraph.from_collections(str(id(dsk)), dsk, dependencies=()) # Perform dask-awkward specific optimizations. - dsk = optimize(dsk, keys=keys) + with typetracer_nochecks(): + dsk = optimize(dsk, keys=keys) # Perform Blockwise optimizations for HLG input dsk = optimize_blockwise(dsk, keys=keys) # fuse nearby layers diff --git a/src/dask_awkward/lib/utils.py b/src/dask_awkward/lib/utils.py index a25b3dd5..7b067386 100644 --- a/src/dask_awkward/lib/utils.py +++ b/src/dask_awkward/lib/utils.py @@ -3,6 +3,7 @@ __all__ = ("trace_form_structure", "buffer_keys_required_to_compute_shapes") from collections.abc import Callable, Iterable, Iterator, Mapping, MutableMapping +from contextlib import contextmanager from typing import TYPE_CHECKING, TypedDict, TypeVar import awkward as ak @@ -164,3 +165,16 @@ def impl(form: Form, key: str) -> None: form = ak.forms.from_dict(form.to_dict()) impl(form, key) return form + + +@contextmanager +def typetracer_nochecks(): + from awkward._nplikes.typetracer import TypeTracerArray + + oldval = getattr(TypeTracerArray, "runtime_typechecks", None) + TypeTracerArray.runtime_typechecks = False + yield + if oldval is not None: + TypeTracerArray.runtime_typechecks = oldval + else: + del TypeTracerArray.runtime_typechecks diff --git a/tests/test_utils.py b/tests/test_utils.py index 3d31a964..3d1b38c7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,7 @@ import pytest +from dask_awkward.lib.utils import typetracer_nochecks from dask_awkward.utils import ( LazyInputsDict, field_access_to_front, @@ -78,3 +79,12 @@ def test_field_access_to_front(pairs): res = field_access_to_front(pairs[0]) assert res[0] == pairs[1] assert res[1] == pairs[2] + + +def test_nocheck_context(): + from awkward._nplikes.typetracer import TypeTracerArray + + assert getattr(TypeTracerArray, "runtime_typechecks", True) + with typetracer_nochecks(): + assert not TypeTracerArray.runtime_typechecks + assert getattr(TypeTracerArray, "runtime_typechecks", True)