Skip to content

Commit

Permalink
further tries
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed Dec 3, 2024
1 parent 6e45d88 commit 1ebc777
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 40 deletions.
3 changes: 3 additions & 0 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ struct DiffRequest {
/// order derivatives.
const clang::CXXRecordDecl* Functor = nullptr;

/// If we're creating a pullback for a lambda function, we need to put the diffed function into a call operator of another lambda class. Therefore, we need to used a pre-created function declaration to put our derived code into. Only applicable if working with lambdas in the reverse mode.
clang::FunctionDecl* LambdaPreCreatedDerivativeTarget = nullptr;

/// Stores differentiation parameters information. Stored information
/// includes info on indices range for array parameters, and nested data
/// member information for record (class) type parameters.
Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ namespace clad {
clang::CXXRecordDecl*
diffLambdaCXXRecordDecl(const clang::CXXRecordDecl* Original);
clang::CXXMethodDecl*
DifferentiateCallOperatorIfLambda(const clang::CXXRecordDecl* RD);
DifferentiateCallOperatorIfLambda(const clang::CXXRecordDecl* RD, clang::FunctionDecl* preCreatedDerivativeDecl);

public:
ReverseModeVisitor(DerivativeBuilder& builder, const DiffRequest& request);
Expand All @@ -386,6 +386,7 @@ namespace clad {
/// y" will give 'f_grad_0_1' and "x, z" will give 'f_grad_0_2'.
DerivativeAndOverload Derive();
DerivativeAndOverload DerivePullback();
DerivativeAndOverload DerivePullback(clang::FunctionDecl* preDerivative);
StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp);
StmtDiff VisitLambdaExpr(const clang::LambdaExpr* LE);
Expand Down
2 changes: 1 addition & 1 deletion lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request);
V.AddExternalSource(*m_ErrorEstHandler.back());
}
result = V.DerivePullback();
result = V.DerivePullback(request.LambdaPreCreatedDerivativeTarget);
if (!m_ErrorEstHandler.empty())
CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel);
} else if (request.Mode == DiffMode::reverse_mode_forward_pass) {
Expand Down
99 changes: 61 additions & 38 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

DerivativeAndOverload ReverseModeVisitor::DerivePullback() {
return DerivePullback(nullptr);
}

DerivativeAndOverload ReverseModeVisitor::DerivePullback(clang::FunctionDecl* preDerivative) {
const clang::FunctionDecl* FD = m_DiffReq.Function;
// FIXME: Duplication of external source here is a workaround
// for the two 'Derive's being different functions.
Expand Down Expand Up @@ -408,18 +412,22 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_Sema.CurContext = const_cast<DeclContext*>(m_DiffReq->getDeclContext());

SourceLocation validLoc{m_DiffReq->getLocation()};
DeclWithContext fnBuildRes =
m_Builder.cloneFunction(m_DiffReq.Function, *this, m_Sema.CurContext,
validLoc, DNI, pullbackFnType);
m_Derivative = fnBuildRes.first;
if (!preDerivative) {
DeclWithContext fnBuildRes =
m_Builder.cloneFunction(m_DiffReq.Function, *this, m_Sema.CurContext,
validLoc, DNI, pullbackFnType);
m_Derivative = fnBuildRes.first;
} else {
m_Derivative = preDerivative;
}

if (m_ExternalSource)
m_ExternalSource->ActBeforeCreatingDerivedFnScope();

beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope |
Scope::DeclScope);
m_Sema.PushFunctionScope();
m_Sema.PushDeclContext(getCurrentScope(), m_Derivative);
m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); // ASK ABOUT IT

if (m_ExternalSource)
m_ExternalSource->ActAfterCreatingDerivedFnScope();
Expand Down Expand Up @@ -490,7 +498,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_Sema.PopDeclContext();
endScope(); // Function decl scope

return DerivativeAndOverload{fnBuildRes.first, nullptr};
return DerivativeAndOverload{m_Derivative, nullptr};
}

void ReverseModeVisitor::DifferentiateWithClad() {
Expand Down Expand Up @@ -1452,7 +1460,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

CXXMethodDecl* ReverseModeVisitor::DifferentiateCallOperatorIfLambda(
const clang::CXXRecordDecl* RD) {
const clang::CXXRecordDecl* RD, FunctionDecl* preCreatedDerivativeDecl) {
if (RD) {
CXXRecordDecl* constructedType = RD->getDefinition();
bool isLambda = constructedType->isLambda();
Expand All @@ -1466,6 +1474,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
req.Function = cxxMethod;
req.Mode = DiffMode::experimental_pullback;
req.BaseFunctionName = utils::ComputeEffectiveFnName(cxxMethod);
req.LambdaPreCreatedDerivativeTarget = preCreatedDerivativeDecl;
// Silence diag outputs in nested derivation process.
req.VerboseDiags = false;

Expand Down Expand Up @@ -1513,49 +1522,65 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
for (auto* Method : Original->methods()) {
if (CXXMethodDecl* OriginalOpCall = dyn_cast<CXXMethodDecl>(Method)) {
if (OriginalOpCall->getOverloadedOperator() == OO_Call) {
auto* diffedOpCall = DifferentiateCallOperatorIfLambda(Original);
if (diffedOpCall) {
diffedOpCall->setAccess(OriginalOpCall->getAccess());
// First, we create an operator copied from the original lambda, which we then differentiate
{

}

// if (diffedOpCall) {
// diffedOpCall->setAccess(OriginalOpCall->getAccess());
// Cloned->addDecl(diffedOpCall);

// TRY CLONING THE OPERATOR FIRST, THEN DIFFERENTIATING IT TO MAINTAIN A CONSISTENT CONTEXT

DiffParams args{};
std::copy(OriginalOpCall->param_begin(), OriginalOpCall->param_end(), std::back_inserter(args));
auto paramTypes = ComputeParamTypes(args);
const auto* originalFnType =
dyn_cast<FunctionProtoType>(OriginalOpCall->getType());
QualType pullbackFnType = m_Context.getFunctionType(
m_Context.VoidTy, paramTypes, originalFnType->getExtProtoInfo());

CXXMethodDecl* ClonedOpCall = CXXMethodDecl::Create(
m_Context, Cloned, diffedOpCall->getBeginLoc(),
m_Context, Cloned, OriginalOpCall->getBeginLoc(),
OriginalOpCall->getNameInfo(),
diffedOpCall
->getType(), // Function type (return type + parameters)
diffedOpCall->getTypeSourceInfo(),
diffedOpCall->getStorageClass()
CLAD_COMPAT_FunctionDecl_UsesFPIntrin_Param(diffedOpCall),
diffedOpCall->isInlineSpecified(), // Inline specifier
pullbackFnType, // Function type (return type + parameters)
OriginalOpCall->getTypeSourceInfo(),
OriginalOpCall->getStorageClass()
CLAD_COMPAT_FunctionDecl_UsesFPIntrin_Param(OriginalOpCall),
OriginalOpCall->isInlineSpecified(), // Inline specifier
clad_compat::Function_GetConstexprKind(
diffedOpCall), // Constexpr specifier
diffedOpCall->getEndLoc() //,
OriginalOpCall), // Constexpr specifier
OriginalOpCall->getEndLoc() //,
// diffedOpCall->getTrailingRequiresClause()
);

llvm::SmallVector<clang::ParmVarDecl*, 8> params;
for (unsigned i = 0; i < diffedOpCall->param_size(); ++i) {
ParmVarDecl* p = diffedOpCall->getParamDecl(i);
ParmVarDecl* NewParam = ParmVarDecl::Create(
m_Context, ClonedOpCall, p->getBeginLoc(), p->getLocation(),
p->getIdentifier(), p->getType(), p->getTypeSourceInfo(),
p->getStorageClass(), p->getDefaultArg());
params.push_back(NewParam);
}
ClonedOpCall->setParams(params);
// Cloned->addDecl(ClonedOpCall); // do we need this?

auto* diffedOpCall = DifferentiateCallOperatorIfLambda(Original, ClonedOpCall);

// llvm::SmallVector<clang::ParmVarDecl*, 8> params;
// for (unsigned i = 0; i < diffedOpCall->param_size(); ++i) {
// ParmVarDecl* p = diffedOpCall->getParamDecl(i);
// ParmVarDecl* NewParam = ParmVarDecl::Create(
// m_Context, ClonedOpCall, p->getBeginLoc(), p->getLocation(),
// p->getIdentifier(), p->getType(), p->getTypeSourceInfo(),
// p->getStorageClass(), p->getDefaultArg());
// params.push_back(NewParam);
// }
// ClonedOpCall->setParams(params);

// Copy the method body if it exists
if (diffedOpCall->hasBody()) {
Stmt* body = diffedOpCall->getBody();
Stmt* ClonedBody = Clone(body);
ClonedOpCall->setBody(ClonedBody);
}
// if (diffedOpCall->hasBody()) {
// Stmt* body = diffedOpCall->getBody();
// Stmt* ClonedBody = Clone(body);
// ClonedOpCall->setBody(ClonedBody);
// }

ClonedOpCall->setAccess(OriginalOpCall->getAccess());
Cloned->addDecl(ClonedOpCall);

break; // we get into an infinite loop otherwise
}
// }
}
}
}
Expand Down Expand Up @@ -3436,7 +3461,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// If the DeclStmt is not empty, check the first declaration in case it is a
// lambda function. This case it is treated differently.
bool isLambda = false;
const auto* declsBegin = DS->decls().begin();
if (declsBegin != DS->decls().end() && isa<VarDecl>(*declsBegin)) {
auto* VD = dyn_cast<VarDecl>(*declsBegin);
Expand All @@ -3445,7 +3469,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
QT = QT->getPointeeType();

auto* typeDecl = QT->getAsCXXRecordDecl();
isLambda = typeDecl && typeDecl->isLambda();
if (typeDecl && clad::utils::hasNonDifferentiableAttribute(typeDecl)) {
for (auto* D : DS->decls())
if (auto* VD = dyn_cast<VarDecl>(D))
Expand Down

0 comments on commit 1ebc777

Please sign in to comment.