Skip to content

Commit

Permalink
Merge branch 'main' of github.com:Accelergy-Project/timeloop-python
Browse files Browse the repository at this point in the history
  • Loading branch information
tanner-andrulis committed Sep 19, 2024
2 parents 8c87b4a + ea1da56 commit 0e09de5
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 16 deletions.
1 change: 1 addition & 0 deletions bindings/looptree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ namespace pytimeloop::looptree_bindings
.FUSED_WORKLOAD_METHOD(tensors_written_by_einsum, TensorsWrittenByEinsum)
.FUSED_WORKLOAD_METHOD(reader_einsums, ReaderEinsums)
.FUSED_WORKLOAD_METHOD(writer_einsum, WriterEinsum)
.FUSED_WORKLOAD_METHOD(get_rank_shape, GetRankShape)
.def_static("parse_cfg", &problem::ParseFusedWorkload);

py::class_<problem::FusedWorkloadDependencyAnalyzer>(m, "LooptreeWorkloadDependencyAnalyzer")
Expand Down
27 changes: 13 additions & 14 deletions pytimeloop/looptree/accesses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import islpy as isl

from pytimeloop.isl.singular import get_sum_of_pw_qpolynomial
from pytimeloop.looptree.mapping_utilities import *


def get_total_accesses(accesses: Mapping):
Expand All @@ -27,9 +28,13 @@ def reads_and_writes_from_fill_by_parent(fills: Mapping, mapping, workload):

parent_buffers = get_parent_buffers(mapping, workload)

einsums_with_complete_mappings = get_einsums_with_complete_mappings(mapping)

for (buffer_id, dspace_id, einsum_id), (tags, fill) in fills.items():
dspace_name = dspace_id_to_name[dspace_id]
einsum_name = einsum_id_to_name[einsum_id]
if einsum_name not in einsums_with_complete_mappings:
continue
parent_buffer = parent_buffers[(buffer_id, dspace_name, einsum_name)]
if parent_buffer is not None:
if dspace_id in workload.tensors_written_by_einsum(einsum_id):
Expand All @@ -42,16 +47,22 @@ def reads_and_writes_from_fill_by_parent(fills: Mapping, mapping, workload):
return reads, writes


def reads_and_writes_from_fill_by_peer(fills: Mapping, workload):
def reads_and_writes_from_fill_by_peer(fills: Mapping, mapping, workload):
mapping = mapping['nodes']
dspace_id_to_name = workload.data_space_id_to_name()
einsum_id_to_name = workload.einsum_id_to_name()

reads = {}
writes = {}

einsums_with_complete_mappings = get_einsums_with_complete_mappings(mapping)

for (buffer_id, dspace_id, einsum_id), (tags, fill) in fills.items():
einsum_name = einsum_id_to_name[einsum_id]
dspace_name = dspace_id_to_name[dspace_id]
if einsum_name not in einsums_with_complete_mappings:
continue

reads[(buffer_id, dspace_name, einsum_name)] = fill
writes[(buffer_id, dspace_name, einsum_name)] = fill

Expand Down Expand Up @@ -87,16 +98,4 @@ def get_parent_buffers(mapping, workload):
if dspace in dspace_to_top_buffer:
parent_buffers[key] = dspace_to_top_buffer[dspace]

return parent_buffers


def get_paths(mapping):
cur_path = []
for node in mapping:
cur_path.append(node)
if node['type'] in ['pipeline', 'sequential']:
for child in node['branches']:
for subpath in get_paths(child):
yield cur_path + subpath
elif node['type'] == 'compute':
yield cur_path.copy()
return parent_buffers
12 changes: 10 additions & 2 deletions pytimeloop/looptree/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pytimeloop.isl.singular import get_sum_of_pw_qpolynomial
from pytimeloop.timeloopfe.v4.ert import Ert
from pytimeloop.looptree.accesses import *
from pytimeloop.looptree.mapping_utilities import *


def gather_actions(looptree_results, mapping, workload, bindings):
Expand All @@ -16,6 +17,7 @@ def gather_actions(looptree_results, mapping, workload, bindings):

peer_reads, peer_writes = reads_and_writes_from_fill_by_peer(
looptree_results.fills_by_peer,
mapping,
workload
)
peer_reads = get_total_accesses(peer_reads)
Expand All @@ -33,8 +35,14 @@ def gather_actions(looptree_results, mapping, workload, bindings):
else:
writes[k] = v

ops = sum(get_sum_of_pw_qpolynomial(v)
for (tags, v) in looptree_results.ops.values()).to_python()
einsums_with_complete_mapping = get_einsums_with_complete_mappings(mapping['nodes'])
einsum_id_to_name = workload.einsum_id_to_name()

ops = sum(
get_sum_of_pw_qpolynomial(v)
for einsum_id, (tags, v) in looptree_results.ops.items()
if einsum_id_to_name[einsum_id] in einsums_with_complete_mapping
).to_python()

actions = {}
for (buf, tensor, einsum), counts in reads.items():
Expand Down
29 changes: 29 additions & 0 deletions pytimeloop/looptree/mapping_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
def get_paths(mapping):
cur_path = []
for node in mapping:
cur_path.append(node)
if node['type'] in ['pipeline', 'sequential']:
for child in node['branches']:
for subpath in get_paths(child):
yield cur_path + subpath
elif node['type'] == 'compute':
yield cur_path.copy()


def get_leaves(mapping):
for node in mapping:
if node['type'] in ['pipeline', 'sequential']:
for child in node['branches']:
yield from get_leaves(child)
elif node['type'] == 'compute':
yield node


def get_einsums_with_complete_mappings(mapping):
einsums_with_complete_mappings = set()
for compute_node in get_leaves(mapping):
if 'incomplete' not in compute_node:
einsums_with_complete_mappings.add(compute_node['einsum'])
if 'incomplete' in compute_node and not compute_node['incomplete']:
einsums_with_complete_mappings.add(compute_node['einsum'])
return einsums_with_complete_mappings
7 changes: 7 additions & 0 deletions tests/looptree/test_workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ def test_dspace_name_to_id(self):
id_to_name = self._workload.data_space_id_to_name()
self.assert_maps_are_inverted_equivalent(name_to_id, id_to_name)

def test_rank_shape(self):
name_to_id = self._workload.dimension_name_to_id()
rank_shape = self._workload.get_rank_shape(name_to_id['P1'])
self.assertEqual((0, 8), rank_shape)
rank_shape = self._workload.get_rank_shape(name_to_id['M2'])
self.assertEqual((0, 7), rank_shape)

def assert_maps_are_inverted_equivalent(self, dict1, dict2):
for key1, value1 in dict1.items():
self.assertEqual(key1, dict2[value1])
Expand Down

0 comments on commit 0e09de5

Please sign in to comment.