Skip to content

Commit

Permalink
super constructor for ExpansionSet
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Oct 29, 2023
1 parent f48ea4f commit daaab1e
Showing 1 changed file with 17 additions and 27 deletions.
44 changes: 17 additions & 27 deletions FIAT/expansions.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,21 @@ def __new__(cls, ref_el, *args, **kwargs):
else:
raise Exception("Unknown reference element type.")

def _tabulate_duffy(self, degree, pts):
raise NotImplementedError
def __init__(self, ref_el):
self.ref_el = ref_el
dim = ref_el.get_spatial_dimension()
self.base_ref_el = reference_element.default_simplex(dim)
v1 = ref_el.get_vertices()
v2 = self.base_ref_el.get_vertices()
self.A, self.b = reference_element.make_affine_mapping(v1, v2)
self.mapping = lambda x: numpy.dot(self.A, x) + self.b
self.scale = numpy.sqrt(numpy.linalg.det(self.A))
self._dmats_cache = {}

def _tabulate_duffy(self, n, pts):
raise NotImplementedError()

def make_dmats(self, degree):
if not hasattr(self, "_dmats_cache"):
self._dmats_cache = {}
cache = self._dmats_cache
key = degree
try:
Expand All @@ -259,8 +268,7 @@ class PointExpansionSet(ExpansionSet):
def __init__(self, ref_el):
if ref_el.get_spatial_dimension() != 0:
raise ValueError("Must have a point")
self.ref_el = ref_el
self.base_ref_el = reference_element.Point()
super(PointExpansionSet, self).__init__(ref_el)

def get_num_members(self, n):
return 1
Expand All @@ -286,13 +294,7 @@ class LineExpansionSet(ExpansionSet):
def __init__(self, ref_el):
if ref_el.get_spatial_dimension() != 1:
raise Exception("Must have a line")
self.ref_el = ref_el
self.base_ref_el = reference_element.DefaultLine()
v1 = ref_el.get_vertices()
v2 = self.base_ref_el.get_vertices()
self.A, self.b = reference_element.make_affine_mapping(v1, v2)
self.mapping = lambda x: numpy.dot(self.A, x) + self.b
self.scale = numpy.sqrt(numpy.linalg.det(self.A))
super(LineExpansionSet, self).__init__(ref_el)

def get_num_members(self, n):
return n + 1
Expand Down Expand Up @@ -351,13 +353,7 @@ class TriangleExpansionSet(ExpansionSet):
def __init__(self, ref_el):
if ref_el.get_spatial_dimension() != 2:
raise Exception("Must have a triangle")
self.ref_el = ref_el
self.base_ref_el = reference_element.DefaultTriangle()
v1 = ref_el.get_vertices()
v2 = self.base_ref_el.get_vertices()
self.A, self.b = reference_element.make_affine_mapping(v1, v2)
self.mapping = lambda x: numpy.dot(self.A, x) + self.b
# self.scale = numpy.sqrt(numpy.linalg.det(self.A))
super(TriangleExpansionSet, self).__init__(ref_el)

def get_num_members(self, n):
return (n + 1) * (n + 2) // 2
Expand Down Expand Up @@ -462,13 +458,7 @@ class TetrahedronExpansionSet(ExpansionSet):
def __init__(self, ref_el):
if ref_el.get_spatial_dimension() != 3:
raise Exception("Must be a tetrahedron")
self.ref_el = ref_el
self.base_ref_el = reference_element.DefaultTetrahedron()
v1 = ref_el.get_vertices()
v2 = self.base_ref_el.get_vertices()
self.A, self.b = reference_element.make_affine_mapping(v1, v2)
self.mapping = lambda x: numpy.dot(self.A, x) + self.b
self.scale = numpy.sqrt(numpy.linalg.det(self.A))
super(TetrahedronExpansionSet, self).__init__(ref_el)

def get_num_members(self, n):
return (n + 1) * (n + 2) * (n + 3) // 6
Expand Down

0 comments on commit daaab1e

Please sign in to comment.