Skip to content

Commit

Permalink
Merge pull request #2961 from stan-dev/grad_pfq-2
Browse files Browse the repository at this point in the history
Re-Implementation of Hypergeometric PFQ gradient function
  • Loading branch information
syclik authored Apr 13, 2024
2 parents 1f94ed3 + ecc713b commit de0c1a7
Show file tree
Hide file tree
Showing 7 changed files with 505 additions and 683 deletions.
54 changes: 32 additions & 22 deletions stan/math/fwd/fun/hypergeometric_pFq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,42 @@ namespace math {
* @return Generalized hypergeometric function
*/
template <typename Ta, typename Tb, typename Tz,
require_all_matrix_t<Ta, Tb>* = nullptr,
require_return_type_t<is_fvar, Ta, Tb, Tz>* = nullptr>
inline return_type_t<Ta, Tb, Tz> hypergeometric_pFq(const Ta& a, const Tb& b,
const Tz& z) {
using fvar_t = return_type_t<Ta, Tb, Tz>;
ref_type_t<Ta> a_ref = a;
ref_type_t<Tb> b_ref = b;
auto grad_tuple = grad_pFq(a_ref, b_ref, z);

typename fvar_t::Scalar grad = 0;

if (!is_constant<Ta>::value) {
grad += dot_product(forward_as<promote_scalar_t<fvar_t, Ta>>(a_ref).d(),
std::get<0>(grad_tuple));
typename FvarT = return_type_t<Ta, Tb, Tz>,
bool grad_a = !is_constant<Ta>::value,
bool grad_b = !is_constant<Tb>::value,
bool grad_z = !is_constant<Tz>::value,
require_all_vector_t<Ta, Tb>* = nullptr,
require_fvar_t<FvarT>* = nullptr>
inline FvarT hypergeometric_pFq(const Ta& a, const Tb& b, const Tz& z) {
using PartialsT = partials_type_t<FvarT>;
using ARefT = ref_type_t<Ta>;
using BRefT = ref_type_t<Tb>;

ARefT a_ref = a;
BRefT b_ref = b;
auto&& a_val = value_of(a_ref);
auto&& b_val = value_of(b_ref);
auto&& z_val = value_of(z);
PartialsT pfq_val = hypergeometric_pFq(a_val, b_val, z_val);
auto grad_tuple
= grad_pFq<grad_a, grad_b, grad_z>(pfq_val, a_val, b_val, z_val);

FvarT rtn = FvarT(pfq_val, 0.0);

if (grad_a) {
rtn.d_ += dot_product(forward_as<promote_scalar_t<FvarT, ARefT>>(a_ref).d(),
std::get<0>(grad_tuple));
}
if (!is_constant<Tb>::value) {
grad += dot_product(forward_as<promote_scalar_t<fvar_t, Tb>>(b_ref).d(),
std::get<1>(grad_tuple));
if (grad_b) {
rtn.d_ += dot_product(forward_as<promote_scalar_t<FvarT, BRefT>>(b_ref).d(),
std::get<1>(grad_tuple));
}
if (!is_constant<Tz>::value) {
grad += forward_as<promote_scalar_t<fvar_t, Tz>>(z).d_
* std::get<2>(grad_tuple);
if (grad_z) {
rtn.d_ += forward_as<promote_scalar_t<FvarT, Tz>>(z).d_
* std::get<2>(grad_tuple);
}

return fvar_t(
hypergeometric_pFq(value_of(a_ref), value_of(b_ref), value_of(z)), grad);
return rtn;
}

} // namespace math
Expand Down
Loading

0 comments on commit de0c1a7

Please sign in to comment.