diff --git a/README.rst b/README.rst index 05873b6fb..c9023fd03 100644 --- a/README.rst +++ b/README.rst @@ -33,7 +33,3 @@ Pytato is written to pose no particular restrictions on the version of numpy used for execution. To use mypy-based type checking on Pytato itself or packages using Pytato, numpy 1.20 or newer is required, due to the typing-based changes to numpy in that release. - -Furthermore, pytato now uses type promotion rules based on those in -`numpy `__ that should result in the same -data types as the currently installed version of numpy. diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 988c11a16..9c156786d 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -31,6 +31,7 @@ from pymbolic.mapper.optimize import optimize_mapper from pytools import memoize_method +from loopy.tools import LoopyKeyBuilder from pytato.array import ( Array, @@ -326,37 +327,37 @@ class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], []]): We only consider the predecessors of a node in a data-flow sense. """ - def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[ArrayOrNames]: - return frozenset({dim for dim in shape if isinstance(dim, Array)}) + def _get_preds_from_shape(self, shape: ShapeType) -> dict[Array, None]: + return dict.fromkeys(dim for dim in shape if isinstance(dim, Array)) - def map_index_lambda(self, expr: IndexLambda) -> frozenset[ArrayOrNames]: - return (frozenset(expr.bindings.values()) + def map_index_lambda(self, expr: IndexLambda) -> dict[Array, None]: + return (dict.fromkeys(expr.bindings.values()) | self._get_preds_from_shape(expr.shape)) - def map_stack(self, expr: Stack) -> frozenset[ArrayOrNames]: - return (frozenset(expr.arrays) + def map_stack(self, expr: Stack) -> dict[Array, None]: + return (dict.fromkeys(expr.arrays) | self._get_preds_from_shape(expr.shape)) - def map_concatenate(self, expr: Concatenate) -> frozenset[ArrayOrNames]: - return (frozenset(expr.arrays) + def map_concatenate(self, expr: Concatenate) -> dict[Array, None]: + return (dict.fromkeys(expr.arrays) | self._get_preds_from_shape(expr.shape)) - def map_einsum(self, expr: Einsum) -> frozenset[ArrayOrNames]: - return (frozenset(expr.args) + def map_einsum(self, expr: Einsum) -> dict[Array, None]: + return (dict.fromkeys(expr.args) | self._get_preds_from_shape(expr.shape)) - def map_loopy_call_result(self, expr: NamedArray) -> frozenset[ArrayOrNames]: + def map_loopy_call_result(self, expr: NamedArray) -> dict[Array, None]: from pytato.loopy import LoopyCall, LoopyCallResult assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) - return (frozenset(ary + return (dict.fromkeys(ary for ary in expr._container.bindings.values() if isinstance(ary, Array)) | self._get_preds_from_shape(expr.shape)) - def _map_index_base(self, expr: IndexBase) -> frozenset[ArrayOrNames]: - return (frozenset([expr.array]) - | frozenset(idx for idx in expr.indices + def _map_index_base(self, expr: IndexBase) -> dict[Array, None]: + return (dict.fromkeys([expr.array]) + | dict.fromkeys(idx for idx in expr.indices if isinstance(idx, Array)) | self._get_preds_from_shape(expr.shape)) @@ -365,34 +366,34 @@ def _map_index_base(self, expr: IndexBase) -> frozenset[ArrayOrNames]: map_non_contiguous_advanced_index = _map_index_base def _map_index_remapping_base(self, expr: IndexRemappingBase - ) -> frozenset[ArrayOrNames]: - return frozenset([expr.array]) + ) -> dict[ArrayOrNames, None]: + return dict.fromkeys([expr.array]) map_roll = 
_map_index_remapping_base map_axis_permutation = _map_index_remapping_base map_reshape = _map_index_remapping_base - def _map_input_base(self, expr: InputArgumentBase) -> frozenset[ArrayOrNames]: + def _map_input_base(self, expr: InputArgumentBase) -> dict[Array, None]: return self._get_preds_from_shape(expr.shape) map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[ArrayOrNames]: + def map_distributed_recv(self, expr: DistributedRecv) -> dict[Array, None]: return self._get_preds_from_shape(expr.shape) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> frozenset[ArrayOrNames]: - return frozenset([expr.passthrough_data]) + ) -> dict[ArrayOrNames, None]: + return dict.fromkeys([expr.passthrough_data]) - def map_call(self, expr: Call) -> frozenset[ArrayOrNames]: - return frozenset(expr.bindings.values()) + def map_call(self, expr: Call) -> dict[ArrayOrNames, None]: + return dict.fromkeys(expr.bindings.values()) def map_named_call_result( - self, expr: NamedCallResult) -> frozenset[ArrayOrNames]: - return frozenset([expr._container]) + self, expr: NamedCallResult) -> dict[ArrayOrNames, None]: + return dict.fromkeys([expr._container]) # }}} @@ -565,4 +566,47 @@ def get_num_call_sites(outputs: Array | DictOfNamedArrays) -> int: # }}} + +# {{{ PytatoKeyBuilder + +class PytatoKeyBuilder(LoopyKeyBuilder): # type: ignore[misc] + """A custom :class:`pytools.persistent_dict.KeyBuilder` subclass + for objects within :mod:`pytato`. + """ + + def update_for_ndarray(self, key_hash: Any, key: Any) -> None: + self.rec(key_hash, key.data.tobytes()) + + def update_for_TaggableCLArray(self, key_hash: Any, key: Any) -> None: + self.rec(key_hash, key.get()) + + def update_for_Array(self, key_hash: Any, key: Any) -> None: + # CL Array + self.rec(key_hash, key.get()) + + # update_for_BitwiseAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_BitwiseNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_BitwiseOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_BitwiseXor = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Call = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + # update_for_CallWithKwargs = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Comparison = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_If = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + # update_for_FloorDiv = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_LeftShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_LogicalAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_LogicalNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_LogicalOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Lookup = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Power = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + # update_for_Product = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + # update_for_Quotient = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Remainder = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: 
N815, E501 + # update_for_RightShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Subscript = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Sum = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + # update_for_Variable = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + +# }}} + # vim: fdm=marker diff --git a/pytato/array.py b/pytato/array.py index 19f8ac12d..805b3671e 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1,4 +1,5 @@ from __future__ import annotations +from traceback import FrameSummary, StackSummary __copyright__ = """ @@ -296,19 +297,25 @@ def normalize_shape_component( # {{{ array interface ConvertibleToIndexExpr = Union[int, slice, "Array", None, EllipsisType] -IndexExpr = Union[IntegerT, "NormalizedSlice", "Array", None] +IndexExpr = Union[IntegerT, "NormalizedSlice", "Array", None, EllipsisType] +DtypeOrScalar = Union[_dtype_any, ScalarT] +ArrayOrScalar = Union["Array", ScalarT] PyScalarType = type[bool] | type[int] | type[float] | type[complex] DtypeOrPyScalarType = _dtype_any | PyScalarType -def _np_result_dtype( - *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, +# https://github.com/numpy/numpy/issues/19302 +def _np_result_type( + # actual dtype: + #*arrays_and_dtypes: Union[np.typing.ArrayLike, np.typing.DTypeLike], + # our dtype: + *arrays_and_dtypes: DtypeOrScalar, ) -> np.dtype[Any]: return np.result_type(*arrays_and_dtypes) -def _truediv_result_type(*dtypes: DtypeOrPyScalarType) -> np.dtype[Any]: - dtype = _np_result_dtype(*dtypes) +def _truediv_result_type(arg1: DtypeOrScalar, arg2: DtypeOrScalar) -> np.dtype[Any]: + dtype = _np_result_type(arg1, arg2) # See: test_true_divide in numpy/core/tests/test_ufunc.py # pylint: disable=no-member if dtype.kind in "iu": @@ -571,11 +578,12 @@ def __matmul__(self, other: Array, reverse: bool = False) -> Array: def _binary_op( self, - op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], + op: Callable[[ScalarExpression, ScalarExpression], + ScalarExpression], other: ArrayOrScalar, get_result_type: Callable[ [ArrayOrScalar, ArrayOrScalar], - np.dtype[Any]] = _np_result_dtype, + np.dtype[Any]] = _np_result_type, reverse: bool = False, cast_to_result_dtype: bool = True, is_pow: bool = False, @@ -632,21 +640,33 @@ def _unary_op(self, op: Any) -> Array: non_equality_tags=_get_created_at_tag(), var_to_reduction_descr=immutabledict()) - __mul__ = partialmethod(_binary_op, operator.mul) - __rmul__ = partialmethod(_binary_op, operator.mul, reverse=True) + # NOTE: Initializing the expression to "prim.Product((expr1, expr2))" is + # essential, as opposed to evaluating "expr1 * expr2". This accounts + # for pymbolic's implementation of the "*" operator, which might not + # instantiate the node corresponding to the operation when one of + # the operands is the neutral element of the operation. + # + # For the same reason, 'prim.(Sum|FloorDiv|Quotient)' is preferred over the + # Python operators on the operands. 
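+ #
+ # For instance (illustrative):
+ #
+ #     x = prim.Variable("x")
+ #     x * 1                  # may fold to Variable("x"); no Product node
+ #     prim.Product((x, 1))   # always constructs the Product node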
- __add__ = partialmethod(_binary_op, operator.add) - __radd__ = partialmethod(_binary_op, operator.add, reverse=True) + __mul__ = partialmethod(_binary_op, lambda l, r: prim.Product((l, r))) + __rmul__ = partialmethod(_binary_op, lambda l, r: prim.Product((l, r)), + reverse=True) - __sub__ = partialmethod(_binary_op, operator.sub) - __rsub__ = partialmethod(_binary_op, operator.sub, reverse=True) + __add__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, r))) + __radd__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, r)), + reverse=True) - __floordiv__ = partialmethod(_binary_op, operator.floordiv) - __rfloordiv__ = partialmethod(_binary_op, operator.floordiv, reverse=True) + __sub__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, -r))) + __rsub__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, -r)), + reverse=True) - __truediv__ = partialmethod(_binary_op, operator.truediv, + __floordiv__ = partialmethod(_binary_op, prim.FloorDiv) + __rfloordiv__ = partialmethod(_binary_op, prim.FloorDiv, reverse=True) + + __truediv__ = partialmethod(_binary_op, prim.Quotient, get_result_type=_truediv_result_type) - __rtruediv__ = partialmethod(_binary_op, operator.truediv, + __rtruediv__ = partialmethod(_binary_op, prim.Quotient, get_result_type=_truediv_result_type, reverse=True) __mod__ = partialmethod(_binary_op, operator.mod) @@ -1397,7 +1417,7 @@ class Stack(_SuppliedAxesAndTagsMixin, Array): @property def dtype(self) -> np.dtype[Any]: - return _np_result_dtype(*(arr.dtype for arr in self.arrays)) + return _np_result_type(*(arr.dtype for arr in self.arrays)) @property def shape(self) -> ShapeType: @@ -1430,7 +1450,7 @@ class Concatenate(_SuppliedAxesAndTagsMixin, Array): @property def dtype(self) -> np.dtype[Any]: - return _np_result_dtype(*(arr.dtype for arr in self.arrays)) + return _np_result_type(*(arr.dtype for arr in self.arrays)) @property def shape(self) -> ShapeType: @@ -1547,6 +1567,7 @@ class Reshape(_SuppliedAxesAndTagsMixin, IndexRemappingBase): if __debug__: def __attrs_post_init__(self) -> None: + # assert self.non_equality_tags super().__attrs_post_init__() @property @@ -2058,9 +2079,9 @@ def reshape(array: Array, newshape: int | Sequence[int], *and* the output array are linearized according to this order and 'matched up'. - Groups are found by multiplying axis lengths on the input and output side, - a matching input/output group is found once adding an input or axis to the - group makes the two products match. + Groups are found by multiplying axis lengths on the input and output + sides; a matching input/output group is found once adding an input or + output axis to the group makes the two products match. The semantics are identical to :func:`numpy.reshape`. diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index c9822549d..86fe5df51 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -368,8 +368,8 @@ def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder: class _PartCommIDs: """A *part*, unlike a *batch*, begins with receives and ends with sends. 
""" - recv_ids: frozenset[CommunicationOpIdentifier] - send_ids: frozenset[CommunicationOpIdentifier] + recv_ids: immutabledict[CommunicationOpIdentifier, None] + send_ids: immutabledict[CommunicationOpIdentifier, None] # {{{ _make_distributed_partition @@ -455,12 +455,14 @@ def _recv_to_comm_id( class _LocalSendRecvDepGatherer( + # Production + # CombineMapper[dict[CommunicationOpIdentifier, None]]): CombineMapper[frozenset[CommunicationOpIdentifier]]): def __init__(self, local_rank: int) -> None: super().__init__() self.local_comm_ids_to_needed_comm_ids: \ dict[CommunicationOpIdentifier, - frozenset[CommunicationOpIdentifier]] = {} + dict[CommunicationOpIdentifier, None]] = {} self.local_recv_id_to_recv_node: \ dict[CommunicationOpIdentifier, DistributedRecv] = {} @@ -470,13 +472,13 @@ def __init__(self, local_rank: int) -> None: self.local_rank = local_rank def combine( - self, *args: frozenset[CommunicationOpIdentifier] - ) -> frozenset[CommunicationOpIdentifier]: - return reduce(frozenset.union, args, frozenset()) + self, *args: dict[CommunicationOpIdentifier, None] + ) -> dict[CommunicationOpIdentifier, None]: + return reduce(lambda x, y: x | y, args, {}) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> frozenset[CommunicationOpIdentifier]: + ) -> dict[CommunicationOpIdentifier, None]: send_id = _send_to_comm_id(self.local_rank, expr.send) if send_id in self.local_send_id_to_send_node: @@ -490,8 +492,8 @@ def map_distributed_send_ref_holder(self, return self.rec(expr.passthrough_data) - def _map_input_base(self, expr: Array) -> frozenset[CommunicationOpIdentifier]: - return frozenset() + def _map_input_base(self, expr: Array) -> dict[CommunicationOpIdentifier, None]: + return {} map_placeholder = _map_input_base map_data_wrapper = _map_input_base @@ -499,21 +501,21 @@ def _map_input_base(self, expr: Array) -> frozenset[CommunicationOpIdentifier]: def map_distributed_recv( self, expr: DistributedRecv - ) -> frozenset[CommunicationOpIdentifier]: + ) -> dict[CommunicationOpIdentifier, None]: recv_id = _recv_to_comm_id(self.local_rank, expr) if recv_id in self.local_recv_id_to_recv_node: from pytato.distributed.verify import DuplicateRecvError raise DuplicateRecvError(f"Multiple receives found for '{recv_id}'") - self.local_comm_ids_to_needed_comm_ids[recv_id] = frozenset() + self.local_comm_ids_to_needed_comm_ids[recv_id] = {} self.local_recv_id_to_recv_node[recv_id] = expr - return frozenset({recv_id}) + return {recv_id: None} def map_named_call_result( - self, expr: NamedCallResult) -> frozenset[CommunicationOpIdentifier]: + self, expr: NamedCallResult) -> dict[CommunicationOpIdentifier, None]: raise NotImplementedError( "LocalSendRecvDepGatherer does not support functions.") @@ -526,6 +528,9 @@ def map_named_call_result( # {{{ _schedule_task_batches (and related) def _schedule_task_batches( + # Production + # task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \ + # -> Sequence[dict[TaskType, None]]: task_ids_to_needed_task_ids: Mapping[TaskType, Set[TaskType]]) \ -> Sequence[Set[TaskType]]: """For each :type:`TaskType`, determine the @@ -541,6 +546,9 @@ def _schedule_task_batches( # {{{ _schedule_task_batches_counted def _schedule_task_batches_counted( + # Production + # task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \ + # -> tuple[Sequence[dict[TaskType, None]], int]: task_ids_to_needed_task_ids: Mapping[TaskType, Set[TaskType]]) \ -> tuple[Sequence[Set[TaskType]], int]: """ @@ -551,10 +559,11 @@ def 
_schedule_task_batches_counted( task_to_dep_level, visits_in_depend = \ _calculate_dependency_levels(task_ids_to_needed_task_ids) nlevels = 1 + max(task_to_dep_level.values(), default=-1) - task_batches: Sequence[set[TaskType]] = [set() for _ in range(nlevels)] + task_batches: Sequence[dict[TaskType, None]] = [{} for _ in range(nlevels)] for task_id, dep_level in task_to_dep_level.items(): - task_batches[dep_level].add(task_id) + if task_id not in task_batches[dep_level]: + task_batches[dep_level][task_id] = None return task_batches, visits_in_depend + len(task_to_dep_level.keys()) @@ -617,7 +626,7 @@ class _MaterializedArrayCollector(CachedWalkMapper[[]]): """ def __init__(self) -> None: super().__init__() - self.materialized_arrays: _OrderedSet[Array] = _OrderedSet() + self.materialized_arrays: dict[Array, None] = {} def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) @@ -627,15 +636,15 @@ def post_visit(self, expr: Any) -> None: from pytato.tags import ImplStored if (isinstance(expr, Array) and expr.tags_of_type(ImplStored)): - self.materialized_arrays.add(expr) + self.materialized_arrays[expr] = None if isinstance(expr, LoopyCallResult): - self.materialized_arrays.add(expr) + self.materialized_arrays[expr] = None from pytato.loopy import LoopyCall assert isinstance(expr._container, LoopyCall) for _, subexpr in sorted(expr._container.bindings.items()): if isinstance(subexpr, Array): - self.materialized_arrays.add(subexpr) + self.materialized_arrays[subexpr] = None else: assert isinstance(subexpr, SCALAR_CLASSES) @@ -645,13 +654,14 @@ def post_visit(self, expr: Any) -> None: # {{{ _set_dict_union_mpi def _set_dict_union_mpi( - dict_a: Mapping[_KeyT, frozenset[_ValueT]], - dict_b: Mapping[_KeyT, frozenset[_ValueT]], - mpi_data_type: mpi4py.MPI.Datatype) -> Mapping[_KeyT, frozenset[_ValueT]]: + dict_a: Mapping[_KeyT, dict[_ValueT, None]], + dict_b: Mapping[_KeyT, dict[_ValueT, None]], + mpi_data_type: mpi4py.MPI.Datatype | None) \ + -> Mapping[_KeyT, dict[_ValueT, None]]: assert mpi_data_type is None result = dict(dict_a) for key, values in dict_b.items(): - result[key] = result.get(key, frozenset()) | values + result[key] = result.get(key, {}) | values return result # }}} @@ -777,6 +787,7 @@ def find_distributed_partition( assigned in :attr:`DistributedGraphPart.name_to_send_nodes`. 
""" import mpi4py.MPI as MPI + from immutabledict import immutabledict from pytato.transform import SubsetDependencyMapper @@ -818,7 +829,8 @@ def find_distributed_partition( comm_batches_or_exc = mpi_communicator.bcast(None) if isinstance(comm_batches_or_exc, Exception): raise comm_batches_or_exc - + # Production + # comm_batches = comm_batches_or_exc comm_batches = cast( Sequence[Set[CommunicationOpIdentifier]], comm_batches_or_exc) @@ -829,30 +841,31 @@ def find_distributed_partition( part_comm_ids: list[_PartCommIDs] = [] if comm_batches: - recv_ids: frozenset[CommunicationOpIdentifier] = frozenset() + recv_ids: immutabledict[CommunicationOpIdentifier, None] = immutabledict() for batch in comm_batches: - send_ids = frozenset( - comm_id for comm_id in batch - if comm_id.src_rank == local_rank) + send_ids: immutabledict[CommunicationOpIdentifier, None] \ + = immutabledict.fromkeys( + comm_id for comm_id in batch + if comm_id.src_rank == local_rank) if recv_ids or send_ids: part_comm_ids.append( _PartCommIDs( recv_ids=recv_ids, send_ids=send_ids)) # These go into the next part - recv_ids = frozenset( + recv_ids = immutabledict.fromkeys( comm_id for comm_id in batch if comm_id.dest_rank == local_rank) if recv_ids: part_comm_ids.append( _PartCommIDs( recv_ids=recv_ids, - send_ids=frozenset())) + send_ids=immutabledict())) else: part_comm_ids.append( _PartCommIDs( - recv_ids=frozenset(), - send_ids=frozenset())) + recv_ids=immutabledict(), + send_ids=immutabledict())) nparts = len(part_comm_ids) @@ -879,13 +892,13 @@ def find_distributed_partition( materialized_arrays_collector = _MaterializedArrayCollector() materialized_arrays_collector(outputs) - # The sets of arrays below must have a deterministic order in order to ensure + # The collections of arrays below must have a deterministic order in order to ensure # that the resulting partition is also deterministic - sent_arrays = _OrderedSet( + sent_arrays = dict.fromkeys( send_node.data for send_node in lsrdg.local_send_id_to_send_node.values()) - received_arrays = _OrderedSet(lsrdg.local_recv_id_to_recv_node.values()) + received_arrays = dict.fromkeys(lsrdg.local_recv_id_to_recv_node.values()) # While receive nodes may be marked as materialized, we shouldn't be # including them here because we're using them (along with the send nodes) @@ -893,13 +906,12 @@ def find_distributed_partition( # We could allow sent *arrays* to be included here because they are distinct # from send *nodes*, but we choose to exclude them in order to simplify the # processing below. 
- materialized_arrays = ( - materialized_arrays_collector.materialized_arrays - - received_arrays - - sent_arrays) + materialized_arrays = {a: None + for a in materialized_arrays_collector.materialized_arrays + if a not in received_arrays | sent_arrays} # "mso" for "materialized/sent/output" - output_arrays = _OrderedSet(outputs._data.values()) + output_arrays = dict.fromkeys(outputs._data.values()) mso_arrays = materialized_arrays | sent_arrays | output_arrays # FIXME: This gathers up materialized_arrays recursively, leading to @@ -964,30 +976,31 @@ def find_distributed_partition( assert all(0 <= part_id < nparts for part_id in stored_ary_to_part_id.values()) - stored_arrays = _OrderedSet(stored_ary_to_part_id) + stored_arrays = dict.fromkeys(stored_ary_to_part_id) # {{{ find which stored arrays should become part outputs # (because they are used in not just their local part, but also others) direct_preds_getter = DirectPredecessorsGetter() - def get_materialized_predecessors(ary: Array) -> _OrderedSet[Array]: - materialized_preds: _OrderedSet[Array] = _OrderedSet() + def get_materialized_predecessors(ary: Array) -> dict[Array, None]: + materialized_preds: dict[Array, None] = {} for pred in direct_preds_getter(ary): assert isinstance(pred, Array) if pred in materialized_arrays: - materialized_preds.add(pred) + materialized_preds[pred] = None else: - materialized_preds |= get_materialized_predecessors(pred) + for p in get_materialized_predecessors(pred): + materialized_preds[p] = None return materialized_preds stored_arrays_promoted_to_part_outputs = { - stored_pred + stored_pred: None for stored_ary in stored_arrays for stored_pred in get_materialized_predecessors(stored_ary) if (stored_ary_to_part_id[stored_ary] != stored_ary_to_part_id[stored_pred]) - } + } # }}} diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 032005049..58564b2fb 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -50,6 +50,7 @@ TYPE_CHECKING, Any, Never, + Union, cast, ) @@ -84,6 +85,9 @@ # {{{ scalar expressions INT_CLASSES = (int, np.integer) +IntegralScalarExpression = Union[IntegerT, prim.Expression] +Scalar = Union[np.number[Any], int, np.bool_, bool, float, complex] +ScalarExpression = Union[Scalar, prim.Expression] PYTHON_SCALAR_CLASSES = (int, float, complex, bool) SCALAR_CLASSES = prim.VALID_CONSTANT_CLASSES diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 25edfc7da..4c281955c 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1831,6 +1831,33 @@ def rec_get_user_nodes(expr: ArrayOrNames, # }}} +# {{{ BranchMorpher + +class BranchMorpher(CopyMapper): + """ + A mapper that replaces equal segments of graphs with identical objects. 
+ """ + def __init__(self) -> None: + super().__init__() + self.result_cache: Dict[ArrayOrNames, ArrayOrNames] = {} + + def cache_key(self, expr: CachedMapperT) -> Any: + return (id(expr), expr) + + # type-ignore reason: incompatible with Mapper.rec + def rec(self, expr: MappedT) -> MappedT: # type: ignore[override] + rec_expr = super().rec(expr) + try: + # type-ignored because 'result_cache' maps to ArrayOrNames + return self.result_cache[rec_expr] # type: ignore[return-value] + except KeyError: + self.result_cache[rec_expr] = rec_expr + # type-ignored because of super-class' relaxed types + return rec_expr # type: ignore[no-any-return] + +# }}} + + # {{{ deduplicate_data_wrappers def _get_data_dedup_cache_key(ary: DataInterface) -> Hashable: @@ -1920,8 +1947,9 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames: len(data_wrapper_cache), data_wrappers_encountered - len(data_wrapper_cache)) - return array_or_names + return BranchMorpher()(array_or_names) # }}} + # vim: foldmethod=marker diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 179366bc8..4a4144805 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -133,6 +133,51 @@ def _get_reshaped_indices(expr: Reshape) -> tuple[ScalarExpression, ...]: assert expr.size == 1 return () + if expr.order not in ["C", "F"]: + raise NotImplementedError("Order expected to be 'C' or 'F'", + f" found {expr.order}") + + if expr.order == "C": + newstrides: list[IntegralT] = [1] # reshaped array strides + for new_axis_len in reversed(expr.shape[1:]): + assert isinstance(new_axis_len, INT_CLASSES) + newstrides.insert(0, newstrides[0]*new_axis_len) + + flattened_idx = sum(prim.Variable(f"_{i}")*stride + for i, stride in enumerate(newstrides)) + + oldstrides: list[IntegralT] = [1] # input array strides + for axis_len in reversed(expr.array.shape[1:]): + assert isinstance(axis_len, INT_CLASSES) + oldstrides.insert(0, oldstrides[0]*axis_len) + + assert isinstance(expr.array.shape[-1], INT_CLASSES) + oldsizetills = [expr.array.shape[-1]] # input array size + # till for axes idx + for old_axis_len in reversed(expr.array.shape[:-1]): + assert isinstance(old_axis_len, INT_CLASSES) + oldsizetills.insert(0, oldsizetills[0]*old_axis_len) + + else: + newstrides: list[IntegralT] = [1] # reshaped array strides + for new_axis_len in expr.shape[:-1]: + assert isinstance(new_axis_len, INT_CLASSES) + newstrides.append(newstrides[-1]*new_axis_len) + + flattened_idx = sum(prim.Variable(f"_{i}")*stride + for i, stride in enumerate(newstrides)) + + oldstrides: list[IntegralT] = [1] # input array strides + for axis_len in expr.array.shape[:-1]: + assert isinstance(axis_len, INT_CLASSES) + oldstrides.append(oldstrides[-1]*axis_len) + + assert isinstance(expr.array.shape[0], INT_CLASSES) + oldsizetills = [expr.array.shape[0]] # input array size till for axes idx + for old_axis_len in expr.array.shape[1:]: + assert isinstance(old_axis_len, INT_CLASSES) + oldsizetills.append(oldsizetills[-1]*old_axis_len) + if new_shape == (): return _generate_index_expressions(old_shape, new_shape, order, index_vars) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 139b7bf5b..b8fee7ec7 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -613,8 +613,12 @@ def rec(self, expr: ArrayOrNames) -> Any: assert expr_copy.ndim == expr.ndim for iaxis in range(expr.ndim): + axis_tags = self.axis_to_tags.get((expr, iaxis), []) + 
if len(axis_tags) == 0: + print(f"failed to infer tags for axis {iaxis} of an array of type {type(expr)}.") + print(f"{expr.non_equality_tags=}") expr_copy = expr_copy.with_tagged_axis( - iaxis, self.axis_to_tags.get((expr, iaxis), [])) + iaxis, axis_tags) # {{{ tag reduction descrs diff --git a/pytato/utils.py b/pytato/utils.py index 72984b015..e243a20ee 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -29,9 +29,22 @@ cast, ) -import islpy as isl -import numpy as np +from typing import (Tuple, List, Union, Callable, Any, Sequence, Dict, + Optional, Iterable, TypeVar, FrozenSet) +from pytato.array import (Array, ShapeType, IndexLambda, SizeParam, ShapeComponent, + DtypeOrScalar, ArrayOrScalar, BasicIndex, + AdvancedIndexInContiguousAxes, + AdvancedIndexInNoncontiguousAxes, + ConvertibleToIndexExpr, IndexExpr, NormalizedSlice, + _dtype_any, Einsum) +from pytato.scalar_expr import (ScalarExpression, IntegralScalarExpression, + SCALAR_CLASSES, INT_CLASSES, BoolT) +from pytools import UniqueNameGenerator +from pytato.transform import Mapper +from pytools.tag import Tag from immutabledict import immutabledict +import numpy as np +import islpy as isl import pymbolic.primitives as prim from pymbolic import ScalarT @@ -168,6 +181,19 @@ def with_indices_for_broadcasted_shape(val: prim.Variable, shape: ShapeType, return val[get_indexing_expression(shape, result_shape)] +def extract_dtypes_or_scalars( + exprs: Sequence[ArrayOrScalar]) -> List[DtypeOrScalar]: + dtypes: List[DtypeOrScalar] = [] + for expr in exprs: + if isinstance(expr, Array): + dtypes.append(expr.dtype) + else: + assert isinstance(expr, SCALAR_CLASSES) + dtypes.append(expr) + + return dtypes + + def update_bindings_and_get_broadcasted_expr(arr: ArrayOrScalar, bnd_name: str, bindings: dict[str, Array], @@ -206,16 +232,19 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, ) -> ArrayOrScalar: from pytato.array import _get_default_axes + if isinstance(a1, SCALAR_CLASSES): + a1 = np.dtype(type(a1)).type(a1) + + if isinstance(a2, SCALAR_CLASSES): + a2 = np.dtype(type(a2)).type(a2) + if np.isscalar(a1) and np.isscalar(a2): from pytato.scalar_expr import evaluate return evaluate(op(a1, a2)) # type: ignore result_shape = get_shape_after_broadcasting([a1, a2]) - - # Note: get_result_type calls np.result_type by default, which means - # that we are passing a pytato array to numpy. Luckily, np.result_type - # only looks at the dtype of input arrays as of numpy v2.1. - result_dtype = get_result_type(a1, a2) + dtypes = extract_dtypes_or_scalars([a1, a2]) + result_dtype = get_result_type(*dtypes) bindings: dict[str, Array] = {} diff --git a/test/test_codegen.py b/test/test_codegen.py index e65f5b7e9..c1499e741 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -317,6 +317,7 @@ def test_scalar_array_binary_arith(ctx_factory, which, reverse): out_ref = np_op(x_in, y_orig.astype(dtype)) assert out.dtype == out_ref.dtype, (out.dtype, out_ref.dtype) + # In some cases, ops are carried out in float32 by loopy but in float64 by numpy. 
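+ # (This is why the assertion below uses np.allclose rather than exact
+ # equality.)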
assert np.allclose(out, out_ref), (out, out_ref) @@ -1612,7 +1613,7 @@ def test_zero_size_cl_array_dedup(ctx_factory): x4 = pt.make_data_wrapper(x_cl2) out = pt.make_dict_of_named_arrays({"out1": 2*x1, - "out2": 2*x2, + "out2": 3*x2, "out3": x3 + x4 }) diff --git a/test/test_distributed.py b/test/test_distributed.py index ac7ca1389..1554a024b 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -899,13 +899,11 @@ def test_number_symbolic_tags_bare_classes(ctx_factory): outputs = pt.make_dict_of_named_arrays({"out": res}) partition = pt.find_distributed_partition(comm, outputs) - (_distp, next_tag) = pt.number_distributed_tags(comm, partition, base_tag=4242) + (distp, next_tag) = pt.number_distributed_tags(comm, partition, base_tag=4242) assert next_tag == 4244 - # FIXME: For the next assertion, find_distributed_partition needs to be - # deterministic too (https://github.com/inducer/pytato/pull/465). - # assert next(iter(distp.parts[0].name_to_send_nodes.values()))[0].comm_tag == 4242 # noqa: E501 + assert next(iter(distp.parts[0].name_to_send_nodes.values()))[0].comm_tag == 4242 # }}} diff --git a/test/test_pytato.py b/test/test_pytato.py index 271c8fb01..5523d6a8b 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -514,8 +514,8 @@ def _assert_stripped_repr(ary: pt.Array, expected_repr: str): dtype='int64', expr=Product((Subscript(Variable('_in0'), (Variable('_0'), Variable('_1'))), - TypeCast(dtype('int64'), Subscript(Variable('_in1'), - (Variable('_0'), Variable('_1')))))), + Subscript(Variable('_in1'), + (Variable('_0'), Variable('_1'))))), bindings={'_in0': Placeholder(shape=(10, 4), dtype='int64', name='y'), '_in1': IndexLambda( shape=(10, 4), @@ -1364,6 +1364,94 @@ def test_dot_visualizers(): # }}} +# {{{ Test PytatoKeyBuilder + +def run_test_with_new_python_invocation(f, *args, extra_env_vars=None) -> None: + import os + if extra_env_vars is None: + extra_env_vars = {} + + from base64 import b64encode + from pickle import dumps + from subprocess import check_call + + env_vars = { + "INVOCATION_INFO": b64encode(dumps((f, args))).decode(), + } + env_vars.update(extra_env_vars) + + my_env = os.environ.copy() + my_env.update(env_vars) + + check_call([sys.executable, __file__], env=my_env) + + +def run_test_with_new_python_invocation_inner() -> None: + from base64 import b64decode + from pickle import loads + import os + + f, args = loads(b64decode(os.environ["INVOCATION_INFO"].encode())) + + f(*args) + + +def test_persistent_hashing_and_persistent_dict() -> None: + from pytools.persistent_dict import WriteOncePersistentDict, ReadOnlyEntryError + from pytato.analysis import PytatoKeyBuilder + import shutil + import tempfile + + try: + tmpdir = tempfile.mkdtemp() + + pkb = PytatoKeyBuilder() + + pd = WriteOncePersistentDict("test_persistent_dict", + key_builder=pkb, + container_dir=tmpdir) + + for i in range(100): + rdagc = RandomDAGContext(np.random.default_rng(seed=i), + axis_len=5, use_numpy=True) + + dag = make_random_dag(rdagc) + + # Make sure the PytatoKeyBuilder can handle 'dag' + pd[dag] = 42 + + # Make sure that the key stays the same within the same Python invocation + with pytest.raises(ReadOnlyEntryError): + pd[dag] = 42 + + # Make sure that the key stays the same across Python invocations + run_test_with_new_python_invocation( + _test_persistent_hashing_and_persistent_dict_stage2, tmpdir) + finally: + shutil.rmtree(tmpdir) + + +def _test_persistent_hashing_and_persistent_dict_stage2(tmpdir) -> None: + from pytools.persistent_dict import 
WriteOncePersistentDict, ReadOnlyEntryError + + from pytato.analysis import PytatoKeyBuilder + pkb = PytatoKeyBuilder() + + pd = WriteOncePersistentDict("test_persistent_dict", + key_builder=pkb, + container_dir=tmpdir) + + for i in range(100): + rdagc = RandomDAGContext(np.random.default_rng(seed=i), + axis_len=5, use_numpy=True) + + dag = make_random_dag(rdagc) + + with pytest.raises(ReadOnlyEntryError): + pd[dag] = 42 + +# }}} + def test_numpy_type_promotion_with_pytato_arrays(): class NotReallyAnArray: @property @@ -1380,7 +1468,10 @@ def dtype(self): if __name__ == "__main__": - if len(sys.argv) > 1: + import os + if "INVOCATION_INFO" in os.environ: + run_test_with_new_python_invocation_inner() + elif len(sys.argv) > 1: exec(sys.argv[1]) else: from pytest import main
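Note on the PytatoKeyBuilder usage above: pytools' KeyBuilder (the base of LoopyKeyBuilder and hence PytatoKeyBuilder) dispatches on the hashed object's type name, so e.g. a numpy.ndarray is handled by update_for_ndarray. A minimal usage sketch, mirroring the test above (the container_dir path is illustrative only):

    import numpy as np
    import pytato as pt
    from pytools.persistent_dict import WriteOncePersistentDict
    from pytato.analysis import PytatoKeyBuilder

    # Build a small DAG; the DataWrapper wraps a numpy array, so hashing
    # it reaches PytatoKeyBuilder.update_for_ndarray.
    x = pt.make_data_wrapper(np.arange(10))
    dag = 2*x + 1

    pd = WriteOncePersistentDict(
        "demo_dict", key_builder=PytatoKeyBuilder(),
        container_dir="/tmp/pytato_kb_demo")
    pd[dag] = 42   # keys remain stable across Python invocations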