From 1c9b121fb38d1a74be206a3ee0109f649b615f11 Mon Sep 17 00:00:00 2001 From: kchristin Date: Tue, 8 Oct 2024 13:13:43 +0300 Subject: [PATCH 1/6] Fix synth literal function for enums --- lib/Differentiator/ConstantFolder.cpp | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/lib/Differentiator/ConstantFolder.cpp b/lib/Differentiator/ConstantFolder.cpp index 6439c6eaa..8e7780ad7 100644 --- a/lib/Differentiator/ConstantFolder.cpp +++ b/lib/Differentiator/ConstantFolder.cpp @@ -7,6 +7,7 @@ //----------------------------------------------------------------------------// #include "ConstantFolder.h" +#include "clad/Differentiator/Compatibility.h" #include "clang/AST/ASTContext.h" @@ -141,7 +142,17 @@ namespace clad { // SourceLocation noLoc; Expr* Result = 0; QT = QT.getCanonicalType(); - if (QT->isPointerType()) { + if (QT->isEnumeralType()) { + llvm::APInt APVal(C.getIntWidth(QT), val, + QT->isSignedIntegerOrEnumerationType()); + Result = clad::synthesizeLiteral( + dyn_cast(QT)->getDecl()->getIntegerType(), C, APVal); + Expr* cast = ImplicitCastExpr::Create( + C, QT, clang::CastKind::CK_IntegralCast, Result, nullptr, + CLAD_COMPAT_ExprValueKind_R_or_PR_Value + CLAD_COMPAT_CLANG12_CastExpr_DefaultFPO); + Result = cast; + } else if (QT->isPointerType()) { Result = clad::synthesizeLiteral(QT, C); } else if (QT->isBooleanType()) { Result = clad::synthesizeLiteral(QT, C, (bool)val); From fd614ef551fe3645e6b1a48091478633b90e99de Mon Sep 17 00:00:00 2001 From: kchristin Date: Fri, 11 Oct 2024 12:10:18 +0300 Subject: [PATCH 2/6] Make cast of enum assignment static C++ cast and add test --- lib/Differentiator/ConstantFolder.cpp | 11 +- test/Gradient/Switch.C | 153 ++++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 4 deletions(-) diff --git a/lib/Differentiator/ConstantFolder.cpp b/lib/Differentiator/ConstantFolder.cpp index 8e7780ad7..e75b918ad 100644 --- a/lib/Differentiator/ConstantFolder.cpp +++ b/lib/Differentiator/ConstantFolder.cpp @@ -147,10 +147,13 @@ namespace clad { QT->isSignedIntegerOrEnumerationType()); Result = clad::synthesizeLiteral( dyn_cast(QT)->getDecl()->getIntegerType(), C, APVal); - Expr* cast = ImplicitCastExpr::Create( - C, QT, clang::CastKind::CK_IntegralCast, Result, nullptr, - CLAD_COMPAT_ExprValueKind_R_or_PR_Value - CLAD_COMPAT_CLANG12_CastExpr_DefaultFPO); + SourceLocation noLoc; + Expr* cast = CXXStaticCastExpr::Create( + C, QT, CLAD_COMPAT_ExprValueKind_R_or_PR_Value, + clang::CastKind::CK_IntegralCast, Result, nullptr, + C.getTrivialTypeSourceInfo(QT, noLoc) + CLAD_COMPAT_CLANG12_CastExpr_DefaultFPO, + noLoc, noLoc, SourceRange()); Result = cast; } else if (QT->isPointerType()) { Result = clad::synthesizeLiteral(QT, C); diff --git a/test/Gradient/Switch.C b/test/Gradient/Switch.C index 6e18bc04d..c701c3c00 100644 --- a/test/Gradient/Switch.C +++ b/test/Gradient/Switch.C @@ -682,6 +682,146 @@ double fn7(double u, double v) { // CHECK-NEXT: } // CHECK-NEXT: } +enum Op { + Add, + Sub, + Mul, + Div +}; + +double fn24(double x, double y, Op op) { + double res = 0; + switch (op) { + case Add: + res = x + y; + break; + case Sub: + res = x - y; + break; + case Mul: + res = x * y; + break; + case Div: + res = x / y; + break; + } + return res; +} + +// CHECK: void fn24_grad_0_1(double x, double y, Op op, double *_d_x, double *_d_y) { +// CHECK-NEXT: Op _d_op = static_cast(0U); +// CHECK-NEXT: Op _cond0; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: double _t3; +// CHECK-NEXT: double _t4; +// CHECK-NEXT: double _d_res = 0.; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: { +// CHECK-NEXT: _cond0 = op; +// CHECK-NEXT: switch (_cond0) { +// CHECK-NEXT: { +// CHECK-NEXT: case Add: +// CHECK-NEXT: res = x + y; +// CHECK-NEXT: _t0 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, 1UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case Sub: +// CHECK-NEXT: res = x - y; +// CHECK-NEXT: _t2 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, 2UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case Mul: +// CHECK-NEXT: res = x * y; +// CHECK-NEXT: _t3 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, 3UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case Div: +// CHECK-NEXT: res = x / y; +// CHECK-NEXT: _t4 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, 4UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t1, 5UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t1)) { +// CHECK-NEXT: case 5UL: +// CHECK-NEXT: ; +// CHECK-NEXT: case 4UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t4; +// CHECK-NEXT: double _r_d3 = _d_res; +// CHECK-NEXT: _d_res = 0.; +// CHECK-NEXT: *_d_x += _r_d3 / y; +// CHECK-NEXT: double _r0 = _r_d3 * -(x / (y * y)); +// CHECK-NEXT: _d_y += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: if (Div == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 3UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t3; +// CHECK-NEXT: double _r_d2 = _d_res; +// CHECK-NEXT: _d_res = 0.; +// CHECK-NEXT: *_d_x += _r_d2 * y; +// CHECK-NEXT: _d_y += x * _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: if (Mul == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 2UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t2; +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: _d_res = 0.; +// CHECK-NEXT: *_d_x += _r_d1; +// CHECK-NEXT: _d_y += -_r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: if (Sub == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t0; +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: _d_res = 0.; +// CHECK-NEXT: *_d_x += _r_d0; +// CHECK-NEXT: _d_y += _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: if (Add == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT:} + #define TEST_2(F, x, y) \ { \ @@ -691,6 +831,14 @@ double fn7(double u, double v) { printf("{%.2f, %.2f}\n", result[0], result[1]); \ } +#define TEST_2_Op(F, x, y, op) \ +{ \ + result[0] = result[1] = 0; \ + auto d_##F = clad::gradient(F, "x, y"); \ + d_##F.execute(x, y, op, result, result + 1); \ + printf("{%.2f, %.2f}\n", result[0], result[1]); \ +} + int main() { double result[2] = {}; @@ -705,4 +853,9 @@ int main() { TEST_GRADIENT(fn6, 2, 3, 5, &result[0], &result[1]); // CHECK-EXEC: {5.00, 3.00} TEST_GRADIENT(fn7, 2, 3, 5, &result[0], &result[1]); // CHECK-EXEC: {3.00, 2.00} + + TEST_2_Op(fn24, 3, 5, Add); // CHECK-EXEC: {1.00, 1.00} + TEST_2_Op(fn24, 3, 5, Sub); // CHECK-EXEC: {1.00, -1.00} + TEST_2_Op(fn24, 3, 5, Mul); // CHECK-EXEC: {5.00, 3.00} + TEST_2_Op(fn24, 3, 5, Div); // CHECK-EXEC: {0.20, -0.12} } From 409b2d71d5203d0f8b90ff74c7ee449dcb9cba96 Mon Sep 17 00:00:00 2001 From: kchristin Date: Fri, 11 Oct 2024 12:19:52 +0300 Subject: [PATCH 3/6] Fix Arch compatibility --- test/Gradient/Switch.C | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Gradient/Switch.C b/test/Gradient/Switch.C index c701c3c00..fe8c4d0c8 100644 --- a/test/Gradient/Switch.C +++ b/test/Gradient/Switch.C @@ -712,7 +712,7 @@ double fn24(double x, double y, Op op) { // CHECK-NEXT: Op _d_op = static_cast(0U); // CHECK-NEXT: Op _cond0; // CHECK-NEXT: double _t0; -// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: clad::tape _t1 = {}; // CHECK-NEXT: double _t2; // CHECK-NEXT: double _t3; // CHECK-NEXT: double _t4; From 9bedeafc7534241269068b6601187382b0857885 Mon Sep 17 00:00:00 2001 From: kchristin Date: Fri, 11 Oct 2024 12:29:33 +0300 Subject: [PATCH 4/6] Fix arch compatibility 2 --- test/Gradient/Switch.C | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/Gradient/Switch.C b/test/Gradient/Switch.C index fe8c4d0c8..7da7ad0fb 100644 --- a/test/Gradient/Switch.C +++ b/test/Gradient/Switch.C @@ -727,7 +727,7 @@ double fn24(double x, double y, Op op) { // CHECK-NEXT: _t0 = res; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: clad::push(_t1, 1UL); +// CHECK-NEXT: clad::push(_t1, {{1U|1UL}}); // CHECK-NEXT: break; // CHECK-NEXT: } // CHECK-NEXT: { @@ -736,7 +736,7 @@ double fn24(double x, double y, Op op) { // CHECK-NEXT: _t2 = res; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: clad::push(_t1, 2UL); +// CHECK-NEXT: clad::push(_t1, {{2U|2UL}}); // CHECK-NEXT: break; // CHECK-NEXT: } // CHECK-NEXT: { @@ -745,7 +745,7 @@ double fn24(double x, double y, Op op) { // CHECK-NEXT: _t3 = res; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: clad::push(_t1, 3UL); +// CHECK-NEXT: clad::push(_t1, {{3U|3UL}}); // CHECK-NEXT: break; // CHECK-NEXT: } // CHECK-NEXT: { @@ -754,10 +754,10 @@ double fn24(double x, double y, Op op) { // CHECK-NEXT: _t4 = res; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: clad::push(_t1, 4UL); +// CHECK-NEXT: clad::push(_t1, {{4U|4UL}}); // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: clad::push(_t1, 5UL); +// CHECK-NEXT: clad::push(_t1, {{5U|5UL}}); // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; From 3326e431bf4f99a4b82b17e34f93dfc24d37818e Mon Sep 17 00:00:00 2001 From: kchristin Date: Fri, 11 Oct 2024 12:43:55 +0300 Subject: [PATCH 5/6] Fix arch compatibility 3 --- test/Gradient/Switch.C | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/Gradient/Switch.C b/test/Gradient/Switch.C index 7da7ad0fb..98a176807 100644 --- a/test/Gradient/Switch.C +++ b/test/Gradient/Switch.C @@ -763,9 +763,9 @@ double fn24(double x, double y, Op op) { // CHECK-NEXT: _d_res += 1; // CHECK-NEXT: { // CHECK-NEXT: switch (clad::pop(_t1)) { -// CHECK-NEXT: case 5UL: +// CHECK-NEXT: case {{5U|5UL}}: // CHECK-NEXT: ; -// CHECK-NEXT: case 4UL: +// CHECK-NEXT: case {{4U|4UL}}: // CHECK-NEXT: ; // CHECK-NEXT: { // CHECK-NEXT: { @@ -779,7 +779,7 @@ double fn24(double x, double y, Op op) { // CHECK-NEXT: if (Div == _cond0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: case 3UL: +// CHECK-NEXT: case {{3U|3UL}}: // CHECK-NEXT: ; // CHECK-NEXT: { // CHECK-NEXT: { @@ -792,7 +792,7 @@ double fn24(double x, double y, Op op) { // CHECK-NEXT: if (Mul == _cond0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: case 2UL: +// CHECK-NEXT: case {{2U|2UL}}: // CHECK-NEXT: ; // CHECK-NEXT: { // CHECK-NEXT: { @@ -805,7 +805,7 @@ double fn24(double x, double y, Op op) { // CHECK-NEXT: if (Sub == _cond0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: case 1UL: +// CHECK-NEXT: case {{1U|1UL}}: // CHECK-NEXT: ; // CHECK-NEXT: { // CHECK-NEXT: { From b3cae4abba593397ebb9ea9623d45d9953fb788c Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Sat, 12 Oct 2024 21:26:45 +0300 Subject: [PATCH 6/6] Update lib/Differentiator/ConstantFolder.cpp --- lib/Differentiator/ConstantFolder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Differentiator/ConstantFolder.cpp b/lib/Differentiator/ConstantFolder.cpp index e75b918ad..900e87a90 100644 --- a/lib/Differentiator/ConstantFolder.cpp +++ b/lib/Differentiator/ConstantFolder.cpp @@ -150,7 +150,7 @@ namespace clad { SourceLocation noLoc; Expr* cast = CXXStaticCastExpr::Create( C, QT, CLAD_COMPAT_ExprValueKind_R_or_PR_Value, - clang::CastKind::CK_IntegralCast, Result, nullptr, + clang::CastKind::CK_IntegralCast, Result, /*CXXCastPath=*/nullptr, C.getTrivialTypeSourceInfo(QT, noLoc) CLAD_COMPAT_CLANG12_CastExpr_DefaultFPO, noLoc, noLoc, SourceRange());