diff --git a/src/polysolve/linear/CuSolverDN.cuh b/src/polysolve/linear/CuSolverDN.cuh index de0f9db..8250c35 100644 --- a/src/polysolve/linear/CuSolverDN.cuh +++ b/src/polysolve/linear/CuSolverDN.cuh @@ -42,6 +42,8 @@ namespace polysolve::linear // Factorize system matrix (dense, preferred) virtual void factorize_dense(const Eigen::MatrixXd &A) override; + bool is_dense() const override { return true; } + // Solve the linear system Ax = b virtual void solve(const Ref b, Ref x) override; diff --git a/src/polysolve/linear/EigenSolver.hpp b/src/polysolve/linear/EigenSolver.hpp index df7ccea..048b706 100644 --- a/src/polysolve/linear/EigenSolver.hpp +++ b/src/polysolve/linear/EigenSolver.hpp @@ -97,6 +97,8 @@ namespace polysolve::linear // Constructor requires a solver name used for finding parameters in the json file passed to set_parameters EigenDenseSolver(const std::string &name) { m_Name = name; } + bool is_dense() const override { return true; } + public: // Get info on the last solve step virtual void get_info(json ¶ms) const override; diff --git a/src/polysolve/linear/FEMSolver.cpp b/src/polysolve/linear/FEMSolver.cpp index 2da5fd7..5d1f8f8 100644 --- a/src/polysolve/linear/FEMSolver.cpp +++ b/src/polysolve/linear/FEMSolver.cpp @@ -103,6 +103,7 @@ namespace polysolve::linear const bool remove_zero_cols, const bool skip_last_cols) { + assert(!solver.is_dense()); // Let Γ be the set of Dirichlet dofs. // To implement nonzero Dirichlet boundary conditions, we seek to replace // the linear system Au = f with a new system Ãx = g, where diff --git a/src/polysolve/linear/Solver.hpp b/src/polysolve/linear/Solver.hpp index 5415191..a02c7f2 100644 --- a/src/polysolve/linear/Solver.hpp +++ b/src/polysolve/linear/Solver.hpp @@ -91,6 +91,9 @@ namespace polysolve::linear // Factorize system matrix of a dense matrix virtual void factorize_dense(const Eigen::MatrixXd &A) {} + // If solver uses dense matrices + virtual bool is_dense() const { return false; } + // // @brief { Solve the linear system Ax = b } // diff --git a/src/polysolve/nonlinear/BFGS.cpp b/src/polysolve/nonlinear/BFGS.cpp index 660e928..02bae26 100644 --- a/src/polysolve/nonlinear/BFGS.cpp +++ b/src/polysolve/nonlinear/BFGS.cpp @@ -12,6 +12,7 @@ namespace polysolve::nonlinear : Superclass(solver_params, characteristic_length, logger) { linear_solver = polysolve::linear::Solver::create(linear_solver_params, logger); + assert(linear_solver->is_dense()); } std::string BFGS::descent_strategy_name(int descent_strategy) const diff --git a/src/polysolve/nonlinear/DenseNewton.cpp b/src/polysolve/nonlinear/DenseNewton.cpp index c34038a..a9207a6 100644 --- a/src/polysolve/nonlinear/DenseNewton.cpp +++ b/src/polysolve/nonlinear/DenseNewton.cpp @@ -14,6 +14,7 @@ namespace polysolve::nonlinear logger) { linear_solver = polysolve::linear::Solver::create(linear_solver_params, logger); + assert(linear_solver->is_dense()); } double DenseNewton::solve_linear_system(Problem &objFunc, diff --git a/src/polysolve/nonlinear/SparseNewton.cpp b/src/polysolve/nonlinear/SparseNewton.cpp index 3c9c955..7cc7fa4 100644 --- a/src/polysolve/nonlinear/SparseNewton.cpp +++ b/src/polysolve/nonlinear/SparseNewton.cpp @@ -16,6 +16,7 @@ namespace polysolve::nonlinear logger) { linear_solver = polysolve::linear::Solver::create(linear_solver_params, logger); + assert(!linear_solver->is_dense()); } double SparseNewton::solve_linear_system(Problem &objFunc, diff --git a/tests/test_linear_solver.cpp b/tests/test_linear_solver.cpp index 065f290..7e449b1 100644 --- a/tests/test_linear_solver.cpp +++ b/tests/test_linear_solver.cpp @@ -106,6 +106,8 @@ TEST_CASE("all", "[solver]") Eigen::SparseMatrix A; const bool ok = loadMarket(A, path + "/A_2.mat"); REQUIRE(ok); + json solver_info; + Eigen::MatrixXd A_dense(A); auto solvers = Solver::available_solvers(); for (const auto &s : solvers) @@ -130,13 +132,22 @@ TEST_CASE("all", "[solver]") Eigen::VectorXd x(b.size()); x.setZero(); - solver->analyze_pattern(A, A.rows()); - solver->factorize(A); + if (solver->is_dense()) + { + solver->analyze_pattern_dense(A, A.rows()); + solver->factorize_dense(A); + } + else + { + solver->analyze_pattern(A, A.rows()); + solver->factorize(A); + } + solver->solve(b, x); REQUIRE(solver->name() == s); - // solver->get_info(solver_info); + solver->get_info(solver_info); // std::cout<<"Solver error: "<