Skip to content

Commit

Permalink
feat: added a commit method on TypeTracerReport to identify touch…
Browse files Browse the repository at this point in the history
…ed buffers in the Dask DAG-building pass (#3043)

* Report.data_touched can no longer be ordered (that's okay, right?)

* implemented TypeTracerReports as byte/bit-sets with a 'commit' method

* refactor: small changes

---------

Co-authored-by: Angus Hollands <[email protected]>
  • Loading branch information
jpivarski and agoose77 authored Mar 21, 2024
1 parent f2e777f commit bf18460
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 41 deletions.
139 changes: 114 additions & 25 deletions src/awkward/_nplikes/typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from __future__ import annotations

from collections.abc import Sequence
from collections.abc import Collection, Sequence, Set
from numbers import Number
from typing import Callable
from typing import Callable, Iterator

import numpy

Expand Down Expand Up @@ -125,39 +125,121 @@ def __str__(self):
)


class ImmutableBitSet(Set):
def __init__(self, byteset: FillableByteSet):
self._labels: dict[str, int] = byteset._labels
if not byteset._is_filled.any():
self._is_filled = None
else:
self._is_filled = numpy.packbits(byteset._is_filled)

def __contains__(self, label: object) -> bool:
if self._is_filled is None or label not in self._labels:
return False
else:
assert isinstance(label, str)
return numpy.unpackbits(self._is_filled)[self._labels[label]]

def __iter__(self) -> Iterator[str]:
if self._is_filled is not None:
is_filled = numpy.unpackbits(self._is_filled)
for label, index in self._labels.items():
if is_filled[index]:
yield label

def __len__(self) -> int:
if self._is_filled is None:
return 0
else:
return int(numpy.unpackbits(self._is_filled).sum())


class FillableByteSet(Set):
# friend class ImmutableBitSet

def __init__(self, labels: Collection[str]):
self._labels = {label: i for i, label in enumerate(labels)}
self._is_filled = numpy.zeros(len(labels), dtype=numpy.bool_)

def add(self, label: str) -> None:
self._is_filled[self._labels[label]] = True

def to_bitset(self) -> ImmutableBitSet:
return ImmutableBitSet(self)

def clear(self) -> None:
self._is_filled.fill(False)

def __contains__(self, label: object) -> bool:
if label not in self._labels:
return False
else:
assert isinstance(label, str)
return self._is_filled[self._labels[label]]

def __iter__(self) -> Iterator[str]:
for label, index in self._labels.items():
if self._is_filled[index]:
yield label

def __len__(self) -> int:
return int(self._is_filled.sum())


class TypeTracerReport:
def __init__(self):
# maybe the order will be useful information
self._shape_touched_set = set()
self._shape_touched = []
self._data_touched_set = set()
self._data_touched = []
self._node_id_to_shape_touched: dict[str, ImmutableBitSet] = {}
self._node_id_to_data_touched: dict[str, ImmutableBitSet] = {}

def __repr__(self):
return f"<TypeTracerReport with {len(self._shape_touched)} shape_touched, {len(self._data_touched)} data_touched>"
return (
f"<TypeTracerReport with {len(self._shape_touched_set)} shape_touched, "
f"{len(self._data_touched_set)} data_touched>"
)

@property
def shape_touched(self):
return self._shape_touched
def set_labels(self, labels: Collection[str]):
self._shape_touched_set = FillableByteSet(labels)
self._data_touched_set = FillableByteSet(labels)

@property
def data_touched(self):
return self._data_touched

def touch_shape(self, label):
if label not in self._shape_touched_set:
self._shape_touched_set.add(label)
self._shape_touched.append(label)
def shape_touched(self) -> list[str]:
return list(self._shape_touched_set)

def touch_data(self, label):
if label not in self._data_touched_set:
# touching data implies that the shape will be touched as well
# implemented here so that the codebase doesn't need to be filled
# with calls to both methods everywhere
self._shape_touched_set.add(label)
self._shape_touched.append(label)
self._data_touched_set.add(label)
self._data_touched.append(label)
@property
def data_touched(self) -> list[str]:
return list(self._data_touched_set)

def touch_shape(self, label: str) -> None:
self._shape_touched_set.add(label)

def touch_data(self, label: str) -> None:
# Touching data implies that the shape will be touched as well
# implemented here so that the codebase doesn't need to be filled
# with calls to both methods everywhere
self._shape_touched_set.add(label)
self._data_touched_set.add(label)

def commit(self, node_id: str) -> None:
assert isinstance(self._shape_touched_set, FillableByteSet)
assert isinstance(self._data_touched_set, FillableByteSet)
self._node_id_to_shape_touched[node_id] = self._shape_touched_set.to_bitset()
self._node_id_to_data_touched[node_id] = self._data_touched_set.to_bitset()
self._shape_touched_set.clear()
self._data_touched_set.clear()

def shape_touched_in(self, node_ids: Collection[str]) -> list[str]:
out: set[str] = set()
for node_id in node_ids:
out.update(self._node_id_to_shape_touched[node_id])
return list(out)

def data_touched_in(self, node_ids: Collection[str]) -> list[str]:
out: set[str] = set()
for node_id in node_ids:
out.update(self._node_id_to_data_touched[node_id])
return list(out)


class TypeTracerArray(NDArrayOperatorsMixin, ArrayLike):
Expand Down Expand Up @@ -1685,4 +1767,11 @@ def typetracer_with_report(
layout = form.length_zero_array().to_typetracer(forget_length=True)
report = TypeTracerReport()
_attach_report(layout, form, report, getkey)

# Optimisation: identify buffer keys ahead of time, and register them with report
def buffer_key(form_key, attribute, form):
return getkey(form, attribute)

report.set_labels(form.expected_from_buffers(buffer_key))

return layout, report
16 changes: 8 additions & 8 deletions tests/test_2027_add_data_touch_reporting_to_TypeTracerArray.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,43 +52,43 @@ def test_prototypical_example():
)

# the listoffsets, but not the numpys, have been touched because of broadcasting
assert report.data_touched == [
assert set(report.data_touched) == {
"listoffset-1",
"listoffset-2",
"listoffset-3",
"listoffset-4",
]
}

pz = restructured.muons.pt * np.sinh(restructured.muons.eta) # noqa: F841

# order is preserved: numpy-eta is used before numpy-pt (may or may not be important)
assert report.data_touched == [
assert set(report.data_touched) == {
"listoffset-1",
"listoffset-2",
"listoffset-3",
"listoffset-4",
"numpy-eta",
"numpy-pt",
]
}

# slices are views, so they shouldn't trigger data access
sliced = restructured.muons[:1] # noqa: F841
assert report.data_touched == [
assert set(report.data_touched) == {
"listoffset-1",
"listoffset-2",
"listoffset-3",
"listoffset-4",
"numpy-eta",
"numpy-pt",
]
}

# changed behavior: printing should not touch data anymore
print(restructured.muons.mass)
assert report.data_touched == [
assert set(report.data_touched) == {
"listoffset-1",
"listoffset-2",
"listoffset-3",
"listoffset-4",
"numpy-eta",
"numpy-pt",
]
}
2 changes: 1 addition & 1 deletion tests/test_2085_empty_if_typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def func(array):
meta = ak.Array(meta)

func(meta)
assert report.data_touched == ["node0", "node2", "node3"]
assert set(report.data_touched) == {"node0", "node2", "node3"}


@pytest.mark.parametrize("regulararray", [False, True])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_2373_unzip_touching.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,4 @@ def test():
ttarray = ak.Array(ttlayout)
pairs = ak.cartesian([ttarray.muon, ttarray.jet], axis=1, nested=True)
a, b = ak.unzip(pairs)
assert report.data_touched == ["muon_list!", "jet_list!"]
assert set(report.data_touched) == {"muon_list!", "jet_list!"}
12 changes: 6 additions & 6 deletions tests/test_2374_cartesian_touching.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,32 +132,32 @@ def test():

mval = delta_r2(a["0"], a["1"])

assert report.data_touched == [
assert set(report.data_touched) == {
"muon_list!",
"jet_list!",
"muon_eta!",
"jet_eta!",
"muon_phi!",
"jet_phi!",
]
}

mmin = ak.argmin(mval, axis=2)
assert report.data_touched == [
assert set(report.data_touched) == {
"muon_list!",
"jet_list!",
"muon_eta!",
"jet_eta!",
"muon_phi!",
"jet_phi!",
]
}

ak.firsts(a["1"][mmin], axis=2)

assert report.data_touched == [
assert set(report.data_touched) == {
"muon_list!",
"jet_list!",
"muon_eta!",
"jet_eta!",
"muon_phi!",
"jet_phi!",
]
}

0 comments on commit bf18460

Please sign in to comment.