Skip to content

Commit

Permalink
Update doc, additional simplification
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns committed Oct 9, 2023
1 parent 535dadb commit e769630
Showing 1 changed file with 52 additions and 6 deletions.
58 changes: 52 additions & 6 deletions stan/math/prim/fun/grad_pFq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,54 @@ inline auto binarysign(const T& x) {
* input arguments:
* \f$ _pF_q(a_1,...,a_p;b_1,...,b_q;z) \f$
*
* Where:
* \f$
* \frac{\partial }{\partial a_1} =
* \sum_{k=0}^{\infty}{
* \frac
* {\psi\left(k+a_1\right)\left(\prod_{j=1}^p\left(a_j\right)_k\right)z^k}
* {k!\prod_{j=1}^q\left(b_j\right)_k}}
* - \psi\left(a_1\right)_pF_q(a_1,...,a_p;b_1,...,b_q;z)
* \f$
* \f$
* \frac{\partial }{\partial b_1} =
* \psi\left(b_1\right)_pF_q(a_1,...,a_p;b_1,...,b_q;z) -
* \sum_{k=0}^{\infty}{
* \frac
* {\psi\left(k+b_1\right)\left(\prod_{j=1}^p\left(a_j\right)_k\right)z^k}
* {k!\prod_{j=1}^q\left(b_j\right)_k}}
* \f$
*
* \f$
* \frac{\partial }{\partial z} =
* \frac{\prod_{j=1}^{p}(a_j)}{\prod_{j=1}^{q} (b_j)}\
* * _pF_q(a_1+1,...,a_p+1;b_1+1,...,b_q+1;z)
* \f$
*
* Noting the the recurrence relation for the digamma function:
* \f$ \psi(x + 1) = \psi(x) + \frac{1}{x} \f$, and as such the presence of the
* digamma function in both operands of the subtraction, this then becomes a
* scaling factor and can be removed. The gradients for the function w.r.t a & b
* then simplify to:
* \f$
* \frac{\partial }{\partial a_1} =
* \sum_{k=1}^{\infty}{
* \frac
* {\left(1 + \sum_{m=0}^{k-1}\frac{1}{m+a_1}\right)
* * (\prod_{j=1}^p\left(a_j\right)_k\right)z^k}
* {k!\prod_{j=1}^q\left(b_j\right)_k}}
* - _pF_q(a_1,...,a_p;b_1,...,b_q;z)
* \f$
* \f$
* \frac{\partial }{\partial b_1} =
* _pF_q(a_1,...,a_p;b_1,...,b_q;z) -
* \sum_{k=1}^{\infty}{
* \frac
* {\left(1 + \sum_{m=0}^{k-1}\frac{1}{m+b_1}\right)
* * \left(\prod_{j=1}^p\left(a_j\right)_k\right)z^k}
* {k!\prod_{j=1}^q\left(b_j\right)_k}}
* \f$
*
* @tparam CalcA Boolean for whether to calculate derivatives wrt to 'a'
* @tparam CalcB Boolean for whether to calculate derivatives wrt to 'b'
* @tparam CalcZ Boolean for whether to calculate derivatives wrt to 'z'
Expand Down Expand Up @@ -57,8 +105,8 @@ auto grad_pFq(const TpFq& pfq_val, const Ta& a, const Tb& b, const Tz& z,
Tz log_z = log(abs(z));
int z_sign = internal::binarysign(z);

Ta_Array digamma_a = select(a_k == 0.0, 0.0, inv(a_k));
Tb_Array digamma_b = select(b_k == 0.0, 0.0, inv(b_k));
Ta_Array digamma_a = Ta_Array::Ones(a.size());
Tb_Array digamma_b = Tb_Array::Ones(b.size());

std::tuple<promote_scalar_t<T_Rtn, plain_type_t<Ta>>,
promote_scalar_t<T_Rtn, plain_type_t<Tb>>, T_Rtn>
Expand Down Expand Up @@ -106,12 +154,10 @@ auto grad_pFq(const TpFq& pfq_val, const Ta& a, const Tb& b, const Tz& z,
}

if (CalcA) {
std::get<0>(ret_tuple).array()
-= select(a_array == 0.0, 0.0, pfq_val / a_array);
std::get<0>(ret_tuple).array() -= pfq_val;
}
if (CalcB) {
std::get<1>(ret_tuple).array()
+= select(b_array == 0.0, 0.0, pfq_val / b_array);
std::get<1>(ret_tuple).array() += pfq_val;
}
}
if (CalcZ) {
Expand Down

0 comments on commit e769630

Please sign in to comment.