Explore more unfused choices
Michael Gilbert committed Nov 13, 2024
1 parent d9a0655 commit 9493b7a
Showing 1 changed file with 41 additions and 31 deletions.
pytimeloop/fastfusion/mapper/per_einsum_mapper.py (41 additions, 31 deletions)
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from itertools import product, permutations
+from itertools import combinations, product, permutations
 from functools import reduce
 from operator import or_, mul

@@ -249,32 +249,48 @@ def get_top_loop_jobs(
             or_, (tensor_to_relevant_ranks[t] for t in intermediate_tensors), set()
         )
         logfunc(f"Allowed top-level loop ranks: {top_level_ranks}")
-        for partial_mapping in make_top_loops(mapping,
-                                              top_level_ranks,
-                                              logfunc):
-            for partial_mapping in place_fusion_level(
-                partial_mapping,
-                intermediate_tensors,
-                tensor_to_relevant_ranks,
-                explore_glb_uneven,
-                logfunc=logfunc
-            ):
-                args.append(dict(
-                    config=config,
-                    pe_array_constraint=pe_array_constraint,
-                    mac_array_constraint=mac_array_constraint,
-                    spec=spec,
-                    explore_glb_uneven=explore_glb_uneven,
-                    explore_pe_uneven=explore_pe_uneven,
-                    einsum_id=einsum_id,
-                    energy_dict=energy_dict,
-                    partial_mapping=partial_mapping,
-                    log_queue=log_queue,
-                    verbose_stream=verbose_stream,
-                ))
+        for partial_mapping in explore_fused_unfused(mapping,
+                                                     intermediate_tensors):
+            for partial_mapping in make_top_loops(mapping,
+                                                  top_level_ranks,
+                                                  logfunc):
+                for partial_mapping in place_glb_level(
+                    partial_mapping,
+                    intermediate_tensors,
+                    tensor_to_relevant_ranks,
+                    explore_glb_uneven,
+                    logfunc=logfunc
+                ):
+                    args.append(dict(
+                        config=config,
+                        pe_array_constraint=pe_array_constraint,
+                        mac_array_constraint=mac_array_constraint,
+                        spec=spec,
+                        explore_glb_uneven=explore_glb_uneven,
+                        explore_pe_uneven=explore_pe_uneven,
+                        einsum_id=einsum_id,
+                        energy_dict=energy_dict,
+                        partial_mapping=partial_mapping,
+                        log_queue=log_queue,
+                        verbose_stream=verbose_stream,
+                    ))
     return args


+def explore_fused_unfused(mapping: LinearMapping,
+                          intermediate_tensors):
+    original = mapping
+    for r in range(len(intermediate_tensors)+1):
+        for unfused_tensors in combinations(intermediate_tensors, r):
+            mapping = original.copy()
+            if len(unfused_tensors) > 0:
+                mapping.add_storage(0, set(unfused_tensors))
+                mapping.add_sequential()
+                yield mapping
+            else:
+                yield mapping
+
+
 def make_top_loops(mapping: LinearMapping, ranks, logfunc):
     original = mapping
     for r in range(len(ranks) + 1):
@@ -286,7 +302,7 @@ def make_top_loops(mapping: LinearMapping, ranks, logfunc):
             yield mapping


-def place_fusion_level(
+def place_glb_level(
     mapping: LinearMapping,
     intermediate_tensors,
     tensor_to_relevant_ranks,
@@ -300,10 +316,8 @@ def place_fusion_level(
         relevant_ranks = tensor_to_relevant_ranks[tensor_id]
         tensor_choices = []
         last_is_relevant = True
-        untiled = True
         for i, node in enumerate(mapping):
             if node["type"] == "temporal":
-                untiled = False
                 rank_id = node["rank"]
                 is_relevant = rank_id in relevant_ranks
                 if last_is_relevant and not is_relevant:
@@ ... @@
         if len(tensor_choices) == 0:
             tensor_choices.append((len(mapping), 1))

-        # If untiled, another choice: unfused
-        if untiled:
-            tensor_choices.append((len(mapping), 0))
-
         all_tensor_choices.append(tensor_choices)

     original = mapping.copy()
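For orientation, the core of the new explore_fused_unfused generator is a walk over every subset of the intermediate tensors, where each subset is the set of tensors to keep unfused at the top of the mapping. Below is a minimal, self-contained sketch of that enumeration; explore_fused_unfused_sketch is a hypothetical stand-in, a plain list replaces LinearMapping, and the appended tuples stand in for add_storage(0, ...) and add_sequential(). Only the combinations-driven subset walk mirrors the committed code.

```python
from itertools import combinations

def explore_fused_unfused_sketch(mapping, intermediate_tensors):
    """Yield one candidate mapping per subset of intermediate tensors kept
    unfused; the empty subset is the fully fused candidate (mapping unchanged)."""
    original = mapping
    for r in range(len(intermediate_tensors) + 1):
        for unfused_tensors in combinations(sorted(intermediate_tensors), r):
            candidate = list(original)  # stands in for LinearMapping.copy()
            if unfused_tensors:
                # Stand-ins for add_storage(0, unfused_tensors) and add_sequential().
                candidate.append(("storage", 0, set(unfused_tensors)))
                candidate.append(("sequential",))
            yield candidate

# Two intermediate tensors give four choices: fully fused, plus three unfused subsets.
for candidate in explore_fused_unfused_sketch([], {"T1", "T2"}):
    print(candidate)
```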

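A note on the design choice, inferred from the removed lines above: the old place_fusion_level only offered an unfused option when the mapping was still untiled, whereas get_top_loop_jobs now crosses every unfused subset with the top-loop and GLB placement choices up front. For n intermediate tensors that outer generator yields 2**n candidates, as the toy count below checks (tensor names are made up):

```python
from itertools import combinations

intermediate_tensors = ["I1", "I2", "I3"]  # made-up tensor names
n_unfused_choices = sum(
    1
    for r in range(len(intermediate_tensors) + 1)
    for _ in combinations(intermediate_tensors, r)
)
assert n_unfused_choices == 2 ** len(intermediate_tensors) == 8
```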