Skip to content

Commit

Permalink
Try to improve literal types
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak authored and vgvassilev committed Aug 2, 2024
1 parent 7d1e26c commit 61982a1
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 7 deletions.
13 changes: 10 additions & 3 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1072,8 +1072,11 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
// If DRE is of type pointer, then the derivative is a null pointer.
if (clonedDRE->getType()->isPointerType())
return StmtDiff(clonedDRE, nullptr);
QualType literalTy = utils::GetValueType(clonedDRE->getType());
if (!literalTy->isRealType())
literalTy = m_Context.IntTy;
return StmtDiff(clonedDRE, ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, /*val=*/0));
literalTy, m_Context, /*val=*/0));
}

StmtDiff BaseForwardModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) {
Expand Down Expand Up @@ -1376,8 +1379,12 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
} else if (opKind == UnaryOperatorKind::UO_Deref) {
if (Expr* dx = diff.getExpr_dx())
return StmtDiff(op, BuildOp(opKind, dx));
return StmtDiff(op, ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, /*val=*/0));
QualType literalTy =
utils::GetValueType(UnOp->getSubExpr()->getType()->getPointeeType());
if (!literalTy->isRealType())
literalTy = m_Context.IntTy;
return StmtDiff(
op, ConstantFolder::synthesizeLiteral(literalTy, m_Context, /*val=*/0));
} else if (opKind == UnaryOperatorKind::UO_AddrOf) {
return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
} else if (opKind == UnaryOperatorKind::UO_LNot) {
Expand Down
2 changes: 1 addition & 1 deletion test/FirstDerivative/CallArguments.C
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ float f_call_inline_fxn(float *params, float const *obs, float const *xlArr) {

// CHECK: float f_call_inline_fxn_darg0_0(float *params, const float *obs, const float *xlArr) {
// CHECK-NEXT: clad::ValueAndPushforward<unsigned int, unsigned int> _t0 = getBin_pushforward(0., 1., params[0], 1, 0., 0., 1.F, 0);
// CHECK-NEXT: const float _d_t116 = 0;
// CHECK-NEXT: const float _d_t116 = 0.F;
// CHECK-NEXT: const float t116 = *(xlArr + _t0.value);
// CHECK-NEXT: return _d_t116 * params[0] + t116 * 1.F;
// CHECK-NEXT: }
Expand Down
2 changes: 1 addition & 1 deletion test/ForwardMode/Functors.C
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ int main() {
// CHECK-NEXT: double _d_i = 1;
// CHECK-NEXT: double _d_jj = 0;
// CHECK-NEXT: double _t0 = x * i;
// CHECK-NEXT: return (0 * i + x * _d_i) * jj + _t0 * _d_jj;
// CHECK-NEXT: return (0. * i + x * _d_i) * jj + _t0 * _d_jj;
// CHECK-NEXT: }

auto lambdaNNS = outer::inner::lambdaNNS;
Expand Down
4 changes: 2 additions & 2 deletions test/ForwardMode/Pointer.C
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ double fn9(double* params, const double *constants) {
}

// CHECK: double fn9_darg0_0(double *params, const double *constants) {
// CHECK-NEXT: double _d_c0 = 0;
// CHECK-NEXT: double _d_c0 = 0.;
// CHECK-NEXT: double c0 = *constants;
// CHECK-NEXT: return 1. * c0 + params[0] * _d_c0;
// CHECK-NEXT: }
Expand All @@ -209,7 +209,7 @@ double fn10(double *params, const double *constants) {
}

// CHECK: double fn10_darg0_0(double *params, const double *constants) {
// CHECK-NEXT: double _d_c0 = 0;
// CHECK-NEXT: double _d_c0 = 0.;
// CHECK-NEXT: double c0 = *(constants + 0);
// CHECK-NEXT: return 1. * c0 + params[0] * _d_c0;
// CHECK-NEXT: }
Expand Down

0 comments on commit 61982a1

Please sign in to comment.