From f86eedee99509b2eda4d16d5660ca8ff051c6d55 Mon Sep 17 00:00:00 2001 From: Christina Koutsou <74819775+kchristin22@users.noreply.github.com> Date: Sat, 12 Oct 2024 22:00:07 +0300 Subject: [PATCH] Fix synthesizing literals function for enums (#1113) --- lib/Differentiator/ConstantFolder.cpp | 16 ++- test/Gradient/Switch.C | 153 ++++++++++++++++++++++++++ 2 files changed, 168 insertions(+), 1 deletion(-) diff --git a/lib/Differentiator/ConstantFolder.cpp b/lib/Differentiator/ConstantFolder.cpp index 6439c6eaa..900e87a90 100644 --- a/lib/Differentiator/ConstantFolder.cpp +++ b/lib/Differentiator/ConstantFolder.cpp @@ -7,6 +7,7 @@ //----------------------------------------------------------------------------// #include "ConstantFolder.h" +#include "clad/Differentiator/Compatibility.h" #include "clang/AST/ASTContext.h" @@ -141,7 +142,20 @@ namespace clad { // SourceLocation noLoc; Expr* Result = 0; QT = QT.getCanonicalType(); - if (QT->isPointerType()) { + if (QT->isEnumeralType()) { + llvm::APInt APVal(C.getIntWidth(QT), val, + QT->isSignedIntegerOrEnumerationType()); + Result = clad::synthesizeLiteral( + dyn_cast(QT)->getDecl()->getIntegerType(), C, APVal); + SourceLocation noLoc; + Expr* cast = CXXStaticCastExpr::Create( + C, QT, CLAD_COMPAT_ExprValueKind_R_or_PR_Value, + clang::CastKind::CK_IntegralCast, Result, /*CXXCastPath=*/nullptr, + C.getTrivialTypeSourceInfo(QT, noLoc) + CLAD_COMPAT_CLANG12_CastExpr_DefaultFPO, + noLoc, noLoc, SourceRange()); + Result = cast; + } else if (QT->isPointerType()) { Result = clad::synthesizeLiteral(QT, C); } else if (QT->isBooleanType()) { Result = clad::synthesizeLiteral(QT, C, (bool)val); diff --git a/test/Gradient/Switch.C b/test/Gradient/Switch.C index 6e18bc04d..98a176807 100644 --- a/test/Gradient/Switch.C +++ b/test/Gradient/Switch.C @@ -682,6 +682,146 @@ double fn7(double u, double v) { // CHECK-NEXT: } // CHECK-NEXT: } +enum Op { + Add, + Sub, + Mul, + Div +}; + +double fn24(double x, double y, Op op) { + double res = 0; + switch (op) { + case Add: + res = x + y; + break; + case Sub: + res = x - y; + break; + case Mul: + res = x * y; + break; + case Div: + res = x / y; + break; + } + return res; +} + +// CHECK: void fn24_grad_0_1(double x, double y, Op op, double *_d_x, double *_d_y) { +// CHECK-NEXT: Op _d_op = static_cast(0U); +// CHECK-NEXT: Op _cond0; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: double _t3; +// CHECK-NEXT: double _t4; +// CHECK-NEXT: double _d_res = 0.; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: { +// CHECK-NEXT: _cond0 = op; +// CHECK-NEXT: switch (_cond0) { +// CHECK-NEXT: { +// CHECK-NEXT: case Add: +// CHECK-NEXT: res = x + y; +// CHECK-NEXT: _t0 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, {{1U|1UL}}); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case Sub: +// CHECK-NEXT: res = x - y; +// CHECK-NEXT: _t2 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, {{2U|2UL}}); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case Mul: +// CHECK-NEXT: res = x * y; +// CHECK-NEXT: _t3 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, {{3U|3UL}}); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case Div: +// CHECK-NEXT: res = x / y; +// CHECK-NEXT: _t4 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, {{4U|4UL}}); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t1, {{5U|5UL}}); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t1)) { +// CHECK-NEXT: case {{5U|5UL}}: +// CHECK-NEXT: ; +// CHECK-NEXT: case {{4U|4UL}}: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t4; +// CHECK-NEXT: double _r_d3 = _d_res; +// CHECK-NEXT: _d_res = 0.; +// CHECK-NEXT: *_d_x += _r_d3 / y; +// CHECK-NEXT: double _r0 = _r_d3 * -(x / (y * y)); +// CHECK-NEXT: _d_y += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: if (Div == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case {{3U|3UL}}: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t3; +// CHECK-NEXT: double _r_d2 = _d_res; +// CHECK-NEXT: _d_res = 0.; +// CHECK-NEXT: *_d_x += _r_d2 * y; +// CHECK-NEXT: _d_y += x * _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: if (Mul == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case {{2U|2UL}}: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t2; +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: _d_res = 0.; +// CHECK-NEXT: *_d_x += _r_d1; +// CHECK-NEXT: _d_y += -_r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: if (Sub == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case {{1U|1UL}}: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t0; +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: _d_res = 0.; +// CHECK-NEXT: *_d_x += _r_d0; +// CHECK-NEXT: _d_y += _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: if (Add == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT:} + #define TEST_2(F, x, y) \ { \ @@ -691,6 +831,14 @@ double fn7(double u, double v) { printf("{%.2f, %.2f}\n", result[0], result[1]); \ } +#define TEST_2_Op(F, x, y, op) \ +{ \ + result[0] = result[1] = 0; \ + auto d_##F = clad::gradient(F, "x, y"); \ + d_##F.execute(x, y, op, result, result + 1); \ + printf("{%.2f, %.2f}\n", result[0], result[1]); \ +} + int main() { double result[2] = {}; @@ -705,4 +853,9 @@ int main() { TEST_GRADIENT(fn6, 2, 3, 5, &result[0], &result[1]); // CHECK-EXEC: {5.00, 3.00} TEST_GRADIENT(fn7, 2, 3, 5, &result[0], &result[1]); // CHECK-EXEC: {3.00, 2.00} + + TEST_2_Op(fn24, 3, 5, Add); // CHECK-EXEC: {1.00, 1.00} + TEST_2_Op(fn24, 3, 5, Sub); // CHECK-EXEC: {1.00, -1.00} + TEST_2_Op(fn24, 3, 5, Mul); // CHECK-EXEC: {5.00, 3.00} + TEST_2_Op(fn24, 3, 5, Div); // CHECK-EXEC: {0.20, -0.12} }