Skip to content

Commit

Permalink
Better partial evaluation support
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Gilbert committed Sep 19, 2024
1 parent b5c679f commit ea1da56
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 16 deletions.
27 changes: 13 additions & 14 deletions pytimeloop/looptree/accesses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import islpy as isl

from pytimeloop.isl.singular import get_sum_of_pw_qpolynomial
from pytimeloop.looptree.mapping_utilities import *


def get_total_accesses(accesses: Mapping):
Expand All @@ -27,9 +28,13 @@ def reads_and_writes_from_fill_by_parent(fills: Mapping, mapping, workload):

parent_buffers = get_parent_buffers(mapping, workload)

einsums_with_complete_mappings = get_einsums_with_complete_mappings(mapping)

for (buffer_id, dspace_id, einsum_id), (tags, fill) in fills.items():
dspace_name = dspace_id_to_name[dspace_id]
einsum_name = einsum_id_to_name[einsum_id]
if einsum_name not in einsums_with_complete_mappings:
continue
parent_buffer = parent_buffers[(buffer_id, dspace_name, einsum_name)]
if parent_buffer is not None:
if dspace_id in workload.tensors_written_by_einsum(einsum_id):
Expand All @@ -42,16 +47,22 @@ def reads_and_writes_from_fill_by_parent(fills: Mapping, mapping, workload):
return reads, writes


def reads_and_writes_from_fill_by_peer(fills: Mapping, workload):
def reads_and_writes_from_fill_by_peer(fills: Mapping, mapping, workload):
mapping = mapping['nodes']
dspace_id_to_name = workload.data_space_id_to_name()
einsum_id_to_name = workload.einsum_id_to_name()

reads = {}
writes = {}

einsums_with_complete_mappings = get_einsums_with_complete_mappings(mapping)

for (buffer_id, dspace_id, einsum_id), (tags, fill) in fills.items():
einsum_name = einsum_id_to_name[einsum_id]
dspace_name = dspace_id_to_name[dspace_id]
if einsum_name not in einsums_with_complete_mappings:
continue

reads[(buffer_id, dspace_name, einsum_name)] = fill
writes[(buffer_id, dspace_name, einsum_name)] = fill

Expand Down Expand Up @@ -87,16 +98,4 @@ def get_parent_buffers(mapping, workload):
if dspace in dspace_to_top_buffer:
parent_buffers[key] = dspace_to_top_buffer[dspace]

return parent_buffers


def get_paths(mapping):
cur_path = []
for node in mapping:
cur_path.append(node)
if node['type'] in ['pipeline', 'sequential']:
for child in node['branches']:
for subpath in get_paths(child):
yield cur_path + subpath
elif node['type'] == 'compute':
yield cur_path.copy()
return parent_buffers
12 changes: 10 additions & 2 deletions pytimeloop/looptree/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pytimeloop.isl.singular import get_sum_of_pw_qpolynomial
from pytimeloop.timeloopfe.v4.ert import Ert
from pytimeloop.looptree.accesses import *
from pytimeloop.looptree.mapping_utilities import *


def gather_actions(looptree_results, mapping, workload, bindings):
Expand All @@ -16,6 +17,7 @@ def gather_actions(looptree_results, mapping, workload, bindings):

peer_reads, peer_writes = reads_and_writes_from_fill_by_peer(
looptree_results.fills_by_peer,
mapping,
workload
)
peer_reads = get_total_accesses(peer_reads)
Expand All @@ -33,8 +35,14 @@ def gather_actions(looptree_results, mapping, workload, bindings):
else:
writes[k] = v

ops = sum(get_sum_of_pw_qpolynomial(v)
for (tags, v) in looptree_results.ops.values()).to_python()
einsums_with_complete_mapping = get_einsums_with_complete_mappings(mapping['nodes'])
einsum_id_to_name = workload.einsum_id_to_name()

ops = sum(
get_sum_of_pw_qpolynomial(v)
for einsum_id, (tags, v) in looptree_results.ops.items()
if einsum_id_to_name[einsum_id] in einsums_with_complete_mapping
).to_python()

actions = {}
for (buf, tensor, einsum), counts in reads.items():
Expand Down
29 changes: 29 additions & 0 deletions pytimeloop/looptree/mapping_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
def get_paths(mapping):
cur_path = []
for node in mapping:
cur_path.append(node)
if node['type'] in ['pipeline', 'sequential']:
for child in node['branches']:
for subpath in get_paths(child):
yield cur_path + subpath
elif node['type'] == 'compute':
yield cur_path.copy()


def get_leaves(mapping):
for node in mapping:
if node['type'] in ['pipeline', 'sequential']:
for child in node['branches']:
yield from get_leaves(child)
elif node['type'] == 'compute':
yield node


def get_einsums_with_complete_mappings(mapping):
einsums_with_complete_mappings = set()
for compute_node in get_leaves(mapping):
if 'incomplete' not in compute_node:
einsums_with_complete_mappings.add(compute_node['einsum'])
if 'incomplete' in compute_node and not compute_node['incomplete']:
einsums_with_complete_mappings.add(compute_node['einsum'])
return einsums_with_complete_mappings

0 comments on commit ea1da56

Please sign in to comment.