From 4ac8ca56de520d0e0e3b95a9f4bdb99e6dfd858f Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Wed, 3 Jul 2024 11:32:47 +0200 Subject: [PATCH] Fix the derivative of string literals in forward mode This commit makes Clad set the derivative of string literals to an empty string in the forward mode. Differentiating string literals used to produce integer zero literals previously. --- lib/Differentiator/BaseForwardModeVisitor.cpp | 7 ++-- test/FirstDerivative/Variables.C | 33 +++++++++++++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 4dd6fd8f1..e068afd05 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1662,10 +1662,9 @@ BaseForwardModeVisitor::VisitCharacterLiteral(const CharacterLiteral* CL) { } StmtDiff BaseForwardModeVisitor::VisitStringLiteral(const StringLiteral* SL) { - llvm::APInt zero(m_Context.getIntWidth(m_Context.IntTy), /*value*/ 0); - auto* constant0 = - IntegerLiteral::Create(m_Context, zero, m_Context.IntTy, noLoc); - return StmtDiff(Clone(SL), constant0); + return StmtDiff(Clone(SL), StringLiteral::Create( + m_Context, "", SL->getKind(), SL->isPascal(), + SL->getType(), utils::GetValidSLoc(m_Sema))); } StmtDiff BaseForwardModeVisitor::VisitWhileStmt(const WhileStmt* WS) { diff --git a/test/FirstDerivative/Variables.C b/test/FirstDerivative/Variables.C index 1bfb93527..9d3af4f08 100644 --- a/test/FirstDerivative/Variables.C +++ b/test/FirstDerivative/Variables.C @@ -4,6 +4,7 @@ #include "clad/Differentiator/Differentiator.h" #include +#include double f_x(double x) { double t0 = x; @@ -86,12 +87,44 @@ double f_sin(double x, double y) { // CHECK-NEXT: return _d_xt + _d_yt; // CHECK-NEXT: } +double f_string(double x) { + const char *s = "string literal"; + return x; +} + +// CHECK: double f_string_darg0(double x) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: const char *_d_s = ""; +// CHECK-NEXT: const char *s = "string literal"; +// CHECK-NEXT: return _d_x; +// CHECK-NEXT: } + +namespace clad { +namespace custom_derivatives { +clad::ValueAndPushforward string_test_pushforward(double x, const char s[], double _d_x, const char *_d_s) { + return {0, 0}; +} +}} +double string_test(double x, const char s[]) { + return 1; +} +double f_string_call(double x) { + return string_test(x, "string literal"); +} + +// CHECK: double f_string_call_darg0(double x) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: clad::ValueAndPushforward _t0 = clad::custom_derivatives::string_test_pushforward(x, "string literal", _d_x, ""); +// CHECK-NEXT: return _t0.pushforward; +// CHECK-NEXT: } int main() { clad::differentiate(f_x, 0); clad::differentiate(f_ops1, 0); clad::differentiate(f_ops2, 0); clad::differentiate(f_sin, 0); + clad::differentiate(f_string, 0); + clad::differentiate(f_string_call, 0); }