Skip to content

Commit

Permalink
Update mapper to get neighbors
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Gilbert committed Oct 2, 2024
1 parent 4250681 commit 2944cf1
Showing 1 changed file with 26 additions and 6 deletions.
32 changes: 26 additions & 6 deletions pytimeloop/fastfusion/mapper/mapper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
from functools import partial

from ruamel.yaml import YAML
Expand Down Expand Up @@ -25,6 +26,8 @@ def mapper(config, spec, workload, name_of_einsum_to_eval, tmp_path):
workload.tensors_written_by_einsum(id_of_einsum_to_eval)
)

adj_list = get_neighbors(workload)

# Shape is given as *inclusive* (min, max) by workload
einsum_shape = {
rank_id: workload.get_rank_shape(rank_id)[1]+1 for rank_id in ranks
Expand Down Expand Up @@ -86,12 +89,12 @@ def partial_model(level, state, temporal_loops, spatial_loops, retained_tensors)
tensors,
fusable_tensors,
id_of_einsum_to_eval,
neighbors,
lower_mapper,
partial_model,
step_back_model,
max_spatial,
max_capacity)
adj_list[id_of_einsum_to_eval],
lower_mapper=cur_mapper,
partial_model=partial_model,
step_back_model=step_back_model,
max_spatial=max_spatial,
max_capacity=max_capacity)

cur_mapper.run(einsum_shape)

Expand All @@ -111,3 +114,20 @@ def get_hardware_levels(arch):
fanout[bindings_id] = (node.spatial.meshX, node.spatial.meshY)
return bindings, fanout


def get_neighbors(workload):
adj_list = defaultdict(lambda: list())
for einsum_u_id in workload.einsum_id_to_name():
for einsum_v_id in workload.einsum_id_to_name():
u_written_tensor = workload.tensor_written_by_einsum(einsum_u_id)
v_read_tensors = workload.tensors_read_by_einsum(einsum_v_id)
if u_written_tensor is not None and u_written_tensor in v_read_tensors:
adj_list[einsum_u_id].append(einsum_v_id)
adj_list[einsum_v_id].append(einsum_u_id)
continue
u_read_tensors = workload.tensors_read_by_einsum(einsum_u_id)
v_written_tensor = workload.tensor_written_by_einsum(einsum_v_id)
if v_written_tensor is not None and v_written_tensor in u_read_tensors:
adj_list[einsum_u_id].append(einsum_v_id)
adj_list[einsum_v_id].append(einsum_u_id)
return adj_list

0 comments on commit 2944cf1

Please sign in to comment.