From e2a6466f883b8900dd9f4bb43ea052967a7e5f13 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 11 Dec 2024 15:34:57 +0100 Subject: [PATCH] Added more extensible meta access replacement function --- dace/sdfg/replace.py | 13 ++-------- dace/sdfg/state.py | 23 +++++++++++++++++ .../interstate/gpu_transform_sdfg.py | 25 ++----------------- 3 files changed, 27 insertions(+), 34 deletions(-) diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index b49f13cee6..cab313fc9b 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -11,7 +11,7 @@ from dace import dtypes, properties, symbolic from dace.codegen import cppunparse from dace.frontend.python.astutils import ASTFindReplace -from dace.sdfg.state import ConditionalBlock, LoopRegion +from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, LoopRegion if TYPE_CHECKING: from dace.sdfg.state import StateSubgraphView @@ -203,13 +203,4 @@ def replace_datadesc_names(sdfg: 'dace.SDFG', repl: Dict[str, str]): edge.data.data = repl[edge.data.data] # Replace in loop or branch conditions: - if isinstance(cf, LoopRegion): - replace_in_codeblock(cf.loop_condition, repl) - if cf.update_statement: - replace_in_codeblock(cf.update_statement, repl) - if cf.init_statement: - replace_in_codeblock(cf.init_statement, repl) - elif isinstance(cf, ConditionalBlock): - for c, _ in cf.branches: - if c is not None: - replace_in_codeblock(c, repl) + cf.replace_meta_accesses(repl) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 5e5d07b288..fbc157f74d 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -13,6 +13,7 @@ import dace from dace.frontend.python import astutils +from dace.sdfg.replace import replace_in_codeblock import dace.serialize from dace import data as dt from dace import dtypes @@ -2612,6 +2613,16 @@ def get_meta_read_memlets(self) -> List[mm.Memlet]: """ return [] + def replace_meta_accesses(self, replacements: dict) -> None: + """ + Replace accesses to specific data containers in reads or writes performed by the control flow region itself in + meta accesses, such as in condition checks for conditional blocks or in loop conditions for loops, etc. + + :param replacements: A dictionary mapping the current data container names to the names of data containers with + which accesses to them should be replaced. + """ + pass + @property def root_sdfg(self) -> 'SDFG': from dace.sdfg.sdfg import SDFG # Avoid import loop @@ -3304,6 +3315,13 @@ def get_meta_read_memlets(self) -> List[mm.Memlet]: read_memlets.extend(memlets_in_ast(self.update_statement.code[0], self.sdfg.arrays)) return read_memlets + def replace_meta_accesses(self, replacements): + replace_in_codeblock(self.loop_condition, replacements) + if self.init_statement: + replace_in_codeblock(self.init_statement, replacements) + if self.update_statement: + replace_in_codeblock(self.update_statement, replacements) + def _used_symbols_internal(self, all_symbols: bool, defined_syms: Optional[Set] = None, @@ -3418,6 +3436,11 @@ def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None, parent: Optio def sub_regions(self): return [b for _, b in self.branches] + def replace_meta_accesses(self, replacements): + for c, _ in self.branches: + if c is not None: + replace_in_codeblock(c, replacements) + def __str__(self): return self._label diff --git a/dace/transformation/interstate/gpu_transform_sdfg.py b/dace/transformation/interstate/gpu_transform_sdfg.py index 901b05cb64..49a2e16227 100644 --- a/dace/transformation/interstate/gpu_transform_sdfg.py +++ b/dace/transformation/interstate/gpu_transform_sdfg.py @@ -598,35 +598,14 @@ def _create_copy_out(arrays_used: Set[str]) -> Dict[str, str]: e.data.replace(devicename, hostname, False) for block in list(sdfg.all_control_flow_blocks()): - arrays_used = set() - if isinstance(block, ConditionalBlock): - for c, _ in block.branches: - if c is not None: - arrays_used.update(set(c.get_free_symbols()) & cloned_data) - elif isinstance(block, LoopRegion): - arrays_used.update(set(block.loop_condition.get_free_symbols()) & cloned_data) - if block.init_statement: - arrays_used.update(set(block.init_statement.get_free_symbols()) & cloned_data) - if block.update_statement: - arrays_used.update(set(block.update_statement.get_free_symbols()) & cloned_data) - else: - continue + arrays_used = set(block.used_symbols(all_symbols=True, with_contents=False)) & cloned_data # Create a state and copy out used arrays if len(arrays_used) > 0: co_state = block.parent_graph.add_state_before(block, block.label + '_icopyout') mapping = _create_copy_out(arrays_used) for devicename, hostname in mapping.items(): - if isinstance(block, ConditionalBlock): - for c, _ in block.branches: - if c is not None: - replace_in_codeblock(c, {devicename: hostname}) - elif isinstance(block, LoopRegion): - replace_in_codeblock(block.loop_condition, {devicename: hostname}) - if block.init_statement: - replace_in_codeblock(block.init_statement, {devicename: hostname}) - if block.update_statement: - replace_in_codeblock(block.update_statement, {devicename: hostname}) + block.replace_meta_accesses({devicename: hostname}) # Step 9: Simplify if not self.simplify: