From 1f4e2a3882e6845307f4895f804b34da50d80afa Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Mon, 22 Jul 2024 17:21:36 +0200 Subject: [PATCH 1/2] Try to improve literal types --- lib/Differentiator/BaseForwardModeVisitor.cpp | 13 ++++++++++--- test/FirstDerivative/CallArguments.C | 2 +- test/ForwardMode/Functors.C | 2 +- test/ForwardMode/Pointer.C | 4 ++-- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 7a70f3bcb..dbd8ca440 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1070,8 +1070,11 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { // If DRE is of type pointer, then the derivative is a null pointer. if (clonedDRE->getType()->isPointerType()) return StmtDiff(clonedDRE, nullptr); + QualType literalTy = utils::GetValueType(clonedDRE->getType()); + if (!literalTy->isRealType()) + literalTy = m_Context.IntTy; return StmtDiff(clonedDRE, ConstantFolder::synthesizeLiteral( - m_Context.IntTy, m_Context, /*val=*/0)); + literalTy, m_Context, /*val=*/0)); } StmtDiff BaseForwardModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) { @@ -1374,8 +1377,12 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { } else if (opKind == UnaryOperatorKind::UO_Deref) { if (Expr* dx = diff.getExpr_dx()) return StmtDiff(op, BuildOp(opKind, dx)); - return StmtDiff(op, ConstantFolder::synthesizeLiteral( - m_Context.IntTy, m_Context, /*val=*/0)); + QualType literalTy = + utils::GetValueType(UnOp->getSubExpr()->getType()->getPointeeType()); + if (!literalTy->isRealType()) + literalTy = m_Context.IntTy; + return StmtDiff( + op, ConstantFolder::synthesizeLiteral(literalTy, m_Context, /*val=*/0)); } else if (opKind == UnaryOperatorKind::UO_AddrOf) { return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx())); } else if (opKind == UnaryOperatorKind::UO_LNot) { diff --git a/test/FirstDerivative/CallArguments.C b/test/FirstDerivative/CallArguments.C index efe05909e..077a85509 100644 --- a/test/FirstDerivative/CallArguments.C +++ b/test/FirstDerivative/CallArguments.C @@ -165,7 +165,7 @@ float f_call_inline_fxn(float *params, float const *obs, float const *xlArr) { // CHECK: float f_call_inline_fxn_darg0_0(float *params, const float *obs, const float *xlArr) { // CHECK-NEXT: clad::ValueAndPushforward _t0 = getBin_pushforward(0., 1., params[0], 1, 0., 0., 1.F, 0); -// CHECK-NEXT: const float _d_t116 = 0; +// CHECK-NEXT: const float _d_t116 = 0.F; // CHECK-NEXT: const float t116 = *(xlArr + _t0.value); // CHECK-NEXT: return _d_t116 * params[0] + t116 * 1.F; // CHECK-NEXT: } diff --git a/test/ForwardMode/Functors.C b/test/ForwardMode/Functors.C index fdfcab92b..85026942c 100644 --- a/test/ForwardMode/Functors.C +++ b/test/ForwardMode/Functors.C @@ -430,7 +430,7 @@ int main() { // CHECK-NEXT: double _d_i = 1; // CHECK-NEXT: double _d_jj = 0; // CHECK-NEXT: double _t0 = x * i; - // CHECK-NEXT: return (0 * i + x * _d_i) * jj + _t0 * _d_jj; + // CHECK-NEXT: return (0. * i + x * _d_i) * jj + _t0 * _d_jj; // CHECK-NEXT: } auto lambdaNNS = outer::inner::lambdaNNS; diff --git a/test/ForwardMode/Pointer.C b/test/ForwardMode/Pointer.C index 7d98cd2f8..9f724fbb2 100644 --- a/test/ForwardMode/Pointer.C +++ b/test/ForwardMode/Pointer.C @@ -198,7 +198,7 @@ double fn9(double* params, const double *constants) { } // CHECK: double fn9_darg0_0(double *params, const double *constants) { -// CHECK-NEXT: double _d_c0 = 0; +// CHECK-NEXT: double _d_c0 = 0.; // CHECK-NEXT: double c0 = *constants; // CHECK-NEXT: return 1. * c0 + params[0] * _d_c0; // CHECK-NEXT: } @@ -209,7 +209,7 @@ double fn10(double *params, const double *constants) { } // CHECK: double fn10_darg0_0(double *params, const double *constants) { -// CHECK-NEXT: double _d_c0 = 0; +// CHECK-NEXT: double _d_c0 = 0.; // CHECK-NEXT: double c0 = *(constants + 0); // CHECK-NEXT: return 1. * c0 + params[0] * _d_c0; // CHECK-NEXT: } From 97acf5fdfc81c3ba697436c8fe774a34cb70d046 Mon Sep 17 00:00:00 2001 From: kchristin Date: Sat, 3 Aug 2024 12:15:29 +0300 Subject: [PATCH 2/2] Get enum's underlying value type --- lib/Differentiator/CladUtils.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index 3b6a379e0..d0ff345ed 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -372,6 +372,10 @@ namespace clad { else if (T->isArrayType()) valueType = T->getPointeeOrArrayElementType()->getCanonicalTypeInternal(); + else if (T->isEnumeralType()){ + if (const auto* ET = dyn_cast(T)) + valueType = ET->getDecl()->getIntegerType(); + } valueType.removeLocalConst(); return valueType; }