Skip to content

Commit

Permalink
Finish implementation of MERGE intrinsic
Browse files Browse the repository at this point in the history
  • Loading branch information
mcopik committed Oct 19, 2023
1 parent e45d62d commit 873c6dc
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 12 deletions.
20 changes: 14 additions & 6 deletions dace/frontend/fortran/intrinsics.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,11 @@ def _parse_binary_op(self, node: ast_internal_classes.Call_Expr_Node, arg: ast_i
return (dominant_array, None, cond)

if len(first_array.indices) != len(second_array.indices):
raise TypeError("Can't parse Fortran ANY with different array ranks!")
raise TypeError("Can't parse Fortran binary op with different array ranks!")

for left_idx, right_idx in zip(first_array.indices, second_array.indices):
if left_idx.type != right_idx.type:
raise TypeError("Can't parse Fortran ANY with different array ranks!")
raise TypeError("Can't parse Fortran binary op with different array ranks!")

# Now, we need to convert the array to a proper subscript node
cond = copy.deepcopy(arg)
Expand Down Expand Up @@ -879,14 +879,22 @@ def _summarize_args(self, node: ast_internal_classes.FNode, new_func_body: List[
# The first main argument is an array -> this dictates loop boundaries
# Other arrays, regardless if they appear as the second array or mask, need to have the same loop boundary.
par_Decl_Range_Finder(self.first_array, self.loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True)
par_Decl_Range_Finder(self.second_array, [], [], [], self.count, new_func_body, self.scope_vars, True)

loop_ranges = []
par_Decl_Range_Finder(self.second_array, loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True)
self._adjust_array_ranges(node, self.second_array, self.loop_ranges, loop_ranges)

par_Decl_Range_Finder(self.destination_array, [], [], [], self.count, new_func_body, self.scope_vars, True)

if self.mask_first_array is not None:
par_Decl_Range_Finder(self.mask_first_array, [], [], [], self.count, new_func_body, self.scope_vars, True)
if self.mask_second_array is not None:
par_Decl_Range_Finder(self.mask_second_array, [], [], [], self.count, new_func_body, self.scope_vars, True)
loop_ranges = []
par_Decl_Range_Finder(self.mask_first_array, loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True)
self._adjust_array_ranges(node, self.mask_first_array, self.loop_ranges, loop_ranges)

if self.mask_second_array is not None:
loop_ranges = []
par_Decl_Range_Finder(self.mask_second_array, loop_ranges, [], [], self.count, new_func_body, self.scope_vars, True)
self._adjust_array_ranges(node, self.mask_second_array, self.loop_ranges, loop_ranges)

def _initialize_result(self, node: ast_internal_classes.FNode) -> Optional[ast_internal_classes.BinOp_Node]:
"""
Expand Down
103 changes: 97 additions & 6 deletions tests/fortran/intrinsic_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,106 @@ def test_fortran_frontend_merge_comparison_arrays():
else:
assert res[i] == 13

# mask comparison on array participating
# mask comparison on two arrays participating
# mask comparison - second array with shift
# mask comparison - both arrays wiht a shift
# second array - shift!
# merge 2d

def test_fortran_frontend_merge_comparison_arrays_offset():
"""
Tests that the generated array map correctly handles offsets.
"""
test_string = """
PROGRAM merge_test
implicit none
double precision, dimension(7) :: input1
double precision, dimension(7) :: input2
double precision, dimension(14) :: mask1
double precision, dimension(14) :: mask2
double precision, dimension(7) :: res
CALL merge_test_function(input1, input2, mask1, mask2, res)
end
SUBROUTINE merge_test_function(input1, input2, mask1, mask2, res)
double precision, dimension(7) :: input1
double precision, dimension(7) :: input2
double precision, dimension(14) :: mask1
double precision, dimension(14) :: mask2
double precision, dimension(7) :: res
res = MERGE(input1, input2, mask1(3:9) .lt. mask2(5:11))
END SUBROUTINE merge_test_function
"""

# Now test to verify it executes correctly with no offset normalization
sdfg = fortran_parser.create_sdfg_from_string(test_string, "merge_test", True)
sdfg.simplify(verbose=True)
sdfg.compile()
size = 7

# Minimum is in the beginning
first = np.full([size], 13, order="F", dtype=np.float64)
second = np.full([size], 42, order="F", dtype=np.float64)
mask1 = np.full([size*2], 30, order="F", dtype=np.float64)
mask2 = np.full([size*2], 0, order="F", dtype=np.float64)
res = np.full([size], 40, order="F", dtype=np.float64)

mask1[2:9] = 3
mask2[4:11] = 4
sdfg(input1=first, input2=second, mask1=mask1, mask2=mask2, res=res)
for val in res:
assert val == 13


def test_fortran_frontend_merge_array_shift():
"""
Tests that the generated array map correctly handles offsets.
"""
test_string = """
PROGRAM merge_test
implicit none
double precision, dimension(7) :: input1
double precision, dimension(21) :: input2
double precision, dimension(14) :: mask1
double precision, dimension(14) :: mask2
double precision, dimension(7) :: res
CALL merge_test_function(input1, input2, mask1, mask2, res)
end
SUBROUTINE merge_test_function(input1, input2, mask1, mask2, res)
double precision, dimension(7) :: input1
double precision, dimension(21) :: input2
double precision, dimension(14) :: mask1
double precision, dimension(14) :: mask2
double precision, dimension(7) :: res
res = MERGE(input1, input2(13:19), mask1(3:9) .gt. mask2(5:11))
END SUBROUTINE merge_test_function
"""

# Now test to verify it executes correctly with no offset normalization
sdfg = fortran_parser.create_sdfg_from_string(test_string, "merge_test", True)
sdfg.simplify(verbose=True)
sdfg.compile()
size = 7

# Minimum is in the beginning
first = np.full([size], 13, order="F", dtype=np.float64)
second = np.full([size*3], 42, order="F", dtype=np.float64)
mask1 = np.full([size*2], 30, order="F", dtype=np.float64)
mask2 = np.full([size*2], 0, order="F", dtype=np.float64)
res = np.full([size], 40, order="F", dtype=np.float64)

second[12:19] = 100
mask1[2:9] = 3
mask2[4:11] = 4
sdfg(input1=first, input2=second, mask1=mask1, mask2=mask2, res=res)
for val in res:
assert val == 100


if __name__ == "__main__":

test_fortran_frontend_merge_1d()
test_fortran_frontend_merge_comparison_scalar()
test_fortran_frontend_merge_comparison_arrays()
test_fortran_frontend_merge_comparison_arrays_offset()
test_fortran_frontend_merge_array_shift()

0 comments on commit 873c6dc

Please sign in to comment.