-
-
Notifications
You must be signed in to change notification settings - Fork 189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow arena_matrix to use move semantics #2928
Changes from 30 commits
00a9f43
056e297
3a6c11e
fa8a464
3f07884
e5e226b
d66673b
53dc399
fcca84d
c13855f
2f42a0a
4189367
b297df5
5b76fc8
5fbaf55
1bfd431
119099d
90b23ab
9ce5c50
1871c64
42cef50
0ec5253
d6b892d
984bdf8
a85e786
1b0504a
4f926cf
0d34e03
3193ad4
b37d163
36e0bd3
eb6276c
8acdb6d
11da0dd
8a1f9f0
c29860e
0707438
8b96c45
0a168c3
fec3689
6b8ae15
c83cbfc
f594b2c
d1feb19
33f0825
5c5dfc6
b2af1cd
b0815c4
515d621
a604fa4
6a71cfb
a2124c1
1e906d9
de7d11e
d89610e
eee02d8
c5f983a
86a3e83
1f25ef7
c9f76db
7a5a009
df305d6
08d8a22
34cf554
a3a88a5
f748825
04124da
9f759e1
045073f
e73651b
7a9601d
d45dff2
91ea4c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
#include <stan/math/prim/fun/Eigen.hpp> | ||
#include <stan/math/rev/core/chainable_alloc.hpp> | ||
#include <stan/math/rev/core/chainablestack.hpp> | ||
|
||
#include <stan/math/rev/core/chainable_object.hpp> | ||
namespace stan { | ||
namespace math { | ||
|
||
|
@@ -54,7 +54,7 @@ class arena_matrix : public Eigen::Map<MatrixType> { | |
size) {} | ||
|
||
/** | ||
* Constructs `arena_matrix` from an expression. | ||
* Constructs `arena_matrix` from an expression | ||
* @param other expression | ||
*/ | ||
template <typename T, require_eigen_t<T>* = nullptr> | ||
|
@@ -73,6 +73,50 @@ class arena_matrix : public Eigen::Map<MatrixType> { | |
*this = other; | ||
} | ||
|
||
/** | ||
* Constructs `arena_matrix` from an expression, then send it to either the | ||
* object stack or memory arena. | ||
* @tparam T A type that inherits from Eigen::DenseBase that is not an | ||
* `arena_matrix`. | ||
* @param other expression | ||
* @note When T is both an rvalue and a plain type, the expression is moved to | ||
* the object stack. However when T is an lvalue, or an rvalue that is not a | ||
* plain type, the expression is copied to the memory arena. | ||
*/ | ||
template <typename T, require_eigen_t<T>* = nullptr, | ||
require_not_arena_matrix_t<T>* = nullptr> | ||
arena_matrix(T&& other) // NOLINT | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be possible to split this to two constructors using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes that's much cleaner. Though I still kept the immediately evaluated lambda as that keeps the constructor empty which I find kind of nice |
||
: Base::Map([](auto&& x) { | ||
using base_map_t = | ||
typename stan::math::arena_matrix<MatrixType>::Base; | ||
using T_t = std::decay_t<T>; | ||
if (std::is_rvalue_reference<decltype(x)>::value | ||
&& is_plain_type<T_t>::value) { | ||
// Note: plain_type_t here does nothing since T_t is plain type | ||
auto other | ||
= make_chainable_ptr(plain_type_t<MatrixType>(std::move(x))); | ||
// other has it's rows and cols swapped already if it needed that | ||
return base_map_t(&(other->coeffRef(0)), other->rows(), | ||
other->cols()); | ||
} else { | ||
base_map_t map( | ||
ChainableStack::instance_->memalloc_.alloc_array<Scalar>( | ||
x.size()), | ||
(RowsAtCompileTime == 1 && T_t::ColsAtCompileTime == 1) | ||
|| (ColsAtCompileTime == 1 | ||
&& T_t::RowsAtCompileTime == 1) | ||
? x.cols() | ||
: x.rows(), | ||
(RowsAtCompileTime == 1 && T_t::ColsAtCompileTime == 1) | ||
|| (ColsAtCompileTime == 1 | ||
&& T_t::RowsAtCompileTime == 1) | ||
? x.rows() | ||
: x.cols()); | ||
map = x; | ||
return map; | ||
} | ||
}(std::forward<T>(other))) {} | ||
|
||
/** | ||
* Constructs `arena_matrix` from an expression. This makes an assumption that | ||
* any other `Eigen::Map` also contains memory allocated in the arena. | ||
|
@@ -110,23 +154,32 @@ class arena_matrix : public Eigen::Map<MatrixType> { | |
* @param a expression to evaluate into this | ||
* @return `*this` | ||
*/ | ||
template <typename T> | ||
arena_matrix& operator=(const T& a) { | ||
// do we need to transpose? | ||
if ((RowsAtCompileTime == 1 && T::ColsAtCompileTime == 1) | ||
|| (ColsAtCompileTime == 1 && T::RowsAtCompileTime == 1)) { | ||
// placement new changes what data map points to - there is no allocation | ||
new (this) Base( | ||
ChainableStack::instance_->memalloc_.alloc_array<Scalar>(a.size()), | ||
a.cols(), a.rows()); | ||
|
||
template <typename T, require_not_arena_matrix_t<T>* = nullptr> | ||
arena_matrix& operator=(T&& a) { | ||
using T_t = std::decay_t<T>; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be |
||
if (std::is_rvalue_reference<T&&>::value && is_plain_type<T_t>::value) { | ||
// Note: plain_type_t here does nothing since T_t is plain type | ||
auto other = make_chainable_ptr(plain_type_t<MatrixType>(std::move(a))); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the |
||
new (this) Base(&(other->coeffRef(0)), other->rows(), other->cols()); | ||
return *this; | ||
} else { | ||
new (this) Base( | ||
ChainableStack::instance_->memalloc_.alloc_array<Scalar>(a.size()), | ||
a.rows(), a.cols()); | ||
// do we need to transpose? | ||
if ((RowsAtCompileTime == 1 && T_t::ColsAtCompileTime == 1) | ||
|| (ColsAtCompileTime == 1 && T_t::RowsAtCompileTime == 1)) { | ||
// placement new changes what data map points to - there is no | ||
// allocation | ||
new (this) Base( | ||
ChainableStack::instance_->memalloc_.alloc_array<Scalar>(a.size()), | ||
a.cols(), a.rows()); | ||
|
||
} else { | ||
new (this) Base( | ||
ChainableStack::instance_->memalloc_.alloc_array<Scalar>(a.size()), | ||
a.rows(), a.cols()); | ||
} | ||
Base::operator=(a); | ||
return *this; | ||
} | ||
Base::operator=(a); | ||
return *this; | ||
} | ||
/** | ||
* Forces hard copying matrices into an arena matrix | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.