Skip to content

Commit

Permalink
Merge pull request #1344 from spcl/used-symbol-fixes
Browse files Browse the repository at this point in the history
Used symbol fixes
  • Loading branch information
alexnick83 authored Oct 4, 2023
2 parents 06144b8 + 17fa4c1 commit e608c2c
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 27 deletions.
4 changes: 2 additions & 2 deletions dace/codegen/compiled_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,8 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
raise TypeError('Passing an object (type %s) to an array in argument "%s"' %
(type(arg).__name__, a))
elif dtypes.is_array(arg) and not isinstance(atype, dt.Array):
# GPU scalars are pointers, so this is fine
if atype.storage != dtypes.StorageType.GPU_Global:
# GPU scalars and return values are pointers, so this is fine
if atype.storage != dtypes.StorageType.GPU_Global and not a.startswith('__return'):
raise TypeError('Passing an array to a scalar (type %s) in argument "%s"' % (atype.dtype.ctype, a))
elif (not isinstance(atype, (dt.Array, dt.Structure)) and
not isinstance(atype.dtype, dtypes.callback) and
Expand Down
12 changes: 7 additions & 5 deletions dace/codegen/targets/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1511,9 +1511,10 @@ def make_restrict(expr: str) -> str:
arguments += [
f'{atype} {restrict} {aname}' for (atype, aname, _), restrict in zip(memlet_references, restrict_args)
]
fsyms = node.sdfg.used_symbols(all_symbols=False, keep_defined_in_mapping=True)
arguments += [
f'{node.sdfg.symbols[aname].as_arg(aname)}' for aname in sorted(node.symbol_mapping.keys())
if aname not in sdfg.constants
if aname in fsyms and aname not in sdfg.constants
]
arguments = ', '.join(arguments)
return f'void {sdfg_label}({arguments}) {{'
Expand All @@ -1522,9 +1523,10 @@ def generate_nsdfg_call(self, sdfg, state, node, memlet_references, sdfg_label,
prepend = []
if state_struct:
prepend = ['__state']
fsyms = node.sdfg.used_symbols(all_symbols=False, keep_defined_in_mapping=True)
args = ', '.join(prepend + [argval for _, _, argval in memlet_references] + [
cpp.sym2cpp(symval)
for symname, symval in sorted(node.symbol_mapping.items()) if symname not in sdfg.constants
cpp.sym2cpp(symval) for symname, symval in sorted(node.symbol_mapping.items())
if symname in fsyms and symname not in sdfg.constants
])
return f'{sdfg_label}({args});'

Expand Down Expand Up @@ -1808,11 +1810,11 @@ def _generate_MapEntry(

# Find if bounds are used within the scope
scope = state_dfg.scope_subgraph(node, False, False)
fsyms = scope.free_symbols
fsyms = self._frame.free_symbols(scope)
# Include external edges
for n in scope.nodes():
for e in state_dfg.all_edges(n):
fsyms |= e.data.free_symbols
fsyms |= e.data.used_symbols(False, e)
fsyms = set(map(str, fsyms))

ntid_is_used = '__omp_num_threads' in fsyms
Expand Down
3 changes: 2 additions & 1 deletion dace/codegen/targets/intel_fpga.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,9 +729,10 @@ def generate_module(self, sdfg, state, kernel_name, module_name, subgraph, param
def generate_nsdfg_header(self, sdfg, state, state_id, node, memlet_references, sdfg_label):
# Intel FPGA needs to deal with streams
arguments = [f'{atype} {aname}' for atype, aname, _ in memlet_references]
fsyms = node.sdfg.used_symbols(all_symbols=False, keep_defined_in_mapping=True)
arguments += [
f'{node.sdfg.symbols[aname].as_arg(aname)}' for aname in sorted(node.symbol_mapping.keys())
if aname not in sdfg.constants
if aname in fsyms and aname not in sdfg.constants
]
arguments = ', '.join(arguments)
function_header = f'void {sdfg_label}({arguments}) {{'
Expand Down
3 changes: 2 additions & 1 deletion dace/codegen/targets/xilinx.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,9 +368,10 @@ def generate_flatten_loop_post(kernel_stream, sdfg, state_id, node):
def generate_nsdfg_header(self, sdfg, state, state_id, node, memlet_references, sdfg_label):
# TODO: Use a single method for GPU kernels, FPGA modules, and NSDFGs
arguments = [f'{atype} {aname}' for atype, aname, _ in memlet_references]
fsyms = node.sdfg.used_symbols(all_symbols=False, keep_defined_in_mapping=True)
arguments += [
f'{node.sdfg.symbols[aname].as_arg(aname)}' for aname in sorted(node.symbol_mapping.keys())
if aname not in sdfg.constants
if aname in fsyms and aname not in sdfg.constants
]
arguments = ', '.join(arguments)
return f'void {sdfg_label}({arguments}) {{\n#pragma HLS INLINE'
Expand Down
37 changes: 31 additions & 6 deletions dace/memlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,22 +512,47 @@ def validate(self, sdfg, state):
if self.data is not None and self.data not in sdfg.arrays:
raise KeyError('Array "%s" not found in SDFG' % self.data)

def used_symbols(self, all_symbols: bool) -> Set[str]:
def used_symbols(self, all_symbols: bool, edge=None) -> Set[str]:
"""
Returns a set of symbols used in this edge's properties.
:param all_symbols: If False, only returns the set of symbols that will be used
in the generated code and are needed as arguments.
:param edge: If given, provides richer context-based tests for the case
of ``all_symbols=False``.
"""
# Symbolic properties are in volume, and the two subsets
result = set()
view_edge = False
if all_symbols:
result |= set(map(str, self.volume.free_symbols))
if self.src_subset:
result |= self.src_subset.free_symbols

if self.dst_subset:
result |= self.dst_subset.free_symbols
elif edge is not None: # Not all symbols are requested, and an edge is given
view_edge = False
from dace.sdfg import nodes
if isinstance(edge.dst, nodes.CodeNode) or isinstance(edge.src, nodes.CodeNode):
view_edge = True
elif edge.dst_conn == 'views' and isinstance(edge.dst, nodes.AccessNode):
view_edge = True
elif edge.src_conn == 'views' and isinstance(edge.src, nodes.AccessNode):
view_edge = True

if not view_edge:
if self.src_subset:
result |= self.src_subset.free_symbols

if self.dst_subset:
result |= self.dst_subset.free_symbols
else:
# View edges do not require the end of the range nor strides
if self.src_subset:
for rb, _, _ in self.src_subset.ndrange():
if symbolic.issymbolic(rb):
result |= set(map(str, rb.free_symbols))

if self.dst_subset:
for rb, _, _ in self.dst_subset.ndrange():
if symbolic.issymbolic(rb):
result |= set(map(str, rb.free_symbols))

return result

Expand Down
8 changes: 7 additions & 1 deletion dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,14 +1323,16 @@ def arrays_recursive(self):
if isinstance(node, nd.NestedSDFG):
yield from node.sdfg.arrays_recursive()

def used_symbols(self, all_symbols: bool) -> Set[str]:
def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]:
"""
Returns a set of symbol names that are used by the SDFG, but not
defined within it. This property is used to determine the symbolic
parameters of the SDFG.
:param all_symbols: If False, only returns the set of symbols that will be used
in the generated code and are needed as arguments.
:param keep_defined_in_mapping: If True, symbols defined in inter-state edges that are in the symbol mapping
will be removed from the set of defined symbols.
"""
defined_syms = set()
free_syms = set()
Expand Down Expand Up @@ -1372,6 +1374,10 @@ def used_symbols(self, all_symbols: bool) -> Set[str]:
# Remove symbols that were used before they were assigned
defined_syms -= used_before_assignment

# Remove from defined symbols those that are in the symbol mapping
if self.parent_nsdfg_node is not None and keep_defined_in_mapping:
defined_syms -= set(self.parent_nsdfg_node.symbol_mapping.keys())

# Add the set of SDFG symbol parameters
# If all_symbols is False, those symbols would only be added in the case of non-Python tasklets
if all_symbols:
Expand Down
7 changes: 4 additions & 3 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def used_symbols(self, all_symbols: bool) -> Set[str]:
if not all_symbols and not self.is_leaf_memlet(e):
continue

freesyms |= e.data.used_symbols(all_symbols)
freesyms |= e.data.used_symbols(all_symbols, e)

# Do not consider SDFG constants as symbols
new_symbols.update(set(sdfg.constants.keys()))
Expand Down Expand Up @@ -688,14 +688,15 @@ def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Dat
defined_syms = defined_syms or self.defined_symbols()
scalar_args.update({
k: dt.Scalar(defined_syms[k]) if k in defined_syms else sdfg.arrays[k]
for k in self.free_symbols if not k.startswith('__dace') and k not in sdfg.constants
for k in self.used_symbols(all_symbols=False) if not k.startswith('__dace') and k not in sdfg.constants
})

# Add scalar arguments from free symbols of data descriptors
for arg in data_args.values():
scalar_args.update({
str(k): dt.Scalar(k.dtype)
for k in arg.free_symbols if not str(k).startswith('__dace') and str(k) not in sdfg.constants
for k in arg.used_symbols(all_symbols=False)
if not str(k).startswith('__dace') and str(k) not in sdfg.constants
})

# Fill up ordered dictionary
Expand Down
95 changes: 95 additions & 0 deletions tests/codegen/codegen_used_symbols_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
""" Tests used-symbols in code generation."""
import dace
import numpy
import pytest


# Symbolic sizes used in the signatures below: n0* (32-bit) size the input array `r`,
# n1* (64-bit) size the output array `s`.
# NOTE(review): the differing integer widths look deliberate (mixed symbol types in codegen) -- confirm.
n0i, n0j, n0k = (dace.symbol(s, dtype=dace.int32) for s in ('n0i', 'n0j', 'n0k'))
n1i, n1j, n1k = (dace.symbol(s, dtype=dace.int64) for s in ('n1i', 'n1j', 'n1k'))


@dace.program
def rprj3(r: dace.float64[n0i, n0j, n0k], s: dace.float64[n1i, n1j, n1k]):
    # Writes a weighted combination of a 3x3x3 neighbourhood of `r` into each interior point of `s`.
    # `r` is read at indices 2*i-1 .. 2*i+1 (likewise for j and k); `s` is modified in place.

    # Only interior points of `s` are written: index 0 and the last index are skipped in every dimension.
    for i, j, k in dace.map[1:s.shape[0] - 1, 1:s.shape[1] - 1, 1:s.shape[2] - 1]:

        # Weights by number of index offsets from (2i, 2j, 2k):
        #   0.5    -> the centre point (no offset)
        #   0.25   -> 6 points offset by +/-1 along one axis
        #   0.125  -> 12 points offset by +/-1 along two axes
        #   0.0625 -> 8 points offset by +/-1 along all three axes
        s[i, j, k] = (
            0.5000 * r[2 * i, 2 * j, 2 * k] +
            0.2500 * (r[2 * i - 1, 2 * j, 2 * k] + r[2 * i + 1, 2 * j, 2 * k] + r[2 * i, 2 * j - 1, 2 * k] +
                      r[2 * i, 2 * j + 1, 2 * k] + r[2 * i, 2 * j, 2 * k - 1] + r[2 * i, 2 * j, 2 * k + 1]) +
            0.1250 * (r[2 * i - 1, 2 * j - 1, 2 * k] + r[2 * i - 1, 2 * j + 1, 2 * k] +
                      r[2 * i + 1, 2 * j - 1, 2 * k] + r[2 * i + 1, 2 * j + 1, 2 * k] +
                      r[2 * i - 1, 2 * j, 2 * k - 1] + r[2 * i - 1, 2 * j, 2 * k + 1] +
                      r[2 * i + 1, 2 * j, 2 * k - 1] + r[2 * i + 1, 2 * j, 2 * k + 1] +
                      r[2 * i, 2 * j - 1, 2 * k - 1] + r[2 * i, 2 * j - 1, 2 * k + 1] +
                      r[2 * i, 2 * j + 1, 2 * k - 1] + r[2 * i, 2 * j + 1, 2 * k + 1]) +
            0.0625 * (r[2 * i - 1, 2 * j - 1, 2 * k - 1] + r[2 * i - 1, 2 * j - 1, 2 * k + 1] +
                      r[2 * i - 1, 2 * j + 1, 2 * k - 1] + r[2 * i - 1, 2 * j + 1, 2 * k + 1] +
                      r[2 * i + 1, 2 * j - 1, 2 * k - 1] + r[2 * i + 1, 2 * j - 1, 2 * k + 1] +
                      r[2 * i + 1, 2 * j + 1, 2 * k - 1] + r[2 * i + 1, 2 * j + 1, 2 * k + 1]))


def test_codegen_used_symbols_cpu():
    """Run ``rprj3`` through DaCe and compare against the plain-Python function ``rprj3.f``."""
    generator = numpy.random.default_rng(42)
    fine_grid = generator.random((10, 10, 10))

    # Two independent zero-initialised outputs: one per execution path.
    coarse_shape = (4, 4, 4)
    expected = numpy.zeros(coarse_shape)
    actual = numpy.zeros(coarse_shape)

    # Reference result first, then the DaCe-compiled program.
    rprj3.f(fine_grid, expected)
    rprj3(fine_grid, actual)

    assert numpy.allclose(expected, actual)


def test_codegen_used_symbols_cpu_2():
    """Same comparison as the plain CPU test, but ``rprj3`` is called from inside another DaCe program."""

    @dace.program
    def rprj3_nested(r: dace.float64[n0i, n0j, n0k], s: dace.float64[n1i, n1j, n1k]):
        rprj3(r, s)

    generator = numpy.random.default_rng(42)
    fine_grid = generator.random((10, 10, 10))

    # Two independent zero-initialised outputs: one per execution path.
    coarse_shape = (4, 4, 4)
    expected = numpy.zeros(coarse_shape)
    actual = numpy.zeros(coarse_shape)

    # Reference result from the undecorated function, then the wrapping program.
    rprj3.f(fine_grid, expected)
    rprj3_nested(fine_grid, actual)

    assert numpy.allclose(expected, actual)


@pytest.mark.gpu
def test_codegen_used_symbols_gpu():
    """Compile ``rprj3`` for the GPU and, when CuPy is available, check its result.

    The SDFG is always transformed and compiled, so code generation is
    exercised even without CuPy. The numerical comparison against the
    plain-Python function ``rprj3.f`` needs CuPy to allocate device arrays and
    is skipped (the test returns early) when CuPy cannot be imported.
    """
    sdfg = rprj3.to_sdfg()
    # Move every non-transient array to GPU global memory before transforming.
    for desc in sdfg.arrays.values():
        if not desc.transient and isinstance(desc, dace.data.Array):
            desc.storage = dace.StorageType.GPU_Global
    sdfg.apply_gpu_transformations()
    func = sdfg.compile()

    # Guard only the import: an ImportError raised while running the compiled
    # program or the reference must fail the test rather than be swallowed.
    # ModuleNotFoundError is a subclass of ImportError, so one clause suffices.
    try:
        import cupy
    except ImportError:
        return

    rng = numpy.random.default_rng(42)
    r = rng.random((10, 10, 10))
    r_dev = cupy.asarray(r)
    s_ref = numpy.zeros((4, 4, 4))
    s_val = cupy.zeros((4, 4, 4))

    rprj3.f(r, s_ref)
    func(r=r_dev, s=s_val, n0i=10, n0j=10, n0k=10, n1i=4, n1j=4, n1k=4)

    # Copy the device result back explicitly instead of relying on implicit
    # host/device dispatch inside numpy.allclose.
    assert numpy.allclose(s_ref, cupy.asnumpy(s_val))


if __name__ == "__main__":
    # Script entry point: run every test in this file, in declaration order.
    for _test in (
            test_codegen_used_symbols_cpu,
            test_codegen_used_symbols_cpu_2,
            test_codegen_used_symbols_gpu,
    ):
        _test()
17 changes: 9 additions & 8 deletions tests/symbol_mapping_replace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@ def outer(A, inp1: float, inp2: float):

def test_symbol_mapping_replace():

with dace.config.set_temporary('optimizer', 'automatic_simplification', value=True):
A = np.ones((10, 10, 10))
ref = A.copy()
b = 2.0
c = 2.0
outer(A, inp1=b, inp2=c)
outer.f(ref, inp1=b, inp2=c)
assert (np.allclose(A, ref))
# TODO/NOTE: Setting temporary config values does not work in the CI
# with dace.config.set_temporary('optimizer', 'automatic_simplification', value=True):
A = np.ones((10, 10, 10))
ref = A.copy()
b = 2.0
c = 2.0
outer(A, inp1=b, inp2=c)
outer.f(ref, inp1=b, inp2=c)
assert (np.allclose(A, ref))


if __name__ == '__main__':
Expand Down

0 comments on commit e608c2c

Please sign in to comment.