Skip to content

Commit

Permalink
Vector-mode support for top level custom derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Jun 14, 2024
1 parent e955d84 commit 900f21d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
8 changes: 8 additions & 0 deletions lib/Differentiator/VectorForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD,
// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl(
derivedFnName, DC, vectorDiffFunctionType)) {
// Set m_Derivative for creating the overload.
m_Derivative = customDerivative;
FunctionDecl* gradientOverloadFD = CreateVectorModeOverload();
return DerivativeAndOverload{customDerivative, gradientOverloadFD};
}

m_Sema.CurContext = DC;
DeclWithContext result = m_Builder.cloneFunction(
m_DiffReq.Function, *this, DC, loc, name, vectorDiffFunctionType);
Expand Down
23 changes: 23 additions & 0 deletions test/ForwardMode/VectorMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,24 @@ double f8(int n, const double* arr) {
// CHECK-NEXT: }
// CHECK-NEXT: }

namespace clad {
namespace custom_derivatives{
void f9_dvec(double x, double y, double *d_x, double *d_y) {
*d_x += 1;
*d_y += 1;
}
}
}

double f9(double x, double y) {
return x + y;
}

// CHECK: void f9_dvec(double x, double y, double *d_x, double *d_y) {
// CHECK-NEXT: *d_x += 1;
// CHECK-NEXT: *d_y += 1;
// CHECK-NEXT: }

#define TEST(F, x, y) \
{ \
result[0] = 0; \
Expand Down Expand Up @@ -338,6 +356,11 @@ int main() {
f8_dvec.execute(3, arr2, darr2_ref);
printf("Result is = {%.2f, %.2f, %.2f}\n", darr2[0], darr2[1], darr2[2]); // CHECK-EXEC: Result is = {1.00, 1.00, 1.00}

auto f9_dvec = clad::differentiate<clad::opts::vector_mode>(f9);
double dx = 0, dy = 0;
f9_dvec.execute(1, 2, &dx, &dy);
printf("Result is = {%.2f, %.2f}\n", dx, dy); // CHECK-EXEC: Result is = {1.00, 1.00}

// CHECK: clad::ValueAndPushforward<double, clad::array<double> > square_vector_pushforward(const double &x, const clad::array<double> &_d_x) {
// CHECK-NEXT: unsigned long indepVarCount = _d_x.size();
// CHECK-NEXT: clad::array<double> _d_vector_z(clad::array<double>(indepVarCount, _d_x * x + x * _d_x));
Expand Down

0 comments on commit 900f21d

Please sign in to comment.