diff --git a/pytimeloop/looptree/accesses.py b/pytimeloop/looptree/accesses.py index dae895f..9482ab2 100644 --- a/pytimeloop/looptree/accesses.py +++ b/pytimeloop/looptree/accesses.py @@ -3,6 +3,7 @@ import islpy as isl from pytimeloop.isl.singular import get_sum_of_pw_qpolynomial +from pytimeloop.looptree.mapping_utilities import * def get_total_accesses(accesses: Mapping): @@ -27,9 +28,13 @@ def reads_and_writes_from_fill_by_parent(fills: Mapping, mapping, workload): parent_buffers = get_parent_buffers(mapping, workload) + einsums_with_complete_mappings = get_einsums_with_complete_mappings(mapping) + for (buffer_id, dspace_id, einsum_id), (tags, fill) in fills.items(): dspace_name = dspace_id_to_name[dspace_id] einsum_name = einsum_id_to_name[einsum_id] + if einsum_name not in einsums_with_complete_mappings: + continue parent_buffer = parent_buffers[(buffer_id, dspace_name, einsum_name)] if parent_buffer is not None: if dspace_id in workload.tensors_written_by_einsum(einsum_id): @@ -42,16 +47,22 @@ def reads_and_writes_from_fill_by_parent(fills: Mapping, mapping, workload): return reads, writes -def reads_and_writes_from_fill_by_peer(fills: Mapping, workload): +def reads_and_writes_from_fill_by_peer(fills: Mapping, mapping, workload): + mapping = mapping['nodes'] dspace_id_to_name = workload.data_space_id_to_name() einsum_id_to_name = workload.einsum_id_to_name() reads = {} writes = {} + einsums_with_complete_mappings = get_einsums_with_complete_mappings(mapping) + for (buffer_id, dspace_id, einsum_id), (tags, fill) in fills.items(): einsum_name = einsum_id_to_name[einsum_id] dspace_name = dspace_id_to_name[dspace_id] + if einsum_name not in einsums_with_complete_mappings: + continue + reads[(buffer_id, dspace_name, einsum_name)] = fill writes[(buffer_id, dspace_name, einsum_name)] = fill @@ -87,16 +98,4 @@ def get_parent_buffers(mapping, workload): if dspace in dspace_to_top_buffer: parent_buffers[key] = dspace_to_top_buffer[dspace] - return parent_buffers - - -def get_paths(mapping): - cur_path = [] - for node in mapping: - cur_path.append(node) - if node['type'] in ['pipeline', 'sequential']: - for child in node['branches']: - for subpath in get_paths(child): - yield cur_path + subpath - elif node['type'] == 'compute': - yield cur_path.copy() \ No newline at end of file + return parent_buffers \ No newline at end of file diff --git a/pytimeloop/looptree/energy.py b/pytimeloop/looptree/energy.py index b76d75e..68efeec 100644 --- a/pytimeloop/looptree/energy.py +++ b/pytimeloop/looptree/energy.py @@ -4,6 +4,7 @@ from pytimeloop.isl.singular import get_sum_of_pw_qpolynomial from pytimeloop.timeloopfe.v4.ert import Ert from pytimeloop.looptree.accesses import * +from pytimeloop.looptree.mapping_utilities import * def gather_actions(looptree_results, mapping, workload, bindings): @@ -16,6 +17,7 @@ def gather_actions(looptree_results, mapping, workload, bindings): peer_reads, peer_writes = reads_and_writes_from_fill_by_peer( looptree_results.fills_by_peer, + mapping, workload ) peer_reads = get_total_accesses(peer_reads) @@ -33,8 +35,14 @@ def gather_actions(looptree_results, mapping, workload, bindings): else: writes[k] = v - ops = sum(get_sum_of_pw_qpolynomial(v) - for (tags, v) in looptree_results.ops.values()).to_python() + einsums_with_complete_mapping = get_einsums_with_complete_mappings(mapping['nodes']) + einsum_id_to_name = workload.einsum_id_to_name() + + ops = sum( + get_sum_of_pw_qpolynomial(v) + for einsum_id, (tags, v) in looptree_results.ops.items() + if einsum_id_to_name[einsum_id] in einsums_with_complete_mapping + ).to_python() actions = {} for (buf, tensor, einsum), counts in reads.items(): diff --git a/pytimeloop/looptree/mapping_utilities.py b/pytimeloop/looptree/mapping_utilities.py new file mode 100644 index 0000000..5a8edab --- /dev/null +++ b/pytimeloop/looptree/mapping_utilities.py @@ -0,0 +1,29 @@ +def get_paths(mapping): + cur_path = [] + for node in mapping: + cur_path.append(node) + if node['type'] in ['pipeline', 'sequential']: + for child in node['branches']: + for subpath in get_paths(child): + yield cur_path + subpath + elif node['type'] == 'compute': + yield cur_path.copy() + + +def get_leaves(mapping): + for node in mapping: + if node['type'] in ['pipeline', 'sequential']: + for child in node['branches']: + yield from get_leaves(child) + elif node['type'] == 'compute': + yield node + + +def get_einsums_with_complete_mappings(mapping): + einsums_with_complete_mappings = set() + for compute_node in get_leaves(mapping): + if 'incomplete' not in compute_node: + einsums_with_complete_mappings.add(compute_node['einsum']) + if 'incomplete' in compute_node and not compute_node['incomplete']: + einsums_with_complete_mappings.add(compute_node['einsum']) + return einsums_with_complete_mappings