Skip to content

Commit

Permalink
Improve memlet label and string initialization (#1680)
Browse files Browse the repository at this point in the history
Follow up on the discussion in #1678.

Supports `src[expr] -> dst[expr]`, `src[expr] -> [expr]`, and `[expr] ->
dst[expr]` initializations for memlets. Also improves memlet label
printouts.

@philip-paul-mueller @phschaad the expression mentioned in the other PR
will now be printed as `[0, 0] -> B[0]` for clarity and can be reparsed.
  • Loading branch information
tbennun authored Oct 12, 2024
1 parent e6440a6 commit 64c54ab
Show file tree
Hide file tree
Showing 10 changed files with 55 additions and 36 deletions.
4 changes: 2 additions & 2 deletions dace/frontend/fortran/fortran_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,8 +536,8 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node,
for i, s in zip(all_indices, array.shape)])
smallsubset = subsets.Range([(0, s - 1, 1) for s in shape])

memlet = Memlet(f'{array_name}[{subset}]->{smallsubset}')
memlet2 = Memlet(f'{viewname}[{smallsubset}]->{subset}')
memlet = Memlet(f'{array_name}[{subset}]->[{smallsubset}]')
memlet2 = Memlet(f'{viewname}[{smallsubset}]->[{subset}]')
wv = None
rv = None
if local_name.name in read_names:
Expand Down
6 changes: 3 additions & 3 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def _numpy_flip(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, axis
# acpy, _ = sdfg.add_temp_transient(desc.shape, desc.dtype, desc.storage)
# vnode = state.add_read(view)
# anode = state.add_read(acpy)
# state.add_edge(vnode, None, anode, None, Memlet(f'{view}[{sset}] -> {dset}'))
# state.add_edge(vnode, None, anode, None, Memlet(f'{view}[{sset}] -> [{dset}]'))

arr_copy, _ = sdfg.add_temp_transient_like(desc)
inpidx = ','.join([f'__i{i}' for i in range(ndim)])
Expand Down Expand Up @@ -3934,7 +3934,7 @@ def implement_ufunc_accumulate(visitor: ProgramVisitor, ast_node: ast.Call, sdfg
init_state = nested_sdfg.add_state(label="init")
r = init_state.add_read(inpconn)
w = init_state.add_write(outconn)
init_state.add_nedge(r, w, dace.Memlet("{a}[{i}] -> {oi}".format(a=inpconn, i='0', oi='0')))
init_state.add_nedge(r, w, dace.Memlet("{a}[{i}] -> [{oi}]".format(a=inpconn, i='0', oi='0')))

body_state = nested_sdfg.add_state(label="body")
r1 = body_state.add_read(inpconn)
Expand Down Expand Up @@ -4189,7 +4189,7 @@ def view(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, dtype, type
find_new_name=True)

# Register view with DaCe program visitor
# NOTE: We do not create here a Memlet of the form `A[subset] -> osubset`
# NOTE: We do not create here a Memlet of the form `A[subset] -> [osubset]`
# because the View can be of a different dtype. Adding `other_subset` in
# such cases will trigger validation error.
pv.views[newarr] = (arr, Memlet.from_array(arr, desc))
Expand Down
45 changes: 32 additions & 13 deletions dace/memlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def __init__(self,
of use API. Must follow one of the following forms:
1. ``ARRAY``,
2. ``ARRAY[SUBSET]``,
3. ``ARRAY[SUBSET] -> OTHER_SUBSET``.
3. ``ARRAY[SUBSET] -> [OTHER_SUBSET]``,
4. ``[OTHER_SUBSET] -> ARRAY[SUBSET]``,
5. ``SRC_ARRAY[SRC_SUBSET] -> DST_ARRAY[DST_SUBSET]``.
:param data: Data descriptor name attached to this memlet.
:param subset: The subset to take from the data attached to the edge,
represented either as a string or a Subset object.
Expand Down Expand Up @@ -330,6 +332,10 @@ def _parse_from_subexpr(self, expr: str):
raise SyntaxError('Invalid memlet syntax "%s"' % expr)
return expr, None

# [subset] syntax
if expr.startswith('['):
return None, SubsetProperty.from_string(expr[1:-1])

# array[subset] syntax
arrname, subset_str = expr[:-1].split('[')
if not dtypes.validate_name(arrname):
Expand All @@ -342,27 +348,40 @@ def _parse_memlet_from_str(self, expr: str):
or the _data,_subset fields.
:param expr: A string expression of the this memlet, given as an ease
of use API. Must follow one of the following forms:
1. ``ARRAY``,
2. ``ARRAY[SUBSET]``,
3. ``ARRAY[SUBSET] -> OTHER_SUBSET``.
Note that modes 2 and 3 are deprecated and will leave
the memlet uninitialized until inserted into an SDFG.
of use API. Must follow one of the following forms:
1. ``ARRAY``,
2. ``ARRAY[SUBSET]``,
3. ``ARRAY[SUBSET] -> [OTHER_SUBSET]``,
4. ``[OTHER_SUBSET] -> ARRAY[SUBSET]``,
5. ``SRC_ARRAY[SRC_SUBSET] -> DST_ARRAY[DST_SUBSET]``.
Note that options 1-2 will leave the memlet uninitialized
until added into an SDFG.
"""
expr = expr.strip()
if '->' not in expr: # Options 1 and 2
self.data, self.subset = self._parse_from_subexpr(expr)
return

# Option 3
# Options 3-5
src_expr, dst_expr = expr.split('->')
src_expr = src_expr.strip()
dst_expr = dst_expr.strip()
if '[' not in src_expr and not dtypes.validate_name(src_expr):
raise SyntaxError('Expression without data name not yet allowed')

self.data, self.subset = self._parse_from_subexpr(src_expr)
self.other_subset = SubsetProperty.from_string(dst_expr)
src_data, src_subset = self._parse_from_subexpr(src_expr)
dst_data, dst_subset = self._parse_from_subexpr(dst_expr)
if src_data is None and dst_data is None:
raise SyntaxError('At least one data name needs to be given')

if src_data is not None: # Prefer src[subset] -> [other_subset]
self.data = src_data
self.subset = src_subset
self.other_subset = dst_subset
self._is_data_src = True
else:
self.data = dst_data
self.subset = dst_subset
self.other_subset = src_subset
self._is_data_src = False

def try_initialize(self, sdfg: 'dace.sdfg.SDFG', state: 'dace.sdfg.SDFGState',
edge: 'dace.sdfg.graph.MultiConnectorEdge'):
Expand Down Expand Up @@ -660,7 +679,7 @@ def _label(self, shape):

if self.other_subset is not None:
if self._is_data_src is False:
result += ' <- [%s]' % str(self.other_subset)
result = f'[{self.other_subset}] -> {result}'
else:
result += ' -> [%s]' % str(self.other_subset)
return result
Expand Down
6 changes: 3 additions & 3 deletions dace/transformation/dataflow/bank_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ def apply(self, graph: SDFGState, sdfg: SDFG) -> Union[Any, None]:
target_offset_str = ", ".join([f"({x}):({x}+{y})" for x, y in zip(target_offset, target_size)])
if collect_src:
copy_memlet = memlet.Memlet(f"{src.data}[{target_hbm_bank_str}, {target_size_str}]->"
f"{target_offset_str}")
f"[{target_offset_str}]")
else:
copy_memlet = memlet.Memlet(f"{src.data}[{target_offset_str}]->{target_hbm_bank_str}, "
f"{target_size_str}")
copy_memlet = memlet.Memlet(f"{src.data}[{target_offset_str}]->[{target_hbm_bank_str}, "
f"{target_size_str}]")
graph.add_edge(src, None, dst, None, copy_memlet)
4 changes: 2 additions & 2 deletions tests/codegen/dependency_edge_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ def test_mapped_dependency_edge(reverse):
state.add_edge(map_entry, "OUT_A", tmp_A, None, dace.Memlet("A[i]"))
state.add_edge(map_entry, "OUT_B", tmp_B, None, dace.Memlet("B[i]"))

state.add_edge(tmp_A, None, A2, None, dace.Memlet("tmp_A[0] -> ((i+1)%2)"))
state.add_edge(tmp_A, None, A2, None, dace.Memlet("tmp_A[0] -> [((i+1)%2)]"))
if not reverse:
state.add_edge(A2, None, tmp_B, None, dace.Memlet()) # Dependency Edge
state.add_edge(A2, None, map_exit, "IN_A", dace.Memlet("A[0:2]"))

state.add_edge(tmp_B, None, A3, None, dace.Memlet("tmp_B[0] -> ((i+1)%2)"))
state.add_edge(tmp_B, None, A3, None, dace.Memlet("tmp_B[0] -> [((i+1)%2)]"))
if reverse:
state.add_edge(A3, None, tmp_A, None, dace.Memlet()) # Dependency Edge
state.add_edge(A3, None, map_exit, "IN_A", dace.Memlet("A[0:2]"))
Expand Down
8 changes: 4 additions & 4 deletions tests/fpga/multibank_copy_fpga_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def copy_multibank_1_mem_type(mem_type):
s, a, _ = mkc(sdfg, None, "a", "x", StorageType.Default, StorageType.FPGA_Global, [3, 4, 4], [3, 4, 4], "a", None,
(mem_type, "0:3"))
s, _, _ = mkc(sdfg, s, "x", "y", None, StorageType.FPGA_Global, None, [2, 4, 4, 4],
"x[1, 1:4, 1:4]->1, 1:4, 1:4, 1", None, (mem_type, "3:5"))
"x[1, 1:4, 1:4]->[1, 1:4, 1:4, 1]", None, (mem_type, "3:5"))
s, _, _ = mkc(sdfg, s, "y", "z", None, StorageType.FPGA_Global, None, [1, 4, 4, 4],
"y[1, 0:4, 0:4, 0:4]->0, 0:4, 0:4, 0:4", None, (mem_type, "5:6"))
"y[1, 0:4, 0:4, 0:4]->[0, 0:4, 0:4, 0:4]", None, (mem_type, "5:6"))
s, _, _ = mkc(sdfg, s, "z", "w", None, StorageType.FPGA_Global, None, [1, 4, 4, 4], "z", None, (mem_type, "6:7"))
s, _, c = mkc(sdfg, s, "w", "c", None, StorageType.Default, None, [1, 4, 4, 4], "w")

Expand All @@ -97,9 +97,9 @@ def copy_multibank_2_mem_type(mem_type_1, mem_type_2):
sdfg = dace.SDFG("copy_multibank_2_mem_type_" + mem_type_1 + "_" + mem_type_2)
s, a, _ = mkc(sdfg, None, "a", "x", StorageType.Default, StorageType.FPGA_Global, [3, 5, 5], [3, 5, 5], "a", None,
(mem_type_1, "0:3"))
s, _, _ = mkc(sdfg, s, "x", "d1", None, StorageType.FPGA_Global, None, [3, 5, 5], "x[2, 0:5, 0:5]->1, 0:5, 0:5",
s, _, _ = mkc(sdfg, s, "x", "d1", None, StorageType.FPGA_Global, None, [3, 5, 5], "x[2, 0:5, 0:5]->[1, 0:5, 0:5]",
None, (mem_type_2, "1:4"))
s, _, _ = mkc(sdfg, s, "d1", "y", None, StorageType.FPGA_Global, None, [1, 7, 7], "d1[1, 0:5,0:5]->0, 2:7, 2:7",
s, _, _ = mkc(sdfg, s, "d1", "y", None, StorageType.FPGA_Global, None, [1, 7, 7], "d1[1, 0:5,0:5]->[0, 2:7, 2:7]",
None, (mem_type_1, "3:4"))
s, _, c = mkc(sdfg, s, "y", "c", None, StorageType.Default, None, [1, 7, 7], "y")

Expand Down
6 changes: 3 additions & 3 deletions tests/inlining_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def test_regression_reshape_unsqueeze():
A = nstate.add_access("view")
W = nstate.add_write("output")

mm1 = dace.Memlet("input[0:3, 0:3] -> 0:3, 0:3")
mm2 = dace.Memlet("view[0:3, 0:2] -> 3:9")
mm1 = dace.Memlet("input[0:3, 0:3] -> [0:3, 0:3]")
mm2 = dace.Memlet("view[0:3, 0:2] -> [3:9]")

nstate.add_edge(R, None, A, None, mm1)
nstate.add_edge(A, None, W, None, mm2)
Expand Down Expand Up @@ -405,7 +405,7 @@ def test_regression_inline_subset():
nsdfg.add_array("input", [96, 32], dace.float64)
nsdfg.add_array("output", [32, 32], dace.float64)
nstate.add_edge(nstate.add_read("input"), None, nstate.add_write("output"), None,
dace.Memlet("input[32:64, 0:32] -> 0:32, 0:32"))
dace.Memlet("input[32:64, 0:32] -> [0:32, 0:32]"))

@dace.program
def test(A: dace.float64[96, 32]):
Expand Down
2 changes: 1 addition & 1 deletion tests/passes/access_ranges_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def tester(A: dace.float64[N, N], B: dace.float64[20, 20]):
# Construct read/write memlets
memlet1 = dace.Memlet('A[0:N, 0:N]')
memlet1._is_data_src = False
memlet2 = dace.Memlet('A[1:21, 1:21] -> 0:20, 0:20')
memlet2 = dace.Memlet('A[1:21, 1:21] -> [0:20, 0:20]')
memlet2._is_data_src = False
memlet3 = dace.Memlet('A[0, 0]')
memlet4 = dace.Memlet('A[0, 0]')
Expand Down
6 changes: 3 additions & 3 deletions tests/sdfg/reference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def _create_scoped_sdfg():
inp = state.add_read('B')
t = state.add_tasklet('doit', {'r'}, {'w'}, 'w = r + 1')
out = state.add_write('A')
state.add_memlet_path(inp, me, ref, memlet=dace.Memlet('B[1, i] -> i'))
state.add_memlet_path(inp, me, ref, memlet=dace.Memlet('B[1, i] -> [i]'))
state.add_edge(ref, None, t, 'r', dace.Memlet('ref[i]'))
state.add_edge_pair(mx, t, out, internal_connector='w', internal_memlet=dace.Memlet('A[10, i]'))

Expand Down Expand Up @@ -250,7 +250,7 @@ def _create_loop_nonfree_symbols_sdfg():
sdfg.add_loop(istate, state, after, 'i', '0', 'i < 20', 'i + 1')

# Reference set inside loop
state.add_edge(state.add_read('A'), None, state.add_write('ref'), 'set', dace.Memlet('A[i] -> 0'))
state.add_edge(state.add_read('A'), None, state.add_write('ref'), 'set', dace.Memlet('A[i] -> [0]'))

# Use outisde loop
t = after.add_tasklet('setone', {}, {'out'}, 'out = 1')
Expand Down Expand Up @@ -519,7 +519,7 @@ def test_reference_loop_nonfree():
assert len(sources) == 1 # There is only one SDFG
sources = sources[0]
assert len(sources) == 1
assert sources['ref'] == {dace.Memlet('A[i] -> 0')}
assert sources['ref'] == {dace.Memlet('A[i] -> [0]')}

# Test loop-to-map - should fail to apply
from dace.transformation.interstate import LoopToMap
Expand Down
4 changes: 2 additions & 2 deletions tests/transformations/prune_connectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def _make_read_write_sdfg(
istate.add_nedge(
inner_A,
inner_B,
dace.Memlet("inner_A[0:4, 0:4] -> 0:4, 0:4"),
dace.Memlet("inner_A[0:4, 0:4] -> [0:4, 0:4]"),
)
else:
# Because the `data` filed of the involved memlets differs the read to
Expand All @@ -216,7 +216,7 @@ def _make_read_write_sdfg(
istate.add_nedge(
inner_A,
inner_B,
dace.Memlet("inner_B[0:4, 0:4] -> 0:4, 0:4"),
dace.Memlet("inner_B[0:4, 0:4] -> [0:4, 0:4]"),
)

# Add the nested SDFG
Expand Down

0 comments on commit 64c54ab

Please sign in to comment.