Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix assignment for nullptr var_value<matrix> and for assigning expressions #2978

Merged
merged 14 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions stan/math/rev/core/arena_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ class arena_matrix : public Eigen::Map<MatrixType> {
Base::operator=(a);
return *this;
}
/**
* Forces hard copying matrices into an arena matrix
* @tparam T Any type assignable to `Base`
* @param x the values to write to `this`
*/
template <typename T>
void hard_copy(const T& x) {
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
Base::operator=(x);
}
};

} // namespace math
Expand Down
64 changes: 57 additions & 7 deletions stan/math/rev/core/var.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,8 @@ class var_value<T, internal::require_matrix_var_value<T>> {
reverse_pass_callback(
[this_vi = this->vi_, other_vi = other.vi_]() mutable {
other_vi->adj_ += this_vi->adj_;
// Reset the adjoint for `this` to replicate SoA before assignment
this_vi->adj_.setZero();
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
});
}

Expand Down Expand Up @@ -1020,9 +1022,10 @@ class var_value<T, internal::require_matrix_var_value<T>> {
* @param other the value to assign
* @return this
*/
template <typename S, require_assignable_t<value_type, S>* = nullptr,
require_all_plain_type_t<T, S>* = nullptr,
require_not_same_t<plain_type_t<T>, plain_type_t<S>>* = nullptr>
template <typename S, typename T_ = T,
require_assignable_t<value_type, S>* = nullptr,
require_all_plain_type_t<T_, S>* = nullptr,
require_not_same_t<plain_type_t<T_>, plain_type_t<S>>* = nullptr>
inline var_value<T>& operator=(const var_value<S>& other) {
static_assert(
EIGEN_PREDICATE_SAME_MATRIX_SIZE(T, S),
Expand All @@ -1032,16 +1035,63 @@ class var_value<T, internal::require_matrix_var_value<T>> {
}

/**
* Assignment of another var value, when either this or the other one does not
* Assignment of another var value, when the `this` does not
* contain a plain type.
* @tparam S type of the value in the `var_value` to assing
* @tparam S type of the value in the `var_value` to assign
* @param other the value to assign
* @return this
*/
template <typename S, typename T_ = T,
require_assignable_t<value_type, S>* = nullptr,
require_any_not_plain_type_t<T_, S>* = nullptr>
require_not_plain_type_t<S>* = nullptr,
require_plain_type_t<T_>* = nullptr>
inline var_value<T>& operator=(const var_value<S>& other) {
// If vi_ is nullptr then the var needs initialized via copy constructor
if (!(this->vi_)) {
*this = var_value<T>(other);
return *this;
}
arena_t<plain_type_t<T>> prev_val(vi_->val_.rows(), vi_->val_.cols());
prev_val.hard_copy(vi_->val_);
vi_->val_.hard_copy(other.val());
// no need to change any adjoints - these are just zeros before the reverse
// pass

reverse_pass_callback(
[this_vi = this->vi_, other_vi = other.vi_, prev_val]() mutable {
this_vi->val_.hard_copy(prev_val);

// we have no way of detecting aliasing between this->vi_->adj_ and
// other.vi_->adj_, so we must copy adjoint before reseting to zero

// we can reuse prev_val instead of allocating a new matrix
prev_val.hard_copy(this_vi->adj_);
this_vi->adj_.setZero();
other_vi->adj_ += prev_val;
});
return *this;
}
/**
* Assignment of another var value, when either both `this` or other does not
* contain a plain type.
* @tparam S type of the value in the `var_value` to assign
* @param other the value to assign
* @return this
*/
template <typename S, typename T_ = T,
require_assignable_t<value_type, S>* = nullptr,
require_any_not_plain_type_t<T_, S>* = nullptr,
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
require_not_plain_type_t<T_>* = nullptr>
inline var_value<T>& operator=(const var_value<S>& other) {
// If vi_ is nullptr then the var needs initialized via copy constructor
if (!(this->vi_)) {
[]() STAN_COLD_PATH {
throw std::domain_error(
"var_value<matrix>::operator=(var_value<expression>):"
" Internal Bug! Please report this with an example"
" of your model to the Stan math github repository.");
}();
}
arena_t<plain_type_t<T>> prev_val = vi_->val_;
vi_->val_ = other.val();
// no need to change any adjoints - these are just zeros before the reverse
Expand All @@ -1055,7 +1105,7 @@ class var_value<T, internal::require_matrix_var_value<T>> {
// other.vi_->adj_, so we must copy adjoint before reseting to zero

// we can reuse prev_val instead of allocating a new matrix
prev_val = this_vi->adj_;
prev_val.hard_copy(this_vi->adj_);
this_vi->adj_.setZero();
other_vi->adj_ += prev_val;
});
Expand Down
47 changes: 47 additions & 0 deletions test/unit/math/rev/core/var_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -910,3 +910,50 @@ TEST_F(AgradRev, matrix_compile_time_conversions) {
EXPECT_MATRIX_FLOAT_EQ(colvec.val(), rowvec.val());
EXPECT_MATRIX_FLOAT_EQ(x11.val(), rowvec.val());
}

TEST_F(AgradRev, assign_nan) {
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
using stan::math::var_value;
using var_vector = var_value<Eigen::Matrix<double, -1, 1>>;
using stan::math::var;
Eigen::VectorXd x_val(10);
for (int i = 0; i < 10; ++i) {
x_val(i) = i + 0.1;
}
var_vector x(x_val);
var_vector y = var_vector(Eigen::Matrix<double, -1, 1>::Constant(
10, std::numeric_limits<double>::quiet_NaN()));
y = stan::math::head(x, 10);
var sigma = 1.0;
var lp = stan::math::normal_lpdf<false>(y, 0, sigma);
lp.grad();
Eigen::VectorXd x_ans_adj(10);
for (int i = 0; i < 10; ++i) {
x_ans_adj(i) = -(i + 0.1);
}
EXPECT_MATRIX_EQ(x.adj(), x_ans_adj);
Eigen::VectorXd y_ans_adj = Eigen::VectorXd::Zero(10);
EXPECT_MATRIX_EQ(y_ans_adj, y.adj());
}

TEST_F(AgradRev, assign_nullptr_vari) {
using stan::math::var_value;
using var_vector = var_value<Eigen::Matrix<double, -1, 1>>;
using stan::math::var;
Eigen::VectorXd x_val(10);
for (int i = 0; i < 10; ++i) {
x_val(i) = i + 0.1;
}
var_vector x(x_val);
var_vector y;
y = stan::math::head(x, 10);
var sigma = 1.0;
var lp = stan::math::normal_lpdf<false>(y, 0, sigma);
lp.grad();
Eigen::VectorXd x_ans_adj(10);
for (int i = 0; i < 10; ++i) {
x_ans_adj(i) = -(i + 0.1);
}
EXPECT_MATRIX_EQ(x.adj(), x_ans_adj);
Eigen::VectorXd y_ans_adj = Eigen::VectorXd::Zero(10);
EXPECT_MATRIX_EQ(y_ans_adj, y.adj());
}