
Commit

Merge branch 'main' into production-pilot-up2date
MTCam committed Aug 29, 2024
2 parents fefadcf + a92a0d1 commit 2a79348
Showing 26 changed files with 611 additions and 238 deletions.
4 changes: 2 additions & 2 deletions README.rst
@@ -4,9 +4,9 @@ Pytato: Get Descriptions of Array Computations via Lazy Evaluation
.. image:: https://gitlab.tiker.net/inducer/pytato/badges/main/pipeline.svg
:alt: Gitlab Build Status
:target: https://gitlab.tiker.net/inducer/pytato/commits/main
.. image:: https://github.com/inducer/pytato/workflows/CI/badge.svg?branch=main&event=push
.. image:: https://github.com/inducer/pytato/workflows/CI/badge.svg?branch=main
:alt: Github Build Status
:target: https://github.com/inducer/pytato/actions?query=branch%3Amain+workflow%3ACI+event%3Apush
:target: https://github.com/inducer/pytato/actions?query=branch%3Amain+workflow%3ACI
.. image:: https://badge.fury.io/py/pytato.png
:alt: Python Package Index Release Page
:target: https://pypi.org/project/pytato/
2 changes: 1 addition & 1 deletion doc/upload-docs.sh
@@ -1,3 +1,3 @@
#! /bin/sh

rsync --verbose --archive --delete _build/html/* doc-upload:doc/pytato
rsync --verbose --archive --delete _build/html/ doc-upload:doc/pytato
1 change: 0 additions & 1 deletion pyproject.toml
@@ -54,7 +54,6 @@ multiline-quotes = "double"
[[tool.mypy.overrides]]
module = [
"islpy",
"loopy.*",
"pymbolic.*",
"pyopencl.*",
"jax.*",
146 changes: 116 additions & 30 deletions pytato/analysis/__init__.py
@@ -62,6 +62,10 @@
.. autofunction:: get_num_nodes
.. autofunction:: get_node_type_counts
.. autofunction:: get_node_multiplicities
.. autofunction:: get_num_call_sites
.. autoclass:: DirectPredecessorsGetter
@@ -322,26 +326,26 @@ class DirectPredecessorsGetter(Mapper):
We only consider the predecessors of a node in a data-flow sense.
"""
def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[Array]:
def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[ArrayOrNames]:
return frozenset({dim for dim in shape if isinstance(dim, Array)})

def map_index_lambda(self, expr: IndexLambda) -> frozenset[Array]:
def map_index_lambda(self, expr: IndexLambda) -> frozenset[ArrayOrNames]:
return (frozenset(expr.bindings.values())
| self._get_preds_from_shape(expr.shape))

def map_stack(self, expr: Stack) -> frozenset[Array]:
def map_stack(self, expr: Stack) -> frozenset[ArrayOrNames]:
return (frozenset(expr.arrays)
| self._get_preds_from_shape(expr.shape))

def map_concatenate(self, expr: Concatenate) -> frozenset[Array]:
def map_concatenate(self, expr: Concatenate) -> frozenset[ArrayOrNames]:
return (frozenset(expr.arrays)
| self._get_preds_from_shape(expr.shape))

def map_einsum(self, expr: Einsum) -> frozenset[Array]:
def map_einsum(self, expr: Einsum) -> frozenset[ArrayOrNames]:
return (frozenset(expr.args)
| self._get_preds_from_shape(expr.shape))

def map_loopy_call_result(self, expr: NamedArray) -> frozenset[Array]:
def map_loopy_call_result(self, expr: NamedArray) -> frozenset[ArrayOrNames]:
from pytato.loopy import LoopyCall, LoopyCallResult
assert isinstance(expr, LoopyCallResult)
assert isinstance(expr._container, LoopyCall)
@@ -350,7 +354,7 @@ def map_loopy_call_result(self, expr: NamedArray) -> frozenset[Array]:
if isinstance(ary, Array))
| self._get_preds_from_shape(expr.shape))

def _map_index_base(self, expr: IndexBase) -> frozenset[Array]:
def _map_index_base(self, expr: IndexBase) -> frozenset[ArrayOrNames]:
return (frozenset([expr.array])
| frozenset(idx for idx in expr.indices
if isinstance(idx, Array))
@@ -361,32 +365,34 @@ def _map_index_base(self, expr: IndexBase) -> frozenset[Array]:
map_non_contiguous_advanced_index = _map_index_base

def _map_index_remapping_base(self, expr: IndexRemappingBase
) -> frozenset[Array]:
) -> frozenset[ArrayOrNames]:
return frozenset([expr.array])

map_roll = _map_index_remapping_base
map_axis_permutation = _map_index_remapping_base
map_reshape = _map_index_remapping_base

def _map_input_base(self, expr: InputArgumentBase) -> frozenset[Array]:
def _map_input_base(self, expr: InputArgumentBase) -> frozenset[ArrayOrNames]:
return self._get_preds_from_shape(expr.shape)

map_placeholder = _map_input_base
map_data_wrapper = _map_input_base
map_size_param = _map_input_base

def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[Array]:
def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[ArrayOrNames]:
return self._get_preds_from_shape(expr.shape)

def map_distributed_send_ref_holder(self,
expr: DistributedSendRefHolder
) -> frozenset[Array]:
) -> frozenset[ArrayOrNames]:
return frozenset([expr.passthrough_data])

def map_named_call_result(self, expr: NamedCallResult) -> frozenset[Array]:
raise NotImplementedError(
"DirectPredecessorsGetter does not yet support expressions containing "
"functions.")
def map_call(self, expr: Call) -> frozenset[ArrayOrNames]:
return frozenset(expr.bindings.values())

def map_named_call_result(
self, expr: NamedCallResult) -> frozenset[ArrayOrNames]:
return frozenset([expr._container])


# }}}
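
For illustration, a minimal sketch of how the updated getter might be used now that map_call and map_named_call_result are handled instead of raising NotImplementedError; the placeholder name and shape below are made up:

    import numpy as np
    import pytato as pt
    from pytato.analysis import DirectPredecessorsGetter

    # A tiny DAG; "x" is a hypothetical placeholder.
    x = pt.make_placeholder(name="x", shape=(10,), dtype=np.float64)
    y = 2*x + 3

    # Frozenset of the nodes that `y` directly depends on, in a data-flow sense.
    preds = DirectPredecessorsGetter()(y)
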
@@ -397,34 +403,115 @@ def map_named_call_result(self, expr: NamedCallResult) -> frozenset[Array]:
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
class NodeCountMapper(CachedWalkMapper):
"""
Counts the number of nodes in a DAG.
Counts the number of nodes of a given type in a DAG.
.. attribute:: count
.. autoattribute:: expr_type_counts
.. autoattribute:: count_duplicates
The number of nodes.
Dictionary mapping node types to number of nodes of that type.
"""

def __init__(self, count_duplicates: bool = False) -> None:
from collections import defaultdict
super().__init__()
self.expr_type_counts: dict[type[Any], int] = defaultdict(int)
self.count_duplicates = count_duplicates

def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames:
# Returns unique nodes only if count_duplicates is False
return id(expr) if self.count_duplicates else expr

def post_visit(self, expr: Any) -> None:
if not isinstance(expr, DictOfNamedArrays):
self.expr_type_counts[type(expr)] += 1


def get_node_type_counts(
outputs: Array | DictOfNamedArrays,
count_duplicates: bool = False
) -> dict[type[Any], int]:
"""
Returns a dictionary mapping node types to node count for that type
in DAG *outputs*.
Instances of `DictOfNamedArrays` are excluded from counting.
"""

from pytato.codegen import normalize_outputs
outputs = normalize_outputs(outputs)

ncm = NodeCountMapper(count_duplicates)
ncm(outputs)

return ncm.expr_type_counts


def get_num_nodes(
outputs: Array | DictOfNamedArrays,
count_duplicates: bool | None = None
) -> int:
"""
Returns the number of nodes in DAG *outputs*.
Instances of `DictOfNamedArrays` are excluded from counting.
"""
if count_duplicates is None:
from warnings import warn
warn(
"The default value of 'count_duplicates' will change "
"from True to False in 2025. "
"For now, pass the desired value explicitly.",
DeprecationWarning, stacklevel=2)
count_duplicates = True

from pytato.codegen import normalize_outputs
outputs = normalize_outputs(outputs)

ncm = NodeCountMapper(count_duplicates)
ncm(outputs)

return sum(ncm.expr_type_counts.values())

# }}}
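
A hedged sketch of the new counting helpers (the arrays below are made up):

    import numpy as np
    import pytato as pt
    from pytato.analysis import get_node_type_counts, get_num_nodes

    x = pt.make_placeholder(name="x", shape=(10, 10), dtype=np.float64)
    outputs = pt.make_dict_of_named_arrays({"out": x @ x.T + x})

    # Per-type histogram of nodes in the DAG (DictOfNamedArrays excluded).
    print(get_node_type_counts(outputs))

    # Unique-node count; passing count_duplicates explicitly avoids the
    # DeprecationWarning about the changing default.
    print(get_num_nodes(outputs, count_duplicates=False))
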


# {{{ NodeMultiplicityMapper


class NodeMultiplicityMapper(CachedWalkMapper):
"""
Computes the multiplicity of each unique node in a DAG.
The multiplicity of a node `x` is the number of nodes with distinct `id()`\\ s
that equal `x`.
.. autoattribute:: expr_multiplicity_counts
"""
def __init__(self) -> None:
from collections import defaultdict
super().__init__()
self.count = 0
self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int)

def get_cache_key(self, expr: ArrayOrNames) -> int:
# Returns each node, including nodes that are duplicates
return id(expr)

def post_visit(self, expr: Any) -> None:
self.count += 1

if not isinstance(expr, DictOfNamedArrays):
self.expr_multiplicity_counts[expr] += 1

def get_num_nodes(outputs: Array | DictOfNamedArrays) -> int:
"""Returns the number of nodes in DAG *outputs*."""

def get_node_multiplicities(
outputs: Array | DictOfNamedArrays) -> dict[Array, int]:
"""
Returns the multiplicity of each unique node in DAG *outputs*.
"""
from pytato.codegen import normalize_outputs
outputs = normalize_outputs(outputs)

ncm = NodeCountMapper()
ncm(outputs)
nmm = NodeMultiplicityMapper()
nmm(outputs)

return ncm.count
return nmm.expr_multiplicity_counts

# }}}
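
A brief sketch of the multiplicity helper; the duplicated subexpression below is hypothetical:

    import numpy as np
    import pytato as pt
    from pytato.analysis import get_node_multiplicities

    x = pt.make_placeholder(name="x", shape=(4,), dtype=np.float64)
    a = x + 1
    b = x + 1          # equal to `a`, but a distinct object
    outputs = pt.make_dict_of_named_arrays({"out": a * b})

    # Maps each unique node to the number of distinct-id() copies of it in
    # the DAG; assuming the two `x + 1` constructions are not deduplicated,
    # that node is reported with multiplicity 2.
    print(get_node_multiplicities(outputs))
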

@@ -449,17 +536,16 @@ def get_cache_key(self, expr: ArrayOrNames) -> int:
return id(expr)

@memoize_method
def map_function_definition(self, /, expr: FunctionDefinition,
*args: Any, **kwargs: Any) -> None:
def map_function_definition(self, expr: FunctionDefinition) -> None:
if not self.visit(expr):
return

new_mapper = self.clone_for_callee(expr)
for subexpr in expr.returns.values():
new_mapper(subexpr, *args, **kwargs)
new_mapper(subexpr)
self.count += new_mapper.count

self.post_visit(expr, *args, **kwargs)
self.post_visit(expr)

def post_visit(self, expr: Any) -> None:
if isinstance(expr, Call):
28 changes: 7 additions & 21 deletions pytato/array.py
@@ -1162,8 +1162,6 @@ class Einsum(_SuppliedAxesAndTagsMixin, Array):
redn_axis_to_redn_descr: Mapping[EinsumReductionAxis,
ReductionDescriptor] = \
attrs.field(validator=attrs.validators.instance_of(immutabledict))
index_to_access_descr: Mapping[str, EinsumAxisDescriptor] = \
attrs.field(validator=attrs.validators.instance_of(immutabledict))
_mapper_method: ClassVar[str] = "map_einsum"

@memoize_method
@@ -1213,30 +1211,20 @@ def dtype(self) -> np.dtype[Any]:
return np.result_type(*[arg.dtype for arg in self.args])

def with_tagged_reduction(self,
redn_axis: EinsumReductionAxis | str,
redn_axis: EinsumReductionAxis,
tag: Tag) -> Einsum:
"""
Returns a copy of *self* with the :class:`ReductionDescriptor`
associated with *redn_axis* tagged with *tag*.
"""
from pytato.diagnostic import InvalidEinsumIndex, NotAReductionAxis
# {{{ sanity checks

# {{{ sanity checks
if isinstance(redn_axis, str):
try:
redn_axis_ = self.index_to_access_descr[redn_axis]
except KeyError as err:
raise InvalidEinsumIndex(
f"'{redn_axis}': not a valid axis index.") from err
if isinstance(redn_axis_, EinsumReductionAxis):
redn_axis = redn_axis_
else:
raise NotAReductionAxis(f"'{redn_axis}' is not"
" a reduction axis.")
elif isinstance(redn_axis, EinsumReductionAxis):
pass
else:
raise TypeError("Argument 'redn_axis' expected to be"
raise TypeError("Argument `redn_axis' as a string is no longer"
" accepted as a valid index type."
" Use the actual EinsumReductionAxis object instead.")
elif not isinstance(redn_axis, EinsumReductionAxis):
raise TypeError(f"Argument `redn_axis' expected to be"
f" EinsumReductionAxis, got {type(redn_axis)}")

if redn_axis in self.redn_axis_to_redn_descr:
Expand All @@ -1259,7 +1247,6 @@ def with_tagged_reduction(self,
redn_axis_to_redn_descr=immutabledict
(new_redn_axis_to_redn_descr),
tags=self.tags,
index_to_access_descr=self.index_to_access_descr,
non_equality_tags=self.non_equality_tags,
)

@@ -1466,7 +1453,6 @@ def einsum(subscripts: str, *operands: Array,
EinsumElementwiseAxis)})
),
redn_axis_to_redn_descr=immutabledict(redn_axis_to_redn_descr),
index_to_access_descr=index_to_descr,
non_equality_tags=_get_created_at_tag(),
)

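
With index_to_access_descr removed, with_tagged_reduction no longer resolves string subscripts; a hedged sketch of the updated call pattern (the tag class and operands are hypothetical):

    import numpy as np
    import pytato as pt
    from pytools.tag import Tag

    class FooRednTag(Tag):
        """A made-up tag for illustration."""

    a = pt.make_placeholder(name="a", shape=(5, 5), dtype=np.float64)
    b = pt.make_placeholder(name="b", shape=(5,), dtype=np.float64)
    c = pt.einsum("ij,j->i", a, b)

    # Previously a subscript string like "j" was accepted; now the
    # EinsumReductionAxis key itself must be passed.
    redn_axis, = c.redn_axis_to_redn_descr
    c = c.with_tagged_reduction(redn_axis, FooRednTag())
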
4 changes: 2 additions & 2 deletions pytato/codegen.py
@@ -159,7 +159,7 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall:
# {{{ eliminate callable name collision

for name, clbl in translation_unit.callables_table.items():
if isinstance(clbl, lp.kernel.function_interface.CallableKernel):
if isinstance(clbl, lp.CallableKernel):
if name in self.kernels_seen and (
translation_unit[name] != self.kernels_seen[name]):
# callee name collision => must rename
@@ -186,7 +186,7 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall:
translation_unit, name, new_name)
name = new_name

self.kernels_seen[name] = translation_unit[name]
self.kernels_seen[name] = clbl.subkernel

# }}}

2 changes: 1 addition & 1 deletion pytato/distributed/execute.py
@@ -103,7 +103,7 @@ def _mpi_send(mpi_communicator: Any, send_node: DistributedSend,

def execute_distributed_partition(
partition: DistributedGraphPartition, prg_per_partition:
dict[Hashable, BoundProgram],
Mapping[Hashable, BoundProgram],
queue: Any, mpi_communicator: Any,
*,
allocator: Any | None = None,
10 changes: 6 additions & 4 deletions pytato/distributed/verify.py
@@ -194,12 +194,14 @@ def _run_partition_diagnostics(

from pytato.analysis import get_num_nodes
num_nodes_per_part = [get_num_nodes(make_dict_of_named_arrays(
{x: gp.name_to_output[x] for x in part.output_names}))
{x: gp.name_to_output[x] for x in part.output_names}),
count_duplicates=False)
for part in gp.parts.values()]

logger.info(f"find_distributed_partition: Split {get_num_nodes(outputs)} nodes "
f"into {len(gp.parts)} parts, with {num_nodes_per_part} nodes in each "
"partition.")
logger.info("find_distributed_partition: "
f"Split {get_num_nodes(outputs, count_duplicates=False)} nodes "
f"into {len(gp.parts)} parts, with {num_nodes_per_part} nodes in each "
"partition.")

# }}}

2 changes: 2 additions & 0 deletions pytato/equality.py
@@ -298,6 +298,7 @@ def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
) -> bool:
return (expr1.__class__ is expr2.__class__
and expr1.parameters == expr2.parameters
and expr1.return_type == expr2.return_type
and (set(expr1.returns.keys()) == set(expr2.returns.keys()))
and all(self.rec(expr1.returns[k], expr2.returns[k])
for k in expr1.returns)
@@ -311,6 +312,7 @@ def map_call(self, expr1: Call, expr2: Any) -> bool:
and all(self.rec(bnd,
expr2.bindings[name])
for name, bnd in expr1.bindings.items())
and expr1.tags == expr2.tags
)

def map_named_call_result(self, expr1: NamedCallResult, expr2: Any) -> bool: