Skip to content

Commit

Permalink
New mapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Gilbert committed Oct 24, 2024
1 parent 13a4e64 commit 55a613e
Show file tree
Hide file tree
Showing 5 changed files with 353 additions and 121 deletions.
46 changes: 35 additions & 11 deletions pytimeloop/fastfusion/fastmodel/fastmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ def compile_mapping(mapping,
tensor_name_to_id = workload.data_space_name_to_id()

einsum_name = mapping[-1]['einsum']
einsum_id = einsum_name_to_id[einsum_name]
if isinstance(einsum_name, int):
einsum_id = einsum_name
else:
einsum_id = einsum_name_to_id[einsum_name]

tensors = (
workload.tensors_read_by_einsum(einsum_id)
Expand Down Expand Up @@ -63,12 +66,18 @@ def compile_mapping(mapping,
for node in mapping:
if node['type'] == 'temporal':
rank_name = node['rank']
rank_id = rank_name_to_id[rank_name]
if isinstance(rank_name, int):
rank_id = rank_name
else:
rank_id = rank_name_to_id[rank_name]
group_id = rank_groups.rank_to_group_id[rank_id]

tile_shape = sympy.symbols(f'tileshape{len(tile_shapes)}')
tile_shapes.append(tile_shape)
factor = einsum_shape[group_id] // tile_shape
if 'tile_shape' not in node:
tile_shape = sympy.symbols(f'tileshape{len(tile_shapes)}')
tile_shapes.append(tile_shape)
else:
tile_shape = node['tile_shape']
factor = einsum_shape[group_id] / tile_shape
einsum_shape[group_id] = tile_shape

latency *= factor
Expand All @@ -78,7 +87,7 @@ def compile_mapping(mapping,
if group_id in relevant_ranks:
actual_tensor_access_multiplier[tensor_id] = \
potential_tensor_access_multiplier[tensor_id]
tensor_size[tensor_id] //= factor
tensor_size[tensor_id] /= factor
else:
potential_tensor_access_multiplier[tensor_id] *= factor
elif node['type'] == 'sequential':
Expand All @@ -87,13 +96,24 @@ def compile_mapping(mapping,
potential_tensor_access_multiplier[tensor_id]
elif node['type'] == 'spatial':
rank_name = node['rank']
rank_id = rank_name_to_id[rank_name]
if isinstance(rank_name, int):
rank_id = rank_name
else:
rank_id = rank_name_to_id[rank_name]
group_id = rank_groups.rank_to_group_id[rank_id]

tile_shape = sympy.symbols(f'tileshape{len(tile_shapes)}')
tile_shapes.append(tile_shape)
factor = einsum_shape[group_id] // tile_shape
if 'tile_shape' not in node:
tile_shape = sympy.symbols(f'tileshape{len(tile_shapes)}')
tile_shapes.append(tile_shape)
else:
tile_shape = node['tile_shape']
factor = einsum_shape[group_id] / tile_shape
einsum_shape[group_id] = tile_shape

for tensor_id in tensors:
relevant_ranks = tensor_to_relevant_ranks[tensor_id]
if group_id in relevant_ranks:
tensor_size[tensor_id] /= factor

if 'spatial' not in node:
spatial = 0
Expand All @@ -107,7 +127,10 @@ def compile_mapping(mapping,
target = node['target']
tensor_names = node['dspace']
for tensor_name in tensor_names:
tensor_id = tensor_name_to_id[tensor_name]
if isinstance(tensor_name, int):
tensor_id = tensor_name
else:
tensor_id = tensor_name_to_id[tensor_name]
if tensor_id not in tensors:
continue

Expand Down Expand Up @@ -166,5 +189,6 @@ def lambdify(d):
output.temporal_steps = lambdify(output.temporal_steps)
output.fanout = lambdify(output.fanout)
output.occupancy = lambdify(output.occupancy)
output.fills_by_parent = lambdify(output.fills_by_parent)

return tile_shapes, output
1 change: 0 additions & 1 deletion pytimeloop/fastfusion/mapper/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ def get_neighbors(workload):


def get_intermediate_tensors(workload: LooptreeWorkload):
tensor_id_to_name = workload.data_space_id_to_name()
result = set()
for einsum in workload.einsum_id_to_name():
written_tensors = workload.tensors_written_by_einsum(einsum)
Expand Down
Loading

0 comments on commit 55a613e

Please sign in to comment.