Skip to content

Commit

Permalink
Add in fixes for gen_gantt script
Browse files Browse the repository at this point in the history
  • Loading branch information
weiya711 committed Oct 11, 2023
1 parent fe3ad32 commit b5d9547
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
3 changes: 3 additions & 0 deletions sam/sim/test/gen_gantt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def gen_gantt(extra_info, testname):
sam_name = ''

for k in extra_info.keys():
print(k, sam_name)
if "done_cycles" in k:
sam_name = k.split('/')[0]
finish_c = extra_info[k]
Expand Down Expand Up @@ -41,6 +42,8 @@ def gen_gantt(extra_info, testname):
if "backpressure" in extra_info.keys() and extra_info["backpressure"]:
back_depth = extra_info["depth"]

print(finish_list, block_list, start_list, duration_list)

# Writing cycle info to csv file
with open(testname + '_' + extra_info["dataset"] + '_back_' + back_depth + '.csv', 'w', newline='') as file:
writer = csv.writer(file)
Expand Down
14 changes: 12 additions & 2 deletions scripts/gen_sam_apps/test_generating_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,11 @@ def generate_header(f, out_name):
f.write("from sam.sim.src.token import *\n")
f.write("from sam.sim.test.test import *\n")
f.write("from sam.sim.test.gold import *\n")
f.write("from sam.sim.test.gen_gantt import gen_gantt\n")
f.write("\n")
f.write("import os\n")
f.write("import csv\n")
f.write("\n")
f.write("cwd = os.getcwd()\n")
if out_name in suitesparse_list:
f.write("formatted_dir = os.getenv('SUITESPARSE_FORMATTED_PATH', default=os.path.join(cwd, 'mode-formats'))\n")
Expand Down Expand Up @@ -402,8 +405,12 @@ def finish_outputs(f, elements, nodes_completed):


def generate_benchmarking_code(f, tensor_format_parse, test_name):
f.write("\n" + tab(1) + "def bench():\n")
f.write("\n")
f.write(tab(1) + "# Print out cycle count for pytest output\n")
f.write(tab(1) + "print(time_cnt)\n")
f.write(tab(1) + "def bench():\n")
f.write(tab(2) + "time.sleep(0.01)\n\n")
f.write("\n")
f.write(tab(1) + "extra_info = dict()\n")
f.write(tab(1) + "extra_info[\"dataset\"] = " + get_dataset_name(test_name) + "\n")
f.write(tab(1) + "extra_info[\"cycles\"] = time_cnt\n")
Expand All @@ -422,7 +429,10 @@ def generate_benchmarking_code(f, tensor_format_parse, test_name):
if d[u]["type"] in statistic_available:
f.write(tab(1) + "sample_dict = " + d[u]["object"] + ".return_statistics()\n")
f.write(tab(1) + "for k in sample_dict.keys():\n")
f.write(tab(2) + "extra_info[\"" + d[u]["object"] + "\" + \"_\" + k] = sample_dict[k]\n\n")
f.write(tab(2) + "extra_info[\"" + d[u]["object"] + "\" + \"/\" + k] = sample_dict[k]\n\n")

f.write(tab(1) + "gen_gantt(extra_info, \"" + test_name + "\")\n")
f.write("\n")


def generate_check_against_gold_code(f, tensor_format_parse, test_name):
Expand Down

0 comments on commit b5d9547

Please sign in to comment.