Skip to content

Commit

Permalink
Added more extensible meta access replacement function
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Dec 11, 2024
1 parent c9d6b51 commit e2a6466
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 34 deletions.
13 changes: 2 additions & 11 deletions dace/sdfg/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from dace import dtypes, properties, symbolic
from dace.codegen import cppunparse
from dace.frontend.python.astutils import ASTFindReplace
from dace.sdfg.state import ConditionalBlock, LoopRegion
from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, LoopRegion

if TYPE_CHECKING:
from dace.sdfg.state import StateSubgraphView
Expand Down Expand Up @@ -203,13 +203,4 @@ def replace_datadesc_names(sdfg: 'dace.SDFG', repl: Dict[str, str]):
edge.data.data = repl[edge.data.data]

# Replace in loop or branch conditions:
if isinstance(cf, LoopRegion):
replace_in_codeblock(cf.loop_condition, repl)
if cf.update_statement:
replace_in_codeblock(cf.update_statement, repl)
if cf.init_statement:
replace_in_codeblock(cf.init_statement, repl)
elif isinstance(cf, ConditionalBlock):
for c, _ in cf.branches:
if c is not None:
replace_in_codeblock(c, repl)
cf.replace_meta_accesses(repl)
23 changes: 23 additions & 0 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import dace
from dace.frontend.python import astutils
from dace.sdfg.replace import replace_in_codeblock
import dace.serialize
from dace import data as dt
from dace import dtypes
Expand Down Expand Up @@ -2612,6 +2613,16 @@ def get_meta_read_memlets(self) -> List[mm.Memlet]:
"""
return []

def replace_meta_accesses(self, replacements: dict) -> None:
"""
Replace accesses to specific data containers in reads or writes performed by the control flow region itself in
meta accesses, such as in condition checks for conditional blocks or in loop conditions for loops, etc.
:param replacements: A dictionary mapping the current data container names to the names of data containers with
which accesses to them should be replaced.
"""
pass

@property
def root_sdfg(self) -> 'SDFG':
from dace.sdfg.sdfg import SDFG # Avoid import loop
Expand Down Expand Up @@ -3304,6 +3315,13 @@ def get_meta_read_memlets(self) -> List[mm.Memlet]:
read_memlets.extend(memlets_in_ast(self.update_statement.code[0], self.sdfg.arrays))
return read_memlets

def replace_meta_accesses(self, replacements):
replace_in_codeblock(self.loop_condition, replacements)
if self.init_statement:
replace_in_codeblock(self.init_statement, replacements)
if self.update_statement:
replace_in_codeblock(self.update_statement, replacements)

def _used_symbols_internal(self,
all_symbols: bool,
defined_syms: Optional[Set] = None,
Expand Down Expand Up @@ -3418,6 +3436,11 @@ def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None, parent: Optio
def sub_regions(self):
return [b for _, b in self.branches]

def replace_meta_accesses(self, replacements):
for c, _ in self.branches:
if c is not None:
replace_in_codeblock(c, replacements)

def __str__(self):
return self._label

Expand Down
25 changes: 2 additions & 23 deletions dace/transformation/interstate/gpu_transform_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,35 +598,14 @@ def _create_copy_out(arrays_used: Set[str]) -> Dict[str, str]:
e.data.replace(devicename, hostname, False)

for block in list(sdfg.all_control_flow_blocks()):
arrays_used = set()
if isinstance(block, ConditionalBlock):
for c, _ in block.branches:
if c is not None:
arrays_used.update(set(c.get_free_symbols()) & cloned_data)
elif isinstance(block, LoopRegion):
arrays_used.update(set(block.loop_condition.get_free_symbols()) & cloned_data)
if block.init_statement:
arrays_used.update(set(block.init_statement.get_free_symbols()) & cloned_data)
if block.update_statement:
arrays_used.update(set(block.update_statement.get_free_symbols()) & cloned_data)
else:
continue
arrays_used = set(block.used_symbols(all_symbols=True, with_contents=False)) & cloned_data

# Create a state and copy out used arrays
if len(arrays_used) > 0:
co_state = block.parent_graph.add_state_before(block, block.label + '_icopyout')
mapping = _create_copy_out(arrays_used)
for devicename, hostname in mapping.items():
if isinstance(block, ConditionalBlock):
for c, _ in block.branches:
if c is not None:
replace_in_codeblock(c, {devicename: hostname})
elif isinstance(block, LoopRegion):
replace_in_codeblock(block.loop_condition, {devicename: hostname})
if block.init_statement:
replace_in_codeblock(block.init_statement, {devicename: hostname})
if block.update_statement:
replace_in_codeblock(block.update_statement, {devicename: hostname})
block.replace_meta_accesses({devicename: hostname})

# Step 9: Simplify
if not self.simplify:
Expand Down

0 comments on commit e2a6466

Please sign in to comment.