-
-
Notifications
You must be signed in to change notification settings - Fork 188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generalised unary vector function framework #1558
Merged
rok-cesnovar
merged 36 commits into
stan-dev:develop
from
andrjohns:feature/vec_gen_design
Jan 24, 2020
Merged
Changes from 34 commits
Commits
Show all changes
36 commits
Select commit
Hold shift + click to select a range
ca89f78
Initial implementation
de964ba
Merge branch 'develop' into feature/vec_gen_design
6b54c52
Add forwarding, rev & fwd versions
eebfd95
Merge branch 'develop' into feature/vec_gen_design
2b085d1
Add autodiff tests, remove arr versions
ac3d5e3
Nested testing
77f3a59
Fix tests, update doc
b3a1132
Tidy doc
27d1b1f
Merge branch 'develop' into feature/vec_gen_design
a4b83a7
Cpplint
d97c963
Tidy missing doc
461f0b0
log_softmax doc errors
cd0a362
Merge branch 'develop' into feature/vec_gen_design
9fa5124
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot cba3a30
Fix failing test
e5a6b28
Merge develop
5a934fe
Revert head replacement
b7b2171
Merge branch 'develop' into feature/vec_gen_design
8ebea40
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot a73be6d
Merge commit '426ad8fe5a2858b9d367aade1b25a631ac5e97e8' into merge_af…
rok-cesnovar 8b5cc7f
Merge commit 'd7eb73884e5fad18eaf323760e4625317e1c4c91' into merge_af…
rok-cesnovar df34056
Merge commit '2b2f7ddff32c12e1e253a6179bf81c1845962306' into merge_af…
rok-cesnovar 8a7017a
Merge commit '731b5f8cf6566db4f13a06851d56cc9e54029146' into merge_af…
rok-cesnovar 8214c93
Merge branch 'develop' into merge_after_flatten
rok-cesnovar d776eac
merge conflicts fix
rok-cesnovar 09c4004
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot a0eb3df
fix header guard
rok-cesnovar 4e83afb
remove include
rok-cesnovar 2e4f6b1
Merge branch 'develop' into feature/vec_gen_design
febffbe
Address review comments
67e23b8
Merge branch 'develop' into feature/vec_gen_design
477cf9f
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot 02afdb9
Fix merge error
8d8539a
cpplint
1fce5ac
Merge branch 'develop' into feature/vec_gen_design
f3b3286
Address comments
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -3,6 +3,7 @@ | |||||
|
||||||
#include <stan/math/prim/mat/fun/Eigen.hpp> | ||||||
#include <stan/math/prim/scal/fun/constants.hpp> | ||||||
#include <stan/math/prim/vectorize/apply_vector_unary.hpp> | ||||||
#include <cmath> | ||||||
#include <vector> | ||||||
|
||||||
|
@@ -12,31 +13,30 @@ namespace math { | |||||
/** | ||||||
* Return the log of the sum of the exponentiated values of the specified | ||||||
* matrix of values. The matrix may be a full matrix, a vector, | ||||||
* or a row vector. | ||||||
* a row vector, or a container of these. | ||||||
* | ||||||
* The function is defined as follows to prevent overflow in exponential | ||||||
* calculations. | ||||||
* | ||||||
* \f$\log \sum_{n=1}^N \exp(x_n) = \max(x) + \log \sum_{n=1}^N \exp(x_n - | ||||||
* \max(x))\f$. | ||||||
* | ||||||
* @tparam R number of rows, can be Eigen::Dynamic | ||||||
* @tparam C number of columns, can be Eigen::Dynamic | ||||||
* | ||||||
* @param[in] x Matrix of specified values | ||||||
* @tparam T Type of input vector or matrix. | ||||||
* @param[in] x Matrix of specified values. | ||||||
* @return The log of the sum of the exponentiated vector values. | ||||||
*/ | ||||||
template <int R, int C> | ||||||
double log_sum_exp(const Eigen::Matrix<double, R, C>& x) { | ||||||
if (x.size() == 0) { | ||||||
return NEGATIVE_INFTY; | ||||||
} | ||||||
|
||||||
const double max = x.maxCoeff(); | ||||||
if (!std::isfinite(max)) { | ||||||
return max; | ||||||
} | ||||||
return max + std::log((x.array() - max).exp().sum()); | ||||||
template <typename T, require_t<std::is_arithmetic<scalar_type_t<T>>>...> | ||||||
inline auto log_sum_exp(const T& x) { | ||||||
return apply_vector_unary<T>::reduce(x, [&](auto& v) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
With no perfect forwarding this can be const. Same for all other lambdas you introduced. |
||||||
if (v.size() == 0) { | ||||||
return NEGATIVE_INFTY; | ||||||
} | ||||||
const double max = v.maxCoeff(); | ||||||
if (!std::isfinite(max)) { | ||||||
return max; | ||||||
} | ||||||
return max + std::log((v.array() - max).exp().sum()); | ||||||
}); | ||||||
} | ||||||
|
||||||
} // namespace math | ||||||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've removed a redundant definition here (and in the rev) header. Originally, there was a definition for
log_sum_exp(const fvar<T>& x1, double x2)
andlog_sum_exp(double x1, const fvar<T>& x2)
, but we can just have one definition and change the order of arguments as needed.