diff --git a/stan/math/fwd/fun/log_add_exp.hpp b/stan/math/fwd/fun/log_add_exp.hpp index 26fcff8068c..f6c0827b1a0 100644 --- a/stan/math/fwd/fun/log_add_exp.hpp +++ b/stan/math/fwd/fun/log_add_exp.hpp @@ -49,36 +49,36 @@ inline fvar log_add_exp(double x1, const fvar& x2) { // Overload for matrices of fvar template -inline fvar log_add_exp(const Eigen::Matrix, -1, -1>& a, const Eigen::Matrix, -1, -1>& b) { - - using fvar_mat_type = Eigen::Matrix, -1, -1>; - fvar_mat_type result(a.rows(), a.cols()); +inline fvar log_add_exp(const Eigen::Matrix, -1, -1>& a, + const Eigen::Matrix, -1, -1>& b) { + using fvar_mat_type = Eigen::Matrix, -1, -1>; + fvar_mat_type result(a.rows(), a.cols()); - // Check for empty inputs - if (a.size() == 0 || b.size() == 0) { + // Check for empty inputs + if (a.size() == 0 || b.size() == 0) { throw std::invalid_argument("Input containers must not be empty."); - } + } - // Check for NaN - if (a.array().isNaN().any() || b.array().isNaN().any()) { + // Check for NaN + if (a.array().isNaN().any() || b.array().isNaN().any()) { result.setConstant(fvar(std::numeric_limits::quiet_NaN())); return result; - } + } - // Check for infinity - if (a.array().isInf().any() || b.array().isInf().any()) { + // Check for infinity + if (a.array().isInf().any() || b.array().isInf().any()) { result.setConstant(fvar(std::numeric_limits::quiet_NaN())); return result; - } + } - // Apply the log_add_exp operation directly - for (int i = 0; i < a.rows(); ++i) { + // Apply the log_add_exp operation directly + for (int i = 0; i < a.rows(); ++i) { for (int j = 0; j < a.cols(); ++j) { - result(i, j) = stan::math::log_add_exp(a(i, j), b(i, j)); - } + result(i, j) = stan::math::log_add_exp(a(i, j), b(i, j)); } + } - return result; // Return the result matrix + return result; // Return the result matrix } // Specialization for nested fvar types diff --git a/stan/math/prim/fun/log_add_exp.hpp b/stan/math/prim/fun/log_add_exp.hpp index 82aa9a83fff..7117d876179 100644 --- a/stan/math/prim/fun/log_add_exp.hpp +++ b/stan/math/prim/fun/log_add_exp.hpp @@ -26,22 +26,21 @@ namespace math { * @param b the second variable */ -template * = nullptr, - require_all_stan_scalar_t* = nullptr> +template * = nullptr, + require_all_stan_scalar_t* = nullptr> inline return_type_t log_add_exp(const T2& a, const T1& b) { - if (a == NEGATIVE_INFTY) { + if (a == NEGATIVE_INFTY) { return b; - } - if (b == NEGATIVE_INFTY) { + } + if (b == NEGATIVE_INFTY) { return a; - } - if (a == INFTY || b == INFTY) { + } + if (a == INFTY || b == INFTY) { return INFTY; - } + } - const double max_val = std::max(a, b); - return max_val + std::log(std::exp(a - max_val) + std::exp(b - max_val)); + const double max_val = std::max(a, b); + return max_val + std::log(std::exp(a - max_val) + std::exp(b - max_val)); } /** @@ -100,9 +99,8 @@ inline auto log_add_exp(const T& a, const T& b) { */ template * = nullptr> inline auto log_add_exp(const T1& a, const T2& b) { - return apply_scalar_binary( - a, b, [](const auto& c, const auto& d) { return log_sum_exp(c, d); } - ); + return apply_scalar_binary( + a, b, [](const auto& c, const auto& d) { return log_sum_exp(c, d); }); } } // namespace math diff --git a/stan/math/rev/fun/log_add_exp.hpp b/stan/math/rev/fun/log_add_exp.hpp index be870405e69..c0663e064a2 100644 --- a/stan/math/rev/fun/log_add_exp.hpp +++ b/stan/math/rev/fun/log_add_exp.hpp @@ -19,31 +19,31 @@ namespace math { namespace internal { class log_add_exp_vv_vari : public op_vv_vari { - public: - log_add_exp_vv_vari(vari* avi, vari* bvi) - : op_vv_vari(log_add_exp(avi->val_, bvi->val_), avi, bvi) {} - void chain() { - double exp_a = std::exp(avi_->val_); - double exp_b = std::exp(bvi_->val_); - double sum_exp = exp_a + exp_b; + public: + log_add_exp_vv_vari(vari* avi, vari* bvi) + : op_vv_vari(log_add_exp(avi->val_, bvi->val_), avi, bvi) {} + void chain() { + double exp_a = std::exp(avi_->val_); + double exp_b = std::exp(bvi_->val_); + double sum_exp = exp_a + exp_b; - avi_->adj_ += adj_ * (exp_a / sum_exp); - bvi_->adj_ += adj_ * (exp_b / sum_exp); - } + avi_->adj_ += adj_ * (exp_a / sum_exp); + bvi_->adj_ += adj_ * (exp_b / sum_exp); + } }; class log_add_exp_vd_vari : public op_vd_vari { - public: - log_add_exp_vd_vari(vari* avi, double b) - : op_vd_vari(log_add_exp(avi->val_, b), avi, b) {} - void chain() { - if (val_ == NEGATIVE_INFTY) { - avi_->adj_ += adj_; - } else { - double exp_a = std::exp(avi_->val_); - avi_->adj_ += adj_ * (exp_a / (exp_a + std::exp(bd_))); - } + public: + log_add_exp_vd_vari(vari* avi, double b) + : op_vd_vari(log_add_exp(avi->val_, b), avi, b) {} + void chain() { + if (val_ == NEGATIVE_INFTY) { + avi_->adj_ += adj_; + } else { + double exp_a = std::exp(avi_->val_); + avi_->adj_ += adj_ * (exp_a / (exp_a + std::exp(bd_))); } + } }; } // namespace internal @@ -52,21 +52,21 @@ class log_add_exp_vd_vari : public op_vd_vari { * Returns the element-wise log sum of exponentials. */ inline var log_add_exp(const var& a, const var& b) { - return var(new internal::log_add_exp_vv_vari(a.vi_, b.vi_)); + return var(new internal::log_add_exp_vv_vari(a.vi_, b.vi_)); } /** * Returns the log sum of exponentials. */ inline var log_add_exp(const var& a, double b) { - return var(new internal::log_add_exp_vd_vari(a.vi_, b)); + return var(new internal::log_add_exp_vd_vari(a.vi_, b)); } /** * Returns the element-wise log sum of exponentials. */ inline var log_add_exp(double a, const var& b) { - return var(new internal::log_add_exp_vd_vari(b.vi_, a)); + return var(new internal::log_add_exp_vd_vari(b.vi_, a)); } /** @@ -78,9 +78,8 @@ inline var log_add_exp(double a, const var& b) { */ template * = nullptr> inline T log_add_exp(const T& x, const T& y) { - return apply_scalar_binary( - x, y, [](const auto& a, const auto& b) { return log_add_exp(a, b); } - ); + return apply_scalar_binary( + x, y, [](const auto& a, const auto& b) { return log_add_exp(a, b); }); } } // namespace math diff --git a/test/unit/math/mix/fun/log_add_exp_test.cpp b/test/unit/math/mix/fun/log_add_exp_test.cpp index 6db2f3f8067..d311aa40034 100644 --- a/test/unit/math/mix/fun/log_add_exp_test.cpp +++ b/test/unit/math/mix/fun/log_add_exp_test.cpp @@ -4,68 +4,68 @@ #include TEST(MathMixMatFun, logAddExp) { - auto f = [](const auto& x, const auto& y) { + auto f = [](const auto& x, const auto& y) { return stan::math::log_add_exp(x, y); - }; - // Test with finite values - Eigen::VectorXd x1(2); - x1 << 2.0, 1.0; - Eigen::VectorXd y1(2); - y1 << 3.0, 2.0; - stan::test::expect_ad(f, x1, y1); + }; + // Test with finite values + Eigen::VectorXd x1(2); + x1 << 2.0, 1.0; + Eigen::VectorXd y1(2); + y1 << 3.0, 2.0; + stan::test::expect_ad(f, x1, y1); - // Test with negative infinity + // Test with negative infinity - stan::test::expect_ad(f, stan::math::NEGATIVE_INFTY, 1.0); - stan::test::expect_ad(f, 1.0, stan::math::NEGATIVE_INFTY); + stan::test::expect_ad(f, stan::math::NEGATIVE_INFTY, 1.0); + stan::test::expect_ad(f, 1.0, stan::math::NEGATIVE_INFTY); - // Test with infinity - stan::test::expect_ad(f, stan::math::INFTY, stan::math::INFTY); + // Test with infinity + stan::test::expect_ad(f, stan::math::INFTY, stan::math::INFTY); } TEST(MathMixMatFun, log_add_exp_elementwise_values) { - auto f = [](const auto& x, const auto& y) { + auto f = [](const auto& x, const auto& y) { return stan::math::log_add_exp(x, y); - }; + }; - Eigen::VectorXd x1(2); - x1 << 2.0, 1.0; - Eigen::VectorXd y1(2); - y1 << 3.0, 2.0; - stan::test::expect_ad(f, x1, y1); + Eigen::VectorXd x1(2); + x1 << 2.0, 1.0; + Eigen::VectorXd y1(2); + y1 << 3.0, 2.0; + stan::test::expect_ad(f, x1, y1); - Eigen::VectorXd x2(2); - x2 << 0.5, -1.0; - Eigen::VectorXd y2(2); - y2 << 1.0, 2.0; - stan::test::expect_ad(f, x2, y2); + Eigen::VectorXd x2(2); + x2 << 0.5, -1.0; + Eigen::VectorXd y2(2); + y2 << 1.0, 2.0; + stan::test::expect_ad(f, x2, y2); - // Test with infinity - Eigen::VectorXd x3(2); - x3 << std::numeric_limits::infinity(), 1.0; - Eigen::VectorXd y3(2); - y3 << 2.0, std::numeric_limits::infinity(); - stan::test::expect_ad(f, x3, y3); + // Test with infinity + Eigen::VectorXd x3(2); + x3 << std::numeric_limits::infinity(), 1.0; + Eigen::VectorXd y3(2); + y3 << 2.0, std::numeric_limits::infinity(); + stan::test::expect_ad(f, x3, y3); } TEST(MathMixMatFun, log_add_exp_edge_cases) { - auto f = [](const auto& x, const auto& y) { + auto f = [](const auto& x, const auto& y) { return stan::math::log_add_exp(x, y); - }; + }; - stan::test::expect_ad(f, stan::math::NEGATIVE_INFTY, 1.0); - stan::test::expect_ad(f, 1.0, stan::math::NEGATIVE_INFTY); - stan::test::expect_ad(f, stan::math::INFTY, stan::math::INFTY); + stan::test::expect_ad(f, stan::math::NEGATIVE_INFTY, 1.0); + stan::test::expect_ad(f, 1.0, stan::math::NEGATIVE_INFTY); + stan::test::expect_ad(f, stan::math::INFTY, stan::math::INFTY); } TEST(MathMixMatFun, log_add_exp_mismatched_sizes) { - auto f = [](const auto& x, const auto& y) { + auto f = [](const auto& x, const auto& y) { return stan::math::log_add_exp(x, y); - }; + }; - std::vector x{1.0, 2.0}; - std::vector y{1.0, 2.0, 3.0}; + std::vector x{1.0, 2.0}; + std::vector y{1.0, 2.0, 3.0}; - stan::test::expect_ad(f, x, y); - stan::test::expect_ad(f, y, x); + stan::test::expect_ad(f, x, y); + stan::test::expect_ad(f, y, x); }