From 28dea375dd2ddd26e4aea6801db52d4bc4ce94eb Mon Sep 17 00:00:00 2001 From: Mihail Mihov Date: Wed, 17 Jul 2024 12:54:53 +0300 Subject: [PATCH] Add support for non-differentiable attribute in reverse mode fixes #717 --- lib/Differentiator/ReverseModeVisitor.cpp | 73 +++++++-- test/Gradient/NonDifferentiable.C | 187 ++++++++++++++++++++++ test/Gradient/NonDifferentiableError.C | 51 ++++++ 3 files changed, 295 insertions(+), 16 deletions(-) create mode 100644 test/Gradient/NonDifferentiable.C create mode 100644 test/Gradient/NonDifferentiableError.C diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 05186fcdf..551acddd2 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1420,6 +1420,26 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(Clone(CE)); } + // If the function is non_differentiable, return zero derivative. + if (clad::utils::hasNonDifferentiableAttribute(CE)) { + // Calling the function without computing derivatives + llvm::SmallVector ClonedArgs; + for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i) + ClonedArgs.push_back(Clone(CE->getArg(i))); + + SourceLocation validLoc = clad::utils::GetValidSLoc(m_Sema); + Expr* Call = m_Sema + .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), + validLoc, ClonedArgs, validLoc) + .get(); + // Creating a zero derivative + auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, + /*val=*/0); + + // Returning the function call and zero derivative + return StmtDiff(Call, zero); + } + auto NArgs = FD->getNumParams(); // If the function has no args and is not a member function call then we // assume that it is not related to independent variables and does not @@ -2061,6 +2081,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } else if (opCode == UnaryOperatorKind::UO_Deref) { diff = Visit(E); Expr* cloneE = BuildOp(UnaryOperatorKind::UO_Deref, diff.getExpr()); + + // If we have a pointer to a member expression, which is + // non-differentiable, we just return a clone of the original expression. + if (auto* ME = dyn_cast(diff.getExpr())) + if (clad::utils::hasNonDifferentiableAttribute(ME->getMemberDecl())) + return {cloneE}; + Expr* diff_dx = diff.getExpr_dx(); bool specialDThisCase = false; Expr* derivedE = nullptr; @@ -2655,9 +2682,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // If `VD` is a reference to a non-local variable then also there's no // need to call `Visit` since non-local variables are not differentiated. if (!isDerivativeOfRefType && (!isPointerType || isInitializedByNewExpr)) { - Expr* derivedE = BuildDeclRef(VDDerived); - if (isInitializedByNewExpr) - derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE); + Expr* derivedE = nullptr; + + if (!clad::utils::hasNonDifferentiableAttribute(VD)) { + derivedE = BuildDeclRef(VDDerived); + if (isInitializedByNewExpr) + derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE); + } + if (VD->getInit()) { if (isa(VD->getInit())) initDiff = Visit(VD->getInit()); @@ -2689,6 +2721,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, addToCurrentBlock(assignToZero, direction::reverse); } } + VarDecl* VDClone = nullptr; Expr* derivedVDE = nullptr; if (VDDerived) @@ -2815,19 +2848,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (declsBegin != DS->decls().end() && isa(*declsBegin)) { auto* VD = dyn_cast(*declsBegin); QualType QT = VD->getType(); - if (!QT->isPointerType()) { - auto* typeDecl = QT->getAsCXXRecordDecl(); - // We should also simply copy the original lambda. The differentiation - // of lambdas is happening in the `VisitCallExpr`. For now, only the - // declarations with lambda expressions without captures are supported. - isLambda = typeDecl && typeDecl->isLambda(); - if (isLambda) { - for (auto* D : DS->decls()) - if (auto* VD = dyn_cast(D)) - decls.push_back(VD); - Stmt* DSClone = BuildDeclStmt(decls); - return StmtDiff(DSClone, nullptr); - } + if (QT->isPointerType()) + QT = QT->getPointeeType(); + + auto* typeDecl = QT->getAsCXXRecordDecl(); + // We should also simply copy the original lambda. The differentiation + // of lambdas is happening in the `VisitCallExpr`. For now, only the + // declarations with lambda expressions without captures are supported. + isLambda = typeDecl && typeDecl->isLambda(); + if (isLambda || + (typeDecl && clad::utils::hasNonDifferentiableAttribute(typeDecl))) { + for (auto* D : DS->decls()) + if (auto* VD = dyn_cast(D)) + decls.push_back(VD); + Stmt* DSClone = BuildDeclStmt(decls); + return StmtDiff(DSClone, nullptr); } } @@ -2839,6 +2874,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, for (auto* D : DS->decls()) { if (auto* VD = dyn_cast(D)) { DeclDiff VDDiff; + if (!isLambda) VDDiff = DifferentiateVarDecl(VD); @@ -3014,6 +3050,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, "CXXMethodDecl nodes not supported yet!"); MemberExpr* clonedME = utils::BuildMemberExpr( m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName()); + if (clad::utils::hasNonDifferentiableAttribute(ME)) { + auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, + /*val=*/0); + return {clonedME, zero}; + } if (!baseDiff.getExpr_dx()) return {clonedME, nullptr}; MemberExpr* derivedME = utils::BuildMemberExpr( diff --git a/test/Gradient/NonDifferentiable.C b/test/Gradient/NonDifferentiable.C new file mode 100644 index 000000000..f86441581 --- /dev/null +++ b/test/Gradient/NonDifferentiable.C @@ -0,0 +1,187 @@ +// RUN: %cladclang %s -I%S/../../include -oNonDifferentiable.out 2>&1 | %filecheck %s +// RUN: ./NonDifferentiable.out | %filecheck_exec %s +// CHECK-NOT: {{.*error|warning|note:.*}} + +#define non_differentiable __attribute__((annotate("another_attribute"), annotate("non_differentiable"))) + +#include "clad/Differentiator/Differentiator.h" + +class SimpleFunctions1 { +public: + SimpleFunctions1() noexcept : x(0), y(0), x_pointer(&x), y_pointer(&y) {} + SimpleFunctions1(double p_x, double p_y) noexcept : x(p_x), y(p_y), x_pointer(&x), y_pointer(&y) {} + double x; + non_differentiable double y; + double* x_pointer; + non_differentiable double* y_pointer; + double mem_fn_1(double i, double j) { return (x + y) * i + i * j * j; } + non_differentiable double mem_fn_2(double i, double j) { return i * j; } + double mem_fn_3(double i, double j) { return mem_fn_1(i, j) + i * j; } + double mem_fn_4(double i, double j) { return mem_fn_2(i, j) + i * j; } + double mem_fn_5(double i, double j) { return mem_fn_2(i, j) * mem_fn_1(i, j) * i; } + SimpleFunctions1 operator+(const SimpleFunctions1& other) const { + return SimpleFunctions1(x + other.x, y + other.y); + } +}; + +double fn_s1_mem_fn(double i, double j) { + SimpleFunctions1 obj(2, 3); + return obj.mem_fn_1(i, j) + i * j; +} + +double fn_s1_field(double i, double j) { + SimpleFunctions1 obj(2, 3); + return obj.x * obj.y + i * j; +} + +double fn_s1_field_pointer(double i, double j) { + SimpleFunctions1 obj(2, 3); + return (*obj.x_pointer) * (*obj.y_pointer) + i * j; +} + +double fn_s1_operator(double i, double j) { + SimpleFunctions1 obj1(2, 3); + SimpleFunctions1 obj2(3, 5); + return (obj1 + obj2).mem_fn_1(i, j); +} + +class non_differentiable SimpleFunctions2 { +public: + SimpleFunctions2() noexcept : x(0), y(0) {} + SimpleFunctions2(double p_x, double p_y) noexcept : x(p_x), y(p_y) {} + double x; + double y; + double mem_fn(double i, double j) { return (x + y) * i + i * j * j; } + SimpleFunctions2 operator+(const SimpleFunctions2& other) const { + return SimpleFunctions2(x + other.x, y + other.y); + } +}; + +double fn_s2_mem_fn(double i, double j) { + SimpleFunctions2 obj(2, 3); + return obj.mem_fn(i, j) + i * j; +} + +double fn_s2_field(double i, double j) { + SimpleFunctions2 *obj0, obj(2, 3); + return obj.x * obj.y + i * j; +} + +double fn_s2_operator(double i, double j) { + SimpleFunctions2 obj1(2, 3); + SimpleFunctions2 obj2(3, 5); + return (obj1 + obj2).mem_fn(i, j); +} + +double fn_non_diff_var(double i, double j) { + non_differentiable double k = i * i * j; + return k; +} + +#define INIT_EXPR(classname) \ + classname expr_1(2, 3); \ + classname expr_2(3, 5); + +#define TEST_CLASS(classname, name, i, j) \ + auto d_##name = clad::gradient(&classname::name); \ + double result_##name[2] = {}; \ + d_##name.execute(expr_1, i, j, &result_##name[0], &result_##name[1]); \ + printf("%.2f %.2f\n\n", result_##name[0], result_##name[1]); + +#define TEST_FUNC(name, i, j) \ + auto d_##name = clad::gradient(&name); \ + double result_##name[2] = {}; \ + d_##name.execute(i, j, &result_##name[0], &result_##name[1]); \ + printf("%.2f %.2f\n\n", result_##name[0], result_##name[1]); + +int main() { + // FIXME: The parts of this test that are commented out are currently not working, due to bugs + // not related to the implementation of the non-differentiable attribute. + INIT_EXPR(SimpleFunctions1); + + /*TEST_CLASS(SimpleFunctions1, mem_fn_1, 3, 5)*/ + + /*TEST_CLASS(SimpleFunctions1, mem_fn_3, 3, 5)*/ + + /*TEST_CLASS(SimpleFunctions1, mem_fn_4, 3, 5)*/ + + /*TEST_CLASS(SimpleFunctions1, mem_fn_5, 3, 5)*/ + + TEST_FUNC(fn_s1_mem_fn, 3, 5) // CHECK-EXEC: 35.00 33.00 + + TEST_FUNC(fn_s1_field, 3, 5) // CHECK-EXEC: 5.00 3.00 + + TEST_FUNC(fn_s1_field_pointer, 3, 5) // CHECK-EXEC: 5.00 3.00 + + /*TEST_FUNC(fn_s1_operator, 3, 5)*/ + + TEST_FUNC(fn_s2_mem_fn, 3, 5) // CHECK-EXEC: 5.00 3.00 + + /*TEST_FUNC(fn_s2_field, 3, 5)*/ + + /*TEST_FUNC(fn_s2_operator, 3, 5)*/ + + TEST_FUNC(fn_non_diff_var, 3, 5) // CHECK-EXEC: 0.00 0.00 + + // CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j); + + // CHECK: void fn_s1_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) { + // CHECK-NEXT: SimpleFunctions1 _d_obj({}); + // CHECK-NEXT: SimpleFunctions1 _t0; + // CHECK-NEXT: SimpleFunctions1 obj(2, 3); + // CHECK-NEXT: _t0 = obj; + // CHECK-NEXT: { + // CHECK-NEXT: double _r0 = 0; + // CHECK-NEXT: double _r1 = 0; + // CHECK-NEXT: _t0.mem_fn_1_pullback(i, j, 1, &_d_obj, &_r0, &_r1); + // CHECK-NEXT: *_d_i += _r0; + // CHECK-NEXT: *_d_j += _r1; + // CHECK-NEXT: *_d_i += 1 * j; + // CHECK-NEXT: *_d_j += i * 1; + // CHECK-NEXT: } + // CHECK-NEXT: } + + // CHECK: void fn_s1_field_grad(double i, double j, double *_d_i, double *_d_j) { + // CHECK-NEXT: SimpleFunctions1 _d_obj({}); + // CHECK-NEXT: SimpleFunctions1 obj(2, 3); + // CHECK-NEXT: { + // CHECK-NEXT: _d_obj.x += 1 * obj.y; + // CHECK-NEXT: *_d_i += 1 * j; + // CHECK-NEXT: *_d_j += i * 1; + // CHECK-NEXT: } + // CHECK-NEXT: } + + // CHECK: void fn_s1_field_pointer_grad(double i, double j, double *_d_i, double *_d_j) { + // CHECK-NEXT: SimpleFunctions1 _d_obj({}); + // CHECK-NEXT: SimpleFunctions1 obj(2, 3); + // CHECK-NEXT: { + // CHECK-NEXT: *_d_obj.x_pointer += 1 * (*obj.y_pointer); + // CHECK-NEXT: *_d_i += 1 * j; + // CHECK-NEXT: *_d_j += i * 1; + // CHECK-NEXT: } + // CHECK-NEXT: } + + // CHECK: void fn_s2_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) { + // CHECK-NEXT: SimpleFunctions2 obj(2, 3); + // CHECK-NEXT: { + // CHECK-NEXT: *_d_i += 1 * j; + // CHECK-NEXT: *_d_j += i * 1; + // CHECK-NEXT: } + // CHECK-NEXT: } + + // CHECK: void fn_non_diff_var_grad(double i, double j, double *_d_i, double *_d_j) { + // CHECK-NEXT: double _d_k = 0; + // CHECK-NEXT: double k = i * i * j; + // CHECK-NEXT: _d_k += 1; + // CHECK-NEXT: } + + // CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { + // CHECK-NEXT: { + // CHECK-NEXT: (*_d_this).x += _d_y * i; + // CHECK-NEXT: *_d_i += (this->x + this->y) * _d_y; + // CHECK-NEXT: *_d_i += _d_y * j * j; + // CHECK-NEXT: *_d_j += i * _d_y * j; + // CHECK-NEXT: *_d_j += i * j * _d_y; + // CHECK-NEXT: } + // CHECK-NEXT: } +} diff --git a/test/Gradient/NonDifferentiableError.C b/test/Gradient/NonDifferentiableError.C new file mode 100644 index 000000000..501c16268 --- /dev/null +++ b/test/Gradient/NonDifferentiableError.C @@ -0,0 +1,51 @@ +// RUN: %cladclang %s -I%S/../../include -fsyntax-only -Xclang -verify 2>&1 + +#define non_differentiable __attribute__((annotate("non_differentiable"))) + +#include "clad/Differentiator/Differentiator.h" + +extern "C" int printf(const char* fmt, ...); + +class non_differentiable SimpleFunctions2 { +public: + SimpleFunctions2() noexcept : x(0), y(0) {} + SimpleFunctions2(double p_x, double p_y) noexcept : x(p_x), y(p_y) {} + double x; + double y; + double mem_fn(double i, double j) { return (x + y) * i + i * j * j; } // expected-error {{attempted differentiation of method 'mem_fn' in class 'SimpleFunctions2', which is marked as non-differentiable}} +}; + +namespace clad { + namespace custom_derivatives { + void fn_s2_mem_fn_pullback(double i, double j, double _d_y, double* _d_i, double* _d_j) { + *_d_i = 1.5; + *_d_j = 2.5; + } + } // namespace custom_derivatives +} // namespace clad + +non_differentiable double fn_s2_mem_fn(double i, double j) { + SimpleFunctions2 obj(2, 3); + return obj.mem_fn(i, j) + i * j; +} + +#define INIT_EXPR(classname) \ + classname expr_1(2, 3); \ + classname expr_2(3, 5); + +#define TEST_CLASS(classname, name, i, j) \ + auto d_##name = clad::differentiate(&classname::name, "i"); \ + printf("%.2f\n", d_##name.execute(expr_1, i, j)); \ + printf("%.2f\n", d_##name.execute(expr_2, i, j)); \ + printf("\n"); + +#define TEST_FUNC(name, i, j) \ + auto d_##name = clad::differentiate(&name, "i"); \ + printf("%.2f\n", d_##name.execute(i, j)); \ + printf("\n"); + +int main() { + INIT_EXPR(SimpleFunctions2); + TEST_CLASS(SimpleFunctions2, mem_fn, 3, 5); + TEST_FUNC(fn_s2_mem_fn, 3, 5); // expected-error {{attempted differentiation of function 'fn_s2_mem_fn', which is marked as non-differentiable}} +}