Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Johnson committed Jan 13, 2020
1 parent 1fce5ac commit f3b3286
Show file tree
Hide file tree
Showing 8 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion stan/math/fwd/fun/log_softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace math {
*/
template <typename T, require_t<is_fvar<scalar_type_t<T>>>...>
inline auto log_softmax(const T& x) {
return apply_vector_unary<T>::apply(x, [&](auto& alpha) {
return apply_vector_unary<T>::apply(x, [&](const auto& alpha) {
using T_fvar = value_type_t<decltype(alpha)>;
using T_fvar_inner = typename T_fvar::Scalar;

Expand Down
2 changes: 1 addition & 1 deletion stan/math/fwd/fun/log_sum_exp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ inline fvar<T> log_sum_exp(const fvar<T>& x1, double x2) {
*/
template <typename T, require_t<is_fvar<scalar_type_t<T>>>...>
inline auto log_sum_exp(const T& x) {
return apply_vector_unary<T>::reduce(x, [&](auto& v) {
return apply_vector_unary<T>::reduce(x, [&](const auto& v) {
using T_fvar_inner = typename value_type_t<decltype(v)>::Scalar;
using mat_type = Eigen::Matrix<T_fvar_inner, -1, -1>;
mat_type vals = v.val();
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/log_softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace math {
*/
template <typename T, require_t<std::is_arithmetic<scalar_type_t<T>>>...>
inline auto log_softmax(const T& x) {
return apply_vector_unary<T>::apply(x, [&](auto& v) {
return apply_vector_unary<T>::apply(x, [&](const auto& v) {
check_nonzero_size("log_softmax", "v", v);
return (v.array() - log_sum_exp(v)).matrix();
});
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/log_sum_exp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ inline return_type_t<T1, T2> log_sum_exp(const T2& a, const T1& b) {
*/
template <typename T, require_t<std::is_arithmetic<scalar_type_t<T>>>...>
inline auto log_sum_exp(const T& x) {
return apply_vector_unary<T>::reduce(x, [&](auto& v) {
return apply_vector_unary<T>::reduce(x, [&](const auto& v) {
if (v.size() == 0) {
return NEGATIVE_INFTY;
}
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/vectorize/apply_vector_unary.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#ifndef STAN_MATH_PRIM_VECTORIZE_APPLY_VECTOR_UNARY_HPP
#define STAN_MATH_PRIM_VECTORIZE_APPLY_VECTOR_UNARY_HPP

#include <stan/math/prim/mat/fun/Eigen.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/meta.hpp>
#include <vector>

Expand Down
2 changes: 1 addition & 1 deletion stan/math/rev/core/matrix_vari.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class op_matrix_vari : public vari {

public:
template <typename T, require_eigen_vt<is_var, T>...>
op_matrix_vari(double f, T&& vs) : vari(f), size_(vs.size()) {
op_matrix_vari(double f, const T& vs) : vari(f), size_(vs.size()) {
vis_ = ChainableStack::instance_->memalloc_.alloc_array<vari*>(size_);
Eigen::Map<Eigen::Matrix<vari*, -1, -1>>(vis_, vs.rows(), vs.cols())
= vs.vi();
Expand Down
2 changes: 1 addition & 1 deletion stan/math/rev/fun/log_softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class log_softmax_elt_vari : public vari {
*/
template <typename T, require_t<is_var<scalar_type_t<T>>>...>
inline auto log_softmax(const T& x) {
return apply_vector_unary<T>::apply(x, [&](auto& alpha) {
return apply_vector_unary<T>::apply(x, [&](const auto& alpha) {
const int a_size = alpha.size();

check_nonzero_size("log_softmax", "alpha", alpha);
Expand Down
6 changes: 3 additions & 3 deletions stan/math/rev/fun/log_sum_exp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ namespace internal {
class log_sum_exp_matrix_vari : public op_matrix_vari {
public:
template <typename T>
explicit log_sum_exp_matrix_vari(T&& x)
: op_matrix_vari(log_sum_exp(x.val()), std::forward<T>(x)) {}
explicit log_sum_exp_matrix_vari(const T& x)
: op_matrix_vari(log_sum_exp(x.val()), x) {}
void chain() {
Eigen::Map<vector_vi> vis_map(vis_, size_);
vis_map.adj().array() += adj_ * (vis_map.val().array() - val_).exp();
Expand All @@ -82,7 +82,7 @@ class log_sum_exp_matrix_vari : public op_matrix_vari {
*/
template <typename T, require_t<is_var<scalar_type_t<T>>>...>
inline auto log_sum_exp(const T& x) {
return apply_vector_unary<T>::reduce(x, [&](auto& v) {
return apply_vector_unary<T>::reduce(x, [&](const auto& v) {
return var(new internal::log_sum_exp_matrix_vari(v));
});
}
Expand Down

0 comments on commit f3b3286

Please sign in to comment.