Skip to content

Commit

Permalink
Fix support for non-type template params
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Jan 15, 2024
1 parent 40d8bec commit 551a423
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 5 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ class BaseForwardModeVisitor
StmtDiff
VisitUnaryExprOrTypeTraitExpr(const clang::UnaryExprOrTypeTraitExpr* UE);
StmtDiff VisitPseudoObjectExpr(const clang::PseudoObjectExpr* POE);
StmtDiff VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP);

virtual clang::QualType
GetPushForwardDerivativeType(clang::QualType ParamType);
Expand Down
2 changes: 2 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,8 @@ namespace clad {
VisitMaterializeTemporaryExpr(const clang::MaterializeTemporaryExpr* MTE);
StmtDiff VisitCXXStaticCastExpr(const clang::CXXStaticCastExpr* SCE);
VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD);
StmtDiff VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP);

/// A helper method to differentiate a single Stmt in the reverse mode.
/// Internally, calls Visit(S, expr). Its result is wrapped into a
Expand Down
4 changes: 2 additions & 2 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,14 @@ namespace clad {
/// \n Variable declaration cannot be added to code directly, instead we
/// have to build a declaration staement.
/// \param[in] D The declaration to build a declaration statement from.
/// \returns The declration statement expression corresponding to the input
/// \returns The declaration statement expression corresponding to the input
/// variable declaration.
clang::DeclStmt* BuildDeclStmt(clang::Decl* D);
/// Wraps a set of declarations in a DeclStmt.
/// \n This function is useful to wrap multiple variable declarations in one
/// single declaration statement.
/// \param[in] D The declarations to build a declaration statement from.
/// \returns The declration statemetn expression corresponding to the input
/// \returns The declaration statement expression corresponding to the input
/// variable declaration.
clang::DeclStmt* BuildDeclStmt(llvm::MutableArrayRef<clang::Decl*> DS);

Expand Down
7 changes: 6 additions & 1 deletion lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2081,4 +2081,9 @@ StmtDiff BaseForwardModeVisitor::VisitPseudoObjectExpr(
return {Clone(POE),
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0)};
}
} // end namespace clad

StmtDiff BaseForwardModeVisitor::VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP) {
return Visit(NTTP->getReplacement());
}
} // end namespace clad
8 changes: 6 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -739,8 +739,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) {
diag(
DiagnosticsEngine::Warning,
S->getBeginLoc(),
DiagnosticsEngine::Warning, S->getBeginLoc(),
"attempted to differentiate unsupported statement, no changes applied");
// Unknown stmt, just clone it.
return StmtDiff(Clone(S));
Expand Down Expand Up @@ -3499,6 +3498,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return MTEDiff;
}

StmtDiff ReverseModeVisitor::VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP) {
return Visit(NTTP->getReplacement());
}

QualType ReverseModeVisitor::GetParameterDerivativeType(QualType yType,
QualType xType) {

Expand Down
21 changes: 21 additions & 0 deletions test/FirstDerivative/BasicArithmeticMulDiv.C
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,23 @@ double m_10(double x, bool flag) {
// CHECK-NEXT: return flag ? (((_d_x = _d_x * 2 + x * 0) , (x *= 2)) , (_d_x * x + x * _d_x)) : (((_d_x += 0) , (x += 1)) , (_d_x * x + x * _d_x));
// CHECK-NEXT: }

template<size_t N>
double m_11(double x) {
const size_t maxN = 53;
const size_t m = maxN < N ? maxN : N;
return x*m;
}

// CHECK: double m_11_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: const size_t _d_maxN = 0;
// CHECK-NEXT: const size_t maxN = 53;
// CHECK-NEXT: bool _t0 = maxN < 64UL;
// CHECK-NEXT: const size_t _d_m = _t0 ? _d_maxN : 0UL;
// CHECK-NEXT: const size_t m = _t0 ? maxN : 64UL;
// CHECK-NEXT: return _d_x * m + x * _d_m;
// CHECK-NEXT: }

int d_1(int x) {
int y = 4;
return y / y; // == 0
Expand Down Expand Up @@ -190,6 +207,7 @@ double m_7_darg0(double x);
double m_8_darg0(double x);
double m_9_darg0(double x);
double m_10_darg0(double x, bool flag);
double m_11_darg0(double x);
int d_1_darg0(int x);
int d_2_darg0(int x);
int d_3_darg0(int x);
Expand Down Expand Up @@ -230,6 +248,9 @@ int main () {
printf("Result is = %f\n", m_10_darg0(1, true)); // CHECK-EXEC: Result is = 8
printf("Result is = %f\n", m_10_darg0(1, false)); // CHECK-EXEC: Result is = 4

clad::differentiate(m_11<64>, 0);
printf("Result is = %f\n", m_11_darg0(1)); // CHECK-EXEC: Result is = 53

clad::differentiate(d_1, 0);
printf("Result is = %d\n", d_1_darg0(1)); // CHECK-EXEC: Result is = 0

Expand Down
26 changes: 26 additions & 0 deletions test/Gradient/Gradients.C
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,27 @@ void fn_increment_in_return_grad(double i, double j, clad::array_ref<double> _d_
// CHECK-NEXT: * _d_i += _d_temp;
// CHECK-NEXT: }

template<size_t N>
double fn_template_non_type(double x) {
const size_t maxN = 53;
const size_t m = maxN < N ? maxN : N;
return x*m;
}

// CHECK: void fn_template_non_type_grad(double x, clad::array_ref<double> _d_x) {
// CHECK-NEXT: size_t _d_maxN = 0;
// CHECK-NEXT: bool _cond0;
// CHECK-NEXT: size_t _d_m = 0;
// CHECK-NEXT: const size_t maxN = 53;
// CHECK-NEXT: _cond0 = maxN < 15UL;
// CHECK-NEXT: const size_t m = _cond0 ? maxN : 15UL;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: * _d_x += 1 * m;
// CHECK-NEXT: if (_cond0)
// CHECK-NEXT: _d_maxN += _d_m;
// CHECK-NEXT: }

#define TEST(F, x, y) \
{ \
result[0] = 0; \
Expand Down Expand Up @@ -812,4 +833,9 @@ int main() {
TEST_GRADIENT(fn_global_var_use, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {7.00, 0.00}

TEST(fn_increment_in_return, 3, 2); // CHECK-EXEC: Result is = {7.00, 0.00}

auto fn_template_non_type_dx = clad::gradient(fn_template_non_type<15>);
double x = 5, dx = 0;
fn_template_non_type_dx.execute(x, &dx);
printf("Result is = %.2f\n", dx); // CHECK-EXEC: Result is = 15.00
}

0 comments on commit 551a423

Please sign in to comment.