-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcompute-complexity.py
125 lines (91 loc) · 4.14 KB
/
compute-complexity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import form
import firedrake
import tsfc
from gem import gem
from gem import impero
from functools import reduce, singledispatch
import sympy
from firedrake import COMM_WORLD, ExtrudedMesh, UnitCubeMesh, UnitSquareMesh, assemble
def expression(expr, temporaries, indices, top=False):
    """Compute the complexity polynomial of an impero (sub)tree.

    A node that already has a temporary is treated as free (cost 0),
    unless it is the root of the walk (``top=True``); every other node is
    handed to the type-dispatched ``_expression``.

    :arg expr: Impero expression
    :arg temporaries: subexpressions for which temporaries exist.
    :arg indices: dictionary mapping gem indices to names.
    :arg top: ignore the temporary for the root node.
    :returns: Sympy polynomial (or a plain number for index-free nodes).
    """
    # Only a non-root node with an existing temporary is skipped.
    if top or expr not in temporaries:
        return _expression(expr, temporaries, indices)
    return 0
@singledispatch
def _expression(expr, temporaries, indices):
raise AssertionError("cannot compute complexity polynomial for %s" % type(expr))
@_expression.register(gem.Product)
@_expression.register(gem.Sum)
@_expression.register(gem.Division)
def _expression_binary(expr, temporaries, indices):
    """Cost of a binary arithmetic node: both operands plus one operation.

    :arg expr: gem Product, Sum or Division node.
    :arg temporaries: subexpressions for which temporaries exist.
    :arg indices: dictionary mapping gem indices to names.
    :returns: cost of the first child + cost of the second child + 1.
    """
    left_cost = expression(expr.children[0], temporaries, indices)
    right_cost = expression(expr.children[1], temporaries, indices)
    return left_cost + right_cost + 1
@_expression.register(gem.Constant)
@_expression.register(gem.Variable)
@_expression.register(impero.Initialise)
def _expression_noop(expr, temporaries, indices):
    """Cost of nodes that are counted as free.

    Constants, variables and ``Initialise`` statements contribute no
    operations to the complexity polynomial.

    :returns: 0, regardless of the arguments.
    """
    return 0
@_expression.register(gem.Indexed)
@_expression.register(gem.FlexiblyIndexed)
def _expression_indexed(expr, temporaries, indices):
    """Cost of an indexing node: the indexing itself is counted as free.

    :arg expr: gem Indexed or FlexiblyIndexed node.
    :arg temporaries: subexpressions for which temporaries exist.
    :arg indices: dictionary mapping gem indices to names.
    :returns: cost of the node's first child (the thing being indexed).
    """
    indexed_operand = expr.children[0]
    return expression(indexed_operand, temporaries, indices)
@_expression.register(impero.Block)
def _expression_block(expr, temporaries, indices):
    """Cost of a block: the sum of the costs of its child statements.

    :arg expr: impero Block node.
    :arg temporaries: subexpressions for which temporaries exist.
    :arg indices: dictionary mapping gem indices to names.
    :returns: total cost of all children; 0 for a block with no children.
    """
    # sum() starts from 0, so an empty block costs nothing instead of
    # raising the TypeError that reduce() gives on an empty sequence.
    return sum(expression(child, temporaries, indices) for child in expr.children)
@_expression.register(impero.Evaluate)
def _expression_evaluate(expr, temporaries, indices):
    """Cost of an ``Evaluate`` statement: the cost of the evaluated expression.

    The walk is restarted with ``top=True`` so that the temporary belonging
    to the expression being evaluated does not make it count as free.

    :arg expr: impero Evaluate node.
    :arg temporaries: subexpressions for which temporaries exist.
    :arg indices: dictionary mapping gem indices to names.
    :returns: cost of ``expr.expression``.
    """
    evaluated = expr.expression
    return expression(evaluated, temporaries, indices, top=True)
@_expression.register(impero.For)
def _expression_for(expr, temporaries, indices):
    """Cost of a ``For`` loop: the loop-index symbol times the body cost.

    :arg expr: impero For node.
    :arg temporaries: subexpressions for which temporaries exist.
    :arg indices: dictionary mapping gem indices to names.
    :returns: ``indices[expr.index]`` multiplied by the cost of the first child.
    """
    # Look the loop index up before descending into the body: a mapping
    # that invents names on demand numbers them in first-lookup order.
    loop_symbol = indices[expr.index]
    body_cost = expression(expr.children[0], temporaries, indices)
    return loop_symbol * body_cost
@_expression.register(impero.Accumulate)
@_expression.register(impero.ReturnAccumulate)
def _expression_accumulate(expr, temporaries, indices):
    """Cost of an accumulation statement: the cost of its summand.

    :arg expr: impero Accumulate or ReturnAccumulate node.
    :arg temporaries: subexpressions for which temporaries exist.
    :arg indices: dictionary mapping gem indices to names.
    :returns: cost of ``expr.indexsum.children[0]``.
    """
    # NOTE(review): the addition performed by the accumulation itself is
    # not counted here - confirm this matches the intended cost model.
    summand = expr.indexsum.children[0]
    return expression(summand, temporaries, indices)
@_expression.register(gem.MathFunction)
def _expression_mathfunction(expr, temporaries, indices):
    """Cost of a math function call; only ``abs`` is supported.

    This handler was previously also named ``_expression_block``, which
    rebound the module-level name of the ``impero.Block`` handler. The
    dispatch registration is what matters, so the rename is safe.

    :arg expr: gem MathFunction node.
    :arg temporaries: subexpressions for which temporaries exist.
    :arg indices: dictionary mapping gem indices to names.
    :returns: cost of the argument plus one operation for ``abs``.
    :raises AssertionError: for any math function other than ``abs``.
    """
    if expr.name == 'abs':
        return expression(expr.children[0], temporaries, indices) + 1
    # The node type alone is always MathFunction here, so also report
    # which function was not supported.
    raise AssertionError(
        "cannot compute complexity polynomial for %s %r" % (type(expr), expr.name)
    )
class IndexDict(dict):
    """Index dictionary which invents index names as they are required.

    Looking up a key that is not present stores a fresh Sympy symbol named
    ``i<count>`` under that key and returns it. ``count`` is the number of
    names invented so far.
    """

    def __init__(self, *args):
        super().__init__(*args)
        # Number of invented names; also the suffix of the next one.
        self.count = 0

    def __missing__(self, key):
        # Called by dict.__getitem__ only when ``key`` is absent.
        invented = sympy.symbols("i%d" % self.count)
        self[key] = invented
        self.count += 1
        return invented
def complexity(form, parameters, action=False):
    """Return the operation-count polynomial, in ``p``, of a form's kernel.

    The form is compiled with TSFC, the impero tree of the first compiled
    kernel is walked with ``expression``, and every loop-index symbol is
    then replaced by ``p + 1``.

    :arg form: the form to analyse (this parameter shadows the imported
        ``form`` module inside this function).
    :arg parameters: form compiler parameters passed to
        ``tsfc.driver.compile_form``; each compiled entry must unpack as
        ``(impero_kernel, index_names)``.
    :arg action: if true, analyse the action of the form on a fresh
        ``firedrake.Function`` instead of the form itself.
    :returns: expanded Sympy expression in the symbol ``p``.
    """
    if action:
        # NOTE(review): the coefficient lives in the space of the form's
        # first argument - confirm this is right when the argument spaces
        # differ.
        coef = firedrake.Function(form.arguments()[0].function_space())
        form = firedrake.action(form, coef)

    # Only the first compiled kernel is analysed.
    impero_kernel, index_names = tsfc.driver.compile_form(form, parameters=parameters)[0]
    indices = IndexDict({idx: sympy.symbols(name) for idx, name in index_names})

    # A loop-free tree yields a plain Python number; sympify it so that
    # .subs()/.expand() below work for that case as well.
    expr = sympy.sympify(
        expression(impero_kernel.tree, impero_kernel.temporaries, indices, top=True)
    )

    # Currently assume p+1 quad points in each direction.
    p1 = sympy.symbols("p") + 1
    # indices is read after the walk so that names invented during it are
    # substituted too.
    return expr.subs([(i, p1) for i in indices.values()]).expand()
# Mesh: a 2x2 quadrilateral unit square, extruded (second argument 2).
m = ExtrudedMesh(UnitSquareMesh(2, 2, quadrilateral=True), 2)

# Test problems from the local ``form`` module, each built with argument 6.
mass = form.mass(m, 6)
poisson = form.poisson(m, 6)
hyperelasticity = form.hyperelasticity(m, 6)
curl_curl = form.curl_curl(m, 6)

# Work on a copy of the form compiler parameters.
parameters = firedrake.parameters['form_compiler'].copy()
parameters['return_impero'] = True
parameters['mode'] = 'spectral'

# (label, form) pairs in the order they are reported.
problems = (
    (" mass: ", mass),
    (" laplacian: ", poisson),
    (" hyperelasticity:", hyperelasticity),
    (" curl_curl:", curl_curl),
)

for mode, action in (("assembly", False), ("action", True)):
    print(mode)
    for label, problem in problems:
        print(label, complexity(problem, parameters, action))