Skip to content

Commit

Permalink
Merge pull request #1945 from StanfordAHA/glb_batch_size
Browse files (browse the repository at this point in the history)
Fix max size of tile calculation
  • Loading branch information
kalhankoul96 authored Jul 29, 2024
2 parents 3c0c217 + 54b6040 commit 2a4e412
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 17 deletions.
27 changes: 12 additions & 15 deletions aha/util/regress.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,22 +236,19 @@ def test_sparse_app(testname, seed_flow, data_tile_pairs, pipeline_num_l=None, o
use_pipeline = False
start = time.time()
if use_pipeline:
# Last batch won't have the same number of tiles as the rest, so we do two VCS calls
data_tile_pairs = [f"{test}_{tile}/GLB_DIR/{test}_combined_seed_{tile}" for tile in data_tile_pairs]
full_tile_pairs = []
full_pipeline_cmd = None
# if there's only one batch, we don't need to handle partially full batches
if len(data_tile_pairs) > 1:
full_tile_pairs = data_tile_pairs[:-1]
full_pipeline_num = pipeline_num_l[0]
full_pipeline_cmd = ["aha", "test"] + full_tile_pairs + ["--sparse", "--multiles", str(full_pipeline_num)]
last_tile_pair = data_tile_pairs[-1]
last_pipeline_num = pipeline_num_l[-1]
last_pipeline_cmd = ["aha", "test"] + [last_tile_pair] + ["--sparse", "--multiles", str(last_pipeline_num)]

cmd_list = [full_pipeline_cmd, last_pipeline_cmd]
if len(full_tile_pairs) == 0:
cmd_list = [last_pipeline_cmd]
# Dictionary grouping tile pairs by pipeline number
grouped_dict = defaultdict(list)
for tile_pair, pipeline_num in zip(data_tile_pairs, pipeline_num_l):
grouped_dict[pipeline_num].append(tile_pair)

# create cmd_list for each pipeline number
cmd_list = []
for pipeline_num, tile_pairs in grouped_dict.items():
# if list is longer than 64, split into batches of 64
tile_pair_batches = [tile_pairs[i:i + 64] for i in range(0, len(tile_pairs), 64)]
for tile_pair in tile_pair_batches:
cmd_list.append(["aha", "test"] + tile_pair + ["--sparse", "--multiles", str(pipeline_num)])

if testname not in test_dataset_runtime_dict:
test_dataset_runtime_dict[testname] = defaultdict(float)
Expand Down
2 changes: 1 addition & 1 deletion garnet
2 changes: 1 addition & 1 deletion sam

0 comments on commit 2a4e412

Please sign in to comment.