Skip to content

Commit

Permalink
Fix support for non-type template params
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Jan 15, 2024
1 parent 40d8bec commit 479aa27
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 4 deletions.
4 changes: 2 additions & 2 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,14 @@ namespace clad {
/// \n Variable declaration cannot be added to code directly, instead we
/// have to build a declaration staement.
/// \param[in] D The declaration to build a declaration statement from.
/// \returns The declration statement expression corresponding to the input
/// \returns The declaration statement expression corresponding to the input
/// variable declaration.
clang::DeclStmt* BuildDeclStmt(clang::Decl* D);
/// Wraps a set of declarations in a DeclStmt.
/// \n This function is useful to wrap multiple variable declarations in one
/// single declaration statement.
/// \param[in] D The declarations to build a declaration statement from.
/// \returns The declration statemetn expression corresponding to the input
/// \returns The declaration statement expression corresponding to the input
/// variable declaration.
clang::DeclStmt* BuildDeclStmt(llvm::MutableArrayRef<clang::Decl*> DS);

Expand Down
6 changes: 5 additions & 1 deletion lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -512,9 +512,13 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
}

StmtDiff BaseForwardModeVisitor::VisitStmt(const Stmt* S) {
// if S is a SubstNonTypeTemplateParmExpr, we need to visit its replacement
// expression instead.
if (const auto* SNTTPE = dyn_cast<SubstNonTypeTemplateParmExpr>(S))
return Visit(SNTTPE->getReplacement());
// Unknown stmt, just clone it and return a warning.
diag(DiagnosticsEngine::Warning, S->getBeginLoc(),
"attempted to differentiate unsupported statement, no changes applied");
// Unknown stmt, just clone it.
return StmtDiff(Clone(S));
}

Expand Down
6 changes: 5 additions & 1 deletion lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,11 +738,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}
StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) {
// if S is a SubstNonTypeTemplateParmExpr, we need to visit its replacement
// expression instead.
if (const auto* SNTTPE = dyn_cast<SubstNonTypeTemplateParmExpr>(S))
return Visit(SNTTPE->getReplacement());
// Unknown stmt, just clone it and return a warning.
diag(
DiagnosticsEngine::Warning,
S->getBeginLoc(),
"attempted to differentiate unsupported statement, no changes applied");
// Unknown stmt, just clone it.
return StmtDiff(Clone(S));
}

Expand Down
21 changes: 21 additions & 0 deletions test/FirstDerivative/BasicArithmeticMulDiv.C
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,23 @@ double m_10(double x, bool flag) {
// CHECK-NEXT: return flag ? (((_d_x = _d_x * 2 + x * 0) , (x *= 2)) , (_d_x * x + x * _d_x)) : (((_d_x += 0) , (x += 1)) , (_d_x * x + x * _d_x));
// CHECK-NEXT: }

template<size_t N>
double m_11(double x) {
const size_t maxN = 53;
const size_t m = maxN < N ? maxN : N;
return x*m;
}

// CHECK: double m_11_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: const size_t _d_maxN = 0;
// CHECK-NEXT: const size_t maxN = 53;
// CHECK-NEXT: bool _t0 = maxN < 64UL;
// CHECK-NEXT: const size_t _d_m = _t0 ? _d_maxN : 0UL;
// CHECK-NEXT: const size_t m = _t0 ? maxN : 64UL;
// CHECK-NEXT: return _d_x * m + x * _d_m;
// CHECK-NEXT: }

int d_1(int x) {
int y = 4;
return y / y; // == 0
Expand Down Expand Up @@ -190,6 +207,7 @@ double m_7_darg0(double x);
double m_8_darg0(double x);
double m_9_darg0(double x);
double m_10_darg0(double x, bool flag);
double m_11_darg0(double x);
int d_1_darg0(int x);
int d_2_darg0(int x);
int d_3_darg0(int x);
Expand Down Expand Up @@ -230,6 +248,9 @@ int main () {
printf("Result is = %f\n", m_10_darg0(1, true)); // CHECK-EXEC: Result is = 8
printf("Result is = %f\n", m_10_darg0(1, false)); // CHECK-EXEC: Result is = 4

clad::differentiate(m_11<64>, 0);
printf("Result is = %f\n", m_11_darg0(1)); // CHECK-EXEC: Result is = 53

clad::differentiate(d_1, 0);
printf("Result is = %d\n", d_1_darg0(1)); // CHECK-EXEC: Result is = 0

Expand Down
26 changes: 26 additions & 0 deletions test/Gradient/Gradients.C
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,27 @@ void fn_increment_in_return_grad(double i, double j, clad::array_ref<double> _d_
// CHECK-NEXT: * _d_i += _d_temp;
// CHECK-NEXT: }

template<size_t N>
double fn_template_non_type(double x) {
const size_t maxN = 53;
const size_t m = maxN < N ? maxN : N;
return x*m;
}

// CHECK: void fn_template_non_type_grad(double x, clad::array_ref<double> _d_x) {
// CHECK-NEXT: size_t _d_maxN = 0;
// CHECK-NEXT: bool _cond0;
// CHECK-NEXT: size_t _d_m = 0;
// CHECK-NEXT: const size_t maxN = 53;
// CHECK-NEXT: _cond0 = maxN < 15UL;
// CHECK-NEXT: const size_t m = _cond0 ? maxN : 15UL;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: * _d_x += 1 * m;
// CHECK-NEXT: if (_cond0)
// CHECK-NEXT: _d_maxN += _d_m;
// CHECK-NEXT: }

#define TEST(F, x, y) \
{ \
result[0] = 0; \
Expand Down Expand Up @@ -812,4 +833,9 @@ int main() {
TEST_GRADIENT(fn_global_var_use, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {7.00, 0.00}

TEST(fn_increment_in_return, 3, 2); // CHECK-EXEC: Result is = {7.00, 0.00}

auto fn_template_non_type_dx = clad::gradient(fn_template_non_type<15>);
double x = 5, dx = 0;
fn_template_non_type_dx.execute(x, &dx);
printf("Result is = %.2f\n", dx); // CHECK-EXEC: Result is = 15.00
}

0 comments on commit 479aa27

Please sign in to comment.