Skip to content

Commit

Permalink
FacetSplitPC: use global permutation to construct virtual submatrix (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck authored Feb 21, 2024
1 parent 10c35cd commit c2ab0b5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 17 deletions.
25 changes: 10 additions & 15 deletions firedrake/preconditioners/facet_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def restrict(ele, restriction_domain):
W = FunctionSpace(V.mesh(), MixedElement([restrict(V.ufl_element(), d) for d in ("interior", "facet")]))
assert W.dim() == V.dim(), "Dimensions of the original and decomposed spaces do not match"

mixed_operator = a(sum(TestFunctions(W)), sum(TrialFunctions(W)), coefficients={})
mixed_operator = a(sum(TestFunctions(W)), sum(TrialFunctions(W)))
mixed_bcs = tuple(bc.reconstruct(V=W[-1], g=0) for bc in bcs)

self.perm = None
Expand Down Expand Up @@ -112,7 +112,9 @@ def _permute_nullspace(nsp):
mixed_opmat.setNearNullSpace(_permute_nullspace(P.getNearNullSpace()))
mixed_opmat.setTransposeNullSpace(_permute_nullspace(P.getTransposeNullSpace()))
elif self.perm:
self._permute_op = partial(PETSc.Mat().createSubMatrixVirtual, P, self.iperm, self.iperm)
global_indices = V.dof_dset.lgmap.apply(self.iperm.indices)
self._global_iperm = PETSc.IS().createGeneral(global_indices, comm=P.getComm())
self._permute_op = partial(PETSc.Mat().createSubMatrixVirtual, P, self._global_iperm, self._global_iperm)
mixed_opmat = self._permute_op()
else:
mixed_opmat = P
Expand Down Expand Up @@ -246,25 +248,18 @@ def get_permutation_map(V, W):
val = numpy.arange(offset, offset + Wsub.dof_count, dtype=PETSc.IntType)
wdats.append(Wsub.make_dat(val=val))
offset += Wsub.dof_dset.layout_vec.sizes[0]

sizes = [Wsub.finat_element.space_dimension() * Wsub.value_size for Wsub in W]
wdat = op2.MixedDat(wdats)
size = sum(Wsub.finat_element.space_dimension() * Wsub.value_size for Wsub in W)
eperm = numpy.concatenate([restricted_dofs(Wsub.finat_element, V.finat_element) for Wsub in W])
pmap = PermutedMap(V.cell_node_map(), eperm)

kernel_code = f"""
void permutation(PetscInt *restrict x,
const PetscInt *restrict xi,
const PetscInt *restrict xf){{
for(PetscInt i=0; i<{sizes[0]}; i++) x[i] = xi[i];
for(PetscInt i=0; i<{sizes[1]}; i++) x[i+{sizes[0]}] = xf[i];
return;
}}
"""
void permutation(PetscInt *restrict v, const PetscInt *restrict w) {{
for (PetscInt i=0; i<{size}; i++) v[i] = w[i];
}}"""
kernel = op2.Kernel(kernel_code, "permutation", requires_zeroed_output_arguments=False)
op2.par_loop(kernel, V.mesh().cell_set,
vdat(op2.WRITE, pmap),
wdats[0](op2.READ, W[0].cell_node_map()),
wdats[1](op2.READ, W[1].cell_node_map()))
vdat(op2.WRITE, pmap), wdat(op2.READ, W.cell_node_map()))

own = V.dof_dset.layout_vec.sizes[0]
return perm[:own]
5 changes: 3 additions & 2 deletions tests/regression/test_facet_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,6 @@ def test_facet_split(quadrilateral, pc_type):


@pytest.mark.parallel
def test_facet_split_parallel():
assert run_facet_split(True, "lu", refine=3) < 1E-10
@pytest.mark.parametrize("pc_type", ["lu", "jacobi"])
def test_facet_split_parallel(pc_type):
assert run_facet_split(True, pc_type, refine=3) < 1E-10

0 comments on commit c2ab0b5

Please sign in to comment.