From 3364a4fda9cc11d46c2ebe8e32fee0b161da27cf Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 16:19:06 -0500 Subject: [PATCH] working pass 1 --- pytato/analysis/__init__.py | 50 +++++++++---------- pytato/distributed/partition.py | 86 ++++++++++++++++++--------------- pytato/transform/__init__.py | 5 +- 3 files changed, 74 insertions(+), 67 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 38ed276fe..1a4359e4b 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -310,48 +310,48 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: # {{{ DirectPredecessorsGetter +from orderedsets import FrozenOrderedSet + class DirectPredecessorsGetter(Mapper): """ Mapper to get the `direct predecessors `__ of a node. - .. note:: - We only consider the predecessors of a nodes in a data-flow sense. """ - def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[Array]: - return frozenset({dim for dim in shape if isinstance(dim, Array)}) + def _get_preds_from_shape(self, shape: ShapeType) -> abc_Set[Array]: + return FrozenOrderedSet([dim for dim in shape if isinstance(dim, Array)]) - def map_index_lambda(self, expr: IndexLambda) -> frozenset[Array]: - return (frozenset(expr.bindings.values()) + def map_index_lambda(self, expr: IndexLambda) -> abc_Set[Array]: + return (FrozenOrderedSet(expr.bindings.values()) | self._get_preds_from_shape(expr.shape)) - def map_stack(self, expr: Stack) -> frozenset[Array]: - return (frozenset(expr.arrays) + def map_stack(self, expr: Stack) -> abc_Set[Array]: + return (FrozenOrderedSet(expr.arrays) | self._get_preds_from_shape(expr.shape)) - def map_concatenate(self, expr: Concatenate) -> frozenset[Array]: - return (frozenset(expr.arrays) + def map_concatenate(self, expr: Concatenate) -> abc_Set[Array]: + return (FrozenOrderedSet(expr.arrays) | self._get_preds_from_shape(expr.shape)) - def map_einsum(self, expr: Einsum) -> frozenset[Array]: - return (frozenset(expr.args) + def map_einsum(self, expr: Einsum) -> abc_Set[Array]: + return (FrozenOrderedSet(expr.args) | self._get_preds_from_shape(expr.shape)) - def map_loopy_call_result(self, expr: NamedArray) -> frozenset[Array]: - from pytato.loopy import LoopyCall, LoopyCallResult + def map_loopy_call_result(self, expr: NamedArray) -> abc_Set[Array]: + from pytato.loopy import LoopyCallResult, LoopyCall assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) - return (frozenset(ary + return (FrozenOrderedSet(ary for ary in expr._container.bindings.values() if isinstance(ary, Array)) | self._get_preds_from_shape(expr.shape)) - def _map_index_base(self, expr: IndexBase) -> frozenset[Array]: - return (frozenset([expr.array]) - | frozenset(idx for idx in expr.indices + def _map_index_base(self, expr: IndexBase) -> abc_Set[Array]: + return (FrozenOrderedSet([expr.array]) + | FrozenOrderedSet(idx for idx in expr.indices if isinstance(idx, Array)) | self._get_preds_from_shape(expr.shape)) @@ -360,29 +360,29 @@ def _map_index_base(self, expr: IndexBase) -> frozenset[Array]: map_non_contiguous_advanced_index = _map_index_base def _map_index_remapping_base(self, expr: IndexRemappingBase - ) -> frozenset[Array]: - return frozenset([expr.array]) + ) -> abc_Set[Array]: + return FrozenOrderedSet([expr.array]) map_roll = _map_index_remapping_base map_axis_permutation = _map_index_remapping_base map_reshape = _map_index_remapping_base - def _map_input_base(self, expr: InputArgumentBase) -> frozenset[Array]: + def _map_input_base(self, expr: InputArgumentBase) -> abc_Set[Array]: return self._get_preds_from_shape(expr.shape) map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[Array]: + def map_distributed_recv(self, expr: DistributedRecv) -> abc_Set[Array]: return self._get_preds_from_shape(expr.shape) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> frozenset[Array]: - return frozenset([expr.passthrough_data]) + ) -> abc_Set[Array]: + return FrozenOrderedSet([expr.passthrough_data]) - def map_named_call_result(self, expr: NamedCallResult) -> frozenset[Array]: + def map_named_call_result(self, expr: NamedCallResult) -> abc_Set[Array]: raise NotImplementedError( "DirectPredecessorsGetter does not yet support expressions containing " "functions.") diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 5865ec491..7c9b510e7 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -476,9 +476,10 @@ def __init__(self, local_rank: int) -> None: self.local_rank = local_rank def combine( - self, *args: frozenset[CommunicationOpIdentifier] - ) -> frozenset[CommunicationOpIdentifier]: - return reduce(frozenset.union, args, frozenset()) + self, *args: Tuple[CommunicationOpIdentifier] + ) -> Tuple[CommunicationOpIdentifier]: + from pytools import unique + return reduce(lambda x, y: tuple(unique(x+y)), args, tuple()) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder @@ -496,8 +497,8 @@ def map_distributed_send_ref_holder(self, return self.rec(expr.passthrough_data) - def _map_input_base(self, expr: Array) -> frozenset[CommunicationOpIdentifier]: - return frozenset() + def _map_input_base(self, expr: Array) -> Tuple[CommunicationOpIdentifier]: + return tuple() map_placeholder = _map_input_base map_data_wrapper = _map_input_base @@ -505,21 +506,21 @@ def _map_input_base(self, expr: Array) -> frozenset[CommunicationOpIdentifier]: def map_distributed_recv( self, expr: DistributedRecv - ) -> frozenset[CommunicationOpIdentifier]: + ) -> Tuple[CommunicationOpIdentifier]: recv_id = _recv_to_comm_id(self.local_rank, expr) if recv_id in self.local_recv_id_to_recv_node: from pytato.distributed.verify import DuplicateRecvError raise DuplicateRecvError(f"Multiple receives found for '{recv_id}'") - self.local_comm_ids_to_needed_comm_ids[recv_id] = frozenset() + self.local_comm_ids_to_needed_comm_ids[recv_id] = tuple() self.local_recv_id_to_recv_node[recv_id] = expr - return frozenset({recv_id}) + return (recv_id,) def map_named_call_result( - self, expr: NamedCallResult) -> frozenset[CommunicationOpIdentifier]: + self, expr: NamedCallResult) -> Tuple[CommunicationOpIdentifier]: raise NotImplementedError( "LocalSendRecvDepGatherer does not support functions.") @@ -557,10 +558,10 @@ def _schedule_task_batches_counted( task_to_dep_level, visits_in_depend = \ _calculate_dependency_levels(task_ids_to_needed_task_ids) nlevels = 1 + max(task_to_dep_level.values(), default=-1) - task_batches: Sequence[set[TaskType]] = [set() for _ in range(nlevels)] + task_batches: Sequence[List[TaskType]] = [list() for _ in range(nlevels)] for task_id, dep_level in task_to_dep_level.items(): - task_batches[dep_level].add(task_id) + task_batches[dep_level].append(task_id) return task_batches, visits_in_depend + len(task_to_dep_level.keys()) @@ -623,7 +624,7 @@ class _MaterializedArrayCollector(CachedWalkMapper): """ def __init__(self) -> None: super().__init__() - self.materialized_arrays: _OrderedSet[Array] = _OrderedSet() + self.materialized_arrays: List[Array] = [] def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) @@ -633,15 +634,15 @@ def post_visit(self, expr: Any) -> None: from pytato.tags import ImplStored if (isinstance(expr, Array) and expr.tags_of_type(ImplStored)): - self.materialized_arrays.add(expr) + self.materialized_arrays.append(expr) if isinstance(expr, LoopyCallResult): - self.materialized_arrays.add(expr) + self.materialized_arrays.append(expr) from pytato.loopy import LoopyCall assert isinstance(expr._container, LoopyCall) for _, subexpr in sorted(expr._container.bindings.items()): if isinstance(subexpr, Array): - self.materialized_arrays.add(subexpr) + self.materialized_arrays.append(subexpr) else: assert isinstance(subexpr, SCALAR_CLASSES) @@ -651,13 +652,13 @@ def post_visit(self, expr: Any) -> None: # {{{ _set_dict_union_mpi def _set_dict_union_mpi( - dict_a: Mapping[_KeyT, frozenset[_ValueT]], - dict_b: Mapping[_KeyT, frozenset[_ValueT]], - mpi_data_type: mpi4py.MPI.Datatype) -> Mapping[_KeyT, frozenset[_ValueT]]: + dict_a: Mapping[_KeyT, Sequence[_ValueT]], + dict_b: Mapping[_KeyT, Sequence[_ValueT]], + mpi_data_type: mpi4py.MPI.Datatype) -> Mapping[_KeyT, Sequence[_ValueT]]: assert mpi_data_type is None result = dict(dict_a) for key, values in dict_b.items(): - result[key] = result.get(key, frozenset()) | values + result[key] = result.get(key, tuple()) + values return result # }}} @@ -782,6 +783,8 @@ def find_distributed_partition( - Gather sent arrays into assigned in :attr:`DistributedGraphPart.name_to_send_nodes`. """ + from pytools import unique + import mpi4py.MPI as MPI from pytato.transform import SubsetDependencyMapper @@ -833,12 +836,13 @@ def find_distributed_partition( # {{{ create (local) parts out of batch ids - part_comm_ids: list[_PartCommIDs] = [] + + part_comm_ids: List[_PartCommIDs] = [] if comm_batches: - recv_ids: frozenset[CommunicationOpIdentifier] = frozenset() + recv_ids: Tuple[CommunicationOpIdentifier] = tuple() for batch in comm_batches: - send_ids = frozenset( - comm_id for comm_id in batch + send_ids = tuple( + comm_id for comm_id in unique(batch) if comm_id.src_rank == local_rank) if recv_ids or send_ids: part_comm_ids.append( @@ -846,19 +850,19 @@ def find_distributed_partition( recv_ids=recv_ids, send_ids=send_ids)) # These go into the next part - recv_ids = frozenset( - comm_id for comm_id in batch + recv_ids = tuple( + comm_id for comm_id in unique(batch) if comm_id.dest_rank == local_rank) if recv_ids: part_comm_ids.append( _PartCommIDs( recv_ids=recv_ids, - send_ids=frozenset())) + send_ids=tuple())) else: part_comm_ids.append( _PartCommIDs( - recv_ids=frozenset(), - send_ids=frozenset())) + recv_ids=tuple(), + send_ids=tuple())) nparts = len(part_comm_ids) @@ -876,7 +880,7 @@ def find_distributed_partition( comm_id_to_part_id = { comm_id: ipart for ipart, comm_ids in enumerate(part_comm_ids) - for comm_id in comm_ids.send_ids | comm_ids.recv_ids} + for comm_id in unique(comm_ids.send_ids + comm_ids.recv_ids)} # }}} @@ -888,10 +892,10 @@ def find_distributed_partition( # The sets of arrays below must have a deterministic order in order to ensure # that the resulting partition is also deterministic - sent_arrays = _OrderedSet( + sent_arrays = tuple( send_node.data for send_node in lsrdg.local_send_id_to_send_node.values()) - received_arrays = _OrderedSet(lsrdg.local_recv_id_to_recv_node.values()) + received_arrays = tuple(lsrdg.local_recv_id_to_recv_node.values()) # While receive nodes may be marked as materialized, we shouldn't be # including them here because we're using them (along with the send nodes) @@ -899,14 +903,16 @@ def find_distributed_partition( # We could allow sent *arrays* to be included here because they are distinct # from send *nodes*, but we choose to exclude them in order to simplify the # processing below. - materialized_arrays = ( - materialized_arrays_collector.materialized_arrays - - received_arrays - - sent_arrays) + materialized_arrays_set = set(materialized_arrays_collector.materialized_arrays) \ + - set(received_arrays) \ + - set(sent_arrays) + + from pytools import unique + materialized_arrays = tuple(a for a in materialized_arrays_collector.materialized_arrays if a in materialized_arrays_set) # "mso" for "materialized/sent/output" - output_arrays = _OrderedSet(outputs._data.values()) - mso_arrays = materialized_arrays | sent_arrays | output_arrays + output_arrays = tuple(outputs._data.values()) + mso_arrays = materialized_arrays + sent_arrays + output_arrays # FIXME: This gathers up materialized_arrays recursively, leading to # result sizes potentially quadratic in the number of materialized arrays. @@ -970,7 +976,7 @@ def find_distributed_partition( assert all(0 <= part_id < nparts for part_id in stored_ary_to_part_id.values()) - stored_arrays = _OrderedSet(stored_ary_to_part_id) + stored_arrays = tuple(unique(stored_ary_to_part_id)) # {{{ find which stored arrays should become part outputs # (because they are used in not just their local part, but also others) @@ -986,13 +992,13 @@ def get_materialized_predecessors(ary: Array) -> _OrderedSet[Array]: materialized_preds |= get_materialized_predecessors(pred) return materialized_preds - stored_arrays_promoted_to_part_outputs = { + stored_arrays_promoted_to_part_outputs = tuple(unique( stored_pred for stored_ary in stored_arrays for stored_pred in get_materialized_predecessors(stored_ary) if (stored_ary_to_part_id[stored_ary] != stored_ary_to_part_id[stored_pred]) - } + )) # }}} diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index b78c24301..56b2a53d6 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -926,9 +926,10 @@ def __init__(self, universe: frozenset[Array]): def combine(self, *args: frozenset[Array]) -> frozenset[Array]: from functools import reduce - return reduce(lambda acc, arg: acc | (arg & self.universe), + from pytools import unique + return reduce(lambda acc, arg: unique(tuple(acc) + tuple(set(arg) & self.universe)), args, - frozenset()) + tuple()) # }}}