Skip to content

Commit

Permalink
Introduce MappedPointSet
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 10, 2025
1 parent 7f2dd72 commit 3f871a6
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 22 deletions.
2 changes: 1 addition & 1 deletion finat/fiat_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
'''
space_dimension = self._element.space_dimension()
value_size = np.prod(self._element.value_shape(), dtype=int)
fiat_result = self._element.tabulate(order, ps.points.reshape(-1, ps.points.shape[-1]), entity)
fiat_result = self._element.tabulate(order, ps.points, entity)
result = {}
# In almost all cases, we have
# self.space_dimension() == self._element.space_dimension()
Expand Down
59 changes: 55 additions & 4 deletions finat/point_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import gem
from gem.utils import cached_property
from FIAT.reference_element import make_affine_mapping


class AbstractPointSet(metaclass=ABCMeta):
Expand Down Expand Up @@ -111,8 +112,7 @@ def points(self):

@cached_property
def indices(self):
N, _ = self._points_expr.shape
return (gem.Index(extent=N),)
return tuple(gem.Index(extent=N) for N in self._points_expr.shape[:-1])

@cached_property
def expression(self):
Expand All @@ -129,7 +129,7 @@ def __init__(self, points):
:arg points: A vector of N points of shape (N, D) where D is the
dimension of each point."""
points = numpy.asarray(points)
assert len(points.shape) > 1
assert len(points.shape) == 2
self.points = points

@cached_property
Expand All @@ -138,7 +138,7 @@ def points(self):

@cached_property
def indices(self):
return tuple(gem.Index(extent=e) for e in self.points.shape[:-1])
return tuple(gem.Index(extent=N) for N in self.points.shape[:-1])

@cached_property
def expression(self):
Expand Down Expand Up @@ -200,3 +200,54 @@ def almost_equal(self, other, tolerance=1e-12):
len(self.factors) == len(other.factors) and \
all(s.almost_equal(o, tolerance=tolerance)
for s, o in zip(self.factors, other.factors))


class MappedPointSet(AbstractPointSet):

def __init__(self, cell, ps):
self.cell = cell
self.ps = ps

@cached_property
def transforms(self):
top = self.cell.topology
dim = self.ps.dimension
sd = self.cell.get_spatial_dimension()
A = numpy.zeros((len(top[dim]), sd, dim))
b = numpy.zeros((len(top[dim]), sd))
ref_verts = self.cell.construct_subelement(dim).vertices
for entity in sorted(top[dim]):
verts = self.cell.get_vertices_of_subcomplex(top[dim][entity])
A[entity], b[entity] = make_affine_mapping(ref_verts, verts)
return A, b

@cached_property
def points(self):
x = self.ps.points
A, b = self.transforms
pts = [numpy.add(numpy.dot(x, A[entity].T), b[entity])
for entity in range(len(A))]
return numpy.concatenate(pts)

@cached_property
def indices(self):
num_facets = len(self.cell.topology[self.ps.dimension])
return (gem.Index(extent=num_facets), *self.ps.indices)

@cached_property
def expression(self):
A, b = self.transforms
x = self.ps.expression
i, *p = self.indices
j, k = (gem.Index(extent=e) for e in A.shape[1:])

xpk = gem.Indexed(x, (*p, k))
Aijk = gem.Indexed(gem.Literal(A), (i, j, k))
bij = gem.Indexed(gem.Literal(b), (i, j))
return gem.Sum(gem.IndexSum(Aijk, xpk, (k,)), bij)

def almost_equal(self, other, tolerance=1e-12):
"""Approximate numerical equality of point sets"""
return type(self) is type(other) and \
self.cell == other.cell and \
self.ps.almost_equal(other.ps, tolerance=tolerance)
22 changes: 6 additions & 16 deletions finat/quadrature_element.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from finat.point_set import UnknownPointSet, PointSet
from finat.point_set import UnknownPointSet, MappedPointSet
from functools import reduce

import numpy
Expand Down Expand Up @@ -91,14 +91,7 @@ def space_dimension(self):
def _point_set(self):
ps = self._rule.point_set
sd = self.cell.get_spatial_dimension()
dim = ps.dimension
if dim != sd:
# Tile the quadrature rule on each subentity
entity_ids = self.entity_dofs()
pts = [self.cell.get_entity_transform(dim, entity)(ps.points)
for entity in entity_ids[dim]]
ps = PointSet(numpy.stack(pts, axis=0))
return ps
return ps if ps.dimension == sd else MappedPointSet(self.cell, ps)

@property
def index_shape(self):
Expand Down Expand Up @@ -147,16 +140,13 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):

# Return an outer product of identity matrices
multiindex = self.get_indices()
fid = ps.indices
if len(multiindex) > len(fid):
fid = (entity_id, *fid)
product = reduce(gem.Product, [gem.Delta(q, r)
for q, r in zip(ps.indices, multiindex[-len(ps.indices):])])
for q, r in zip(fid, multiindex)])

sd = self.cell.get_spatial_dimension()
if sd != ps.dimension:
data = numpy.zeros(self.index_shape[:-1], dtype=object)
data[...] = gem.Zero()
data[entity_id] = gem.Literal(1)
product = gem.Product(product, gem.Indexed(gem.ListTensor(data), multiindex[:1]))

return {(0,) * sd: gem.ComponentTensor(product, multiindex)}

def point_evaluation(self, order, refcoords, entity=None):
Expand Down
11 changes: 10 additions & 1 deletion finat/tensor_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from gem.utils import cached_property

from finat.finiteelementbase import FiniteElementBase
from finat.point_set import PointSingleton, PointSet, TensorPointSet
from finat.point_set import PointSingleton, PointSet, TensorPointSet, MappedPointSet


class TensorProductElement(FiniteElementBase):
Expand Down Expand Up @@ -138,6 +138,15 @@ def _merge_evaluations(self, factor_results):
return result

def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
if isinstance(ps, MappedPointSet):
top = self.cell.topology
evals = [self.basis_evaluation(order, ps.ps, entity=(dim, entity),
coordinate_mapping=coordinate_mapping)
for dim in sorted(top)
for entity in sorted(top[dim])
if sum(dim) == ps.ps.dimension]
return {key: gem.ListTensor([e[key] for e in evals]) for key in evals[0]}

entities = self._factor_entity(entity)
entity_dim, _ = zip(*entities)

Expand Down
10 changes: 10 additions & 0 deletions gem/gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,16 @@ def __new__(cls, i, j, dtype=None):
if isinstance(i, int) and isinstance(j, int):
return one if i == j else Zero()

if isinstance(i, int):
expr = numpy.full((j.extent), Zero(), dtype=object)
expr[i] = one
return Indexed(ListTensor(expr), (j,))

if isinstance(j, int):
expr = numpy.full((i.extent), Zero(), dtype=object)
expr[j] = one
return Indexed(ListTensor(expr), (i,))

self = super(Delta, cls).__new__(cls)
self.i = i
self.j = j
Expand Down

0 comments on commit 3f871a6

Please sign in to comment.