Skip to content

Commit

Permalink
feat: add ak.typetracer.length_one_if_typetracer (#2499)
Browse files Browse the repository at this point in the history
* feat: length_one_if_typetracer.

* feat: add `_smallest_zero_buffer_lengths` generator

This enables us to intelligently select the smallest possible buffer

* docs: update comments

* fix: missing forms

* fix: support NumpyArray with `inner_shape`

* Remove _smallest_zero_buffer_lengths.

* Prepare form + container via #2499 (comment)

* index_size_bytes is now dead code.

* Fix Windows.

* Deprecate 'empty_if_typetracer' in favor of 'length_zero_if_typetracer'.

---------

Co-authored-by: Angus Hollands <[email protected]>
  • Loading branch information
jpivarski and agoose77 authored Jun 9, 2023
1 parent d34b21e commit f0c8be8
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 41 deletions.
6 changes: 5 additions & 1 deletion src/awkward/forms/emptyform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import awkward as ak
from awkward._errors import deprecate
from awkward._typing import final
from awkward._nplikes.shape import ShapeItem
from awkward._typing import Iterator, final
from awkward._util import UNSET
from awkward.forms.form import Form, JSONMapping

Expand Down Expand Up @@ -103,3 +104,6 @@ def _select_columns(self, index, specifier, matches, output):

def _column_types(self) -> tuple[str, ...]:
return ("empty",)

def _length_one_buffer_lengths(self) -> Iterator[ShapeItem]:
    """Yield the buffer lengths reported by this form: a single ``0``."""
    yield from (0,)
104 changes: 92 additions & 12 deletions src/awkward/forms/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from awkward._backends.numpy import NumpyBackend
from awkward._errors import deprecate
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.shape import unknown_length
from awkward._nplikes.shape import ShapeItem, unknown_length
from awkward._parameters import parameters_union
from awkward._typing import Final, JSONMapping, JSONSerializable

Expand Down Expand Up @@ -453,19 +453,99 @@ def length_zero_array(
)

def length_one_array(self, *, backend=numpy_backend, highlevel=True, behavior=None):
# A length-1 array will need at least N bytes, where N is the largest dtype (e.g. 256 bit complex)
# Similarly, a length-1 array will need no more than 2*N bytes, as all contents need at most two
# index-types e.g. `ListOffsetArray.offsets` for their various index metadata. Therefore, we
# create a buffer of this length (2N) and instruct all contents to use it (via `buffer_key=""`).
# At the same time, with all index-like metadata set to 0, the list types will have zero lengths
# whilst unions, indexed, and option types will contain a single value.
# The naive implementation of a length-1 array requires that we have a sufficiently
# large buffer to be able to build _any_ subtree.
def max_prefer_unknown(this: ShapeItem, that: ShapeItem) -> ShapeItem:
    """Return the larger of two shape items, letting ``unknown_length`` win.

    If either argument is the ``unknown_length`` sentinel it is returned
    (``this`` is checked first); otherwise the result is ``max(this, that)``.
    """
    for item in (this, that):
        if item is unknown_length:
            return item
    return max(this, that)

container = {}

def prepare(form, multiplier):
    """Return a copy of ``form`` whose buffers are registered in ``container``.

    Walks the form tree and, for every node that owns a buffer, stores a
    constant-filled ``bytes`` object in the enclosing ``container`` dict under
    a fresh ``"node-<n>"`` key, returning a copy of the node that carries that
    key as its ``form_key``.

    Args:
        form: the Form node to prepare.
        multiplier: number of entries this node has to supply. It starts at 1
            and is scaled by ``size`` each time a RegularForm is crossed.

    Returns:
        A copy of ``form`` with ``form_key`` (and, where recursed, content)
        replaced.

    Raises:
        TypeError: if an EmptyForm is reached by recursion, i.e. it is not
            below a BitMaskedForm, ByteMaskedForm, or IndexedOptionForm.
        AssertionError: if ``form`` is none of the handled Form classes.
    """
    # The key depends on how many buffers exist so far; node types that store
    # no buffer simply leave the number available for their descendants.
    form_key = f"node-{len(container)}"

    if isinstance(form, (ak.forms.BitMaskedForm, ak.forms.ByteMaskedForm)):
        # Fill the mask with the pattern opposite to `valid_when`
        # (presumably marking every entry as missing, so the content below
        # never has to be built).
        if form.valid_when:
            container[form_key] = b"\x00" * multiplier
        else:
            container[form_key] = b"\xff" * multiplier
        return form.copy(form_key=form_key)  # DO NOT RECURSE

    elif isinstance(form, ak.forms.IndexedOptionForm):
        # All-0xff bytes read back as -1 for a signed integer index. Size the
        # buffer like the IndexedForm/ListForm branch (8 bytes per entry) so
        # that a node nested below a RegularForm still has `multiplier`
        # entries, instead of a single 8-byte word.
        container[form_key] = b"\xff" * (8 * multiplier)
        return form.copy(form_key=form_key)  # DO NOT RECURSE

    elif isinstance(form, ak.forms.EmptyForm):
        # no error if protected by non-recursing node type
        raise TypeError(
            "cannot generate a length_one_array from a Form with an "
            "unknowntype that cannot be hidden (EmptyForm not within "
            "BitMaskedForm, ByteMaskedForm, or IndexedOptionForm)"
        )

    elif isinstance(form, ak.forms.UnmaskedForm):
        # No buffer of its own: only the content is prepared.
        return form.copy(content=prepare(form.content, multiplier))

    elif isinstance(form, (ak.forms.IndexedForm, ak.forms.ListForm)):
        # Zero-filled, 8 bytes per entry.
        container[form_key] = b"\x00" * (8 * multiplier)
        return form.copy(
            content=prepare(form.content, multiplier), form_key=form_key
        )

    elif isinstance(form, ak.forms.ListOffsetForm):
        # offsets length == array length + 1
        container[form_key] = b"\x00" * (8 * (multiplier + 1))
        return form.copy(
            content=prepare(form.content, multiplier), form_key=form_key
        )

    elif isinstance(form, ak.forms.RegularForm):
        size = form.size

        # https://github.com/scikit-hep/awkward/pull/2499#discussion_r1220503454
        if size is unknown_length:
            size = 1

        # Each entry of a regular node spans `size` entries of its content.
        return form.copy(content=prepare(form.content, multiplier * size))

    elif isinstance(form, ak.forms.NumpyForm):
        dtype = ak.types.numpytype.primitive_to_dtype(form._primitive)
        # Bytes needed: entries * itemsize * product of the known inner
        # dimensions (unknown dimensions are skipped).
        size = multiplier * dtype.itemsize
        for x in form.inner_shape:
            if x is not unknown_length:
                size *= x

        container[form_key] = b"\x00" * size
        return form.copy(form_key=form_key)

    elif isinstance(form, ak.forms.RecordForm):
        return form.copy(
            # recurse down all contents
            contents=[prepare(x, multiplier) for x in form.contents]
        )

    elif isinstance(form, ak.forms.UnionForm):
        # both tags and index will get this buffer, but index is 8 bytes
        container[form_key] = b"\x00" * (8 * multiplier)
        return form.copy(
            # only recurse down contents[0] because all index == 0
            contents=(
                [prepare(form.contents[0], multiplier)] + form.contents[1:]
            ),
            form_key=form_key,
        )

    else:
        raise AssertionError(f"not a Form: {form!r}")

return ak.operations.ak_from_buffers._impl(
form=self,
form=prepare(self, 1),
length=1,
container={
"": b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
},
buffer_key="",
container=container,
buffer_key="{form_key}",
backend=backend,
byteorder=ak._util.native_byteorder,
highlevel=highlevel,
Expand Down
13 changes: 11 additions & 2 deletions src/awkward/forms/listoffsetform.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
from __future__ import annotations

import awkward as ak
from awkward._parameters import type_parameters_equal
from awkward._typing import final
from awkward._typing import JSONMapping, final
from awkward._util import UNSET
from awkward.forms.form import Form

Expand All @@ -10,7 +12,14 @@
class ListOffsetForm(Form):
is_list = True

def __init__(self, offsets, content, *, parameters=None, form_key=None):
def __init__(
self,
offsets: str,
content: Form,
*,
parameters: JSONMapping | None = None,
form_key: str | None = None,
):
if not isinstance(offsets, str):
raise TypeError(
"{} 'offsets' must be of type str, not {}".format(
Expand Down
29 changes: 12 additions & 17 deletions src/awkward/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,29 @@
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyLike, NumpyMetadata
from awkward._nplikes.typetracer import TypeTracer
from awkward._typing import Self
from awkward._typing import Final, Self

np = NumpyMetadata.instance()
numpy = Numpy.instance()
np: Final = NumpyMetadata.instance()
numpy: Final = Numpy.instance()


_dtype_to_form = {
_dtype_to_form: Final = {
np.dtype(np.int8): "i8",
np.dtype(np.uint8): "u8",
np.dtype(np.int32): "i32",
np.dtype(np.uint32): "u32",
np.dtype(np.int64): "i64",
}

_form_to_dtype: Final = {v: k for k, v in _dtype_to_form.items()}

def _form_to_zero_length(form):
if form == "i8":
return Index8(numpy.zeros(0, dtype=np.int8))
elif form == "u8":
return IndexU8(numpy.zeros(0, dtype=np.uint8))
elif form == "i32":
return Index32(numpy.zeros(0, dtype=np.int32))
elif form == "u32":
return IndexU32(numpy.zeros(0, dtype=np.uint32))
elif form == "i64":
return Index64(numpy.zeros(0, dtype=np.int64))
else:
raise AssertionError(f"unrecognized Index form: {form!r}")

def _form_to_zero_length(form: str) -> Index:
    """Build a zero-length ``Index`` for the given index form string.

    The string is looked up in ``_form_to_dtype``; an unrecognized string
    raises ``AssertionError``.
    """
    dtype = _form_to_dtype.get(form)
    if dtype is None:
        raise AssertionError(f"unrecognized Index form: {form!r}")
    empty = numpy.zeros(0, dtype=dtype)
    return Index(empty)


class Index:
Expand Down
23 changes: 21 additions & 2 deletions src/awkward/typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from awkward._backends.typetracer import TypeTracerBackend
from awkward._behavior import behavior_of
from awkward._errors import deprecate
from awkward._layout import wrap_layout
from awkward._nplikes.placeholder import PlaceholderArray
from awkward._nplikes.shape import unknown_length
Expand All @@ -22,20 +23,38 @@
typetracer_with_report,
)
from awkward._typing import TypeVar
from awkward.forms import Form
from awkward.highlevel import Array, Record
from awkward.operations.ak_to_layout import to_layout

T = TypeVar("T", Array, Record)


def empty_if_typetracer(array: T) -> T:
def _length_0_1_if_typetracer(array: T, function):
typetracer_backend = TypeTracerBackend.instance()

layout = to_layout(array, allow_other=False)
behavior = behavior_of(array)

if layout.backend is typetracer_backend:
layout._touch_data(True)
layout = layout.form.length_zero_array(highlevel=False)
layout = function(layout.form, highlevel=False)

return wrap_layout(layout, behavior=behavior)


def empty_if_typetracer(array: T) -> T:
    """Deprecated name for `length_zero_if_typetracer`.

    Issues a deprecation notice through ``deprecate`` and then forwards
    ``array`` to `length_zero_if_typetracer`, returning its result.
    """
    message = "'empty_if_typetracer' is being replaced by 'length_zero_if_typetracer' (change name)"
    deprecate(message, "2.4.0")
    result = length_zero_if_typetracer(array)
    return result


def length_zero_if_typetracer(array: T) -> T:
    """Apply `_length_0_1_if_typetracer` to ``array`` using ``Form.length_zero_array``."""
    builder = Form.length_zero_array
    return _length_0_1_if_typetracer(array, builder)


def length_one_if_typetracer(array: T) -> T:
    """Apply `_length_0_1_if_typetracer` to ``array`` using ``Form.length_one_array``."""
    builder = Form.length_one_array
    return _length_0_1_if_typetracer(array, builder)
31 changes: 27 additions & 4 deletions tests/test_2085_empty_if_typetracer.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,30 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import numpy as np
import pytest # noqa: F401
import pytest

import awkward as ak
from awkward.typetracer import empty_if_typetracer, typetracer_with_report
from awkward.typetracer import (
length_one_if_typetracer,
length_zero_if_typetracer,
typetracer_with_report,
)


def test():
@pytest.mark.parametrize(
"function", [length_zero_if_typetracer, length_one_if_typetracer]
)
def test_typetracer(function):
def func(array):
assert ak.backend(array) == "typetracer"

radius = np.sqrt(array.x**2 + array.y**2)
radius = empty_if_typetracer(radius)
radius = function(radius)
assert ak.backend(radius) == "cpu"
if function is length_zero_if_typetracer:
assert len(radius) == 0
else:
assert len(radius) == 1

hist_contents, hist_edges = np.histogram(ak.flatten(radius, axis=None))

Expand All @@ -34,3 +45,15 @@ def func(array):

func(meta)
assert report.data_touched == ["node0", "node2", "node3"]


@pytest.mark.parametrize("regulararray", [False, True])
def test_multiplier(regulararray):
    """A (2, 3, 5) int64 array's form yields a zero-filled length-one array of shape 1 * 3 * 5."""
    source = np.arange(2 * 3 * 5, dtype=np.int64).reshape(2, 3, 5)

    array = ak.from_numpy(source, regulararray=regulararray)
    assert str(array.type) == "2 * 3 * 5 * int64"

    length_one = array.layout.form.length_one_array()
    assert str(length_one.type) == "1 * 3 * 5 * int64"

    # One outer entry holding a 3 x 5 grid of zeros.
    zero_grid = [[0] * 5 for _ in range(3)]
    assert length_one.tolist() == [zero_grid]
6 changes: 3 additions & 3 deletions tests/test_2501_positional_record_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


def _min_pair(array, mask):
array = ak.typetracer.empty_if_typetracer(array)
array = ak.typetracer.length_zero_if_typetracer(array)

# Find location of minimum 0 slot
i_min = ak.argmin(array["0"], axis=-1, keepdims=True, mask_identity=True)
Expand All @@ -25,15 +25,15 @@ def _min_pair(array, mask):


def _argmin_pair(array, mask):
array = ak.typetracer.empty_if_typetracer(array)
array = ak.typetracer.length_zero_if_typetracer(array)

assert not mask
# Find location of minimum 0 slot
return ak.argmin(array["0"], axis=-1, keepdims=False, mask_identity=mask)


def _argmin_pair_bad(array, mask):
array = ak.typetracer.empty_if_typetracer(array)
array = ak.typetracer.length_zero_if_typetracer(array)

assert not mask
# Find location of minimum 0 slot
Expand Down

0 comments on commit f0c8be8

Please sign in to comment.