From e7337041c950fc45a86ca39235e929a5733ee620 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Thu, 18 Jul 2024 13:46:40 -0500 Subject: [PATCH 1/8] Apply CEESD changes --- meshmode/array_context.py | 1540 ++++++++++++++++++ meshmode/discretization/connection/direct.py | 57 +- meshmode/discretization/poly_element.py | 6 + meshmode/distributed.py | 112 +- meshmode/mesh/__init__.py | 20 +- meshmode/mesh/io.py | 17 +- meshmode/mesh/processing.py | 11 +- meshmode/pytato_utils.py | 115 ++ meshmode/transform_metadata.py | 10 + requirements.txt | 10 +- test/3x3.msh | 38 + test/3x3_bound.msh | 54 + test/3x3_minus.msh | 38 + test/3x3_twisted.msh | 38 + test/3x3_twisted_bound.msh | 54 + test/test_mesh.py | 62 +- 16 files changed, 2123 insertions(+), 59 deletions(-) create mode 100644 meshmode/pytato_utils.py create mode 100644 test/3x3.msh create mode 100644 test/3x3_bound.msh create mode 100644 test/3x3_minus.msh create mode 100644 test/3x3_twisted.msh create mode 100644 test/3x3_twisted_bound.msh diff --git a/meshmode/array_context.py b/meshmode/array_context.py index fbd8e25da..15e739136 100644 --- a/meshmode/array_context.py +++ b/meshmode/array_context.py @@ -25,6 +25,17 @@ THE SOFTWARE. """ +import logging +import numpy as np +from typing import ( + Union, + FrozenSet, + Tuple, + Any, + Optional, + Callable, + TYPE_CHECKING +) from warnings import warn from arraycontext import ( @@ -36,6 +47,40 @@ _PytestPytatoPyOpenCLArrayContextFactory, register_pytest_array_context_factory, ) +from loopy.translation_unit import for_each_kernel + +from loopy.tools import memoize_on_disk +from pytools import ProcessLogger, memoize_on_first_arg +from pytools.tag import UniqueTag, tag_dataclass + +from meshmode import Error +from meshmode.transform_metadata import (DiscretizationElementAxisTag, + DiscretizationDOFAxisTag, + DiscretizationFaceAxisTag, + DiscretizationDimAxisTag, + DiscretizationTopologicalDimAxisTag, + DiscretizationAmbientDimAxisTag, + DiscretizationFlattenedDOFAxisTag, + DiscretizationEntityAxisTag) +from dataclasses import dataclass + +if TYPE_CHECKING: + import pyopencl as cl + +from immutabledict import immutabledict +logger = logging.getLogger(__name__) + + +class ArrayContextLoopyTransformError(Error): + pass + + +class AxisTagInferenceError(ArrayContextLoopyTransformError): + pass + + +class EinsumInferenceError(ArrayContextLoopyTransformError): + pass def thaw(actx, ary): @@ -342,4 +387,1499 @@ def __getattr__(name): # }}} +@for_each_kernel +def _single_grid_work_group_transform(kernel, cl_device): + import loopy as lp + from meshmode.transform_metadata import (ConcurrentElementInameTag, + ConcurrentDOFInameTag) + + splayed_inames = set() + ngroups = cl_device.max_compute_units * 4 # '4' to overfill the device + l_one_size = 4 + l_zero_size = 16 + + for insn in kernel.instructions: + if insn.within_inames in splayed_inames: + continue + + if isinstance(insn, lp.CallInstruction): + # must be a callable kernel, don't touch. 
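+ # (The loop structure lives in the callee kernel, which for_each_kernel transforms on its own; leave the call site alone.)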
+ pass + elif isinstance(insn, lp.Assignment): + bigger_loop = None + smaller_loop = None + + if len(insn.within_inames) == 0: + continue + + if len(insn.within_inames) == 1: + iname, = insn.within_inames + + kernel = lp.split_iname(kernel, iname, + ngroups * l_zero_size * l_one_size) + kernel = lp.split_iname(kernel, f"{iname}_inner", + l_zero_size, inner_tag="l.0") + kernel = lp.split_iname(kernel, f"{iname}_inner_outer", + l_one_size, inner_tag="l.1", + outer_tag="g.0") + + splayed_inames.add(insn.within_inames) + continue + + for iname in insn.within_inames: + if kernel.iname_tags_of_type(iname, + ConcurrentElementInameTag): + assert bigger_loop is None + bigger_loop = iname + elif kernel.iname_tags_of_type(iname, + ConcurrentDOFInameTag): + assert smaller_loop is None + smaller_loop = iname + else: + pass + + if bigger_loop or smaller_loop: + assert (bigger_loop is not None + and smaller_loop is not None) + else: + sorted_inames = sorted(tuple(insn.within_inames), + key=kernel.get_constant_iname_length) + smaller_loop = sorted_inames[0] + bigger_loop = sorted_inames[-1] + + kernel = lp.split_iname(kernel, f"{bigger_loop}", + l_one_size * ngroups) + kernel = lp.split_iname(kernel, f"{bigger_loop}_inner", + l_one_size, inner_tag="l.1", outer_tag="g.0") + kernel = lp.split_iname(kernel, smaller_loop, + l_zero_size, inner_tag="l.0") + splayed_inames.add(insn.within_inames) + elif isinstance(insn, lp.BarrierInstruction): + pass + else: + raise NotImplementedError(type(insn)) + + return kernel + + +def _alias_global_temporaries(t_unit): + """ + Returns a copy of *t_unit* with temporaries of that have disjoint live + intervals using the same :attr:`loopy.TemporaryVariable.base_storage`. + """ + from loopy.kernel.data import AddressSpace + from collections import defaultdict + + kernel = t_unit.default_entrypoint + toposorted_iels = _get_element_loop_topo_sorted_order(kernel) + iel_order = {iel: i + for i, iel in enumerate(toposorted_iels)} + + temp_vars = frozenset(tv.name + for tv in kernel.temporary_variables.values() + if tv.address_space == AddressSpace.GLOBAL) + temp_var_to_iels = {tv: set() for tv in temp_vars} + all_iels = { + iel + for iel in kernel.all_inames() + if kernel.inames[iel].tags_of_type((DiscretizationElementAxisTag, + DiscretizationFlattenedDOFAxisTag))} + + if not all_iels: + # no element loops => return the t_unit as is. + return t_unit + + for insn in kernel.instructions: + iel, = insn.within_inames & all_iels + + for tv in insn.dependency_names() & temp_vars: + temp_var_to_iels[tv].add(iel) + + temp_to_iel_start = {tv: min(iels, + key=lambda x: iel_order[x], + default=toposorted_iels[-1] + ) + for tv, iels in temp_var_to_iels.items()} + temp_to_iel_end = {tv: max(iels, + key=lambda x: iel_order[x], + default=toposorted_iels[0] + ) + for tv, iels in temp_var_to_iels.items()} + + iel_to_temps_to_allocate = {iel: set() for iel in all_iels} + iel_to_temps_to_free = {iel: set() for iel in all_iels} + for tv in temp_vars: + allocate_iel, free_iel = temp_to_iel_start[tv], temp_to_iel_end[tv] + if iel_order[allocate_iel] >= iel_order[free_iel]: + continue + iel_to_temps_to_allocate[allocate_iel].add(tv) + iel_to_temps_to_free[free_iel].add(tv) + + vng = kernel.get_var_name_generator() + # a mapping from shape to the available base storages from temp variables + # that were dead. 
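+ # (Keyed by tv.nbytes despite the name: matching byte footprints, not matching shapes, are what allow two temporaries to share base storage.)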
+ shape_to_available_base_storage = defaultdict(set) + + new_tvs = {} + + for iel in toposorted_iels: + to_be_allocated_temps = iel_to_temps_to_allocate[iel] + + for tv_name in sorted(to_be_allocated_temps): + assert len(to_be_allocated_temps) <= 1 + tv = kernel.temporary_variables[tv_name] + assert tv.name not in new_tvs + assert tv.base_storage is None + if shape_to_available_base_storage[tv.nbytes]: + base_storage = sorted(shape_to_available_base_storage[tv.nbytes])[0] + shape_to_available_base_storage[tv.nbytes].remove(base_storage) + else: + base_storage = vng("_msh_actx_tmp_base") + + new_tvs[tv.name] = tv.copy(base_storage=base_storage) + + just_dead_temps = iel_to_temps_to_free[iel] + for tv_name in sorted(just_dead_temps): + tv = new_tvs[tv_name] + assert tv.base_storage is not None + assert tv.base_storage not in shape_to_available_base_storage[tv.nbytes] + shape_to_available_base_storage[tv.nbytes].add(tv.base_storage) + + for name, tv in kernel.temporary_variables.items(): + if tv.address_space != AddressSpace.GLOBAL: + new_tvs[name] = tv + else: + # FIXME: Need tighter assertion condition (this doesn't work when + # zero-size arrays are present) + # assert name in new_tvs + pass + + kernel = kernel.copy(temporary_variables=new_tvs) + + old_tmp_mem_requirement = sum( + tv.nbytes + for tv in kernel.temporary_variables.values()) + + new_tmp_mem_requirement = sum( + {tv.base_storage: tv.nbytes + for tv in kernel.temporary_variables.values()}.values()) + + logger.info( + f"[_alias_global_temporaries]: Reduced memory requirement of '{kernel.name}' from " + f"{old_tmp_mem_requirement*1e-6:.1f}MB to" + f" {new_tmp_mem_requirement*1e-6:.1f}MB.") + + return t_unit.with_kernel(kernel) + + +def _can_be_eagerly_computed(ary) -> bool: + from pytato.transform import InputGatherer + from pytato.array import Placeholder + return all(not isinstance(inp, Placeholder) + for inp in InputGatherer()(ary)) + + +def deduplicate_data_wrappers(dag): + import pytato as pt + data_wrapper_cache = {} + data_wrappers_encountered = 0 + + def cached_data_wrapper_if_present(ary): + nonlocal data_wrappers_encountered + + if isinstance(ary, pt.DataWrapper): + + data_wrappers_encountered += 1 + cache_key = (ary.data.base_data.int_ptr, ary.data.offset, + ary.shape, ary.data.strides) + try: + result = data_wrapper_cache[cache_key] + except KeyError: + result = ary + data_wrapper_cache[cache_key] = result + + return result + else: + return ary + + dag = pt.transform.map_and_copy(dag, cached_data_wrapper_if_present) + + if data_wrappers_encountered: + logger.info("data wrapper de-duplication: " + "%d encountered, %d kept, %d eliminated", + data_wrappers_encountered, + len(data_wrapper_cache), + data_wrappers_encountered - len(data_wrapper_cache)) + + return dag + + +class SingleGridWorkBalancingPytatoArrayContext(PytatoPyOpenCLArrayContextBase): + """ + A :class:`PytatoPyOpenCLArrayContext` that parallelizes work in an OpenCL + kernel so that the work + """ + def transform_loopy_program(self, t_unit): + import loopy as lp + + t_unit = _single_grid_work_group_transform(t_unit, self.queue.device) + t_unit = lp.set_options(t_unit, "insert_gbarriers") + + return t_unit + + def _get_fake_numpy_namespace(self): + from meshmode.pytato_utils import ( + EagerReduceComputingPytatoFakeNumpyNamespace) + return EagerReduceComputingPytatoFakeNumpyNamespace(self) + + def transform_dag(self, dag): + import pytato as pt + + # {{{ face_mass: materialize einsum args + + def materialize_face_mass_vec(expr): + if (isinstance(expr, 
pt.Einsum) + and pt.analysis.is_einsum_similar_to_subscript( + expr, "ifj,fej,fej->ei")): + mat, jac, vec = expr.args + return pt.einsum("ifj,fej,fej->ei", + mat, + jac, + vec.tagged(pt.tags.ImplStored())) + else: + return expr + + dag = pt.transform.map_and_copy(dag, materialize_face_mass_vec) + + # }}} + + # {{{ materialize all einsums + + def materialize_einsums(ary: pt.Array) -> pt.Array: + if isinstance(ary, pt.Einsum): + return ary.tagged(pt.tags.ImplStored()) + + return ary + + dag = pt.transform.map_and_copy(dag, materialize_einsums) + + # }}} + + dag = pt.transform.materialize_with_mpms(dag) + dag = deduplicate_data_wrappers(dag) + + # {{{ /!\ Remove tags from Loopy call results. + # See + + def untag_loopy_call_results(expr): + from pytato.loopy import LoopyCallResult + if isinstance(expr, LoopyCallResult): + return expr.copy(tags=frozenset(), + axes=(pt.Axis(frozenset()),)*expr.ndim) + else: + return expr + + dag = pt.transform.map_and_copy(dag, untag_loopy_call_results) + + # }}} + + return dag + + +def get_temps_not_to_contract(knl): + from functools import reduce + wmap = knl.writer_map() + rmap = knl.reader_map() + + temps_not_to_contract = set() + for tv in knl.temporary_variables: + if len(wmap.get(tv, set())) == 1: + writer_id, = wmap[tv] + writer_loop_nest = knl.id_to_insn[writer_id].within_inames + insns_in_writer_loop_nest = reduce(frozenset.union, + (knl.iname_to_insns()[iname] + for iname in writer_loop_nest), + frozenset()) + if ( + (not (rmap.get(tv, frozenset()) + <= insns_in_writer_loop_nest)) + or len(knl.id_to_insn[writer_id].reduction_inames()) != 0 + or any((len(knl.id_to_insn[reader_id].reduction_inames()) != 0) + for reader_id in rmap.get(tv, frozenset()))): + temps_not_to_contract.add(tv) + else: + temps_not_to_contract.add(tv) + return temps_not_to_contract + + # Better way to query it... 
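+ # (Sketch of the same query as a miniKanren relational program; kept commented out since it assumes a loopy 'relations' interface that is not generally available.)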
+ # import loopy as lp + # from kanren.constraints import neq as kanren_neq + # + # tempo = lp.relations.get_tempo(knl) + # producero = lp.relations.get_producero(knl) + # consumero = lp.relations.get_consumero(knl) + # withino = lp.relations.get_withino(knl) + # reduce_insno = lp.relations.get_reduce_insno(knl) + # + # # temp_k: temporary variable that cannot be contracted + # temp_k = kanren.var() + # producer_insn_k = kanren.var() + # producer_loops_k = kanren.var() + # consumer_insn_k = kanren.var() + # consumer_loops_k = kanren.var() + + # temps_not_to_contract = kanren.run(0, + # temp_k, + # tempo(temp_k), + # producero(producer_insn_k, + # temp_k), + # consumero(consumer_insn_k, + # temp_k), + # withino(producer_insn_k, + # producer_loops_k), + # withino(consumer_insn_k, + # consumer_loops_k), + # kanren.lany( + # kanren_neq( + # producer_loops_k, + # consumer_loops_k), + # reduce_insno(consumer_insn_k)), + # results_filter=frozenset) + # return temps_not_to_contract + + +def _is_iel_loop_part_of_global_dof_loops(iel: str, knl) -> bool: + insn, = knl.iname_to_insns()[iel] + return any(iname + for iname in knl.id_to_insn[insn].within_inames + if knl.iname_tags_of_type(iname, DiscretizationDOFAxisTag)) + + +def _discr_entity_sort_key(discr_tag: DiscretizationEntityAxisTag + ) -> Tuple[Any, ...]: + + return type(discr_tag).__name__ + + +# {{{ define FEMEinsumTag + +@dataclass(frozen=True) +class EinsumIndex: + discr_entity: DiscretizationEntityAxisTag + length: int + + @classmethod + def from_iname(cls, iname, kernel): + discr_entity, = kernel.filter_iname_tags_by_type( + iname, DiscretizationEntityAxisTag) + length = kernel.get_constant_iname_length(iname) + return cls(discr_entity, length) + + +@dataclass(frozen=True) +class FreeEinsumIndex(EinsumIndex): + pass + + +@dataclass(frozen=True) +class SummationEinsumIndex(EinsumIndex): + pass + + +@dataclass(frozen=True) +class FEMEinsumTag(UniqueTag): + indices: Tuple[Tuple[EinsumIndex, ...], ...] 
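+ # indices[i][j] is the EinsumIndex describing axis j of einsum argument i.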
+ + +class NotAnFEMEinsumError(ValueError): + """ + pass + """ + +# }}} + + +@memoize_on_first_arg +def _get_redn_iname_to_insns(kernel): + redn_iname_to_insns = {iname: set() + for iname in kernel.all_inames()} + + for insn in kernel.instructions: + for redn_iname in insn.reduction_inames(): + redn_iname_to_insns[redn_iname].add(insn.id) + + return immutabledict({k: frozenset(v) + for k, v in redn_iname_to_insns.items()}) + + +def _do_inames_belong_to_different_einsum_types(iname1, iname2, kernel): + if kernel.iname_to_insns()[iname1]: + assert (len(kernel.iname_to_insns()[iname1]) + == len(kernel.iname_to_insns()[iname2]) + == 1) + insn1, = kernel.iname_to_insns()[iname1] + insn2, = kernel.iname_to_insns()[iname2] + else: + redn_iname_to_insns = _get_redn_iname_to_insns(kernel) + assert (len(redn_iname_to_insns[iname1]) + == len(redn_iname_to_insns[iname2]) + == 1) + insn1, = redn_iname_to_insns[iname1] + insn2, = redn_iname_to_insns[iname2] + + assert (len(redn_iname_to_insns[iname1]) + == len(redn_iname_to_insns[iname2]) + == 1) + + var1_name, = kernel.id_to_insn[insn1].assignee_var_names() + var2_name, = kernel.id_to_insn[insn2].assignee_var_names() + var1 = kernel.get_var_descriptor(var1_name) + var2 = kernel.get_var_descriptor(var2_name) + + ensm1, = var1.tags_of_type(FEMEinsumTag) + ensm2, = var2.tags_of_type(FEMEinsumTag) + + return ensm1 != ensm2 + + +def _fuse_loops_over_a_discr_entity(knl, + mesh_entity, + fused_loop_prefix, + should_fuse_redn_loops, + orig_knl): + import loopy as lp + import kanren + from functools import reduce, partial + taggedo = lp.relations.get_taggedo_of_type(orig_knl, mesh_entity) + + redn_loops = reduce(frozenset.union, + (insn.reduction_inames() + for insn in orig_knl.instructions), + frozenset()) + + non_redn_loops = reduce(frozenset.union, + (insn.within_inames + for insn in orig_knl.instructions), + frozenset()) + + # tag_k: tag of type 'mesh_entity' + tag_k = kanren.var() + tags = kanren.run(0, + tag_k, + taggedo(kanren.var(), tag_k), + results_filter=frozenset) + for itag, tag in enumerate( + sorted(tags, key=lambda x: _discr_entity_sort_key(x))): + # iname_k: iname tagged with 'tag' + iname_k = kanren.var() + inames = kanren.run(0, + iname_k, + taggedo(iname_k, tag), + results_filter=frozenset) + inames = frozenset(inames) + if should_fuse_redn_loops: + inames = inames & redn_loops + else: + inames = inames & non_redn_loops + + length_to_inames = {} + for iname in inames: + length = knl.get_constant_iname_length(iname) + length_to_inames.setdefault(length, set()).add(iname) + + for i, (_, inames_to_fuse) in enumerate( + sorted(length_to_inames.items())): + + knl = lp.rename_inames_in_batch( + knl, + lp.get_kennedy_unweighted_fusion_candidates( + knl, inames_to_fuse, + prefix=f"{fused_loop_prefix}_{itag}_{i}_", + force_infusible=partial( + _do_inames_belong_to_different_einsum_types, + kernel=orig_knl), + )) + knl = lp.tag_inames(knl, {f"{fused_loop_prefix}_{itag}_*": tag}) + + return knl + + +@memoize_on_disk +def fuse_same_discretization_entity_loops(knl): + # maintain an 'orig_knl' to keep the original iname and tags before + # transforming it. 
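+ # (Each fusion pass renames inames, so all entity-tag queries below go through this untransformed copy.)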
+ orig_knl = knl + + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationFaceAxisTag, + "iface", + False, + orig_knl) + + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationElementAxisTag, + "iel", + False, + orig_knl) + + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDOFAxisTag, + "idof", + False, + orig_knl) + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDimAxisTag, + "idim", + False, + orig_knl) + + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationFaceAxisTag, + "iface", + True, + orig_knl) + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDOFAxisTag, + "idof", + True, + orig_knl) + knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationDimAxisTag, + "idim", + True, + orig_knl) + + return knl + + +@memoize_on_disk +def contract_arrays(knl, callables_table): + import loopy as lp + from loopy.transform.precompute import precompute_for_single_kernel + + temps_not_to_contract = get_temps_not_to_contract(knl) + all_temps = frozenset(knl.temporary_variables) + + logger.info("Array Contraction: Contracting " + f"{len(all_temps-frozenset(temps_not_to_contract))} temps") + + wmap = knl.writer_map() + + for temp in sorted(all_temps - frozenset(temps_not_to_contract)): + writer_id, = wmap[temp] + rmap = knl.reader_map() + ensm_tag, = knl.id_to_insn[writer_id].tags_of_type(EinsumTag) + + knl = lp.assignment_to_subst(knl, temp, + remove_newly_unused_inames=False) + if temp not in rmap: + # no one was reading 'temp' i.e. dead code got eliminated :) + assert f"{temp}_subst" not in knl.substitutions + continue + try: + knl = precompute_for_single_kernel( + knl, callables_table, f"{temp}_subst", + sweep_inames=(), + temporary_address_space=lp.AddressSpace.PRIVATE, + compute_insn_id=f"_mm_contract_{temp}", + _enable_mirgecom_workaround=True, + ) + except TypeError as e: + if "_enable_mirgecom_workaround" in str(e): + knl = precompute_for_single_kernel( + knl, callables_table, f"{temp}_subst", + sweep_inames=(), + temporary_address_space=lp.AddressSpace.PRIVATE, + compute_insn_id=f"_mm_contract_{temp}", + ) + else: + raise + + knl = lp.map_instructions(knl, + f"id:_mm_contract_{temp}", + lambda x: x.tagged(ensm_tag)) + + return lp.remove_unused_inames(knl) + + +def _get_group_size_for_dof_array_loop(nunit_dofs): + """ + Returns the OpenCL workgroup size for a loop iterating over the global DOFs + of a discretization with *nunit_dofs* per cell. + """ + if nunit_dofs == {6}: + return 16, 6 + elif nunit_dofs == {10}: + return 16, 10 + elif nunit_dofs == {20}: + return 16, 10 + elif nunit_dofs == {1}: + return 32, 1 + elif nunit_dofs == {2}: + return 32, 2 + elif nunit_dofs == {4}: + return 16, 4 + elif nunit_dofs == {3}: + return 32, 3 + elif nunit_dofs == {35}: + return 9, 7 + elif nunit_dofs == {15}: + return 8, 8 + elif nunit_dofs == {7}: + return 9, 7 + else: + # /!\ not ideal performance-wise but better than raising. 
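+ # (nunit_dofs value without a tuned size: fall back to a conservative 8x4 workgroup rather than raising.)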
+ return 8, 4 + + +def _get_iel_to_idofs(kernel): + iel_inames = {iname + for iname in kernel.all_inames() + if (kernel + .inames[iname] + .tags_of_type((DiscretizationElementAxisTag, + DiscretizationFlattenedDOFAxisTag))) + } + idof_inames = {iname + for iname in kernel.all_inames() + if (kernel + .inames[iname] + .tags_of_type(DiscretizationDOFAxisTag)) + } + iface_inames = {iname + for iname in kernel.all_inames() + if (kernel + .inames[iname] + .tags_of_type(DiscretizationFaceAxisTag)) + } + idim_inames = {iname + for iname in kernel.all_inames() + if (kernel + .inames[iname] + .tags_of_type(DiscretizationDimAxisTag)) + } + + iel_to_idofs = {iel: set() for iel in iel_inames} + + for insn in kernel.instructions: + if (len(insn.within_inames) == 1 + and (insn.within_inames) <= iel_inames): + iel, = insn.within_inames + if all(kernel.id_to_insn[el_insn].within_inames == insn.within_inames + for el_insn in kernel.iname_to_insns()[iel]): + # the iel here doesn't interfere with any idof i.e. we + # support parallelizing such loops. + pass + else: + raise NotImplementedError(f"The loop {insn.within_inames}" + " does not appear as a singly nested" + " loop.") + elif ((len(insn.within_inames) == 2) + and (len(insn.within_inames & iel_inames) == 1) + and (len(insn.within_inames & idof_inames) == 1)): + iel, = insn.within_inames & iel_inames + idof, = insn.within_inames & idof_inames + iel_to_idofs[iel].add(idof) + if all((iel in kernel.id_to_insn[dof_insn].within_inames) + for dof_insn in kernel.iname_to_insns()[idof]): + pass + else: + for dof_insn in kernel.iname_to_insns()[idof]: + if iel not in kernel.id_to_insn[dof_insn].within_inames: + print(f"_get_iel_to_idofs: {str(kernel.id_to_insn[dof_insn])=}") + raise NotImplementedError("The loop " + f"'{insn.within_inames}' has the idof-loop" + " that's not nested within the iel-loop.") + elif ((len(insn.within_inames) > 2) + and (len(insn.within_inames & iel_inames) == 1) + and (len(insn.within_inames & idof_inames) == 1) + and (len(insn.within_inames & (idim_inames | iface_inames)) + == (len(insn.within_inames) - 2))): + iel, = insn.within_inames & iel_inames + idof, = insn.within_inames & idof_inames + iel_to_idofs[iel].add(idof) + if all((all({iel, idof} <= kernel.id_to_insn[non_iel_insn].within_inames + for non_iel_insn in kernel.iname_to_insns()[non_iel_iname])) + for non_iel_iname in insn.within_inames - {iel}): + iel_to_idofs[iel].add(idof) + else: + raise NotImplementedError("Could not fit into " + " loop nest pattern.") + else: + print(f"_get_iel_to_idofs: {str(insn)=}") + raise NotImplementedError(f"Cannot fit loop nest '{insn.within_inames}'" + " into known set of loop-nest patterns.") + + return immutabledict({iel: frozenset(idofs) + for iel, idofs in iel_to_idofs.items()}) + + +def _get_iel_loop_from_insn(insn, knl): + iel, = {iname + for iname in insn.within_inames + if knl.inames[iname].tags_of_type((DiscretizationElementAxisTag, + DiscretizationFlattenedDOFAxisTag))} + return iel + + +def _get_element_loop_topo_sorted_order(knl): + from loopy import MultiAssignmentBase + dag = {iel: set() + for iel in knl.all_inames() + if knl.inames[iel].tags_of_type(DiscretizationElementAxisTag)} + + for insn in knl.instructions: + if isinstance(insn, MultiAssignmentBase): + succ_iel = _get_iel_loop_from_insn(insn, knl) + for dep_id in insn.depends_on: + pred_iel = _get_iel_loop_from_insn(knl.id_to_insn[dep_id], knl) + if pred_iel != succ_iel: + dag[pred_iel].add(succ_iel) + + from pytools.graph import compute_topological_order + return 
compute_topological_order(dag, key=lambda x: x) + + +@tag_dataclass +class EinsumTag(UniqueTag): + orig_loop_nest: FrozenSet[str] + + +def _prepare_kernel_for_parallelization(kernel): + discr_tag_to_prefix = {DiscretizationElementAxisTag: "iel", + DiscretizationDOFAxisTag: "idof", + DiscretizationDimAxisTag: "idim", + DiscretizationAmbientDimAxisTag: "idim", + DiscretizationTopologicalDimAxisTag: "idim", + DiscretizationFlattenedDOFAxisTag: "imsh_nodes", + DiscretizationFaceAxisTag: "iface"} + import loopy as lp + from loopy.match import ObjTagged + + # A mapping from inames that the instruction accesss to + # the instructions ids within that iname. + ensm_buckets = {} + vng = kernel.get_var_name_generator() + + for insn in kernel.instructions: + inames = insn.within_inames | insn.reduction_inames() + ensm_buckets.setdefault(tuple(sorted(inames)), set()).add(insn.id) + + # FIXME: Dependency violation is a big concern here + # Waiting on the loopy feature: https://github.com/inducer/loopy/issues/550 + + for ieinsm, (loop_nest, insns) in enumerate(sorted(ensm_buckets.items())): + new_insns = [insn.tagged(EinsumTag(frozenset(loop_nest))) + if insn.id in insns + else insn + for insn in kernel.instructions] + kernel = kernel.copy(instructions=new_insns) + + new_inames = [] + for iname in loop_nest: + discr_tag, = kernel.iname_tags_of_type(iname, + DiscretizationEntityAxisTag) + new_iname = vng(f"{discr_tag_to_prefix[type(discr_tag)]}_ensm{ieinsm}") + new_inames.append(new_iname) + + kernel = lp.duplicate_inames( + kernel, + loop_nest, + within=ObjTagged(EinsumTag(frozenset(loop_nest))), + new_inames=new_inames, + tags={iname: kernel.inames[iname].tags + for iname in loop_nest}) + + return kernel + + +def _get_elementwise_einsum(t_unit, einsum_tag): + import loopy as lp + import feinsum as fnsm + from loopy.match import ObjTagged + from pymbolic.primitives import Variable, Subscript + + kernel = t_unit.default_entrypoint + + assert isinstance(einsum_tag, EinsumTag) + insn_match = ObjTagged(einsum_tag) + + global_vars = ({tv.name + for tv in kernel.temporary_variables.values() + if tv.address_space == lp.AddressSpace.GLOBAL} + | set(kernel.arg_dict.keys())) + insns = [insn + for insn in kernel.instructions + if insn_match(kernel, insn)] + idx_tuples = set() + + for insn in insns: + assert len(insn.assignees) == 1 + if isinstance(insn.assignee, Variable): + if insn.assignee.name in global_vars: + raise NotImplementedError(insn) + else: + assert (kernel.temporary_variables[insn.assignee.name].address_space + == lp.AddressSpace.PRIVATE) + elif isinstance(insn.assignee, Subscript): + assert insn.assignee_name in global_vars + idx_tuples.add(tuple(idx.name + for idx in insn.assignee.index_tuple)) + else: + raise NotImplementedError(insn) + + if len(idx_tuples) != 1: + raise NotImplementedError("Multiple einsums in the same loop nest =>" + " not allowed.") + idx_tuple, = idx_tuples + subscript = "{lhs}, {lhs}->{lhs}".format( + lhs="".join(chr(97+i) + for i in range(len(idx_tuple)))) + arg_shape = tuple(np.inf + if kernel.iname_tags_of_type(idx, DiscretizationElementAxisTag) + else kernel.get_constant_iname_length(idx) + for idx in idx_tuple) + return fnsm.einsum(subscript, + fnsm.array(arg_shape, "float64"), + fnsm.array(arg_shape, "float64")) + + +def _combine_einsum_domains(knl): + import islpy as isl + from functools import reduce + + new_domains = [] + einsum_tags = reduce( + frozenset.union, + (insn.tags_of_type(EinsumTag) + for insn in knl.instructions), + frozenset()) + + for tag in 
sorted(einsum_tags, + key=lambda x: sorted(x.orig_loop_nest)): + insns = [insn + for insn in knl.instructions + if tag in insn.tags] + inames = reduce(frozenset.union, + ((insn.within_inames | insn.reduction_inames()) + for insn in insns), + frozenset()) + domain = knl.get_inames_domain(frozenset(inames)) + new_domains.append(domain.project_out_except(sorted(inames), + [isl.dim_type.set])) + + return knl.copy(domains=new_domains) + + +class FusionContractorArrayContext( + SingleGridWorkBalancingPytatoArrayContext): + + def __init__( + self, queue: "cl.CommandQueue", allocator=None, *, + use_memory_pool: Optional[bool] = None, + compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None, + use_axis_tag_inference_fallback: bool = False, + use_einsum_inference_fallback: bool = False, + + # do not use: only for testing + _force_svm_arg_limit: Optional[int] = None, + ) -> None: + super().__init__( + queue, allocator, + use_memory_pool=use_memory_pool, + compile_trace_callback=compile_trace_callback, + _force_svm_arg_limit=_force_svm_arg_limit) + self.use_axis_tag_inference_fallback = use_axis_tag_inference_fallback + self.use_einsum_inference_fallback = use_einsum_inference_fallback + + def transform_dag(self, dag): + import pytato as pt + + # {{{ Remove FEMEinsumTags that might have been propagated + + # TODO: Is this too hacky? + + def remove_fem_einsum_tags(expr): + if isinstance(expr, pt.Array): + try: + fem_ensm_tag = next(iter(expr.tags_of_type(FEMEinsumTag))) + except StopIteration: + return expr + else: + # See https://github.com/inducer/arraycontext/pull/229 + # assert isinstance(expr, pt.InputArgumentBase) + return expr.without_tags(fem_ensm_tag) + else: + return expr + + dag = pt.transform.map_and_copy(dag, remove_fem_einsum_tags) + + # }}} + + # {{{ CSE + + with ProcessLogger(logger, "transform_dag.mpms_materialization"): + dag = pt.transform.materialize_with_mpms(dag) + + def mark_materialized_nodes_as_cse( + ary: Union[pt.Array, + pt.AbstractResultWithNamedArrays]) -> pt.Array: + if isinstance(ary, pt.AbstractResultWithNamedArrays): + return ary + + if ary.tags_of_type(pt.tags.ImplStored): + return ary.tagged(pt.tags.PrefixNamed("cse")) + else: + return ary + + with ProcessLogger(logger, "transform_dag.naming_cse"): + dag = pt.transform.map_and_copy(dag, mark_materialized_nodes_as_cse) + + # }}} + + # {{{ indirect addressing are non-negative + + indirection_maps = set() + + class _IndirectionMapRecorder(pt.transform.CachedWalkMapper): + # type-ignore-reason: dropped the extra `*args, **kwargs`. 
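+ # (An id()-based cache key suffices: the walk only needs to visit each node object once.)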
+ def get_cache_key(self, expr) -> int: # type: ignore[override] + return id(expr) + + def post_visit(self, expr): + if isinstance(expr, pt.IndexBase): + for idx in expr.indices: + if isinstance(idx, pt.Array): + indirection_maps.add(idx) + + _IndirectionMapRecorder()(dag) + + def tag_indices_as_non_negative(ary): + if ary in indirection_maps: + return ary.tagged(pt.tags.AssumeNonNegative()) + else: + return ary + + with ProcessLogger(logger, "transform_dag.tag_indices_as_non_negative"): + dag = pt.transform.map_and_copy(dag, tag_indices_as_non_negative) + + # }}} + + with ProcessLogger(logger, "transform_dag.deduplicate_data_wrappers"): + dag = pt.transform.deduplicate_data_wrappers(dag) + + # {{{ get rid of copies for different views of a cl-array + + def eliminate_reshapes_of_data_wrappers(ary): + if (isinstance(ary, pt.Reshape) + and isinstance(ary.array, pt.DataWrapper)): + return pt.make_data_wrapper(ary.array.data.reshape(ary.shape), + tags=ary.tags, + axes=ary.axes) + else: + return ary + + dag = pt.transform.map_and_copy(dag, + eliminate_reshapes_of_data_wrappers) + + # }}} + + # {{{ face_mass: materialize einsum args + + def materialize_face_mass_input_and_output(expr): + if (isinstance(expr, pt.Einsum) + and pt.analysis.is_einsum_similar_to_subscript( + expr, + "ifj,fej,fej->ei")): + mat, jac, vec = expr.args + return (pt.einsum("ifj,fej,fej->ei", + mat, + jac, + vec.tagged(pt.tags.ImplStored())) + .tagged((pt.tags.ImplStored(), + pt.tags.PrefixNamed("face_mass")))) + else: + return expr + + with ProcessLogger(logger, + "transform_dag.materialize_face_mass_ins_and_outs"): + dag = pt.transform.map_and_copy(dag, + materialize_face_mass_input_and_output) + + # }}} + + # {{{ materialize inverse mass inputs + + def materialize_inverse_mass_inputs(expr): + if (isinstance(expr, pt.Einsum) + and pt.analysis.is_einsum_similar_to_subscript( + expr, + "ei,ij,ej->ei")): + arg1, arg2, arg3 = expr.args + if not arg3.tags_of_type(pt.tags.PrefixNamed): + arg3 = arg3.tagged(pt.tags.PrefixNamed("mass_inv_inp")) + if not arg3.tags_of_type(pt.tags.ImplStored): + arg3 = arg3.tagged(pt.tags.ImplStored()) + + return expr.copy(args=(arg1, arg2, arg3)) + else: + return expr + + dag = pt.transform.map_and_copy(dag, materialize_inverse_mass_inputs) + + # }}} + + # {{{ materialize all einsums + + def materialize_all_einsums_or_reduces(expr): + from pytato.raising import (index_lambda_to_high_level_op, + ReduceOp) + + if isinstance(expr, pt.Einsum): + return expr.tagged(pt.tags.ImplStored()) + elif (isinstance(expr, pt.IndexLambda) + and isinstance(index_lambda_to_high_level_op(expr), ReduceOp)): + return expr.tagged(pt.tags.ImplStored()) + else: + return expr + + with ProcessLogger(logger, + "transform_dag.materialize_all_einsums_or_reduces"): + dag = pt.transform.map_and_copy(dag, materialize_all_einsums_or_reduces) + + # }}} + + # {{{ infer axis types + + from meshmode.pytato_utils import unify_discretization_entity_tags + + with ProcessLogger(logger, "transform_dag.infer_axes_tags"): + dag = unify_discretization_entity_tags(dag) + + # }}} + + # {{{ /!\ Remove tags from Loopy call results. 
+ # See + + def untag_loopy_call_results(expr): + from pytato.loopy import LoopyCallResult + if isinstance(expr, LoopyCallResult): + return expr.copy(tags=frozenset(), + axes=(pt.Axis(frozenset()),)*expr.ndim) + else: + return expr + + dag = pt.transform.map_and_copy(dag, untag_loopy_call_results) + + # }}} + + # {{{ remove broadcasts from einsums: help feinsum + + ensm_arg_rewrite_cache = {} + + def _get_rid_of_broadcasts_from_einsum(expr): + # Helpful for matching against the available expressions + # in feinsum. + + from pytato.utils import (are_shape_components_equal, + are_shapes_equal) + if isinstance(expr, pt.Einsum): + from pytato.array import EinsumElementwiseAxis + idx_to_len = expr._access_descr_to_axis_len() + new_access_descriptors = [] + new_args = [] + inp_gatherer = pt.transform.InputGatherer() + access_descr_to_axes = dict(expr.redn_axis_to_redn_descr) + for iax, axis in enumerate(expr.axes): + access_descr_to_axes[EinsumElementwiseAxis(iax)] = axis + + for access_descrs, arg in zip(expr.access_descriptors, + expr.args): + new_shape = [] + new_access_descrs = [] + new_axes = [] + for iaxis, (access_descr, axis_len) in enumerate( + zip(access_descrs, + arg.shape)): + if not are_shape_components_equal(axis_len, + idx_to_len[access_descr]): + assert are_shape_components_equal(axis_len, 1) + if any(isinstance(inp, pt.Placeholder) + for inp in inp_gatherer(arg)): + # do not get rid of broadcasts from parameteric + # data. + new_shape.append(axis_len) + new_access_descrs.append(access_descr) + new_axes.append(arg.axes[iaxis]) + else: + new_axes.append(arg.axes[iaxis]) + new_shape.append(axis_len) + new_access_descrs.append(access_descr) + + if not are_shapes_equal(new_shape, arg.shape): + assert len(new_axes) == len(new_shape) + arg_to_freeze = (arg.reshape(new_shape) + .copy(axes=tuple( + access_descr_to_axes[acc_descr] + for acc_descr in new_access_descrs))) + + try: + new_arg = ensm_arg_rewrite_cache[arg_to_freeze] + except KeyError: + new_arg = self.thaw(self.freeze(arg_to_freeze)) + ensm_arg_rewrite_cache[arg_to_freeze] = new_arg + + arg = new_arg + + assert arg.ndim == len(new_access_descrs) + new_args.append(arg) + new_access_descriptors.append(tuple(new_access_descrs)) + + return expr.copy( + access_descriptors=tuple(new_access_descriptors), + args=tuple(new_args)) + else: + return expr + + dag = pt.transform.map_and_copy(dag, _get_rid_of_broadcasts_from_einsum) + + # }}} + + # {{{ remove any PartID tags + + # FIXME: Remove after https://github.com/inducer/pytato/pull/393 goes in + try: + from pytato.distributed import PartIDTag + + def remove_part_id_tags(expr): + if isinstance(expr, pt.Array) and expr.tags_of_type(PartIDTag): + tag, = expr.tags_of_type(PartIDTag) + return expr.without_tags(tag) + else: + return expr + except ImportError: + remove_part_id_tags = None + + if remove_part_id_tags is not None: + dag = pt.transform.map_and_copy(dag, remove_part_id_tags) + + # }}} + + # {{{ attach FEMEinsumTag tags + + dag_outputs = frozenset(dag._data.values()) + + def add_fem_einsum_tags(expr): + if isinstance(expr, pt.Einsum): + from pytato.array import (EinsumElementwiseAxis, + EinsumReductionAxis) + assert expr.tags_of_type(pt.tags.ImplStored) + ensm_indices = [] + for arg, access_descrs in zip(expr.args, + expr.access_descriptors): + arg_indices = [] + for iaxis, access_descr in enumerate(access_descrs): + try: + discr_tag = next( + iter(arg + .axes[iaxis] + .tags_of_type(DiscretizationEntityAxisTag))) + except StopIteration: + raise NotAnFEMEinsumError(expr) + else: + 
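+ # Classify each axis of the argument as a free (output) index or a summation (reduction) index.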
if isinstance(access_descr, EinsumElementwiseAxis): + arg_indices.append(FreeEinsumIndex(discr_tag, + arg.shape[iaxis])) + elif isinstance(access_descr, EinsumReductionAxis): + arg_indices.append(SummationEinsumIndex( + discr_tag, + arg.shape[iaxis])) + else: + raise NotImplementedError(access_descr) + ensm_indices.append(tuple(arg_indices)) + + return expr.tagged(FEMEinsumTag(tuple(ensm_indices))) + elif (isinstance(expr, pt.Array) + and (expr.tags_of_type(pt.tags.ImplStored) + or expr in dag_outputs)): + if (isinstance(expr, pt.IndexLambda) + and expr.var_to_reduction_descr + and expr.shape == ()): + raise NotImplementedError("all-reduce expressions not" + " supported") + else: + discr_tags = [] + for axis in expr.axes: + try: + discr_tag = next( + iter(axis.tags_of_type(DiscretizationEntityAxisTag))) + except StopIteration: + raise NotAnFEMEinsumError(expr) + else: + discr_tags.append(discr_tag) + + fem_ensm_tag = FEMEinsumTag( + (tuple(FreeEinsumIndex(discr_tag, dim) + for dim, discr_tag in zip(expr.shape, + discr_tags)),) * 2 + ) + + return expr.tagged(fem_ensm_tag) + + else: + return expr + + try: + dag = pt.transform.map_and_copy(dag, add_fem_einsum_tags) + except NotAnFEMEinsumError: + pass + + # }}} + + # {{{ untag outputs tagged from being tagged ImplStored + + def _untag_impl_stored(expr): + if isinstance(expr, pt.InputArgumentBase): + return expr + else: + return expr.without_tags(pt.tags.ImplStored(), + verify_existence=False) + + dag = pt.make_dict_of_named_arrays({ + name: _untag_impl_stored(named_ary.expr) + for name, named_ary in dag.items()}) + + # }}} + + return dag + + def transform_loopy_program(self, t_unit): + import loopy as lp + from functools import reduce + from arraycontext.impl.pytato.compile import FromArrayContextCompile + + original_t_unit = t_unit + + # from loopy.transform.instruction import simplify_indices + # t_unit = simplify_indices(t_unit) + + knl = t_unit.default_entrypoint + + logger.info(f"Transforming kernel '{knl.name}' with {len(knl.instructions)} statements.") + + # {{{ fallback: if the inames are not inferred which mesh entity they + # iterate over. 
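+ # (If tags are missing, either raise or, when use_axis_tag_inference_fallback is set, defer to the base class's generic transform.)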
+ + for iname in knl.all_inames(): + if not knl.iname_tags_of_type(iname, DiscretizationEntityAxisTag): + if not self.use_axis_tag_inference_fallback: + raise AxisTagInferenceError("Unable to infer axis tags.") + else: + warn(f"[{knl.name}]: Falling back to a slower transformation" + " strategy as some loops are uninferred which mesh entity" + " they belong to.", + stacklevel=2) + return super().transform_loopy_program(original_t_unit) + + for insn in knl.instructions: + for assignee in insn.assignee_var_names(): + var = knl.get_var_descriptor(assignee) + if not var.tags_of_type(FEMEinsumTag): + if not self.use_einsum_inference_fallback: + raise EinsumInferenceError( + "Unable to infer instructions as einsums.") + else: + warn(f"[{knl.name}]: Falling back to a slower transformation" + " strategy as some instructions couldn't be inferred as" + " einsums", + stacklevel=2) + return super().transform_loopy_program(original_t_unit) + + # }}} + + # {{{ hardcode offset to 0 (sorry humanity) + + knl = knl.copy(args=[arg.copy(offset=0) + for arg in knl.args]) + + # }}} + + # {{{ loop fusion + + with ProcessLogger(logger, "Loop Fusion"): + knl = fuse_same_discretization_entity_loops(knl) + + # }}} + + # {{{ align kernels for fused einsums + + knl = _prepare_kernel_for_parallelization(knl) + knl = _combine_einsum_domains(knl) + + # }}} + + # {{{ array contraction + + with ProcessLogger(logger, "Array Contraction"): + knl = contract_arrays(knl, t_unit.callables_table) + + # }}} + + # {{{ Stats Collection (Disabled) + + if 0: + with ProcessLogger(logger, "Counting Kernel Ops"): + from loopy.kernel.array import ArrayBase + from pytools import product + knl = knl.copy( + silenced_warnings=(knl.silenced_warnings + + ["insn_count_subgroups_upper_bound", + "summing_if_branches_ops"])) + + t_unit = t_unit.with_kernel(knl) + + op_map = lp.get_op_map(t_unit, subgroup_size=32) + + c64_ops = {op_type: (op_map.filter_by(dtype=[np.complex64], + name=op_type, + kernel_name=knl.name) + .eval_and_sum({})) + for op_type in ["add", "mul", "div"]} + c128_ops = {op_type: (op_map.filter_by(dtype=[np.complex128], + name=op_type, + kernel_name=knl.name) + .eval_and_sum({})) + for op_type in ["add", "mul", "div"]} + f32_ops = ((op_map.filter_by(dtype=[np.float32], + kernel_name=knl.name) + .eval_and_sum({})) + + (2 * c64_ops["add"] + + 6 * c64_ops["mul"] + + (6 + 3 + 2) * c64_ops["div"])) + f64_ops = ((op_map.filter_by(dtype=[np.float64], + kernel_name="_pt_kernel") + .eval_and_sum({})) + + (2 * c128_ops["add"] + + 6 * c128_ops["mul"] + + (6 + 3 + 2) * c128_ops["div"])) + + # {{{ footprint gathering + + nfootprint_bytes = 0 + + for ary in knl.args: + if (isinstance(ary, ArrayBase) + and ary.address_space == lp.AddressSpace.GLOBAL): + nfootprint_bytes += (product(ary.shape) + * ary.dtype.itemsize) + + for ary in knl.temporary_variables.values(): + if ary.address_space == lp.AddressSpace.GLOBAL: + # global temps would be written once and read once + nfootprint_bytes += (2 * product(ary.shape) + * ary.dtype.itemsize) + + # }}} + + if f32_ops: + logger.info(f"Single-prec. GFlOps: {f32_ops * 1e-9}") + if f64_ops: + logger.info(f"Double-prec. GFlOps: {f64_ops * 1e-9}") + logger.info(f"Footprint GBs: {nfootprint_bytes * 1e-9}") + + # }}} + + # {{{ check whether we can parallelize the kernel + + try: + iel_to_idofs = _get_iel_to_idofs(knl) + except NotImplementedError as err: + if knl.tags_of_type(FromArrayContextCompile): + raise err + else: + warn(f"[{knl.name}]: FusionContractorArrayContext." 
+ "transform_loopy_program not broad enough (yet)." + " Falling back to a possibly slower" + " transformation strategy.") + return super().transform_loopy_program(original_t_unit) + + # }}} + + # {{{ insert barriers between consecutive iel-loops + + toposorted_iels = _get_element_loop_topo_sorted_order(knl) + + for iel_pred, iel_succ in zip(toposorted_iels[:-1], + toposorted_iels[1:]): + knl = lp.add_barrier(knl, + insn_before=f"iname:{iel_pred}", + insn_after=f"iname:{iel_succ}") + + # }}} + + t_unit = _alias_global_temporaries(t_unit) + + # {{{ Parallelization strategy: Use feinsum + + t_unit = t_unit.with_kernel(knl) + del knl + + if False and t_unit.default_entrypoint.tags_of_type(FromArrayContextCompile): + # FIXME: Enable this branch, WIP for now and hence disabled it. + from loopy.match import ObjTagged + import feinsum as fnsm + from meshmode.feinsum_transformations import FEINSUM_TO_TRANSFORMS + + assert all(insn.tags_of_type(EinsumTag) + for insn in t_unit.default_entrypoint.instructions + if isinstance(insn, lp.MultiAssignmentBase) + ) + + einsum_tags = reduce( + frozenset.union, + (insn.tags_of_type(EinsumTag) + for insn in t_unit.default_entrypoint.instructions), + frozenset()) + for ensm_tag in sorted(einsum_tags, + key=lambda x: sorted(x.orig_loop_nest)): + if reduce(frozenset.union, + (insn.reduction_inames() + for insn in (t_unit.default_entrypoint.instructions) + if ensm_tag in insn.tags), + frozenset()): + fused_einsum = fnsm.match_einsum(t_unit, ObjTagged(ensm_tag)) + else: + # elementwise loop + fused_einsum = _get_elementwise_einsum(t_unit, ensm_tag) + + try: + fnsm_transform = FEINSUM_TO_TRANSFORMS[ + fnsm.normalize_einsum(fused_einsum)] + except KeyError: + fnsm.query(fused_einsum, + self.queue.context, + err_if_no_results=True) + 1/0 + + t_unit = fnsm_transform(t_unit, + insn_match=ObjTagged(ensm_tag)) + else: + knl = t_unit.default_entrypoint + for iel, idofs in sorted(iel_to_idofs.items()): + if idofs: + nunit_dofs = {knl.get_constant_iname_length(idof) + for idof in idofs} + idof, = idofs + + l_one_size, l_zero_size = _get_group_size_for_dof_array_loop( + nunit_dofs) + + knl = lp.split_iname(knl, iel, l_one_size, + inner_tag="l.1", outer_tag="g.0") + knl = lp.split_iname(knl, idof, l_zero_size, + inner_tag="l.0", outer_tag="unr") + else: + knl = lp.split_iname(knl, iel, 32, + outer_tag="g.0", inner_tag="l.0") + + t_unit = t_unit.with_kernel(knl) + + # }}} + + return t_unit + + # vim: foldmethod=marker diff --git a/meshmode/discretization/connection/direct.py b/meshmode/discretization/connection/direct.py index 06d00dd8a..e3db6ad63 100644 --- a/meshmode/discretization/connection/direct.py +++ b/meshmode/discretization/connection/direct.py @@ -47,6 +47,7 @@ ConcurrentElementInameTag, DiscretizationDOFAxisTag, DiscretizationElementAxisTag, + DiscretizationDOFPickListAxisTag, ) @@ -478,6 +479,10 @@ def _per_target_group_pick_info( cgrp = self.groups[i_tgrp] tgrp = self.to_discr.groups[i_tgrp] + if tgrp.nelements == 1: + from warnings import warn + warn("_per_target_group_pick_info: tgrp has 1 element") + batch_dof_pick_lists = [ self._resample_point_pick_indices(i_tgrp, i_batch) for i_batch in range(len(cgrp.batches))] @@ -541,17 +546,22 @@ def _per_target_group_pick_info( _FromGroupPickData( from_group_index=source_group_index, dof_pick_lists=actx.freeze( - actx.tag(NameHint("dof_pick_lists"), - actx.from_numpy(dof_pick_lists))), + actx.tag_axis(0, DiscretizationDOFPickListAxisTag(), + actx.tag(NameHint("dof_pick_lists"), + actx.from_numpy(dof_pick_lists)))), 
dof_pick_list_indices=actx.freeze( - actx.tag(NameHint("dof_pick_list_indices"), - actx.from_numpy(dof_pick_list_indices))), + actx.tag_axis(0, DiscretizationElementAxisTag(), + actx.tag(NameHint("dof_pick_list_indices"), + actx.from_numpy(dof_pick_list_indices)))), from_el_present=actx.freeze( - actx.tag(NameHint("from_el_present"), - actx.from_numpy(from_el_present.astype(np.int8)))), + actx.tag_axis(0, DiscretizationElementAxisTag(), + actx.tag(NameHint("from_el_present"), + actx.from_numpy( + from_el_present.astype(np.int8))))), from_element_indices=actx.freeze( - actx.tag(NameHint("from_el_indices"), - actx.from_numpy(from_el_indices))), + actx.tag_axis(0, DiscretizationElementAxisTag(), + actx.tag(NameHint("from_el_indices"), + actx.from_numpy(from_el_indices)))), is_surjective=from_el_present.all() )) @@ -723,7 +733,7 @@ def group_pick_knl(is_surjective: bool): group_pick_info = None if group_pick_info is not None: - group_array_contributions = [] + # group_array_contributions = [] if actx.permits_advanced_indexing and not _force_use_loopy: for fgpd in group_pick_info: @@ -731,8 +741,10 @@ def group_pick_knl(is_surjective: bool): if ary[fgpd.from_group_index].size: grp_ary_contrib = ary[fgpd.from_group_index][ + tag_axes(actx, { + 1: DiscretizationDOFAxisTag()}, _reshape_and_preserve_tags( - actx, from_element_indices, (-1, 1)), + actx, from_element_indices, (-1, 1))), actx.thaw(fgpd.dof_pick_lists)[ actx.thaw(fgpd.dof_pick_list_indices)] ] @@ -740,8 +752,10 @@ def group_pick_knl(is_surjective: bool): if not fgpd.is_surjective: from_el_present = actx.thaw(fgpd.from_el_present) grp_ary_contrib = actx.np.where( - _reshape_and_preserve_tags( - actx, from_el_present, (-1, 1)), + tag_axes(actx, { + 1: DiscretizationDOFAxisTag()}, + _reshape_and_preserve_tags( + actx, from_el_present, (-1, 1))), grp_ary_contrib, 0) @@ -791,8 +805,10 @@ def group_pick_knl(is_surjective: bool): mat = self._resample_matrix(actx, i_tgrp, i_batch) if actx.permits_advanced_indexing and not _force_use_loopy: batch_result = actx.np.where( + tag_axes(actx, { + 1: DiscretizationDOFAxisTag()}, _reshape_and_preserve_tags( - actx, from_el_present, (-1, 1)), + actx, from_el_present, (-1, 1))), actx.einsum("ij,ej->ei", mat, grp_ary[from_element_indices]), 0) @@ -813,11 +829,15 @@ def group_pick_knl(is_surjective: bool): if actx.permits_advanced_indexing and not _force_use_loopy: batch_result = actx.np.where( + tag_axes(actx, { + 1: DiscretizationDOFAxisTag()}, _reshape_and_preserve_tags( - actx, from_el_present, (-1, 1)), + actx, from_el_present, (-1, 1))), from_vec[ + tag_axes(actx, { + 1: DiscretizationDOFAxisTag()}, _reshape_and_preserve_tags( - actx, from_element_indices, (-1, 1)), + actx, from_element_indices, (-1, 1))), pick_list], 0) else: @@ -844,10 +864,13 @@ def group_pick_knl(is_surjective: bool): else: # If no batched data at all, return zeros for this # particular group array - group_array = actx.zeros( + group_array = tag_axes(actx, { + 0: DiscretizationElementAxisTag(), + 1: DiscretizationDOFAxisTag()}, + actx.zeros( shape=(self.to_discr.groups[i_tgrp].nelements, self.to_discr.groups[i_tgrp].nunit_dofs), - dtype=ary.entry_dtype) + dtype=ary.entry_dtype)) group_arrays.append(group_array) diff --git a/meshmode/discretization/poly_element.py b/meshmode/discretization/poly_element.py index 9d511e74b..f118414d0 100644 --- a/meshmode/discretization/poly_element.py +++ b/meshmode/discretization/poly_element.py @@ -569,8 +569,14 @@ def __init__(self, mesh_el_group: _MeshTensorProductElementGroup, "`unit_nodes` 
dim = {unit_nodes.shape[0]}.") self._basis = basis + self._bases_1d = basis.bases[0] self._nodes = unit_nodes + def bases_1d(self): + """Return 1D component bases used to construct the tensor product basis. + """ + return self._bases_1d + def basis_obj(self): return self._basis diff --git a/meshmode/distributed.py b/meshmode/distributed.py index df9e1de28..24d33508e 100644 --- a/meshmode/distributed.py +++ b/meshmode/distributed.py @@ -2,6 +2,7 @@ .. autoclass:: InterRankBoundaryInfo .. autoclass:: MPIBoundaryCommSetupHelper +.. autofunction:: mpi_distribute .. autofunction:: get_partition_by_pymetis .. autofunction:: membership_list_to_map .. autofunction:: get_connected_parts @@ -36,11 +37,22 @@ """ from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Hashable, List, Mapping, Sequence, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Hashable, + List, + Optional, + Mapping, + Sequence, + Set, + Union, + cast +) from warnings import warn import numpy as np - +from contextlib import contextmanager from arraycontext import ArrayContext from meshmode.discretization import ElementGroupFactory @@ -66,6 +78,70 @@ # {{{ mesh distributor +@contextmanager +def _duplicate_mpi_comm(mpi_comm): + dup_comm = mpi_comm.Dup() + try: + yield dup_comm + finally: + dup_comm.Free() + + +def mpi_distribute( + mpi_comm: "mpi4py.MPI.Intracomm", + source_data: Optional[Mapping[int, Any]] = None, + source_rank: int = 0) -> Optional[Any]: + """ + Distribute data to a set of processes. + + :arg mpi_comm: An ``MPI.Intracomm`` + :arg source_data: A :class:`dict` mapping destination ranks to data to be sent. + Only present on the source rank. + :arg source_rank: The rank from which the data is being sent. + + :returns: The data local to the current process if there is any, otherwise + *None*. + """ + with _duplicate_mpi_comm(mpi_comm) as mpi_comm: + num_proc = mpi_comm.Get_size() + rank = mpi_comm.Get_rank() + + local_data = None + + if rank == source_rank: + if source_data is None: + raise TypeError("source rank has no data.") + + sending_to = [False] * num_proc + for dest_rank in source_data.keys(): + sending_to[dest_rank] = True + + mpi_comm.scatter(sending_to, root=source_rank) + + reqs = [] + for dest_rank, data in source_data.items(): + if dest_rank == rank: + local_data = data + logger.info("rank %d: received data", rank) + else: + reqs.append(mpi_comm.isend(data, dest=dest_rank)) + + logger.info("rank %d: sent all data", rank) + + from mpi4py import MPI + MPI.Request.waitall(reqs) + + else: + receiving = mpi_comm.scatter([], root=source_rank) + + if receiving: + local_data = mpi_comm.recv(source=source_rank) + logger.info("rank %d: received data", rank) + + return local_data + + +# TODO: Deprecate? class MPIMeshDistributor: """ .. automethod:: is_mananger_rank @@ -97,9 +173,7 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts): Sends each part to a different rank. Returns one part that was not sent to any other rank. 
""" - mpi_comm = self.mpi_comm - rank = mpi_comm.Get_rank() - assert num_parts <= mpi_comm.Get_size() + assert num_parts <= self.mpi_comm.Get_size() assert self.is_mananger_rank() @@ -108,38 +182,16 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts): from meshmode.mesh.processing import partition_mesh parts = partition_mesh(mesh, part_num_to_elements) - local_part = None - - reqs = [] - for r, part in parts.items(): - if r == self.manager_rank: - local_part = part - else: - reqs.append(mpi_comm.isend(part, dest=r, tag=TAG_DISTRIBUTE_MESHES)) - - logger.info("rank %d: sent all mesh parts", rank) - for req in reqs: - req.wait() - - return local_part + return mpi_distribute( + self.mpi_comm, source_data=parts, source_rank=self.manager_rank) def receive_mesh_part(self): """ Returns the mesh sent by the manager rank. """ - mpi_comm = self.mpi_comm - rank = mpi_comm.Get_rank() - assert not self.is_mananger_rank(), "Manager rank cannot receive mesh" - from mpi4py import MPI - status = MPI.Status() - result = self.mpi_comm.recv( - source=self.manager_rank, tag=TAG_DISTRIBUTE_MESHES, - status=status) - logger.info("rank %d: received local mesh (size = %d)", rank, status.count) - - return result + return mpi_distribute(self.mpi_comm, source_rank=self.manager_rank) # }}} diff --git a/meshmode/mesh/__init__.py b/meshmode/mesh/__init__.py index 7218d943e..7914623b9 100644 --- a/meshmode/mesh/__init__.py +++ b/meshmode/mesh/__init__.py @@ -903,7 +903,8 @@ def check_mesh_consistency( "parameter force_positive_orientation=True to make_mesh().") else: warn("Unimplemented: Cannot check element orientation for a mesh with " - "mesh.dim != mesh.ambient_dim", stacklevel=2) + f"mesh.dim != mesh.ambient_dim ({mesh.dim=},{mesh.ambient_dim=})", + stacklevel=2) def is_mesh_consistent( @@ -944,6 +945,7 @@ def make_mesh( node_vertex_consistency_tolerance: Optional[float] = None, skip_element_orientation_test: bool = False, force_positive_orientation: bool = False, + face_vertex_indices_to_tags=None, ) -> "Mesh": """Construct a new mesh from a given list of *groups*. 
@@ -1032,6 +1034,15 @@ def make_mesh( nodal_adjacency = ( NodalAdjacency(neighbors_starts=nb_starts, neighbors=nbs)) + face_vert_ind_to_tags_local = None + if face_vertex_indices_to_tags is not None: + face_vert_ind_to_tags_local = face_vertex_indices_to_tags.copy() + + if (facial_adjacency_groups is False or facial_adjacency_groups is None): + if face_vertex_indices_to_tags is not None: + facial_adjacency_groups = _compute_facial_adjacency_from_vertices( + groups, np.int32, np.int8, face_vertex_indices_to_tags) + if ( facial_adjacency_groups is not False and facial_adjacency_groups is not None): @@ -1058,8 +1069,13 @@ def make_mesh( if force_positive_orientation: if mesh.dim == mesh.ambient_dim: import meshmode.mesh.processing as mproc + mesh_making_kwargs = { + "face_vertex_indices_to_tags": face_vert_ind_to_tags_local + } mesh = mproc.perform_flips( - mesh, mproc.find_volume_mesh_element_orientations(mesh) < 0) + mesh=mesh, + flip_flags=mproc.find_volume_mesh_element_orientations(mesh) < 0, + skip_tests=False, mesh_making_kwargs=mesh_making_kwargs) else: raise ValueError("cannot enforce positive element orientation " "on non-volume meshes") diff --git a/meshmode/mesh/io.py b/meshmode/mesh/io.py index f3d3a3264..70e2a8e62 100644 --- a/meshmode/mesh/io.py +++ b/meshmode/mesh/io.py @@ -257,6 +257,7 @@ def get_mesh(self, return_tag_to_elements_map=False): # compute facial adjacency for Mesh if there is tag information facial_adjacency_groups = None + face_vert_ind_to_tags_local = face_vertex_indices_to_tags.copy() if is_conforming and self.tags: from meshmode.mesh import _compute_facial_adjacency_from_vertices facial_adjacency_groups = _compute_facial_adjacency_from_vertices( @@ -266,6 +267,7 @@ def get_mesh(self, return_tag_to_elements_map=False): vertices, groups, is_conforming=is_conforming, facial_adjacency_groups=facial_adjacency_groups, + face_vertex_indices_to_tags=face_vert_ind_to_tags_local, **self.mesh_construction_kwargs) return (mesh, tag_to_elements) if return_tag_to_elements_map else mesh @@ -294,10 +296,21 @@ def read_gmsh( belong to that volume. """ from gmsh_interop.reader import read_gmsh + import time + print("Reading gmsh mesh from disk file...") recv = GmshMeshReceiver(mesh_construction_kwargs=mesh_construction_kwargs) - read_gmsh(recv, filename, force_dimension=force_ambient_dim) - return recv.get_mesh(return_tag_to_elements_map=return_tag_to_elements_map) + read_start = time.time() + read_gmsh(recv, filename, force_dimension=force_ambient_dim) + read_finish = time.time() + print("Done. 
Populating meshmode data structures...") + retval = recv.get_mesh( + return_tag_to_elements_map=return_tag_to_elements_map) + get_mesh_finish = time.time() + print("Done.") + print(f"Read GMSH: {read_finish - read_start}\n" + f"MeshData: {get_mesh_finish - read_finish}") + return retval def generate_gmsh(source, dimensions=None, order=None, other_options=None, diff --git a/meshmode/mesh/processing.py b/meshmode/mesh/processing.py index 9e03f2531..13a701e27 100644 --- a/meshmode/mesh/processing.py +++ b/meshmode/mesh/processing.py @@ -820,7 +820,8 @@ def flip_element_group( def perform_flips( mesh: Mesh, flip_flags: np.ndarray, - skip_tests: bool = False) -> Mesh: + skip_tests: bool = False, + mesh_making_kwargs = None) -> Mesh: """ :arg flip_flags: A :class:`numpy.ndarray` with :attr:`meshmode.mesh.Mesh.nelements` entries @@ -830,6 +831,9 @@ def perform_flips( if mesh.vertices is None: raise ValueError("Mesh must have vertices to perform flips") + if mesh_making_kwargs is None: + mesh_making_kwargs = {} + flip_flags = flip_flags.astype(bool) new_groups = [] @@ -844,9 +848,8 @@ def perform_flips( new_groups.append(new_grp) return make_mesh( - mesh.vertices, new_groups, skip_tests=skip_tests, - is_conforming=mesh.is_conforming, - ) + mesh.vertices, groups=new_groups, skip_tests=skip_tests, + is_conforming=mesh.is_conforming, **mesh_making_kwargs) # }}} diff --git a/meshmode/pytato_utils.py b/meshmode/pytato_utils.py new file mode 100644 index 000000000..2071bcf57 --- /dev/null +++ b/meshmode/pytato_utils.py @@ -0,0 +1,115 @@ +import pyopencl.array as cl_array +import pytato as pt +import logging + +from functools import partial, reduce +from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace +from arraycontext import rec_map_reduce_array_container +from meshmode.transform_metadata import DiscretizationEntityAxisTag +from pytato.transform import ArrayOrNames +from pytato.transform.metadata import ( + AxesTagsEquationCollector as BaseAxesTagsEquationCollector) +from arraycontext import ArrayContainer +from arraycontext.container.traversal import rec_map_array_container +from typing import Union +logger = logging.getLogger(__name__) + + +MAX_UNIFY_RETRIES = 50 # used by unify_discretization_entity_tags + + +def _can_be_eagerly_computed(ary) -> bool: + from pytato.transform import InputGatherer + from pytato.array import Placeholder + return all(not isinstance(inp, Placeholder) + for inp in InputGatherer()(ary)) + + +class EagerReduceComputingPytatoFakeNumpyNamespace(PytatoFakeNumpyNamespace): + """ + A Numpy-namespace that computes the reductions eagerly whenever possible. 
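+ A reduction is computed eagerly when its operands contain no :class:`pytato.Placeholder`: the argument is frozen and reduced via :mod:`pyopencl.array` (e.g. ``sum``, ``min``, ``max`` with ``axis=None``) instead of being staged into the lazy DAG.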
+ """ + def sum(self, a, axis=None, dtype=None): + if (rec_map_reduce_array_container(all, + _can_be_eagerly_computed, a) + and axis is None): + + def _pt_sum(ary): + return cl_array.sum(self._array_context.freeze(ary), + dtype=dtype, + queue=self._array_context.queue) + + return self._array_context.thaw( + rec_map_reduce_array_container(sum, _pt_sum, a)) + else: + return super().sum(a, axis=axis, dtype=dtype) + + def min(self, a, axis=None): + if (rec_map_reduce_array_container(all, + _can_be_eagerly_computed, a) + and axis is None): + queue = self._array_context.queue + frozen_result = rec_map_reduce_array_container( + partial(reduce, partial(cl_array.minimum, queue=queue)), + lambda ary: cl_array.min(self._array_context.freeze(ary), + queue=queue), + a) + return self._array_context.thaw(frozen_result) + else: + return super().min(a, axis=axis) + + def max(self, a, axis=None): + if (rec_map_reduce_array_container(all, + _can_be_eagerly_computed, a) + and axis is None): + queue = self._array_context.queue + frozen_result = rec_map_reduce_array_container( + partial(reduce, partial(cl_array.maximum, queue=queue)), + lambda ary: cl_array.max(self._array_context.freeze(ary), + queue=queue), + a) + return self._array_context.thaw(frozen_result) + else: + return super().max(a, axis=axis) + + +# {{{ solve for discretization metadata for arrays' axes + +class AxesTagsEquationCollector(BaseAxesTagsEquationCollector): + def map_reshape(self, expr: pt.Reshape) -> None: + super().map_reshape(expr) + + if (expr.size > 0 + and (1 not in (expr.array.shape)) # leads to ambiguous newaxis + and (set(expr.shape) <= (set(expr.array.shape) | {1}))): + i_in_axis = 0 + for i_out_axis, dim in enumerate(expr.shape): + if dim != 1: + assert dim == expr.array.shape[i_in_axis] + self.record_equation( + self.get_var_for_axis(expr.array, + i_in_axis), + self.get_var_for_axis(expr, + i_out_axis) + ) + i_in_axis += 1 + else: + # print(f"Skipping: {expr.array.shape} -> {expr.shape}") + # Wacky reshape => bail. + pass + + +def unify_discretization_entity_tags(expr: Union[ArrayContainer, ArrayOrNames] + ) -> ArrayOrNames: + if not isinstance(expr, (pt.Array, pt.DictOfNamedArrays)): + return rec_map_array_container(unify_discretization_entity_tags, + expr) + + return pt.unify_axes_tags(expr, + tag_t=DiscretizationEntityAxisTag, + equations_collector_t=AxesTagsEquationCollector) + +# }}} + + +# vim: fdm=marker diff --git a/meshmode/transform_metadata.py b/meshmode/transform_metadata.py index a352aecc0..148ea577e 100644 --- a/meshmode/transform_metadata.py +++ b/meshmode/transform_metadata.py @@ -8,6 +8,7 @@ .. autoclass:: DiscretizationDOFAxisTag .. autoclass:: DiscretizationAmbientDimAxisTag .. autoclass:: DiscretizationTopologicalDimAxisTag +.. autoclass:: DiscretizationDOFPickListAxisTag """ __copyright__ = """ @@ -121,3 +122,12 @@ class DiscretizationTopologicalDimAxisTag(DiscretizationDimAxisTag): Array dimensions tagged with this tag type describe an axis indexing over the discretization's physical coordinate dimensions. """ + + +@tag_dataclass +class DiscretizationDOFPickListAxisTag(DiscretizationEntityAxisTag): + """ + Array dimensions tagged with this tag type describe an axis indexing over + DOF pick lists. See :mod:`meshmode.discretization.connection.direct` for + details. 
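+
+    A hedged usage sketch (``actx`` and ``from_element_indices`` are
+    illustrative names only, and axis 0 of the pick-list array is assumed
+    to be the axis running over pick lists)::
+
+        from_element_indices = actx.tag_axis(
+            0, DiscretizationDOFPickListAxisTag(), from_element_indices)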
+ """ diff --git a/requirements.txt b/requirements.txt index 7ad3c51b2..cef503369 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,13 +6,15 @@ git+https://github.com/inducer/pyvisfile.git#egg=pyvisfile git+https://github.com/inducer/modepy.git#egg=modepy git+https://github.com/inducer/pyopencl.git#egg=pyopencl git+https://github.com/inducer/islpy.git#egg=islpy -git+https://github.com/inducer/pytato.git#egg=pytato +# git+https://github.com/inducer/pytato.git#egg=pytato +git+https://github.com/kaushikcfd/pytato.git#egg=pytato # required by pytential, which is in turn needed for some tests git+https://github.com/inducer/pymbolic.git#egg=pymbolic # also depends on pymbolic, so should come after it -git+https://github.com/inducer/loopy.git#egg=loopy +# git+https://github.com/inducer/loopy.git#egg=loopy +git+https://github.com/kaushikcfd/loopy.git#egg=loopy # depends on loopy, so should come after it. git+https://github.com/inducer/arraycontext.git#egg=arraycontext @@ -27,3 +29,7 @@ git+https://github.com/inducer/pymetis.git#egg=pymetis # for examples/tp-lagrange-stl.py numpy-stl + +# for FusionContractorActx transforms +git+https://github.com/kaushikcfd/feinsum.git#egg=feinsum +git+https://github.com/pythological/kanren.git#egg=miniKanren diff --git a/test/3x3.msh b/test/3x3.msh new file mode 100644 index 000000000..faa3b135a --- /dev/null +++ b/test/3x3.msh @@ -0,0 +1,38 @@ +$MeshFormat +2.2 0 8 +$EndMeshFormat +$PhysicalNames +1 +2 10 "fluid" +$EndPhysicalNames +$Nodes +16 +1 0 0 0 +2 1 0 0 +3 1 0.5 0 +4 0 0.5 0 +5 0.3333333333333333 0 0 +6 0.6666666666666666 0 0 +7 1 0.1666666666666667 0 +8 1 0.3333333333333333 0 +9 0.6666666666666667 0.5 0 +10 0.3333333333333334 0.5 0 +11 0 0.3333333333333334 0 +12 0 0.1666666666666667 0 +13 0.3333333333333334 0.1666666666666667 0 +14 0.3333333333333334 0.3333333333333333 0 +15 0.6666666666666667 0.1666666666666667 0 +16 0.6666666666666667 0.3333333333333333 0 +$EndNodes +$Elements +9 +1 3 2 10 1 1 5 13 12 +2 3 2 10 1 12 13 14 11 +3 3 2 10 1 11 14 10 4 +4 3 2 10 1 5 6 15 13 +5 3 2 10 1 13 15 16 14 +6 3 2 10 1 14 16 9 10 +7 3 2 10 1 6 2 7 15 +8 3 2 10 1 15 7 8 16 +9 3 2 10 1 16 8 3 9 +$EndElements diff --git a/test/3x3_bound.msh b/test/3x3_bound.msh new file mode 100644 index 000000000..de74e581c --- /dev/null +++ b/test/3x3_bound.msh @@ -0,0 +1,54 @@ +$MeshFormat +2.2 0 8 +$EndMeshFormat +$PhysicalNames +5 +1 2 "Left" +1 3 "Right" +1 4 "Bottom" +1 5 "Top" +2 1 "fluid" +$EndPhysicalNames +$Nodes +16 +1 0 0 0 +2 1 0 0 +3 1 0.5 0 +4 0 0.5 0 +5 0.3333333333333333 0 0 +6 0.6666666666666666 0 0 +7 1 0.1666666666666667 0 +8 1 0.3333333333333333 0 +9 0.6666666666666667 0.5 0 +10 0.3333333333333334 0.5 0 +11 0 0.3333333333333334 0 +12 0 0.1666666666666667 0 +13 0.3333333333333334 0.1666666666666667 0 +14 0.3333333333333334 0.3333333333333333 0 +15 0.6666666666666667 0.1666666666666667 0 +16 0.6666666666666667 0.3333333333333333 0 +$EndNodes +$Elements +21 +1 1 2 2 1 1 12 +2 1 2 2 1 12 11 +3 1 2 2 1 11 4 +4 1 2 3 1 7 2 +5 1 2 3 1 8 7 +6 1 2 3 1 3 8 +7 1 2 4 1 5 1 +8 1 2 4 1 6 5 +9 1 2 4 1 2 6 +10 1 2 5 1 3 9 +11 1 2 5 1 9 10 +12 1 2 5 1 10 4 +13 3 2 1 1 12 13 5 1 +14 3 2 1 1 11 14 13 12 +15 3 2 1 1 4 10 14 11 +16 3 2 1 1 13 15 6 5 +17 3 2 1 1 14 16 15 13 +18 3 2 1 1 10 9 16 14 +19 3 2 1 1 15 7 2 6 +20 3 2 1 1 16 8 7 15 +21 3 2 1 1 9 3 8 16 +$EndElements diff --git a/test/3x3_minus.msh b/test/3x3_minus.msh new file mode 100644 index 000000000..b7192593b --- /dev/null +++ b/test/3x3_minus.msh @@ -0,0 +1,38 @@ +$MeshFormat +2.2 0 8 +$EndMeshFormat 
+$PhysicalNames +1 +2 10 "fluid" +$EndPhysicalNames +$Nodes +16 +1 0 0 0 +2 1 0 0 +3 1 0.5 0 +4 0 0.5 0 +5 0.3333333333333333 0 0 +6 0.6666666666666666 0 0 +7 1 0.1666666666666667 0 +8 1 0.3333333333333333 0 +9 0.6666666666666667 0.5 0 +10 0.3333333333333334 0.5 0 +11 0 0.3333333333333334 0 +12 0 0.1666666666666667 0 +13 0.3333333333333334 0.1666666666666667 0 +14 0.3333333333333334 0.3333333333333333 0 +15 0.6666666666666667 0.1666666666666667 0 +16 0.6666666666666667 0.3333333333333333 0 +$EndNodes +$Elements +9 +1 3 2 10 1 12 13 5 1 +2 3 2 10 1 14 13 12 11 +3 3 2 10 1 10 14 11 4 +4 3 2 10 1 13 15 6 5 +5 3 2 10 1 14 16 15 13 +6 3 2 10 1 16 14 10 9 +7 3 2 10 1 6 15 7 2 +8 3 2 10 1 15 16 8 7 +9 3 2 10 1 8 16 9 3 +$EndElements diff --git a/test/3x3_twisted.msh b/test/3x3_twisted.msh new file mode 100644 index 000000000..e0607bef5 --- /dev/null +++ b/test/3x3_twisted.msh @@ -0,0 +1,38 @@ +$MeshFormat +2.2 0 8 +$EndMeshFormat +$PhysicalNames +1 +2 10 "fluid" +$EndPhysicalNames +$Nodes +16 +1 0 0 0 +2 1 0 0 +3 1 0.5 0 +4 0 0.5 0 +5 0.3333333333333333 0 0 +6 0.6666666666666666 0 0 +7 1 0.1666666666666667 0 +8 1 0.3333333333333333 0 +9 0.6666666666666667 0.5 0 +10 0.3333333333333334 0.5 0 +11 0 0.3333333333333334 0 +12 0 0.1666666666666667 0 +13 0.3333333333333334 0.1666666666666667 0 +14 0.3333333333333334 0.3333333333333333 0 +15 0.6666666666666667 0.1666666666666667 0 +16 0.6666666666666667 0.3333333333333333 0 +$EndNodes +$Elements +9 +1 3 2 10 1 1 5 13 12 +2 3 2 10 1 11 12 13 14 +3 3 2 10 1 4 11 14 10 +4 3 2 10 1 5 6 15 13 +5 3 2 10 1 13 15 16 14 +6 3 2 10 1 9 10 14 16 +7 3 2 10 1 2 7 15 6 +8 3 2 10 1 7 8 16 15 +9 3 2 10 1 3 9 16 8 +$EndElements diff --git a/test/3x3_twisted_bound.msh b/test/3x3_twisted_bound.msh new file mode 100644 index 000000000..1f659fe4f --- /dev/null +++ b/test/3x3_twisted_bound.msh @@ -0,0 +1,54 @@ +$MeshFormat +2.2 0 8 +$EndMeshFormat +$PhysicalNames +5 +1 2 "Left" +1 3 "Right" +1 4 "Bottom" +1 5 "Top" +2 1 "fluid" +$EndPhysicalNames +$Nodes +16 +1 0 0 0 +2 1 0 0 +3 1 0.5 0 +4 0 0.5 0 +5 0.3333333333333333 0 0 +6 0.6666666666666666 0 0 +7 1 0.1666666666666667 0 +8 1 0.3333333333333333 0 +9 0.6666666666666667 0.5 0 +10 0.3333333333333334 0.5 0 +11 0 0.3333333333333334 0 +12 0 0.1666666666666667 0 +13 0.3333333333333334 0.1666666666666667 0 +14 0.3333333333333334 0.3333333333333333 0 +15 0.6666666666666667 0.1666666666666667 0 +16 0.6666666666666667 0.3333333333333333 0 +$EndNodes +$Elements +21 +1 1 2 2 1 1 12 +2 1 2 2 1 12 1 +3 1 2 2 1 11 4 +4 1 2 3 1 2 7 +5 1 2 3 1 7 8 +6 1 2 3 1 8 3 +7 1 2 4 1 1 5 +8 1 2 4 1 5 6 +9 1 2 4 1 6 2 +10 1 2 5 1 3 9 +11 1 2 5 1 9 10 +12 1 2 5 1 10 4 +13 3 2 1 1 1 5 13 12 +14 3 2 1 1 11 12 13 14 +15 3 2 1 1 4 11 14 10 +16 3 2 1 1 5 6 15 13 +17 3 2 1 1 13 15 16 14 +18 3 2 1 1 9 10 14 16 +19 3 2 1 1 2 7 15 6 +20 3 2 1 1 7 8 16 15 +21 3 2 1 1 3 9 16 8 +$EndElements diff --git a/test/test_mesh.py b/test/test_mesh.py index c471e0951..45d76a199 100644 --- a/test/test_mesh.py +++ b/test/test_mesh.py @@ -578,12 +578,15 @@ def test_merge_and_map(actx_factory, group_cls, visualize=False): # {{{ element orientation -@pytest.mark.parametrize("case", ["blob", "gh-394"]) +@pytest.mark.parametrize("case", ["blob", "gh-394", "3x3", "3x3_twisted", + "3x3_minus", "3x3_bound", + "3x3_twisted_bound"]) def test_element_orientation_via_flipping(case): from meshmode.mesh.io import FileSource, generate_gmsh mesh_order = 3 + meshfile = f"{thisdir}/{case}.msh" if case == "blob": mesh = generate_gmsh( FileSource(str(thisdir / "blob-2d.step")), 2, order=mesh_order, 
@@ -593,13 +596,68 @@ def test_element_orientation_via_flipping(case): ) elif case == "gh-394": mesh = mio.read_gmsh( - str(thisdir / "gh-394.msh"), + meshfile, + force_ambient_dim=2, + mesh_construction_kwargs={"skip_tests": True}) + elif case == "3x3": # regular ole rectangular 3x3 tensor product els (TPE) + mesh = mio.read_gmsh( + meshfile, + force_ambient_dim=2, + mesh_construction_kwargs={"skip_tests": True}) + elif case == "3x3_twisted": # TPEs, rotated connectivities, all positive + mesh = mio.read_gmsh( + meshfile, + force_ambient_dim=2, + mesh_construction_kwargs={"skip_tests": True}) + elif case == "3x3_minus": # TPEs with negative orientation (clockwise conn) + mesh = mio.read_gmsh( + meshfile, + force_ambient_dim=2, + mesh_construction_kwargs={"skip_tests": True}) + elif case == "3x3_bound": # TPEs (clockwise conn, w/boundaries) + mesh = mio.read_gmsh( + meshfile, + force_ambient_dim=2, + mesh_construction_kwargs={"skip_tests": True}) + elif case == "3x3_twisted_bound": # TPEs (clockwise conn, w/boundaries) + mesh = mio.read_gmsh( + meshfile, force_ambient_dim=2, mesh_construction_kwargs={"skip_tests": True}) else: raise ValueError(f"unknown case: {case}") + boundary_tags = set() + for igrp in range(len(mesh.groups)): + bdry_fagrps = [ + fagrp for fagrp in mesh.facial_adjacency_groups[igrp] + if isinstance(fagrp, BoundaryAdjacencyGroup)] + for bdry_fagrp in bdry_fagrps: + print(f"Boundary tag: {bdry_fagrp.boundary_tag}") + boundary_tags.add(bdry_fagrp.boundary_tag) + mesh_orient = mproc.find_volume_mesh_element_orientations(mesh) + if not (mesh_orient > 0).all(): + logger.info(f"Mesh({meshfile}) is negative, trying to reorient.") + print(f"Mesh({meshfile}) is negative, trying to reorient.") + mesh = mio.read_gmsh( + meshfile, + force_ambient_dim=2, + mesh_construction_kwargs={ + "skip_tests": True, + "force_positive_orientation": True}) + + mesh_orient = mproc.find_volume_mesh_element_orientations(mesh) + boundary_tags_reoriented = set() + for igrp in range(len(mesh.groups)): + bdry_fagrps = [ + fagrp for fagrp in mesh.facial_adjacency_groups[igrp] + if isinstance(fagrp, BoundaryAdjacencyGroup)] + for bdry_fagrp in bdry_fagrps: + boundary_tags_reoriented.add(bdry_fagrp.boundary_tag) + + # Make sure rotation doesn't lose boundaries + assert boundary_tags == boundary_tags_reoriented assert (mesh_orient > 0).all() From f400c8a6585c8b2fb80df05ef87a814c413df82c Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Tue, 30 Jul 2024 08:03:36 -0500 Subject: [PATCH 2/8] Add annular cylinder init func. --- meshmode/mesh/generation.py | 91 +++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/meshmode/mesh/generation.py b/meshmode/mesh/generation.py index 323fa66e8..6099d8a8e 100644 --- a/meshmode/mesh/generation.py +++ b/meshmode/mesh/generation.py @@ -1550,6 +1550,97 @@ def m(x: np.ndarray) -> np.ndarray: # {{{ generate_annular_cylinder_slice_mesh +def generate_annular_cylinder_mesh( + n: int, center: np.ndarray, inner_radius: float, outer_radius: float, + nelements_per_axis: Optional[int] = None, + periodic: bool = False, group_cls=None, dim: int = 3) -> Mesh: + r""" + Generate a slice of a 3D annular cylinder for + :math:`\theta \in [-\frac{\pi}{4}, \frac{\pi}{4}]`. Optionally periodic in + $\theta$. 
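+
+    Note that, as implemented, ``theta`` is mapped over the full circle
+    (``theta = 2*np.pi*x[1]``; the quarter-slice mapping is commented out
+    in the body), so this produces a complete annulus (``dim=2``) or
+    annular cylinder (``dim=3``) rather than a slice.
+
+    A usage sketch (argument values are illustrative only)::
+
+        mesh = generate_annular_cylinder_mesh(
+            8, center=np.array([0., 0., 0.]), inner_radius=0.5,
+            outer_radius=1.0, periodic=True, dim=3)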
+ """ + if nelements_per_axis is None: + nelements_per_axis = (n,)*dim + boundary_tag_to_face = { + "-r": ["-x"], + "+r": ["+x"], + "-theta": ["-y"], + "+theta": ["+y"], + } + if dim == 3: + boundary_tag_to_face["-z"] = ["-z"] + boundary_tag_to_face["+z"] = ["+z"] + if periodic: + boundary_tag_to_face["periodic_-theta"] = ["-y"] + boundary_tag_to_face["periodic_+theta"] = ["+y"] + if dim == 3: + boundary_tag_to_face["periodic_-z"] = ["-z"] + boundary_tag_to_face["periodic_+z"] = ["+z"] + + unit_mesh = generate_regular_rect_mesh( + a=(0,)*dim, + b=(1,)*dim, + nelements_per_axis=nelements_per_axis, + boundary_tag_to_face = boundary_tag_to_face, + group_cls=group_cls) + + def transform3(x: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + r = inner_radius*(1 - x[0]) + outer_radius*x[0] + # theta = -np.pi/4*(1 - x[1]) + np.pi/4*x[1] + theta = 2*np.pi*x[1] + z = -0.5*(1 - x[2]) + 0.5*x[2] + return ( + center[0] + r*np.cos(theta), + center[1] + r*np.sin(theta), + center[2] + z) + + def transform2(x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + r = inner_radius*(1 - x[0]) + outer_radius*x[0] + # theta = -np.pi/4*(1 - x[1]) + np.pi/4*x[1] + theta = 2*np.pi*x[1] + return ( + center[0] + r*np.cos(theta), + center[1] + r*np.sin(theta)) + + from meshmode.mesh.processing import map_mesh + if dim == 3: + mesh = map_mesh(unit_mesh, lambda x: np.stack(transform3(x))) + else: + mesh = map_mesh(unit_mesh, lambda x: np.stack(transform2(x))) + + if periodic: + from meshmode import AffineMap + from meshmode.mesh.processing import ( + BoundaryPairMapping, glue_mesh_boundaries) + bdry_pair_mappings_and_tols = [] + for idim in range(dim): + # if periodic[idim]: + if idim == 1: + offset = np.zeros(dim, dtype=np.float64) + # offset[idim] = axis_coords[idim][-1] - axis_coords[idim][0] + bdry_pair_mappings_and_tols.append(( + BoundaryPairMapping( + "periodic_-theta", # + axes[idim], + "periodic_+theta", # + axes[idim], + AffineMap(offset=offset)), + 1e-12)) + if idim == 2: + offset = np.zeros(dim, dtype=np.float64) + offset[idim] = 1.0 + bdry_pair_mappings_and_tols.append(( + BoundaryPairMapping( + "periodic_-z", # + axes[idim], + "periodic_+z", # + axes[idim], + AffineMap(offset=offset)), + 1e-12)) + + periodic_mesh = glue_mesh_boundaries(mesh, bdry_pair_mappings_and_tols) + + return periodic_mesh + else: + return mesh + + def generate_annular_cylinder_slice_mesh( n: int, center: np.ndarray, inner_radius: float, outer_radius: float, periodic: bool = False) -> Mesh: From 86d8611cac4bcb4e8e82250dc77c03260198b506 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Thu, 29 Aug 2024 10:19:36 -0500 Subject: [PATCH 3/8] Defer switch to actx.np.zeros --- meshmode/discretization/__init__.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/meshmode/discretization/__init__.py b/meshmode/discretization/__init__.py index 1a0ab6523..7980285e5 100644 --- a/meshmode/discretization/__init__.py +++ b/meshmode/discretization/__init__.py @@ -476,7 +476,8 @@ def empty(self, actx: ArrayContext, f"in 2025. 
Use '{type(self).__name__}.zeros' instead.", DeprecationWarning, stacklevel=2) - return self._new_array(actx, actx.np.zeros, dtype=dtype) + # return self._new_array(actx, actx.np.zeros, dtype=dtype) + return self._new_array(actx, actx.zeros, dtype=dtype) def zeros(self, actx: ArrayContext, dtype: Optional[np.dtype] = None) -> _DOFArray: @@ -490,7 +491,8 @@ def zeros(self, actx: ArrayContext, raise TypeError( f"'actx' must be an ArrayContext, not '{type(actx).__name__}'") - return self._new_array(actx, actx.np.zeros, dtype=dtype) + # return self._new_array(actx, actx.np.zeros, dtype=dtype) + return self._new_array(actx, actx.zeros, dtype=dtype) def empty_like(self, array: _DOFArray) -> _DOFArray: warn(f"'{type(self).__name__}.empty_like' is deprecated and will be removed " @@ -498,7 +500,8 @@ def empty_like(self, array: _DOFArray) -> _DOFArray: DeprecationWarning, stacklevel=2) actx = array.array_context - return self._new_array(actx, actx.np.zeros, dtype=array.entry_dtype) + # return self._new_array(actx, actx.np.zeros, dtype=array.entry_dtype) + return self._new_array(actx, actx.zeros, dtype=array.entry_dtype) def zeros_like(self, array: _DOFArray) -> _DOFArray: return self.zeros(array.array_context, dtype=array.entry_dtype) From 2275a94fee4a78f379c050b4e62191b9ccb1fc39 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 4 Sep 2024 13:55:19 -0700 Subject: [PATCH 4/8] Revert "Add some diagnostic checks to ferret out loop-nest errors." This reverts commit 353fd2528d3f6ab01fd4949ecdae64ccb118d19b. --- meshmode/array_context.py | 4 ---- meshmode/discretization/connection/direct.py | 4 ---- 2 files changed, 8 deletions(-) diff --git a/meshmode/array_context.py b/meshmode/array_context.py index d82e1e9cc..86c6668a4 100644 --- a/meshmode/array_context.py +++ b/meshmode/array_context.py @@ -1068,9 +1068,6 @@ def _get_iel_to_idofs(kernel): for dof_insn in kernel.iname_to_insns()[idof]): pass else: - for dof_insn in kernel.iname_to_insns()[idof]: - if iel not in kernel.id_to_insn[dof_insn].within_inames: - print(f"_get_iel_to_idofs: {str(kernel.id_to_insn[dof_insn])=}") raise NotImplementedError("The loop " f"'{insn.within_inames}' has the idof-loop" " that's not nested within the iel-loop.") @@ -1090,7 +1087,6 @@ def _get_iel_to_idofs(kernel): raise NotImplementedError("Could not fit into " " loop nest pattern.") else: - print(f"_get_iel_to_idofs: {str(insn)=}") raise NotImplementedError(f"Cannot fit loop nest '{insn.within_inames}'" " into known set of loop-nest patterns.") diff --git a/meshmode/discretization/connection/direct.py b/meshmode/discretization/connection/direct.py index f6bcf54a2..8a08bfc59 100644 --- a/meshmode/discretization/connection/direct.py +++ b/meshmode/discretization/connection/direct.py @@ -472,10 +472,6 @@ def _per_target_group_pick_info( cgrp = self.groups[i_tgrp] tgrp = self.to_discr.groups[i_tgrp] - if tgrp.nelements == 1: - from warnings import warn - warn("_per_target_group_pick_info: tgrp has 1 element") - batch_dof_pick_lists = [ self._resample_point_pick_indices(i_tgrp, i_batch) for i_batch in range(len(cgrp.batches))] From 039bdbd38901bf74ad85f8a5747d45b5a564ba2b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 5 Sep 2024 14:35:51 -0500 Subject: [PATCH 5/8] Fusion actx: cache transform_loopy_program --- meshmode/array_context.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/meshmode/array_context.py b/meshmode/array_context.py index 86c6668a4..a8fd76c5d 100644 --- a/meshmode/array_context.py +++ 
b/meshmode/array_context.py @@ -1251,6 +1251,9 @@ def _combine_einsum_domains(knl): return knl.copy(domains=new_domains) +from pytools.persistent_dict import WriteOncePersistentDict +from pytato.analysis import PytatoKeyBuilder + class FusionContractorArrayContext( SingleGridWorkBalancingPytatoArrayContext): @@ -1272,6 +1275,10 @@ def __init__( self.use_axis_tag_inference_fallback = use_axis_tag_inference_fallback self.use_einsum_inference_fallback = use_einsum_inference_fallback + self.transform_loopy_cache = WriteOncePersistentDict("meshmode-fusion_actx_transform_loopy_cache-v1", + key_builder=PytatoKeyBuilder(), + safe_sync=False) + def transform_dag(self, dag): import pytato as pt @@ -1639,11 +1646,20 @@ def transform_loopy_program(self, t_unit): from arraycontext.impl.pytato.compile import FromArrayContextCompile original_t_unit = t_unit + knl = t_unit.default_entrypoint + + try: + r = self.transform_loopy_cache[t_unit] + except KeyError: + logger.debug(f"FusionContractorArrayContext.transform_loopy_program '{knl.name}': cache miss") + pass + else: + logger.info(f"FusionContractorArrayContext.transform_loopy_program '{knl.name}': cache hit") + return r # from loopy.transform.instruction import simplify_indices # t_unit = simplify_indices(t_unit) - knl = t_unit.default_entrypoint logger.info(f"Transforming kernel '{knl.name}' with {len(knl.instructions)} statements.") @@ -1866,6 +1882,8 @@ def transform_loopy_program(self, t_unit): # }}} + self.transform_loopy_cache[original_t_unit] = t_unit + return t_unit From 6886f21ce56f2de7d8a516934fa2c03abfe38bb6 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 6 Sep 2024 18:18:36 -0500 Subject: [PATCH 6/8] hotfix: don't fail when another rank added transformation --- meshmode/array_context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/meshmode/array_context.py b/meshmode/array_context.py index a8fd76c5d..94c181f04 100644 --- a/meshmode/array_context.py +++ b/meshmode/array_context.py @@ -1882,7 +1882,7 @@ def transform_loopy_program(self, t_unit): # }}} - self.transform_loopy_cache[original_t_unit] = t_unit + self.transform_loopy_cache.store_if_not_present(original_t_unit, t_unit) return t_unit From 749dfaef66b3048999c662eaa1bacecad1230fd2 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Mon, 21 Oct 2024 13:11:04 -0500 Subject: [PATCH 7/8] slight pedantry --- meshmode/mesh/generation.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/meshmode/mesh/generation.py b/meshmode/mesh/generation.py index 6099d8a8e..6650acb2d 100644 --- a/meshmode/mesh/generation.py +++ b/meshmode/mesh/generation.py @@ -1634,11 +1634,9 @@ def transform2(x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: AffineMap(offset=offset)), 1e-12)) - periodic_mesh = glue_mesh_boundaries(mesh, bdry_pair_mappings_and_tols) + mesh = glue_mesh_boundaries(mesh, bdry_pair_mappings_and_tols) - return periodic_mesh - else: - return mesh + return mesh def generate_annular_cylinder_slice_mesh( From cc8e781de354ddfa929af3d6c9a1dd9f5123f639 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Mon, 4 Nov 2024 19:49:35 -0600 Subject: [PATCH 8/8] Remove extraneous Optional --- meshmode/mesh/generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/meshmode/mesh/generation.py b/meshmode/mesh/generation.py index 796c04d89..ac8893530 100644 --- a/meshmode/mesh/generation.py +++ b/meshmode/mesh/generation.py @@ -1553,7 +1553,7 @@ def m(x: np.ndarray) -> np.ndarray: def generate_annular_cylinder_mesh( n: int, center: 
np.ndarray, inner_radius: float, outer_radius: float, - nelements_per_axis: Optional[int] = None, + nelements_per_axis: int | None = None, periodic: bool = False, group_cls=None, dim: int = 3) -> Mesh: r""" Generate a slice of a 3D annular cylinder for