-
-
Notifications
You must be signed in to change notification settings - Fork 189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reorganize /opencl and add missing matrix_cl overloads #1364
Changes from all commits
f60a88d
077bec2
29146a0
798f661
912c05c
88ee0c4
a2464ac
b1cd81d
0b4d1ea
0e084a2
b36103f
de53951
72f88a1
40be0f0
675df84
4e5927d
522d83b
7a8f326
b416f74
af84a9e
aca4aca
3c57a14
d097a9e
d4d52ff
75c83f6
197977f
d69b112
9bc190e
fdd62f2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#ifndef STAN_MATH_OPENCL_PRIM_CHOLESKY_DECOMPOSE_HPP | ||
#define STAN_MATH_OPENCL_PRIM_CHOLESKY_DECOMPOSE_HPP | ||
#ifdef STAN_OPENCL | ||
#include <stan/math/opencl/matrix_cl.hpp> | ||
#include <stan/math/opencl/cholesky_decompose.hpp> | ||
#include <stan/math/opencl/copy_triangular.hpp> | ||
#include <stan/math/prim/meta.hpp> | ||
#include <cl.hpp> | ||
#include <algorithm> | ||
#include <cmath> | ||
|
||
namespace stan { | ||
namespace math { | ||
/** | ||
* Returns the lower-triangular Cholesky factor (i.e., matrix | ||
* square root) of the specified square, symmetric matrix on the OpenCL device. | ||
* The return value \f$L\f$ will be a lower-traingular matrix such that the | ||
* original matrix \f$A\f$ is given by <p>\f$A = L \times L^T\f$. | ||
* @param A Input square matrix | ||
* @return Square root of matrix. | ||
* @throw std::domain_error if m is not a symmetric matrix or | ||
* if m is not positive definite (if m has more than 0 elements) | ||
*/ | ||
template <typename T, typename = require_floating_point_t<T>> | ||
inline matrix_cl<T> cholesky_decompose(matrix_cl<T>& A) { | ||
check_square("cholesky_decompose", "A", A); | ||
check_symmetric("cholesky_decompose", "A", A); | ||
matrix_cl<T> res = copy_cl(A); | ||
if (res.rows() == 0) { | ||
return res; | ||
} | ||
opencl::cholesky_decompose(res); | ||
// check_pos_definite on matrix_cl is check_nan + check_diagonal_zeros | ||
check_nan("cholesky_decompose (OpenCL)", "A", res); | ||
check_diagonal_zeros("cholesky_decompose (OpenCL)", "A", res); | ||
res.template zeros_strict_tri<matrix_cl_view::Upper>(); | ||
return res; | ||
} | ||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif | ||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
#ifndef STAN_MATH_OPENCL_PRIM_MDIVIDE_LEFT_TRI_LOW_HPP | ||
#define STAN_MATH_OPENCL_PRIM_MDIVIDE_LEFT_TRI_LOW_HPP | ||
#ifdef STAN_OPENCL | ||
#include <stan/math/prim/mat/err/check_square.hpp> | ||
#include <stan/math/prim/mat/err/check_multiplicable.hpp> | ||
#include <stan/math/opencl/matrix_cl.hpp> | ||
#include <stan/math/opencl/multiply.hpp> | ||
#include <stan/math/opencl/tri_inverse.hpp> | ||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Returns the solution of the system Ax=b when A is lower triangular. | ||
* @tparam T1 type of elements in A | ||
* @tparam T2 type of elements in b | ||
* @param A Triangular matrix. | ||
* @param b Right hand side matrix or vector. | ||
* @return x = A^-1 b, solution of the linear system. | ||
* @throws std::domain_error if A is not square or the rows of b don't | ||
* match the size of A. | ||
*/ | ||
template <typename T1, typename T2, | ||
typename = require_all_floating_point_t<T1, T2>> | ||
inline matrix_cl<return_type_t<T1, T2>> mdivide_left_tri_low( | ||
const matrix_cl<T1>& A, const matrix_cl<T2>& b) { | ||
check_square("mdivide_left_tri_low", "A", A); | ||
check_multiplicable("mdivide_left_tri_low", "A", A, "b", b); | ||
return tri_inverse<matrix_cl_view::Lower>(A) * b; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In order to force the Lower I had to extend the tri_inverse with a template. The other option would be to force a change of A before calling tri_inverse but that changes the view for A globally, which is bad. The third option would be to have the "forced" view as an argument to tri_inverse. I am also fine with that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like the option you chose. |
||
} | ||
|
||
/** | ||
* Returns the solution of the system Ax=b when A is triangular and b=I. | ||
* @tparam T type of elements in A | ||
* @tparam R1 number of rows in A | ||
* @tparam C1 number of columns in A | ||
* @param A Triangular matrix. | ||
* @return x = A^-1 . | ||
* @throws std::domain_error if A is not square | ||
*/ | ||
template <typename T, typename = require_all_floating_point_t<T>> | ||
inline matrix_cl<T> mdivide_left_tri_low(const matrix_cl<T>& A) { | ||
check_square("mdivide_left_tri_low", "A", A); | ||
return tri_inverse<matrix_cl_view::Lower>(A); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif | ||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#ifndef STAN_MATH_OPENCL_PRIM_MDIVIDE_RIGHT_TRI_LOW_HPP | ||
#define STAN_MATH_OPENCL_PRIM_MDIVIDE_RIGHT_TRI_LOW_HPP | ||
#ifdef STAN_OPENCL | ||
#include <stan/math/prim/mat/err/check_square.hpp> | ||
#include <stan/math/prim/mat/err/check_multiplicable.hpp> | ||
#include <stan/math/opencl/matrix_cl.hpp> | ||
#include <stan/math/opencl/multiply.hpp> | ||
#include <stan/math/opencl/tri_inverse.hpp> | ||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Returns the solution of the system Ax=b where A is a | ||
* lower triangular matrix. | ||
* @param A Matrix. | ||
* @param b Right hand side matrix or vector. | ||
* @return x = b * tri(A)^-1, solution of the linear system. | ||
* @throws std::domain_error if A is not square or the rows of b don't | ||
* match the size of A. | ||
*/ | ||
template <typename T1, typename T2, | ||
typename = require_all_floating_point_t<T1, T2>> | ||
inline matrix_cl<return_type_t<T1, T2>> mdivide_right_tri_low( | ||
const matrix_cl<T2>& b, const matrix_cl<T1>& A) { | ||
check_square("mdivide_right_tri_low (OpenCL)", "A", A); | ||
check_multiplicable("mdivide_right_tri_low (OpenCL)", "b", b, "A", A); | ||
return b * tri_inverse<matrix_cl_view::Lower>(A); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif | ||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This section was move to /prim. The actual implementation that was already in opencl:: was left in /opencl.