From 900f21d1bc8df81dabe70fcea45138ad8080153c Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Fri, 14 Jun 2024 21:17:05 +0200 Subject: [PATCH] Vector-mode support for top level custom derivatives --- .../VectorForwardModeVisitor.cpp | 8 +++++++ test/ForwardMode/VectorMode.C | 23 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/lib/Differentiator/VectorForwardModeVisitor.cpp b/lib/Differentiator/VectorForwardModeVisitor.cpp index f5efbc3dd..cd9730567 100644 --- a/lib/Differentiator/VectorForwardModeVisitor.cpp +++ b/lib/Differentiator/VectorForwardModeVisitor.cpp @@ -111,6 +111,14 @@ VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD, // FIXME: We should not use const_cast to get the decl context here. // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) auto* DC = const_cast(m_DiffReq->getDeclContext()); + if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl( + derivedFnName, DC, vectorDiffFunctionType)) { + // Set m_Derivative for creating the overload. + m_Derivative = customDerivative; + FunctionDecl* gradientOverloadFD = CreateVectorModeOverload(); + return DerivativeAndOverload{customDerivative, gradientOverloadFD}; + } + m_Sema.CurContext = DC; DeclWithContext result = m_Builder.cloneFunction( m_DiffReq.Function, *this, DC, loc, name, vectorDiffFunctionType); diff --git a/test/ForwardMode/VectorMode.C b/test/ForwardMode/VectorMode.C index 9ee1e10de..bf504dc14 100644 --- a/test/ForwardMode/VectorMode.C +++ b/test/ForwardMode/VectorMode.C @@ -275,6 +275,24 @@ double f8(int n, const double* arr) { // CHECK-NEXT: } // CHECK-NEXT: } +namespace clad { + namespace custom_derivatives{ + void f9_dvec(double x, double y, double *d_x, double *d_y) { + *d_x += 1; + *d_y += 1; + } + } +} + +double f9(double x, double y) { + return x + y; +} + +// CHECK: void f9_dvec(double x, double y, double *d_x, double *d_y) { +// CHECK-NEXT: *d_x += 1; +// CHECK-NEXT: *d_y += 1; +// CHECK-NEXT: } + #define TEST(F, x, y) \ { \ result[0] = 0; \ @@ -338,6 +356,11 @@ int main() { f8_dvec.execute(3, arr2, darr2_ref); printf("Result is = {%.2f, %.2f, %.2f}\n", darr2[0], darr2[1], darr2[2]); // CHECK-EXEC: Result is = {1.00, 1.00, 1.00} + auto f9_dvec = clad::differentiate(f9); + double dx = 0, dy = 0; + f9_dvec.execute(1, 2, &dx, &dy); + printf("Result is = {%.2f, %.2f}\n", dx, dy); // CHECK-EXEC: Result is = {1.00, 1.00} + // CHECK: clad::ValueAndPushforward > square_vector_pushforward(const double &x, const clad::array &_d_x) { // CHECK-NEXT: unsigned long indepVarCount = _d_x.size(); // CHECK-NEXT: clad::array _d_vector_z(clad::array(indepVarCount, _d_x * x + x * _d_x));