From 8963b8cadfc6864c666c0f7c5e99820029428370 Mon Sep 17 00:00:00 2001 From: Mihail Mihov Date: Wed, 17 Jul 2024 15:23:34 +0300 Subject: [PATCH 1/2] Add tests for temporary expressions in reverse mode --- test/Gradient/TemporaryExpr.C | 55 +++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 test/Gradient/TemporaryExpr.C diff --git a/test/Gradient/TemporaryExpr.C b/test/Gradient/TemporaryExpr.C new file mode 100644 index 000000000..84598c265 --- /dev/null +++ b/test/Gradient/TemporaryExpr.C @@ -0,0 +1,55 @@ +// RUN: %cladclang %s -fno-exceptions -I%S/../../include -oTemporaryExpr.out 2>&1 | %filecheck %s +// RUN: ./TemporaryExpr.out | %filecheck_exec %s + +//CHECK-NOT: {{.*error|warning|note:.*}} +#include "clad/Differentiator/Differentiator.h" + +class SimpleFunctions { +public: + SimpleFunctions(double p_x = 0, double p_y = 0) : x(p_x), y(p_y) {} + double x, y; + double mem_fn(double i, double j) { return (x + y) * i + i * j; } + + SimpleFunctions operator*(SimpleFunctions& rhs) { + return SimpleFunctions(this->x * rhs.x, this->y * rhs.y); + } +}; + +double fn1(double i, double j) { + SimpleFunctions sf(3, 5); + + return sf.mem_fn(i, j); +} + +double fn2(double i, double j) { + SimpleFunctions sf(3 * i, 5 * j); + + return sf.mem_fn(i, j); +} + +double fn3(double i, double j) { + SimpleFunctions sf1 (3, 5); + SimpleFunctions sf2 (i, j); + + SimpleFunctions r = sf1 * sf2; + return r.mem_fn(i, j); + + /*return (sf1 * sf2).mem_fn(i, j);*/ +} + +int main() { + double result1[2] = {}; + auto fn1_grad = clad::gradient(fn1); + fn1_grad.execute(4, 5, &result1[0], &result1[1]); + printf("%f %f\n", result1[0], result1[1]); // CHECK-EXEC: 13.000000 4.000000 + + double result2[2] = {}; + auto fn2_grad = clad::gradient(fn2); + fn2_grad.execute(4, 5, &result2[0], &result2[1]); + printf("%f %f\n", result2[0], result2[1]); // CHECK-EXEC: 54.000000 24.000000 + + double result3[2] = {}; + auto fn3_grad = clad::gradient(fn3); + fn3_grad.execute(4, 5, &result3[0], &result3[1]); + printf("%f %f\n", result3[0], result3[1]); // CHECK-EXEC: 54.000000 24.000000 +} From 41a920e7b104c421eabed9ad86ef052d445fdd01 Mon Sep 17 00:00:00 2001 From: Mihail Mihov Date: Wed, 17 Jul 2024 15:23:41 +0300 Subject: [PATCH 2/2] Fix temporary expressions in reverse mode fixes #917 --- lib/Differentiator/ReverseModeVisitor.cpp | 36 +++++++++++++++++------ 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 6394ee9dd..cd68edb51 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1369,6 +1369,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(Clone(CE)); } + auto* CEModified = dyn_cast(Clone(CE)); + auto NArgs = FD->getNumParams(); // If the function has no args and is not a member function call then we // assume that it is not related to independent variables and does not @@ -1618,8 +1620,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, baseOriginalE = OCE->getArg(0); baseDiff = Visit(baseOriginalE); - Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr()); - baseDiff.updateStmt(baseDiffStore); + + if (auto* ME = dyn_cast(CEModified->getCallee())) + ME->setBase(baseDiff.getExpr()); + Expr* baseDerivative = baseDiff.getExpr_dx(); if (!baseDerivative->getType()->isPointerType()) baseDerivative = @@ -1878,13 +1882,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); return StmtDiff(resValue, nullptr, resAdjoint); } // Recreate the original call expression. + call = m_Sema - .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, + .ActOnCallExpr(getCurrentScope(), CEModified->getCallee(), Loc, CallArgs, Loc) .get(); - return StmtDiff(call); - return {}; + return StmtDiff(call); } Expr* ReverseModeVisitor::GetMultiArgCentralDiffCall( @@ -3747,8 +3751,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, const clang::MaterializeTemporaryExpr* MTE) { // `MaterializeTemporaryExpr` node will be created automatically if it is // required by `ActOn`/`Build` Sema functions. - StmtDiff MTEDiff = Visit(clad_compat::GetSubExpr(MTE), dfdx()); - return MTEDiff; + if (dfdx()) { + StmtDiff MTEDiff = Visit(clad_compat::GetSubExpr(MTE), dfdx()); + return MTEDiff; + } + + Expr* MTEStore = + GlobalStoreAndRef(Clone(clad_compat::GetSubExpr(MTE)), "_t", + /*force=*/true); + + auto* MTEStoreDRE = dyn_cast(MTEStore); + DeclDiff MTEDerived = + DifferentiateVarDecl(dyn_cast(MTEStoreDRE->getDecl())); + addToCurrentBlock(BuildDeclStmt(MTEDerived.getDecl_dx())); + + return StmtDiff{MTEStore, BuildDeclRef(MTEDerived.getDecl_dx())}; } StmtDiff ReverseModeVisitor::VisitSubstNonTypeTemplateParmExpr( @@ -3928,8 +3945,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_DiffReq.Mode == DiffMode::experimental_pullback && !m_DiffReq->getReturnType()->isVoidType()) { IdentifierInfo* pullbackParamII = CreateUniqueIdentifier("_d_y"); - QualType pullbackType = - derivativeFnType->getParamType(m_DiffReq->getNumParams()); + /*QualType pullbackType =*/ + /* derivativeFnType->getParamType(m_DiffReq->getNumParams());*/ + QualType pullbackType = m_Context.DoubleTy; ParmVarDecl* pullbackPVD = utils::BuildParmVarDecl( m_Sema, m_Derivative, pullbackParamII, pullbackType); paramDerivatives.insert(paramDerivatives.begin(), pullbackPVD);