From f0c8be808039280a5c3437fd0c5580493beb60b3 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Thu, 8 Jun 2023 21:50:23 -0500 Subject: [PATCH] feat: add `ak.typetracer.length_one_if_typetracer` (#2499) * feat: length_one_if_typetracer. * feat: add `_smallest_zero_buffer_lengths` generator This enables us to intelligently select the smallest possible buffer * docs: update comments * fix: missing forms * fix: support NumpyArray with `inner_shape` * Remove _smallest_zero_buffer_lengths. * Prepare form + container via https://github.com/scikit-hep/awkward/pull/2499#issuecomment-1579572174 * index_size_bytes is now dead code. * Fix Windows. * Deprecate 'empty_if_typetracer' in favor of 'length_zero_if_typetracer'. --------- Co-authored-by: Angus Hollands --- src/awkward/forms/emptyform.py | 6 +- src/awkward/forms/form.py | 104 ++++++++++++++++--- src/awkward/forms/listoffsetform.py | 13 ++- src/awkward/index.py | 29 +++--- src/awkward/typetracer.py | 23 +++- tests/test_2085_empty_if_typetracer.py | 31 +++++- tests/test_2501_positional_record_reducer.py | 6 +- 7 files changed, 171 insertions(+), 41 deletions(-) diff --git a/src/awkward/forms/emptyform.py b/src/awkward/forms/emptyform.py index 430745a9b3..b0b7686eb7 100644 --- a/src/awkward/forms/emptyform.py +++ b/src/awkward/forms/emptyform.py @@ -3,7 +3,8 @@ import awkward as ak from awkward._errors import deprecate -from awkward._typing import final +from awkward._nplikes.shape import ShapeItem +from awkward._typing import Iterator, final from awkward._util import UNSET from awkward.forms.form import Form, JSONMapping @@ -103,3 +104,6 @@ def _select_columns(self, index, specifier, matches, output): def _column_types(self) -> tuple[str, ...]: return ("empty",) + + def _length_one_buffer_lengths(self) -> Iterator[ShapeItem]: + yield 0 diff --git a/src/awkward/forms/form.py b/src/awkward/forms/form.py index a980947157..d73296479d 100644 --- a/src/awkward/forms/form.py +++ b/src/awkward/forms/form.py @@ -10,7 +10,7 @@ from awkward._backends.numpy import NumpyBackend from awkward._errors import deprecate from awkward._nplikes.numpylike import NumpyMetadata -from awkward._nplikes.shape import unknown_length +from awkward._nplikes.shape import ShapeItem, unknown_length from awkward._parameters import parameters_union from awkward._typing import Final, JSONMapping, JSONSerializable @@ -453,19 +453,99 @@ def length_zero_array( ) def length_one_array(self, *, backend=numpy_backend, highlevel=True, behavior=None): - # A length-1 array will need at least N bytes, where N is the largest dtype (e.g. 256 bit complex) - # Similarly, a length-1 array will need no more than 2*N bytes, as all contents need at most two - # index-types e.g. `ListOffsetArray.offsets` for their various index metadata. Therefore, we - # create a buffer of this length (2N) and instruct all contents to use it (via `buffer_key=""`). - # At the same time, with all index-like metadata set to 0, the list types will have zero lengths - # whilst unions, indexed, and option types will contain a single value. + # The naive implementation of a length-1 array requires that we have a sufficiently + # large buffer to be able to build _any_ subtree. + def max_prefer_unknown(this: ShapeItem, that: ShapeItem) -> ShapeItem: + if this is unknown_length: + return this + if that is unknown_length: + return that + return max(this, that) + + container = {} + + def prepare(form, multiplier): + form_key = f"node-{len(container)}" + + if isinstance(form, (ak.forms.BitMaskedForm, ak.forms.ByteMaskedForm)): + if form.valid_when: + container[form_key] = b"\x00" * multiplier + else: + container[form_key] = b"\xff" * multiplier + return form.copy(form_key=form_key) # DO NOT RECURSE + + elif isinstance(form, ak.forms.IndexedOptionForm): + container[form_key] = b"\xff\xff\xff\xff\xff\xff\xff\xff" # -1 + return form.copy(form_key=form_key) # DO NOT RECURSE + + elif isinstance(form, ak.forms.EmptyForm): + # no error if protected by non-recursing node type + raise TypeError( + "cannot generate a length_one_array from a Form with an " + "unknowntype that cannot be hidden (EmptyForm not within " + "BitMaskedForm, ByteMaskedForm, or IndexedOptionForm)" + ) + + elif isinstance(form, ak.forms.UnmaskedForm): + return form.copy(content=prepare(form.content, multiplier)) + + elif isinstance(form, (ak.forms.IndexedForm, ak.forms.ListForm)): + container[form_key] = b"\x00" * (8 * multiplier) + return form.copy( + content=prepare(form.content, multiplier), form_key=form_key + ) + + elif isinstance(form, ak.forms.ListOffsetForm): + # offsets length == array length + 1 + container[form_key] = b"\x00" * (8 * (multiplier + 1)) + return form.copy( + content=prepare(form.content, multiplier), form_key=form_key + ) + + elif isinstance(form, ak.forms.RegularForm): + size = form.size + + # https://github.com/scikit-hep/awkward/pull/2499#discussion_r1220503454 + if size is unknown_length: + size = 1 + + return form.copy(content=prepare(form.content, multiplier * size)) + + elif isinstance(form, ak.forms.NumpyForm): + dtype = ak.types.numpytype.primitive_to_dtype(form._primitive) + size = multiplier * dtype.itemsize + for x in form.inner_shape: + if x is not unknown_length: + size *= x + + container[form_key] = b"\x00" * size + return form.copy(form_key=form_key) + + elif isinstance(form, ak.forms.RecordForm): + return form.copy( + # recurse down all contents + contents=[prepare(x, multiplier) for x in form.contents] + ) + + elif isinstance(form, ak.forms.UnionForm): + # both tags and index will get this buffer, but index is 8 bytes + container[form_key] = b"\x00" * (8 * multiplier) + return form.copy( + # only recurse down contents[0] because all index == 0 + contents=( + [prepare(form.contents[0], multiplier)] + form.contents[1:] + ), + form_key=form_key, + ) + + else: + raise AssertionError(f"not a Form: {form!r}") + return ak.operations.ak_from_buffers._impl( - form=self, + form=prepare(self, 1), length=1, - container={ - "": b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - }, - buffer_key="", + container=container, + buffer_key="{form_key}", backend=backend, byteorder=ak._util.native_byteorder, highlevel=highlevel, diff --git a/src/awkward/forms/listoffsetform.py b/src/awkward/forms/listoffsetform.py index bab844fbab..98a66be299 100644 --- a/src/awkward/forms/listoffsetform.py +++ b/src/awkward/forms/listoffsetform.py @@ -1,7 +1,9 @@ # BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE +from __future__ import annotations + import awkward as ak from awkward._parameters import type_parameters_equal -from awkward._typing import final +from awkward._typing import JSONMapping, final from awkward._util import UNSET from awkward.forms.form import Form @@ -10,7 +12,14 @@ class ListOffsetForm(Form): is_list = True - def __init__(self, offsets, content, *, parameters=None, form_key=None): + def __init__( + self, + offsets: str, + content: Form, + *, + parameters: JSONMapping | None = None, + form_key: str | None = None, + ): if not isinstance(offsets, str): raise TypeError( "{} 'offsets' must be of type str, not {}".format( diff --git a/src/awkward/index.py b/src/awkward/index.py index c9f1347ca0..3660aad9ae 100644 --- a/src/awkward/index.py +++ b/src/awkward/index.py @@ -10,13 +10,13 @@ from awkward._nplikes.numpy import Numpy from awkward._nplikes.numpylike import NumpyLike, NumpyMetadata from awkward._nplikes.typetracer import TypeTracer -from awkward._typing import Self +from awkward._typing import Final, Self -np = NumpyMetadata.instance() -numpy = Numpy.instance() +np: Final = NumpyMetadata.instance() +numpy: Final = Numpy.instance() -_dtype_to_form = { +_dtype_to_form: Final = { np.dtype(np.int8): "i8", np.dtype(np.uint8): "u8", np.dtype(np.int32): "i32", @@ -24,20 +24,15 @@ np.dtype(np.int64): "i64", } +_form_to_dtype: Final = {v: k for k, v in _dtype_to_form.items()} -def _form_to_zero_length(form): - if form == "i8": - return Index8(numpy.zeros(0, dtype=np.int8)) - elif form == "u8": - return IndexU8(numpy.zeros(0, dtype=np.uint8)) - elif form == "i32": - return Index32(numpy.zeros(0, dtype=np.int32)) - elif form == "u32": - return IndexU32(numpy.zeros(0, dtype=np.uint32)) - elif form == "i64": - return Index64(numpy.zeros(0, dtype=np.int64)) - else: - raise AssertionError(f"unrecognized Index form: {form!r}") + +def _form_to_zero_length(form: str) -> Index: + try: + dtype = _form_to_dtype[form] + except KeyError: + raise AssertionError(f"unrecognized Index form: {form!r}") from None + return Index(numpy.zeros(0, dtype=dtype)) class Index: diff --git a/src/awkward/typetracer.py b/src/awkward/typetracer.py index 9567926a42..0fba624ce9 100644 --- a/src/awkward/typetracer.py +++ b/src/awkward/typetracer.py @@ -12,6 +12,7 @@ from awkward._backends.typetracer import TypeTracerBackend from awkward._behavior import behavior_of +from awkward._errors import deprecate from awkward._layout import wrap_layout from awkward._nplikes.placeholder import PlaceholderArray from awkward._nplikes.shape import unknown_length @@ -22,13 +23,14 @@ typetracer_with_report, ) from awkward._typing import TypeVar +from awkward.forms import Form from awkward.highlevel import Array, Record from awkward.operations.ak_to_layout import to_layout T = TypeVar("T", Array, Record) -def empty_if_typetracer(array: T) -> T: +def _length_0_1_if_typetracer(array: T, function): typetracer_backend = TypeTracerBackend.instance() layout = to_layout(array, allow_other=False) @@ -36,6 +38,23 @@ def empty_if_typetracer(array: T) -> T: if layout.backend is typetracer_backend: layout._touch_data(True) - layout = layout.form.length_zero_array(highlevel=False) + layout = function(layout.form, highlevel=False) return wrap_layout(layout, behavior=behavior) + + +def empty_if_typetracer(array: T) -> T: + deprecate( + "'empty_if_typetracer' is being replaced by 'length_zero_if_typetracer' (change name)", + "2.4.0", + ) + + return length_zero_if_typetracer(array) + + +def length_zero_if_typetracer(array: T) -> T: + return _length_0_1_if_typetracer(array, Form.length_zero_array) + + +def length_one_if_typetracer(array: T) -> T: + return _length_0_1_if_typetracer(array, Form.length_one_array) diff --git a/tests/test_2085_empty_if_typetracer.py b/tests/test_2085_empty_if_typetracer.py index 690ca4a745..a17d205eb2 100644 --- a/tests/test_2085_empty_if_typetracer.py +++ b/tests/test_2085_empty_if_typetracer.py @@ -1,19 +1,30 @@ # BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE import numpy as np -import pytest # noqa: F401 +import pytest import awkward as ak -from awkward.typetracer import empty_if_typetracer, typetracer_with_report +from awkward.typetracer import ( + length_one_if_typetracer, + length_zero_if_typetracer, + typetracer_with_report, +) -def test(): +@pytest.mark.parametrize( + "function", [length_zero_if_typetracer, length_one_if_typetracer] +) +def test_typetracer(function): def func(array): assert ak.backend(array) == "typetracer" radius = np.sqrt(array.x**2 + array.y**2) - radius = empty_if_typetracer(radius) + radius = function(radius) assert ak.backend(radius) == "cpu" + if function is length_zero_if_typetracer: + assert len(radius) == 0 + else: + assert len(radius) == 1 hist_contents, hist_edges = np.histogram(ak.flatten(radius, axis=None)) @@ -34,3 +45,15 @@ def func(array): func(meta) assert report.data_touched == ["node0", "node2", "node3"] + + +@pytest.mark.parametrize("regulararray", [False, True]) +def test_multiplier(regulararray): + a = np.arange(2 * 3 * 5, dtype=np.int64).reshape(2, 3, 5) + + b = ak.from_numpy(a, regulararray=regulararray) + assert str(b.type) == "2 * 3 * 5 * int64" + + c = b.layout.form.length_one_array() + assert str(c.type) == "1 * 3 * 5 * int64" + assert c.tolist() == [[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]] diff --git a/tests/test_2501_positional_record_reducer.py b/tests/test_2501_positional_record_reducer.py index 314091498a..33e87c8cb4 100644 --- a/tests/test_2501_positional_record_reducer.py +++ b/tests/test_2501_positional_record_reducer.py @@ -6,7 +6,7 @@ def _min_pair(array, mask): - array = ak.typetracer.empty_if_typetracer(array) + array = ak.typetracer.length_zero_if_typetracer(array) # Find location of minimum 0 slot i_min = ak.argmin(array["0"], axis=-1, keepdims=True, mask_identity=True) @@ -25,7 +25,7 @@ def _min_pair(array, mask): def _argmin_pair(array, mask): - array = ak.typetracer.empty_if_typetracer(array) + array = ak.typetracer.length_zero_if_typetracer(array) assert not mask # Find location of minimum 0 slot @@ -33,7 +33,7 @@ def _argmin_pair(array, mask): def _argmin_pair_bad(array, mask): - array = ak.typetracer.empty_if_typetracer(array) + array = ak.typetracer.length_zero_if_typetracer(array) assert not mask # Find location of minimum 0 slot