From 04bd9ff86398972b3a0018c30dad697c3edb5f49 Mon Sep 17 00:00:00 2001 From: Timo Betcke Date: Wed, 18 Dec 2019 23:53:46 +0000 Subject: [PATCH] Improved dense support for blocked operators. --- .../assembly/discrete_boundary_operator.py | 4 +-- bempp/api/linalg/direct_solvers.py | 25 ++++++++++++------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/bempp/api/assembly/discrete_boundary_operator.py b/bempp/api/assembly/discrete_boundary_operator.py index bbfb8053..cf08dcd8 100644 --- a/bempp/api/assembly/discrete_boundary_operator.py +++ b/bempp/api/assembly/discrete_boundary_operator.py @@ -89,7 +89,7 @@ def __init__(self, op1, op2): def _matvec(self, x): """Evaluate matvec.""" - return op1 @ x + op2 @ x + return self._op1 @ x + self._op2 @ x @property def A(self): @@ -117,7 +117,7 @@ def __init__(self, op1, op2): def _matvec(self, x): """Evaluate matvec.""" - return op1 @ (op2 @ x) + return self._op1 @ (self._op2 @ x) @property def A(self): diff --git a/bempp/api/linalg/direct_solvers.py b/bempp/api/linalg/direct_solvers.py index 443de407..128e51c4 100644 --- a/bempp/api/linalg/direct_solvers.py +++ b/bempp/api/linalg/direct_solvers.py @@ -38,17 +38,24 @@ def lu(A, b, lu_factor=None): from bempp.api import GridFunction, as_matrix from scipy.linalg import solve, lu_solve from bempp.api.assembly.blocked_operator import BlockedOperatorBase + from bempp.api.assembly.blocked_operator import projections_from_grid_functions_list + from bempp.api.assembly.blocked_operator import grid_function_list_from_coefficients if isinstance(A, BlockedOperatorBase): blocked = True - - - - if lu_factor is not None: - vec = b.projections(A.dual_to_range) - sol = lu_solve(lu_factor, vec) + vec = projections_from_grid_functions_list(b, A.dual_to_range_spaces) + if lu_factor is not None: + sol = lu_solve(lu_factor, vec) + else: + mat = A.weak_form().A + sol = solve(mat, vec) + return grid_function_list_from_coefficients(sol, A.domain_spaces) else: - mat = as_matrix(A.weak_form()) vec = b.projections(A.dual_to_range) - sol = solve(mat, vec) - return GridFunction(A.domain, coefficients=sol) + if lu_factor is not None: + sol = lu_solve(lu_factor, vec) + else: + mat = A.weak_form().A + sol = solve(mat, vec) + return GridFunction(A.domain, coefficients=sol) +