Skip to content

Commit

Permalink
improve testing, comments and hessians
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Jun 11, 2024
1 parent 10e85f4 commit c335823
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 56 deletions.
4 changes: 3 additions & 1 deletion include/clad/Differentiator/HessianModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ namespace clad {
DerivativeAndOverload
Merge(std::vector<clang::FunctionDecl*> secDerivFuncs,
llvm::SmallVector<size_t, 16> IndependentArgsSize,
size_t TotalIndependentArgsSize, std::string hessianFuncName);
size_t TotalIndependentArgsSize, const std::string& hessianFuncName,
clang::DeclContext* FD, clang::QualType hessianFuncType,
llvm::SmallVector<clang::QualType, 16> paramTypes);

public:
HessianModeVisitor(DerivativeBuilder& builder);
Expand Down
4 changes: 4 additions & 0 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,10 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
if (auto* FD = dyn_cast<FunctionDecl>(ND))
// Check if FD and functionType have the same signature.
if (utils::SameCanonicalType(FD->getType(), functionType))
// Make sure that it is not the case that FD is the forward
// declaration generated by Clad. It should be user defined custom
// derivative (either within the same translation unit or linked in
// from another translation unit).
if (FD->isDefined() || !m_DFC.IsDerivative(FD))
return FD;

Expand Down
49 changes: 25 additions & 24 deletions lib/Differentiator/HessianModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,26 @@ namespace clad {
}
}

llvm::SmallVector<QualType, 16> paramTypes(m_Function->getNumParams() + 1);
std::transform(m_Function->param_begin(), m_Function->param_end(),
std::begin(paramTypes),
[](const ParmVarDecl* PVD) { return PVD->getType(); });
paramTypes.back() = m_Context.getPointerType(m_Function->getReturnType());

const auto* originalFnProtoType =
cast<FunctionProtoType>(m_Function->getType());
QualType hessianFunctionType = m_Context.getFunctionType(
m_Context.VoidTy,
llvm::ArrayRef<QualType>(paramTypes.data(), paramTypes.size()),
// Cast to function pointer.
originalFnProtoType->getExtProtoInfo());

// Check if the function is already declared as a custom derivative.
auto* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl(
hessianFuncName, DC, hessianFunctionType))
return DerivativeAndOverload{customDerivative, nullptr};

// Ascertains the independent arguments and differentiates the function
// in forward and reverse mode by calling ProcessDiffRequest twice each
// iteration, storing each generated second derivative function
Expand Down Expand Up @@ -229,7 +249,8 @@ namespace clad {
}
}
return Merge(secondDerivativeColumns, IndependentArgsSize,
TotalIndependentArgsSize, hessianFuncName);
TotalIndependentArgsSize, hessianFuncName, DC,
hessianFunctionType, paramTypes);
}

// Combines all generated second derivative functions into a
Expand All @@ -239,7 +260,9 @@ namespace clad {
HessianModeVisitor::Merge(std::vector<FunctionDecl*> secDerivFuncs,
SmallVector<size_t, 16> IndependentArgsSize,
size_t TotalIndependentArgsSize,
std::string hessianFuncName) {
const std::string& hessianFuncName, DeclContext* DC,
QualType hessianFunctionType,
llvm::SmallVector<QualType, 16> paramTypes) {
DiffParams args;
std::copy(m_Function->param_begin(),
m_Function->param_end(),
Expand All @@ -248,28 +271,6 @@ namespace clad {
IdentifierInfo* II = &m_Context.Idents.get(hessianFuncName);
DeclarationNameInfo name(II, noLoc);

llvm::SmallVector<QualType, 16> paramTypes(m_Function->getNumParams() + 1);

std::transform(m_Function->param_begin(),
m_Function->param_end(),
std::begin(paramTypes),
[](const ParmVarDecl* PVD) { return PVD->getType(); });

paramTypes.back() = m_Context.getPointerType(m_Function->getReturnType());

auto originalFnProtoType = cast<FunctionProtoType>(m_Function->getType());
QualType hessianFunctionType = m_Context.getFunctionType(
m_Context.VoidTy,
llvm::ArrayRef<QualType>(paramTypes.data(), paramTypes.size()),
// Cast to function pointer.
originalFnProtoType->getExtProtoInfo());

// Check if the function is already declared as a custom derivative.
DeclContext* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl(
hessianFuncName, DC, hessianFunctionType))
return DerivativeAndOverload{customDerivative, nullptr};

// Create the gradient function declaration.
llvm::SaveAndRestore<DeclContext*> SaveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> SaveScope(getCurrentScope(),
Expand Down
70 changes: 39 additions & 31 deletions test/Hessian/BuiltinDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,6 @@
#include "clad/Differentiator/Differentiator.h"
#include <math.h>

namespace clad {
namespace custom_derivatives {
float f7_darg0(float x, float y) {
return cos(x);
}

float f7_darg1(float x, float y) {
return exp(y);
}

void f8_hessian(double x, double y, double *hessianMatrix) {
hessianMatrix[0] = 1.0;
hessianMatrix[1] = 1.0;
hessianMatrix[2] = 1.0;
hessianMatrix[3] = 1.0;
}
}
}

float f1(float x) {
return sin(x) + cos(x);
}
Expand Down Expand Up @@ -109,6 +90,29 @@ float f6(float x, float y) {
// CHECK-NEXT: f6_darg1_grad(x, y, hessianMatrix + {{2U|2UL}}, hessianMatrix + {{3U|3UL}});
// CHECK-NEXT: }

namespace clad {
namespace custom_derivatives {
float f7_darg0(float x, float y) {
return cos(x);
}

float f7_darg1(float x, float y) {
return exp(y);
}

void f7_darg1_grad(float x, float y, float *d_x, float *d_y) {
*d_y += exp(y);
}

void f8_hessian(float x, float y, float *hessianMatrix) {
hessianMatrix[0] = 1.0;
hessianMatrix[1] = 1.0;
hessianMatrix[2] = 1.0;
hessianMatrix[3] = 1.0;
}
}
}

float f7(float x, float y) {
return sin(x) + exp(y);
}
Expand All @@ -123,13 +127,26 @@ float f7(float x, float y) {
// CHECK-NEXT: return exp(y);
// CHECK-NEXT: }

// CHECK: void f7_darg1_grad(float x, float y, float *_d_x, float *_d_y);
// CHECK: void f7_darg1_grad(float x, float y, float *d_x, float *d_y) {
// CHECK-NEXT: *d_y += exp(y);
// CHECK-NEXT: }

// CHECK: void f7_hessian(float x, float y, float *hessianMatrix) {
// CHECK-NEXT: f7_darg0_grad(x, y, hessianMatrix + {{0U|0UL}}, hessianMatrix + {{1U|1UL}});
// CHECK-NEXT: f7_darg1_grad(x, y, hessianMatrix + {{2U|2UL}}, hessianMatrix + {{3U|3UL}});
// CHECK-NEXT: }

float f8(float x, float y) {
return (x*x + y*y)/2 + x*y;
}

// CHECK: void f8_hessian(float x, float y, float *hessianMatrix) {
// CHECK-NEXT: hessianMatrix[0] = 1.;
// CHECK-NEXT: hessianMatrix[1] = 1.;
// CHECK-NEXT: hessianMatrix[2] = 1.;
// CHECK-NEXT: hessianMatrix[3] = 1.;
// CHECK-NEXT: }

#define TEST1(F, x) { \
result[0] = 0; \
auto h = clad::hessian(F); \
Expand All @@ -155,6 +172,7 @@ int main() {
TEST1(f5, 3); // CHECK-EXEC: Result is = {3.84}
TEST2(f6, 3, 4); // CHECK-EXEC: Result is = {108.00, 145.65, 145.65, 97.76}
TEST2(f7, 3, 4); // CHECK-EXEC: Result is = {-0.14, 0.00, 0.00, 54.60}
TEST2(f8, 3, 4); // CHECK-EXEC: Result is = {1.00, 1.00, 1.00, 1.00}

// CHECK: float f1_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
Expand Down Expand Up @@ -363,21 +381,11 @@ int main() {
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: float _r0 = 0;
// CHECK-NEXT: _r0 += 1 * clad::custom_derivatives::std::cos_pushforward(x, 1.F).pushforward;
// CHECK-NEXT: _r0 += 1 * clad::custom_derivatives{{(::std)?}}::cos_pushforward(x, 1.F).pushforward;
// CHECK-NEXT: *_d_x += _r0;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void f7_darg1_grad(float x, float y, float *_d_x, float *_d_y) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: float _r0 = 0;
// CHECK-NEXT: _r0 += 1 * clad::custom_derivatives::std::exp_pushforward(y, 1.F).pushforward;
// CHECK-NEXT: *_d_y += _r0;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void sin_pushforward_pullback(float x, float d_x, ValueAndPushforward<float, float> _d_y, float *_d_x, float *_d_d_x) {
// CHECK-NEXT: float _t0;
// CHECK-NEXT: _t0 = ::std::cos(x);
Expand Down

0 comments on commit c335823

Please sign in to comment.