Skip to content

Commit

Permalink
added a fct to wrap scalar values for derivs
Browse files Browse the repository at this point in the history
  • Loading branch information
Konrad1991 committed Jun 4, 2024
1 parent 77db7cd commit b160944
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
3 changes: 2 additions & 1 deletion include/etr_bits/Allocation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "Allocation/Colon.hpp"
#include "Allocation/Matrix.hpp"
#include "Allocation/Rep.hpp"
#include "Allocation/ScalarForDeriv.hpp"
#include "Allocation/Vector.hpp"

#endif
#endif
19 changes: 19 additions & 0 deletions include/etr_bits/Allocation/ScalarForDeriv.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#ifndef SCALAR_DERIV_ETR_H
#define SCALAR_DERIV_ETR_H

#include "../Core.hpp"

namespace etr {

template <int Idx, typename AV, typename T>
requires std::is_arithmetic_v<T>
inline auto scalarDeriv(AV &av, T s) {
av.varConstants[Idx].resize(1);
av.varConstants[Idx][0] = s;
Vec<T, VarPointer<decltype(av), Idx, -1>, ConstantTypeTrait> ret(av);
return ret;
}

} // namespace etr

#endif
4 changes: 2 additions & 2 deletions tests/Derivatives_Tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ int main() {
// NOTE: test minus
{
std::cout << "test minus" << std::endl;
AllVars<2, 0, 0, 2> av(0, 0);
AllVars<2, 0, 0, 3> av(0, 0);
Vec<double, VarPointer<decltype(av), 0, 0>, VariableTypeTrait> a(av);
Vec<double, VarPointer<decltype(av), 1, 0>, VariableTypeTrait> b(av);

Expand Down Expand Up @@ -69,7 +69,7 @@ int main() {

std::cout << "\n"
<< "a = a / b" << std::endl;
a = a / b;
a = a / scalarDeriv<3>(av, 3.14);
print(a, av);
print(b, av);
print(get_derivs(a));
Expand Down

0 comments on commit b160944

Please sign in to comment.