diff --git a/scripts/gen_sam_apps/test_generating_code.py b/scripts/gen_sam_apps/test_generating_code.py index 5aaef13e..17bfa91b 100755 --- a/scripts/gen_sam_apps/test_generating_code.py +++ b/scripts/gen_sam_apps/test_generating_code.py @@ -198,6 +198,19 @@ def get_common_test_name(test_name): return test_name +def get_out_crd_str(d, u_, index_value): + # By default, the input primitive connected to a crddrop will be a level scanner + out_crd_str = "out_crd" + # However, if the input primitive is another crddrop, we need to make sure it's reading from + # the correct input crddrop output. + if d[u_]["type"] == "crddrop": + if index_value == d[u_]["inner"]: + out_crd_str += "_inner" + elif index_value == d[u_]["outer"]: + out_crd_str += "_outer" + return out_crd_str + + def generate_datasets_code(f, tensor_formats, scope_lvl, tensor_info, tensor_format_parse, test_name): # Assuming the format is csr and csc: for ten in tensor_format_parse.return_all_tensors(): @@ -539,7 +552,7 @@ def get_all_files(directory_path): continue out_name.append(filename[0:-3]) # checking if it is a file - print(out_name[-1]) + print("Test Name:", out_name[-1]) if os.path.isfile(f): file_paths.append(f) return file_paths, out_name @@ -810,9 +823,13 @@ def get_all_files(directory_path): for u_ in data.get_parents()[v]: index_value = data.get_edge_data()[v][data.get_parents()[v].index(u_)][-1] if index_value == d[v]["inner"]: - f.write(tab(2) + d[v]["object"] + ".set_inner_crd" + "(" + d[u_]["object"] + ".out_crd())\n") + out_crd_str = get_out_crd_str(d, u_, index_value) + f.write(tab(2) + d[v]["object"] + ".set_inner_crd" + "(" + d[u_]["object"] + "." + + out_crd_str + "())\n") if index_value == d[v]["outer"]: - f.write(tab(2) + d[v]["object"] + ".set_outer_crd" + "(" + d[u_]["object"] + ".out_crd())\n") + out_crd_str = get_out_crd_str(d, u_, index_value) + f.write(tab(2) + d[v]["object"] + ".set_outer_crd" + "(" + d[u_]["object"] + "." + + out_crd_str + "())\n") nodes_updating_list.append(tab(2) + d[v]["object"] + ".update()\n") # f.write(tab(2) + d[v]["object"] + ".update()\n\n") data.add_done(v) @@ -933,7 +950,6 @@ def get_all_files(directory_path): if "val" not in data.get_edge_data()[v][i] and "spaccumulator" \ in d[u_]["object"]: local_index = data.get_edge_data()[v][i][-1] - print(d[u_], " ", local_index, " ", apath) if d[u_]["in0"] == local_index: local_cord = "_inner" else: