forked from FEniCS/fiat
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'gem/master' into connorjward/add-finat-gem
- Loading branch information
Showing
14 changed files
with
4,598 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from gem.gem import * # noqa | ||
from gem.optimise import select_expression # noqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
"""This module contains an implementation of the COFFEE optimisation | ||
algorithm operating on a GEM representation. | ||
This file is NOT for code generation as a COFFEE AST. | ||
""" | ||
|
||
from collections import OrderedDict | ||
import itertools | ||
import logging | ||
|
||
import numpy | ||
|
||
from gem.gem import IndexSum, one | ||
from gem.optimise import make_sum, make_product | ||
from gem.refactorise import Monomial | ||
from gem.utils import groupby | ||
|
||
|
||
__all__ = ['optimise_monomial_sum'] | ||
|
||
|
||
def monomial_sum_to_expression(monomial_sum): | ||
"""Convert a monomial sum to a GEM expression. | ||
:arg monomial_sum: an iterable of :class:`Monomial`s | ||
:returns: GEM expression | ||
""" | ||
indexsums = [] # The result is summation of indexsums | ||
# Group monomials according to their sum indices | ||
groups = groupby(monomial_sum, key=lambda m: frozenset(m.sum_indices)) | ||
# Create IndexSum's from each monomial group | ||
for _, monomials in groups: | ||
sum_indices = monomials[0].sum_indices | ||
products = [make_product(monomial.atomics + (monomial.rest,)) for monomial in monomials] | ||
indexsums.append(IndexSum(make_sum(products), sum_indices)) | ||
return make_sum(indexsums) | ||
|
||
|
||
def index_extent(factor, linear_indices): | ||
"""Compute the product of the extents of linear indices of a GEM expression | ||
:arg factor: GEM expression | ||
:arg linear_indices: set of linear indices | ||
:returns: product of extents of linear indices | ||
""" | ||
return numpy.prod([i.extent for i in factor.free_indices if i in linear_indices]) | ||
|
||
|
||
def find_optimal_atomics(monomials, linear_indices): | ||
"""Find optimal atomic common subexpressions, which produce least number of | ||
terms in the resultant IndexSum when factorised. | ||
:arg monomials: A list of :class:`Monomial`s, all of which should have | ||
the same sum indices | ||
:arg linear_indices: tuple of linear indices | ||
:returns: list of atomic GEM expressions | ||
""" | ||
atomics = tuple(OrderedDict.fromkeys(itertools.chain(*(monomial.atomics for monomial in monomials)))) | ||
|
||
def cost(solution): | ||
extent = sum(map(lambda atomic: index_extent(atomic, linear_indices), solution)) | ||
# Prefer shorter solutions, but larger extents | ||
return (len(solution), -extent) | ||
|
||
optimal_solution = set(atomics) # pessimal but feasible solution | ||
solution = set() | ||
|
||
max_it = 1 << 12 | ||
it = iter(range(max_it)) | ||
|
||
def solve(idx): | ||
while idx < len(monomials) and solution.intersection(monomials[idx].atomics): | ||
idx += 1 | ||
|
||
if idx < len(monomials): | ||
if len(solution) < len(optimal_solution): | ||
for atomic in monomials[idx].atomics: | ||
solution.add(atomic) | ||
solve(idx + 1) | ||
solution.remove(atomic) | ||
else: | ||
if cost(solution) < cost(optimal_solution): | ||
optimal_solution.clear() | ||
optimal_solution.update(solution) | ||
next(it) | ||
|
||
try: | ||
solve(0) | ||
except StopIteration: | ||
logger = logging.getLogger('tsfc') | ||
logger.warning("Solution to ILP problem may not be optimal: search " | ||
"interrupted after examining %d solutions.", max_it) | ||
|
||
return tuple(atomic for atomic in atomics if atomic in optimal_solution) | ||
|
||
|
||
def factorise_atomics(monomials, optimal_atomics, linear_indices): | ||
"""Group and factorise monomials using a list of atomics as common | ||
subexpressions. Create new monomials for each group and optimise them recursively. | ||
:arg monomials: an iterable of :class:`Monomial`s, all of which should have | ||
the same sum indices | ||
:arg optimal_atomics: list of tuples of atomics to be used as common subexpression | ||
:arg linear_indices: tuple of linear indices | ||
:returns: an iterable of :class:`Monomials`s after factorisation | ||
""" | ||
if not optimal_atomics or len(monomials) <= 1: | ||
return monomials | ||
|
||
# Group monomials with respect to each optimal atomic | ||
def group_key(monomial): | ||
for oa in optimal_atomics: | ||
if oa in monomial.atomics: | ||
return oa | ||
assert False, "Expect at least one optimal atomic per monomial." | ||
factor_group = groupby(monomials, key=group_key) | ||
|
||
# We should not drop monomials | ||
assert sum(len(ms) for _, ms in factor_group) == len(monomials) | ||
|
||
sum_indices = next(iter(monomials)).sum_indices | ||
new_monomials = [] | ||
for oa, monomials in factor_group: | ||
# Create new MonomialSum for the factorised out terms | ||
sub_monomials = [] | ||
for monomial in monomials: | ||
atomics = list(monomial.atomics) | ||
atomics.remove(oa) # remove common factor | ||
sub_monomials.append(Monomial((), tuple(atomics), monomial.rest)) | ||
# Continue to factorise the remaining expression | ||
sub_monomials = optimise_monomials(sub_monomials, linear_indices) | ||
if len(sub_monomials) == 1: | ||
# Factorised part is a product, we add back the common atomics then | ||
# add to new MonomialSum directly rather than forming a product node | ||
# Retaining the monomial structure enables applying associativity | ||
# when forming GEM nodes later. | ||
sub_monomial, = sub_monomials | ||
new_monomials.append( | ||
Monomial(sum_indices, (oa,) + sub_monomial.atomics, sub_monomial.rest)) | ||
else: | ||
# Factorised part is a summation, we need to create a new GEM node | ||
# and multiply with the common factor | ||
node = monomial_sum_to_expression(sub_monomials) | ||
# If the free indices of the new node intersect with linear indices, | ||
# add to the new monomial as `atomic`, otherwise add as `rest`. | ||
# Note: we might want to continue to factorise with the new atomics | ||
# by running optimise_monoials twice. | ||
if set(linear_indices) & set(node.free_indices): | ||
new_monomials.append(Monomial(sum_indices, (oa, node), one)) | ||
else: | ||
new_monomials.append(Monomial(sum_indices, (oa, ), node)) | ||
return new_monomials | ||
|
||
|
||
def optimise_monomial_sum(monomial_sum, linear_indices): | ||
"""Choose optimal common atomic subexpressions and factorise a | ||
:class:`MonomialSum` object to create a GEM expression. | ||
:arg monomial_sum: a :class:`MonomialSum` object | ||
:arg linear_indices: tuple of linear indices | ||
:returns: factorised GEM expression | ||
""" | ||
groups = groupby(monomial_sum, key=lambda m: frozenset(m.sum_indices)) | ||
new_monomials = [] | ||
for _, monomials in groups: | ||
new_monomials.extend(optimise_monomials(monomials, linear_indices)) | ||
return monomial_sum_to_expression(new_monomials) | ||
|
||
|
||
def optimise_monomials(monomials, linear_indices): | ||
"""Choose optimal common atomic subexpressions and factorise an iterable | ||
of monomials. | ||
:arg monomials: a list of :class:`Monomial`s, all of which should have | ||
the same sum indices | ||
:arg linear_indices: tuple of linear indices | ||
:returns: an iterable of factorised :class:`Monomials`s | ||
""" | ||
assert len(set(frozenset(m.sum_indices) for m in monomials)) <= 1, \ | ||
"All monomials required to have same sum indices for factorisation" | ||
|
||
result = [m for m in monomials if not m.atomics] # skipped monomials | ||
active_monomials = [m for m in monomials if m.atomics] | ||
optimal_atomics = find_optimal_atomics(active_monomials, linear_indices) | ||
result += factorise_atomics(active_monomials, optimal_atomics, linear_indices) | ||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
""" | ||
This file contains all the necessary functions to accurately count the | ||
total number of floating point operations for a given script. | ||
""" | ||
|
||
import gem.gem as gem | ||
import gem.impero as imp | ||
from functools import singledispatch | ||
import numpy | ||
import math | ||
|
||
|
||
@singledispatch | ||
def statement(tree, temporaries): | ||
raise NotImplementedError | ||
|
||
|
||
@statement.register(imp.Block) | ||
def statement_block(tree, temporaries): | ||
flops = sum(statement(child, temporaries) for child in tree.children) | ||
return flops | ||
|
||
|
||
@statement.register(imp.For) | ||
def statement_for(tree, temporaries): | ||
extent = tree.index.extent | ||
assert extent is not None | ||
child, = tree.children | ||
flops = statement(child, temporaries) | ||
return flops * extent | ||
|
||
|
||
@statement.register(imp.Initialise) | ||
def statement_initialise(tree, temporaries): | ||
return 0 | ||
|
||
|
||
@statement.register(imp.Accumulate) | ||
def statement_accumulate(tree, temporaries): | ||
flops = expression_flops(tree.indexsum.children[0], temporaries) | ||
return flops + 1 | ||
|
||
|
||
@statement.register(imp.Return) | ||
def statement_return(tree, temporaries): | ||
flops = expression_flops(tree.expression, temporaries) | ||
return flops + 1 | ||
|
||
|
||
@statement.register(imp.ReturnAccumulate) | ||
def statement_returnaccumulate(tree, temporaries): | ||
flops = expression_flops(tree.indexsum.children[0], temporaries) | ||
return flops + 1 | ||
|
||
|
||
@statement.register(imp.Evaluate) | ||
def statement_evaluate(tree, temporaries): | ||
flops = expression_flops(tree.expression, temporaries, top=True) | ||
return flops | ||
|
||
|
||
@singledispatch | ||
def flops(expr, temporaries): | ||
raise NotImplementedError(f"Don't know how to count flops of {type(expr)}") | ||
|
||
|
||
@flops.register(gem.Failure) | ||
def flops_failure(expr, temporaries): | ||
raise ValueError("Not expecting a Failure node") | ||
|
||
|
||
@flops.register(gem.Variable) | ||
@flops.register(gem.Identity) | ||
@flops.register(gem.Delta) | ||
@flops.register(gem.Zero) | ||
@flops.register(gem.Literal) | ||
@flops.register(gem.Index) | ||
@flops.register(gem.VariableIndex) | ||
def flops_zero(expr, temporaries): | ||
# Initial set up of these Gem nodes are of 0 floating point operations. | ||
return 0 | ||
|
||
|
||
@flops.register(gem.LogicalNot) | ||
@flops.register(gem.LogicalAnd) | ||
@flops.register(gem.LogicalOr) | ||
@flops.register(gem.ListTensor) | ||
def flops_zeroplus(expr, temporaries): | ||
# These nodes contribute 0 floating point operations, but their children may not. | ||
return 0 + sum(expression_flops(child, temporaries) | ||
for child in expr.children) | ||
|
||
|
||
@flops.register(gem.Product) | ||
def flops_product(expr, temporaries): | ||
# Multiplication by -1 is not a flop. | ||
a, b = expr.children | ||
if isinstance(a, gem.Literal) and a.value == -1: | ||
return expression_flops(b, temporaries) | ||
elif isinstance(b, gem.Literal) and b.value == -1: | ||
return expression_flops(a, temporaries) | ||
else: | ||
return 1 + sum(expression_flops(child, temporaries) | ||
for child in expr.children) | ||
|
||
|
||
@flops.register(gem.Sum) | ||
@flops.register(gem.Division) | ||
@flops.register(gem.Comparison) | ||
@flops.register(gem.MathFunction) | ||
@flops.register(gem.MinValue) | ||
@flops.register(gem.MaxValue) | ||
def flops_oneplus(expr, temporaries): | ||
return 1 + sum(expression_flops(child, temporaries) | ||
for child in expr.children) | ||
|
||
|
||
@flops.register(gem.Power) | ||
def flops_power(expr, temporaries): | ||
base, exponent = expr.children | ||
base_flops = expression_flops(base, temporaries) | ||
if isinstance(exponent, gem.Literal): | ||
exponent = exponent.value | ||
if exponent > 0 and exponent == math.floor(exponent): | ||
return base_flops + int(math.ceil(math.log2(exponent))) | ||
else: | ||
return base_flops + 5 # heuristic | ||
else: | ||
return base_flops + 5 # heuristic | ||
|
||
|
||
@flops.register(gem.Conditional) | ||
def flops_conditional(expr, temporaries): | ||
condition, then, else_ = (expression_flops(child, temporaries) | ||
for child in expr.children) | ||
return condition + max(then, else_) | ||
|
||
|
||
@flops.register(gem.Indexed) | ||
@flops.register(gem.FlexiblyIndexed) | ||
def flops_indexed(expr, temporaries): | ||
aggregate = sum(expression_flops(child, temporaries) | ||
for child in expr.children) | ||
# Average flops per entry | ||
return aggregate / numpy.prod(expr.children[0].shape, dtype=int) | ||
|
||
|
||
@flops.register(gem.IndexSum) | ||
def flops_indexsum(expr, temporaries): | ||
raise ValueError("Not expecting IndexSum") | ||
|
||
|
||
@flops.register(gem.Inverse) | ||
def flops_inverse(expr, temporaries): | ||
n, _ = expr.shape | ||
# 2n^3 + child flop count | ||
return 2*n**3 + sum(expression_flops(child, temporaries) | ||
for child in expr.children) | ||
|
||
|
||
@flops.register(gem.Solve) | ||
def flops_solve(expr, temporaries): | ||
n, m = expr.shape | ||
# 2mn + inversion cost of A + children flop count | ||
return 2*n*m + 2*n**3 + sum(expression_flops(child, temporaries) | ||
for child in expr.children) | ||
|
||
|
||
@flops.register(gem.ComponentTensor) | ||
def flops_componenttensor(expr, temporaries): | ||
raise ValueError("Not expecting ComponentTensor") | ||
|
||
|
||
def expression_flops(expression, temporaries, top=False): | ||
"""An approximation to flops required for each expression. | ||
:arg expression: GEM expression. | ||
:arg temporaries: Expressions that are assigned to temporaries | ||
:arg top: are we at the root? | ||
:returns: flop count for the expression | ||
""" | ||
if not top and expression in temporaries: | ||
return 0 | ||
else: | ||
return flops(expression, temporaries) | ||
|
||
|
||
def count_flops(impero_c): | ||
"""An approximation to flops required for a scheduled impero_c tree. | ||
:arg impero_c: a :class:`~.Impero_C` object. | ||
:returns: approximate flop count for the tree. | ||
""" | ||
try: | ||
return statement(impero_c.tree, set(impero_c.temporaries)) | ||
except (ValueError, NotImplementedError): | ||
return 0 |
Oops, something went wrong.