Skip to content

Commit

Permalink
Merge remote-tracking branch 'gem/master' into connorjward/add-finat-gem
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward committed Dec 4, 2024
2 parents 358aab4 + f4c4866 commit 4746c81
Show file tree
Hide file tree
Showing 14 changed files with 4,598 additions and 0 deletions.
2 changes: 2 additions & 0 deletions gem/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from gem.gem import * # noqa
from gem.optimise import select_expression # noqa
192 changes: 192 additions & 0 deletions gem/coffee.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
"""This module contains an implementation of the COFFEE optimisation
algorithm operating on a GEM representation.
This file is NOT for code generation as a COFFEE AST.
"""

from collections import OrderedDict
import itertools
import logging

import numpy

from gem.gem import IndexSum, one
from gem.optimise import make_sum, make_product
from gem.refactorise import Monomial
from gem.utils import groupby


__all__ = ['optimise_monomial_sum']


def monomial_sum_to_expression(monomial_sum):
"""Convert a monomial sum to a GEM expression.
:arg monomial_sum: an iterable of :class:`Monomial`s
:returns: GEM expression
"""
indexsums = [] # The result is summation of indexsums
# Group monomials according to their sum indices
groups = groupby(monomial_sum, key=lambda m: frozenset(m.sum_indices))
# Create IndexSum's from each monomial group
for _, monomials in groups:
sum_indices = monomials[0].sum_indices
products = [make_product(monomial.atomics + (monomial.rest,)) for monomial in monomials]
indexsums.append(IndexSum(make_sum(products), sum_indices))
return make_sum(indexsums)


def index_extent(factor, linear_indices):
"""Compute the product of the extents of linear indices of a GEM expression
:arg factor: GEM expression
:arg linear_indices: set of linear indices
:returns: product of extents of linear indices
"""
return numpy.prod([i.extent for i in factor.free_indices if i in linear_indices])


def find_optimal_atomics(monomials, linear_indices):
"""Find optimal atomic common subexpressions, which produce least number of
terms in the resultant IndexSum when factorised.
:arg monomials: A list of :class:`Monomial`s, all of which should have
the same sum indices
:arg linear_indices: tuple of linear indices
:returns: list of atomic GEM expressions
"""
atomics = tuple(OrderedDict.fromkeys(itertools.chain(*(monomial.atomics for monomial in monomials))))

def cost(solution):
extent = sum(map(lambda atomic: index_extent(atomic, linear_indices), solution))
# Prefer shorter solutions, but larger extents
return (len(solution), -extent)

optimal_solution = set(atomics) # pessimal but feasible solution
solution = set()

max_it = 1 << 12
it = iter(range(max_it))

def solve(idx):
while idx < len(monomials) and solution.intersection(monomials[idx].atomics):
idx += 1

if idx < len(monomials):
if len(solution) < len(optimal_solution):
for atomic in monomials[idx].atomics:
solution.add(atomic)
solve(idx + 1)
solution.remove(atomic)
else:
if cost(solution) < cost(optimal_solution):
optimal_solution.clear()
optimal_solution.update(solution)
next(it)

try:
solve(0)
except StopIteration:
logger = logging.getLogger('tsfc')
logger.warning("Solution to ILP problem may not be optimal: search "
"interrupted after examining %d solutions.", max_it)

return tuple(atomic for atomic in atomics if atomic in optimal_solution)


def factorise_atomics(monomials, optimal_atomics, linear_indices):
"""Group and factorise monomials using a list of atomics as common
subexpressions. Create new monomials for each group and optimise them recursively.
:arg monomials: an iterable of :class:`Monomial`s, all of which should have
the same sum indices
:arg optimal_atomics: list of tuples of atomics to be used as common subexpression
:arg linear_indices: tuple of linear indices
:returns: an iterable of :class:`Monomials`s after factorisation
"""
if not optimal_atomics or len(monomials) <= 1:
return monomials

# Group monomials with respect to each optimal atomic
def group_key(monomial):
for oa in optimal_atomics:
if oa in monomial.atomics:
return oa
assert False, "Expect at least one optimal atomic per monomial."
factor_group = groupby(monomials, key=group_key)

# We should not drop monomials
assert sum(len(ms) for _, ms in factor_group) == len(monomials)

sum_indices = next(iter(monomials)).sum_indices
new_monomials = []
for oa, monomials in factor_group:
# Create new MonomialSum for the factorised out terms
sub_monomials = []
for monomial in monomials:
atomics = list(monomial.atomics)
atomics.remove(oa) # remove common factor
sub_monomials.append(Monomial((), tuple(atomics), monomial.rest))
# Continue to factorise the remaining expression
sub_monomials = optimise_monomials(sub_monomials, linear_indices)
if len(sub_monomials) == 1:
# Factorised part is a product, we add back the common atomics then
# add to new MonomialSum directly rather than forming a product node
# Retaining the monomial structure enables applying associativity
# when forming GEM nodes later.
sub_monomial, = sub_monomials
new_monomials.append(
Monomial(sum_indices, (oa,) + sub_monomial.atomics, sub_monomial.rest))
else:
# Factorised part is a summation, we need to create a new GEM node
# and multiply with the common factor
node = monomial_sum_to_expression(sub_monomials)
# If the free indices of the new node intersect with linear indices,
# add to the new monomial as `atomic`, otherwise add as `rest`.
# Note: we might want to continue to factorise with the new atomics
# by running optimise_monoials twice.
if set(linear_indices) & set(node.free_indices):
new_monomials.append(Monomial(sum_indices, (oa, node), one))
else:
new_monomials.append(Monomial(sum_indices, (oa, ), node))
return new_monomials


def optimise_monomial_sum(monomial_sum, linear_indices):
"""Choose optimal common atomic subexpressions and factorise a
:class:`MonomialSum` object to create a GEM expression.
:arg monomial_sum: a :class:`MonomialSum` object
:arg linear_indices: tuple of linear indices
:returns: factorised GEM expression
"""
groups = groupby(monomial_sum, key=lambda m: frozenset(m.sum_indices))
new_monomials = []
for _, monomials in groups:
new_monomials.extend(optimise_monomials(monomials, linear_indices))
return monomial_sum_to_expression(new_monomials)


def optimise_monomials(monomials, linear_indices):
"""Choose optimal common atomic subexpressions and factorise an iterable
of monomials.
:arg monomials: a list of :class:`Monomial`s, all of which should have
the same sum indices
:arg linear_indices: tuple of linear indices
:returns: an iterable of factorised :class:`Monomials`s
"""
assert len(set(frozenset(m.sum_indices) for m in monomials)) <= 1, \
"All monomials required to have same sum indices for factorisation"

result = [m for m in monomials if not m.atomics] # skipped monomials
active_monomials = [m for m in monomials if m.atomics]
optimal_atomics = find_optimal_atomics(active_monomials, linear_indices)
result += factorise_atomics(active_monomials, optimal_atomics, linear_indices)
return result
197 changes: 197 additions & 0 deletions gem/flop_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""
This file contains all the necessary functions to accurately count the
total number of floating point operations for a given script.
"""

import gem.gem as gem
import gem.impero as imp
from functools import singledispatch
import numpy
import math


@singledispatch
def statement(tree, temporaries):
raise NotImplementedError


@statement.register(imp.Block)
def statement_block(tree, temporaries):
flops = sum(statement(child, temporaries) for child in tree.children)
return flops


@statement.register(imp.For)
def statement_for(tree, temporaries):
extent = tree.index.extent
assert extent is not None
child, = tree.children
flops = statement(child, temporaries)
return flops * extent


@statement.register(imp.Initialise)
def statement_initialise(tree, temporaries):
return 0


@statement.register(imp.Accumulate)
def statement_accumulate(tree, temporaries):
flops = expression_flops(tree.indexsum.children[0], temporaries)
return flops + 1


@statement.register(imp.Return)
def statement_return(tree, temporaries):
flops = expression_flops(tree.expression, temporaries)
return flops + 1


@statement.register(imp.ReturnAccumulate)
def statement_returnaccumulate(tree, temporaries):
flops = expression_flops(tree.indexsum.children[0], temporaries)
return flops + 1


@statement.register(imp.Evaluate)
def statement_evaluate(tree, temporaries):
flops = expression_flops(tree.expression, temporaries, top=True)
return flops


@singledispatch
def flops(expr, temporaries):
raise NotImplementedError(f"Don't know how to count flops of {type(expr)}")


@flops.register(gem.Failure)
def flops_failure(expr, temporaries):
raise ValueError("Not expecting a Failure node")


@flops.register(gem.Variable)
@flops.register(gem.Identity)
@flops.register(gem.Delta)
@flops.register(gem.Zero)
@flops.register(gem.Literal)
@flops.register(gem.Index)
@flops.register(gem.VariableIndex)
def flops_zero(expr, temporaries):
# Initial set up of these Gem nodes are of 0 floating point operations.
return 0


@flops.register(gem.LogicalNot)
@flops.register(gem.LogicalAnd)
@flops.register(gem.LogicalOr)
@flops.register(gem.ListTensor)
def flops_zeroplus(expr, temporaries):
# These nodes contribute 0 floating point operations, but their children may not.
return 0 + sum(expression_flops(child, temporaries)
for child in expr.children)


@flops.register(gem.Product)
def flops_product(expr, temporaries):
# Multiplication by -1 is not a flop.
a, b = expr.children
if isinstance(a, gem.Literal) and a.value == -1:
return expression_flops(b, temporaries)
elif isinstance(b, gem.Literal) and b.value == -1:
return expression_flops(a, temporaries)
else:
return 1 + sum(expression_flops(child, temporaries)
for child in expr.children)


@flops.register(gem.Sum)
@flops.register(gem.Division)
@flops.register(gem.Comparison)
@flops.register(gem.MathFunction)
@flops.register(gem.MinValue)
@flops.register(gem.MaxValue)
def flops_oneplus(expr, temporaries):
return 1 + sum(expression_flops(child, temporaries)
for child in expr.children)


@flops.register(gem.Power)
def flops_power(expr, temporaries):
base, exponent = expr.children
base_flops = expression_flops(base, temporaries)
if isinstance(exponent, gem.Literal):
exponent = exponent.value
if exponent > 0 and exponent == math.floor(exponent):
return base_flops + int(math.ceil(math.log2(exponent)))
else:
return base_flops + 5 # heuristic
else:
return base_flops + 5 # heuristic


@flops.register(gem.Conditional)
def flops_conditional(expr, temporaries):
condition, then, else_ = (expression_flops(child, temporaries)
for child in expr.children)
return condition + max(then, else_)


@flops.register(gem.Indexed)
@flops.register(gem.FlexiblyIndexed)
def flops_indexed(expr, temporaries):
aggregate = sum(expression_flops(child, temporaries)
for child in expr.children)
# Average flops per entry
return aggregate / numpy.prod(expr.children[0].shape, dtype=int)


@flops.register(gem.IndexSum)
def flops_indexsum(expr, temporaries):
raise ValueError("Not expecting IndexSum")


@flops.register(gem.Inverse)
def flops_inverse(expr, temporaries):
n, _ = expr.shape
# 2n^3 + child flop count
return 2*n**3 + sum(expression_flops(child, temporaries)
for child in expr.children)


@flops.register(gem.Solve)
def flops_solve(expr, temporaries):
n, m = expr.shape
# 2mn + inversion cost of A + children flop count
return 2*n*m + 2*n**3 + sum(expression_flops(child, temporaries)
for child in expr.children)


@flops.register(gem.ComponentTensor)
def flops_componenttensor(expr, temporaries):
raise ValueError("Not expecting ComponentTensor")


def expression_flops(expression, temporaries, top=False):
"""An approximation to flops required for each expression.
:arg expression: GEM expression.
:arg temporaries: Expressions that are assigned to temporaries
:arg top: are we at the root?
:returns: flop count for the expression
"""
if not top and expression in temporaries:
return 0
else:
return flops(expression, temporaries)


def count_flops(impero_c):
"""An approximation to flops required for a scheduled impero_c tree.
:arg impero_c: a :class:`~.Impero_C` object.
:returns: approximate flop count for the tree.
"""
try:
return statement(impero_c.tree, set(impero_c.temporaries))
except (ValueError, NotImplementedError):
return 0
Loading

0 comments on commit 4746c81

Please sign in to comment.