Skip to content

Commit

Permalink
Add option for transposed solve
Browse files Browse the repository at this point in the history
  • Loading branch information
octave-user committed Apr 28, 2024
1 parent e114bf8 commit 62355de
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 18 deletions.
23 changes: 17 additions & 6 deletions inst/numerical_tests_01.tst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
## numerical_tests.tst:01
%!test
%! if (~isempty(which("pastix")))
%! for i=1:2
%! for i=1:4
%! for j=1:2
%! A = [1 0 0 0 0
%! 0 3 0 0 0
Expand All @@ -26,12 +26,23 @@
%! opts.number_of_threads = int32(4);
%! opts.check_solution = true;
%! switch i
%! case 1
%! x = pastix(A, b, opts);
%! case 2
%! x = pastix(pastix(A, opts), b);
%! case {1, 2}
%! trans = 0;
%! case {3, 4}
%! trans = 2;
%! endswitch
%! switch i
%! case {1, 3}
%! x = pastix(A, b, opts, trans);
%! case {2, 4}
%! x = pastix(pastix(A, opts), b, trans);
%! endswitch
%! switch (trans)
%! case 0
%! f = max(norm(A * x - b, "cols") ./ norm(A * x + b, "cols"));
%! case 2
%! f = max(norm(A.' * x - b, "cols") ./ norm(A.' * x + b, "cols"));
%! endswitch
%! f = max(norm(A * x - b, "cols") ./ norm(A * x + b, "cols"));
%! assert(f <= eps^0.8);
%! endfor
%! endfor
Expand Down
18 changes: 15 additions & 3 deletions src/pardiso.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class PardisoObject : public octave_base_value {
virtual ~PardisoObject(void);
virtual size_t byte_size() const;
virtual dim_vector dims() const;
bool solve(DenseMatrixType& b, DenseMatrixType& x) const;
bool solve(DenseMatrixType& b, DenseMatrixType& x, long long sys) const;
static bool get_options(const octave_value& ovOptions, PardisoObject::Options& options);
virtual bool is_constant(void) const{ return true; }
virtual bool is_defined(void) const{ return true; }
Expand Down Expand Up @@ -447,7 +447,7 @@ PardisoObject<T>::~PardisoObject()
}

template <typename T>
bool PardisoObject<T>::solve(DenseMatrixType& b, DenseMatrixType& x) const {
bool PardisoObject<T>::solve(DenseMatrixType& b, DenseMatrixType& x, long long sys) const {
if (b.rows() != n) {
error_with_id("pardiso:solve", "pardiso: rows(b)=%Ld must be equal to rows(A)=%Ld", static_cast<long long>(b.rows()), n);
return false;
Expand All @@ -456,8 +456,14 @@ bool PardisoObject<T>::solve(DenseMatrixType& b, DenseMatrixType& x) const {
assert(b.rows() == x.rows());
assert(b.columns() == x.columns());

const auto save_sys = iparm[11];

iparm[11] = sys;

long long ierror = pardiso(b.fortran_vec(), x.fortran_vec(), b.columns());

iparm[11] = save_sys;

if (ierror != 0LL) {
error_with_id("pardiso:solve", "pardiso solve failed with status %Ld", ierror);
return false;
Expand Down Expand Up @@ -623,8 +629,14 @@ octave_value_list PardisoObject<T>::eval(const octave_value_list& args, int narg
bOwnPardiso = true;
}

long long sys = 0LL;

if (args.length() > iarg) {
sys = args(iarg++).long_value();
}

if (bHaveRightHandSide) {
if (pPardiso->solve(b, x)) {
if (pPardiso->solve(b, x, sys)) {
retval.append(x);
}

Expand Down
34 changes: 30 additions & 4 deletions src/pastix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class PastixObject : public octave_base_value {
virtual ~PastixObject(void);
virtual size_t byte_size() const;
virtual dim_vector dims() const;
bool solve(DenseMatrixType& b, DenseMatrixType& x) const;
bool solve(DenseMatrixType& b, DenseMatrixType& x, pastix_trans_t sys) const;
static bool get_options(const octave_value& ovOptions, PastixObject::Options& options);
virtual bool is_constant(void) const{ return true; }
virtual bool is_defined(void) const{ return true; }
Expand Down Expand Up @@ -423,7 +423,7 @@ void PastixObject<T>::cleanup()
}

template <typename T>
bool PastixObject<T>::solve(DenseMatrixType& b, DenseMatrixType& x) const {
bool PastixObject<T>::solve(DenseMatrixType& b, DenseMatrixType& x, pastix_trans_t sys) const {
if (b.rows() != ncols) {
error_with_id("pastix:solve", "pastix: rows(b)=%ld must be equal to rows(A)=%ld", long(b.rows()), long(ncols));
return false;
Expand All @@ -433,12 +433,18 @@ bool PastixObject<T>::solve(DenseMatrixType& b, DenseMatrixType& x) const {

x = b;

const auto save_sys = iparm[IPARM_TRANSPOSE_SOLVE];

iparm[IPARM_TRANSPOSE_SOLVE] = sys;

int rc = pastix_task_solve(pastix_data,
x.rows(),
x.columns(),
x.fortran_vec(),
x.rows());

iparm[IPARM_TRANSPOSE_SOLVE] = save_sys;

if (PASTIX_SUCCESS != rc) {
error_with_id("pastix:solve", "pastix_task_solve failed with status %d", rc);
return false;
Expand All @@ -458,6 +464,8 @@ bool PastixObject<T>::solve(DenseMatrixType& b, DenseMatrixType& x) const {
}

if (!bZeroVec) {
iparm[IPARM_TRANSPOSE_SOLVE] = sys;

// Avoid division zero by zero in PaStiX
rc = pastix_task_refine(pastix_data,
spm.n,
Expand All @@ -467,6 +475,8 @@ bool PastixObject<T>::solve(DenseMatrixType& b, DenseMatrixType& x) const {
x.fortran_vec() + j * x.rows(),
x.rows());

iparm[IPARM_TRANSPOSE_SOLVE] = save_sys;

if (PASTIX_SUCCESS != rc) {
error_with_id("pastix:solve", "pastix_task_refine failed with status %d", rc);
return false;
Expand All @@ -476,7 +486,8 @@ bool PastixObject<T>::solve(DenseMatrixType& b, DenseMatrixType& x) const {
OCTAVE_QUIT;
}

if (options.check_solution) {
if (options.check_solution && sys == PastixNoTrans) {
// FIXME: Is there any transposed version of spmCheckAxb?
rc = spmCheckAxb(dparm[DPARM_EPSILON_REFINEMENT],
b.columns(),
&spm,
Expand Down Expand Up @@ -792,8 +803,23 @@ octave_value_list PastixObject<T>::eval(const octave_value_list& args, int nargo
#endif
}

pastix_trans_t sys = PastixNoTrans;

if (args.length() > iarg) {
switch (args(iarg++).long_value()) {
case 0:
sys = PastixNoTrans;
break;
case 2:
sys = PastixTrans;
break;
default:
sys = static_cast<pastix_trans_t>(-1);
}
}

if (bHaveRightHandSide) {
if (pPastix->solve(b, x)) {
if (pPastix->solve(b, x, sys)) {
retval.append(x);
}

Expand Down
16 changes: 11 additions & 5 deletions src/umfpack.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2019(-2023) Reinhard <[email protected]>
// Copyright (C) 2019(-2024) Reinhard <[email protected]>

// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
Expand Down Expand Up @@ -150,7 +150,7 @@ class UmfpackObject : public octave_base_value {
UmfpackObject();
explicit UmfpackObject(const SparseMatrixType& A, const Options& options);
virtual ~UmfpackObject(void);
DenseMatrixType solve(const DenseMatrixType& b);
DenseMatrixType solve(const DenseMatrixType& b, SuiteSparse_long sys);
virtual bool is_constant(void) const{ return true; }
virtual bool is_defined(void) const{ return true; }
virtual dim_vector dims (void) const { return oMat.dims(); }
Expand Down Expand Up @@ -245,7 +245,7 @@ UmfpackObject<T>::~UmfpackObject()
}

template <typename T>
typename UmfpackObject<T>::DenseMatrixType UmfpackObject<T>::solve(const DenseMatrixType& b)
typename UmfpackObject<T>::DenseMatrixType UmfpackObject<T>::solve(const DenseMatrixType& b, SuiteSparse_long sys)
{
DenseMatrixType x(b.rows(), b.columns());

Expand All @@ -255,7 +255,7 @@ typename UmfpackObject<T>::DenseMatrixType UmfpackObject<T>::solve(const DenseMa
const T* const bp = b.data();

for (octave_idx_type j = 0; j < b.columns(); ++j) {
auto status = oMat.umfpack_solve(UMFPACK_A,
auto status = oMat.umfpack_solve(sys,
xp + j * n,
bp + j * n,
Numeric,
Expand Down Expand Up @@ -411,6 +411,12 @@ octave_value_list UmfpackObject<T>::eval(const octave_value_list& args, int narg
}
}

SuiteSparse_long sys = UMFPACK_A;

if (args.length() > iarg) {
sys = args(iarg++).long_value();
}

try {
if (bHaveMatrix) {
pUmfpack = new UmfpackObjectType{A, options};
Expand All @@ -419,7 +425,7 @@ octave_value_list UmfpackObject<T>::eval(const octave_value_list& args, int narg
}

if (bHaveRightHandSide) {
retval.append(pUmfpack->solve(b));
retval.append(pUmfpack->solve(b, sys));

if (bOwnUmfpack) {
delete pUmfpack;
Expand Down

0 comments on commit 62355de

Please sign in to comment.