From ec76b704ced2bb7471b25b51a7c24273374700d9 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Sun, 3 Dec 2023 11:57:12 +0530 Subject: [PATCH] Add support for call expressions in vector forward mode AD (#638) --- .../Differentiator/BaseForwardModeVisitor.h | 15 ++ .../clad/Differentiator/DerivativeBuilder.h | 3 +- include/clad/Differentiator/DiffMode.h | 1 + .../clad/Differentiator/ForwardModeVisitor.h | 33 ---- .../Differentiator/PushForwardModeVisitor.h | 25 +++ .../Differentiator/VectorForwardModeVisitor.h | 8 + .../VectorPushForwardModeVisitor.h | 20 ++ lib/Differentiator/BaseForwardModeVisitor.cpp | 175 ++++++++++++++++- lib/Differentiator/CMakeLists.txt | 3 +- lib/Differentiator/DerivativeBuilder.cpp | 8 +- lib/Differentiator/ForwardModeVisitor.cpp | 185 ------------------ lib/Differentiator/PushForwardModeVisitor.cpp | 38 ++++ .../VectorForwardModeVisitor.cpp | 37 ++++ .../VectorPushForwardModeVisitor.cpp | 68 +++++++ test/FirstDerivative/FunctionCalls.C | 23 +++ test/ForwardMode/VectorMode.C | 136 +++++++++++++ 16 files changed, 548 insertions(+), 230 deletions(-) delete mode 100644 include/clad/Differentiator/ForwardModeVisitor.h create mode 100644 include/clad/Differentiator/PushForwardModeVisitor.h create mode 100644 include/clad/Differentiator/VectorPushForwardModeVisitor.h delete mode 100644 lib/Differentiator/ForwardModeVisitor.cpp create mode 100644 lib/Differentiator/PushForwardModeVisitor.cpp create mode 100644 lib/Differentiator/VectorPushForwardModeVisitor.cpp diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 068d2170d..370097b42 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -37,6 +37,16 @@ class BaseForwardModeVisitor DerivativeAndOverload Derive(const clang::FunctionDecl* FD, const DiffRequest& request); + DerivativeAndOverload DerivePushforward(const clang::FunctionDecl* FD, + const DiffRequest& request); + + /// Returns the return type for the pushforward function of the function + /// `m_Function`. + /// \note `m_Function` field should be set before using this function. + clang::QualType ComputePushforwardFnReturnType(); + + virtual void ExecuteInsidePushforwardFunctionBlock(); + static bool IsDifferentiableType(clang::QualType T); virtual StmtDiff @@ -93,6 +103,11 @@ class BaseForwardModeVisitor VisitUnaryExprOrTypeTraitExpr(const clang::UnaryExprOrTypeTraitExpr* UE); StmtDiff VisitPseudoObjectExpr(const clang::PseudoObjectExpr* POE); + virtual clang::QualType + GetPushForwardDerivativeType(clang::QualType ParamType); + virtual std::string GetPushForwardFunctionSuffix(); + virtual DiffMode GetPushForwardMode(); + protected: /// Helper function for differentiating the switch statement body. /// diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index c1015db68..75a500244 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -72,8 +72,9 @@ namespace clad { private: friend class VisitorBase; friend class BaseForwardModeVisitor; - friend class ForwardModeVisitor; + friend class PushForwardModeVisitor; friend class VectorForwardModeVisitor; + friend class VectorPushForwardModeVisitor; friend class ReverseModeVisitor; friend class HessianModeVisitor; friend class JacobianModeVisitor; diff --git a/include/clad/Differentiator/DiffMode.h b/include/clad/Differentiator/DiffMode.h index a03e77e49..919d22c80 100644 --- a/include/clad/Differentiator/DiffMode.h +++ b/include/clad/Differentiator/DiffMode.h @@ -8,6 +8,7 @@ enum class DiffMode { vector_forward_mode, experimental_pushforward, experimental_pullback, + experimental_vector_pushforward, reverse, hessian, jacobian, diff --git a/include/clad/Differentiator/ForwardModeVisitor.h b/include/clad/Differentiator/ForwardModeVisitor.h deleted file mode 100644 index 244285ff2..000000000 --- a/include/clad/Differentiator/ForwardModeVisitor.h +++ /dev/null @@ -1,33 +0,0 @@ -//--------------------------------------------------------------------*- C++ -*- -// clad - the C++ Clang-based Automatic Differentiator -// version: $Id: ClangPlugin.cpp 7 2013-06-01 22:48:03Z v.g.vassilev@gmail.com $ -// author: Vassil Vassilev -//------------------------------------------------------------------------------ - -#ifndef CLAD_FORWARD_MODE_VISITOR_H -#define CLAD_FORWARD_MODE_VISITOR_H - -#include "BaseForwardModeVisitor.h" - -namespace clad { - /// A visitor for processing the function code in forward mode. - /// Used to compute derivatives by clad::differentiate. -class ForwardModeVisitor : public BaseForwardModeVisitor { - -public: - ForwardModeVisitor(DerivativeBuilder& builder); - ~ForwardModeVisitor(); - - DerivativeAndOverload DerivePushforward(const clang::FunctionDecl* FD, - const DiffRequest& request); - - /// Returns the return type for the pushforward function of the function - /// `m_Function`. - /// \note `m_Function` field should be set before using this function. - clang::QualType ComputePushforwardFnReturnType(); - - StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override; -}; -} // end namespace clad - -#endif // CLAD_FORWARD_MODE_VISITOR_H diff --git a/include/clad/Differentiator/PushForwardModeVisitor.h b/include/clad/Differentiator/PushForwardModeVisitor.h new file mode 100644 index 000000000..82347ab2f --- /dev/null +++ b/include/clad/Differentiator/PushForwardModeVisitor.h @@ -0,0 +1,25 @@ +//--------------------------------------------------------------------*- C++ -*- +// clad - the C++ Clang-based Automatic Differentiator +// version: $Id: ClangPlugin.cpp 7 2013-06-01 22:48:03Z v.g.vassilev@gmail.com $ +// author: Vassil Vassilev +//------------------------------------------------------------------------------ + +#ifndef CLAD_DIFFERENTIATOR_PUSHFORWARDMODEVISITOR_H +#define CLAD_DIFFERENTIATOR_PUSHFORWARDMODEVISITOR_H + +#include "BaseForwardModeVisitor.h" + +namespace clad { +/// A visitor for processing the function code in forward mode. +/// Used to compute derivatives by clad::differentiate. +class PushForwardModeVisitor : public BaseForwardModeVisitor { + +public: + PushForwardModeVisitor(DerivativeBuilder& builder); + ~PushForwardModeVisitor() override; + + StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override; +}; +} // end namespace clad + +#endif // CLAD_DIFFERENTIATOR_PUSHFORWARDMODEVISITOR_H diff --git a/include/clad/Differentiator/VectorForwardModeVisitor.h b/include/clad/Differentiator/VectorForwardModeVisitor.h index a5a21d13a..2a34e8fc1 100644 --- a/include/clad/Differentiator/VectorForwardModeVisitor.h +++ b/include/clad/Differentiator/VectorForwardModeVisitor.h @@ -78,6 +78,14 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor { StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override; // Decl is not Stmt, so it cannot be visited directly. VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD) override; + + clang::QualType + GetPushForwardDerivativeType(clang::QualType ParamType) override; + std::string GetPushForwardFunctionSuffix() override; + DiffMode GetPushForwardMode() override; + + // Function for setting the independent variables for vector mode. + void SetIndependentVarsExpr(clang::Expr* IndVarCountExpr); }; } // end namespace clad diff --git a/include/clad/Differentiator/VectorPushForwardModeVisitor.h b/include/clad/Differentiator/VectorPushForwardModeVisitor.h new file mode 100644 index 000000000..eabeefae0 --- /dev/null +++ b/include/clad/Differentiator/VectorPushForwardModeVisitor.h @@ -0,0 +1,20 @@ +#ifndef CLAD_DIFFERENTIATOR_VECTORPUSHFORWARDMODEVISITOR_H +#define CLAD_DIFFERENTIATOR_VECTORPUSHFORWARDMODEVISITOR_H + +#include "PushForwardModeVisitor.h" +#include "VectorForwardModeVisitor.h" + +namespace clad { +class VectorPushForwardModeVisitor : public VectorForwardModeVisitor { + +public: + VectorPushForwardModeVisitor(DerivativeBuilder& builder); + ~VectorPushForwardModeVisitor() override; + + void ExecuteInsidePushforwardFunctionBlock() override; + + StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override; +}; +} // end namespace clad + +#endif // CLAD_DIFFERENTIATOR_VECTORPUSHFORWARDMODEVISITOR_H diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index cfae69aca..986b61bdf 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -359,6 +359,157 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, /*OverloadFunctionDecl=*/nullptr}; } +clang::QualType BaseForwardModeVisitor::ComputePushforwardFnReturnType() { + assert(m_Mode == GetPushForwardMode()); + QualType originalFnRT = m_Function->getReturnType(); + if (originalFnRT->isVoidType()) + return m_Context.VoidTy; + TemplateDecl* valueAndPushforward = + LookupTemplateDeclInCladNamespace("ValueAndPushforward"); + assert(valueAndPushforward && + "clad::ValueAndPushforward template not found!!"); + QualType RT = InstantiateTemplate( + valueAndPushforward, + {originalFnRT, GetPushForwardDerivativeType(originalFnRT)}); + return RT; +} + +void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() { + Stmt* bodyDiff = Visit(m_Function->getBody()).getStmt(); + auto* CS = cast(bodyDiff); + for (Stmt* S : CS->body()) + addToCurrentBlock(S); +} + +DerivativeAndOverload +BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, + const DiffRequest& request) { + m_Function = FD; + m_Functor = request.Functor; + m_DerivativeOrder = request.CurrentDerivativeOrder; + m_Mode = GetPushForwardMode(); + assert(!m_DerivativeInFlight && + "Doesn't support recursive diff. Use DiffPlan."); + m_DerivativeInFlight = true; + + auto originalFnEffectiveName = utils::ComputeEffectiveFnName(m_Function); + + IdentifierInfo* derivedFnII = &m_Context.Idents.get( + originalFnEffectiveName + GetPushForwardFunctionSuffix()); + DeclarationNameInfo derivedFnName(derivedFnII, noLoc); + llvm::SmallVector paramTypes; + llvm::SmallVector derivedParamTypes; + + // If we are differentiating an instance member function then + // create a parameter type for the parameter that will represent the + // derivative of `this` pointer with respect to the independent parameter. + if (const auto* MD = dyn_cast(FD)) { + if (MD->isInstance()) { + QualType thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD); + derivedParamTypes.push_back(thisType); + } + } + + for (auto* PVD : m_Function->parameters()) { + paramTypes.push_back(PVD->getType()); + + if (BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) + derivedParamTypes.push_back(GetPushForwardDerivativeType(PVD->getType())); + } + + paramTypes.insert(paramTypes.end(), derivedParamTypes.begin(), + derivedParamTypes.end()); + + const auto* originalFnType = + dyn_cast(m_Function->getType()); + QualType returnType = ComputePushforwardFnReturnType(); + QualType derivedFnType = m_Context.getFunctionType( + returnType, paramTypes, originalFnType->getExtProtoInfo()); + llvm::SaveAndRestore saveContext(m_Sema.CurContext); + llvm::SaveAndRestore saveScope(m_CurScope); + auto* DC = const_cast(m_Function->getDeclContext()); + m_Sema.CurContext = DC; + + DeclWithContext cloneFunctionResult = m_Builder.cloneFunction( + m_Function, *this, DC, noLoc, derivedFnName, derivedFnType); + m_Derivative = cloneFunctionResult.first; + + llvm::SmallVector params; + llvm::SmallVector derivedParams; + beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | + Scope::DeclScope); + m_Sema.PushFunctionScope(); + m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); + + // If we are differentiating an instance member function then + // create a parameter for representing derivative of + // `this` pointer with respect to the independent parameter. + if (const auto* MFD = dyn_cast(FD)) { + if (MFD->isInstance()) { + auto thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MFD); + IdentifierInfo* derivedPVDII = CreateUniqueIdentifier("_d_this"); + auto* derivedPVD = utils::BuildParmVarDecl(m_Sema, m_Sema.CurContext, + derivedPVDII, thisType); + m_Sema.PushOnScopeChains(derivedPVD, getCurrentScope(), + /*AddToContext=*/false); + derivedParams.push_back(derivedPVD); + m_ThisExprDerivative = BuildDeclRef(derivedPVD); + } + } + + std::size_t numParamsOriginalFn = m_Function->getNumParams(); + for (std::size_t i = 0; i < numParamsOriginalFn; ++i) { + const auto* PVD = m_Function->getParamDecl(i); + // Some of the special member functions created implicitly by compilers + // have missing parameter identifier. + bool identifierMissing = false; + IdentifierInfo* PVDII = PVD->getIdentifier(); + if (!PVDII || PVDII->getLength() == 0) { + PVDII = CreateUniqueIdentifier("param"); + identifierMissing = true; + } + auto* newPVD = CloneParmVarDecl(PVD, PVDII, + /*pushOnScopeChains=*/true, + /*cloneDefaultArg=*/false); + params.push_back(newPVD); + + if (identifierMissing) + m_DeclReplacements[PVD] = newPVD; + + if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) + continue; + auto derivedPVDName = "_d_" + std::string(PVDII->getName()); + IdentifierInfo* derivedPVDII = CreateUniqueIdentifier(derivedPVDName); + auto* derivedPVD = utils::BuildParmVarDecl( + m_Sema, m_Derivative, derivedPVDII, + GetPushForwardDerivativeType(PVD->getType()), PVD->getStorageClass()); + derivedParams.push_back(derivedPVD); + m_Variables[newPVD] = BuildDeclRef(derivedPVD); + } + + params.insert(params.end(), derivedParams.begin(), derivedParams.end()); + m_Derivative->setParams(params); + m_Derivative->setBody(nullptr); + + beginScope(Scope::FnScope | Scope::DeclScope); + m_DerivativeFnScope = getCurrentScope(); + beginBlock(); + + // execute the functor inside the function body. + ExecuteInsidePushforwardFunctionBlock(); + + Stmt* derivativeBody = endBlock(); + m_Derivative->setBody(derivativeBody); + + endScope(); // Function body scope + m_Sema.PopFunctionScopeInfo(); + m_Sema.PopDeclContext(); + endScope(); // Function decl scope + + m_DerivativeInFlight = false; + return DerivativeAndOverload{cloneFunctionResult.first}; +} + StmtDiff BaseForwardModeVisitor::VisitStmt(const Stmt* S) { diag(DiagnosticsEngine::Warning, S->getBeginLoc(), "attempted to differentiate unsupported statement, no changes applied"); @@ -913,6 +1064,19 @@ Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff( return OverloadedFn; } +QualType +BaseForwardModeVisitor::GetPushForwardDerivativeType(QualType ParamType) { + return ParamType; +} + +std::string BaseForwardModeVisitor::GetPushForwardFunctionSuffix() { + return "_pushforward"; +} + +DiffMode BaseForwardModeVisitor::GetPushForwardMode() { + return DiffMode::experimental_pushforward; +} + StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { const FunctionDecl* FD = CE->getDirectCallee(); if (!FD) { @@ -1007,10 +1171,6 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { CallArgs.push_back(argDiff.getExpr()); if (BaseForwardModeVisitor::IsDifferentiableType(arg->getType())) { Expr* dArg = argDiff.getExpr_dx(); - QualType CallArgTy = CallArgs.back()->getType(); - assert((!dArg || m_Context.hasSameType(CallArgTy, dArg->getType())) && - "Type mismatch, we might fail to instantiate a pullback"); - (void)CallArgTy; // FIXME: What happens when dArg is nullptr? diffArgs.push_back(dArg); } @@ -1033,14 +1193,13 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { // Try to find a user-defined overloaded derivative. std::string customPushforward = - clad::utils::ComputeEffectiveFnName(FD) + "_pushforward"; + clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix(); Expr* callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( customPushforward, customDerivativeArgs, getCurrentScope(), const_cast(FD->getDeclContext())); // Check if it is a recursive call. - if (!callDiff && (FD == m_Function) && - m_Mode == DiffMode::experimental_pushforward) { + if (!callDiff && (FD == m_Function) && m_Mode == GetPushForwardMode()) { // The differentiated function is called recursively. Expr* derivativeRef = m_Sema @@ -1089,7 +1248,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { // derive the called function. DiffRequest pushforwardFnRequest; pushforwardFnRequest.Function = FD; - pushforwardFnRequest.Mode = DiffMode::experimental_pushforward; + pushforwardFnRequest.Mode = GetPushForwardMode(); pushforwardFnRequest.BaseFunctionName = FD->getNameAsString(); // pushforwardFnRequest.RequestedDerivativeOrder = m_DerivativeOrder; // Silence diag outputs in nested derivation process. diff --git a/lib/Differentiator/CMakeLists.txt b/lib/Differentiator/CMakeLists.txt index d778be998..5cd2fc103 100644 --- a/lib/Differentiator/CMakeLists.txt +++ b/lib/Differentiator/CMakeLists.txt @@ -27,15 +27,16 @@ add_llvm_library(cladDifferentiator DiffPlanner.cpp ErrorEstimator.cpp EstimationModel.cpp - ForwardModeVisitor.cpp HessianModeVisitor.cpp JacobianModeVisitor.cpp MultiplexExternalRMVSource.cpp + PushForwardModeVisitor.cpp ReverseModeForwPassVisitor.cpp ReverseModeVisitor.cpp TBRAnalyzer.cpp StmtClone.cpp VectorForwardModeVisitor.cpp + VectorPushForwardModeVisitor.cpp Version.cpp VisitorBase.cpp ${version_inc} diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index c07e70057..23417dfda 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -18,13 +18,14 @@ #include "clad/Differentiator/CladUtils.h" #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/ErrorEstimator.h" -#include "clad/Differentiator/ForwardModeVisitor.h" #include "clad/Differentiator/HessianModeVisitor.h" #include "clad/Differentiator/JacobianModeVisitor.h" +#include "clad/Differentiator/PushForwardModeVisitor.h" #include "clad/Differentiator/ReverseModeForwPassVisitor.h" #include "clad/Differentiator/ReverseModeVisitor.h" #include "clad/Differentiator/StmtClone.h" #include "clad/Differentiator/VectorForwardModeVisitor.h" +#include "clad/Differentiator/VectorPushForwardModeVisitor.h" #include @@ -209,11 +210,14 @@ namespace clad { BaseForwardModeVisitor V(*this); result = V.Derive(FD, request); } else if (request.Mode == DiffMode::experimental_pushforward) { - ForwardModeVisitor V(*this); + PushForwardModeVisitor V(*this); result = V.DerivePushforward(FD, request); } else if (request.Mode == DiffMode::vector_forward_mode) { VectorForwardModeVisitor V(*this); result = V.DeriveVectorMode(FD, request); + } else if (request.Mode == DiffMode::experimental_vector_pushforward) { + VectorPushForwardModeVisitor V(*this); + result = V.DerivePushforward(FD, request); } else if (request.Mode == DiffMode::reverse) { ReverseModeVisitor V(*this); result = V.Derive(FD, request); diff --git a/lib/Differentiator/ForwardModeVisitor.cpp b/lib/Differentiator/ForwardModeVisitor.cpp deleted file mode 100644 index d702dcbfc..000000000 --- a/lib/Differentiator/ForwardModeVisitor.cpp +++ /dev/null @@ -1,185 +0,0 @@ -//--------------------------------------------------------------------*- C++ -*- -// clad - the C++ Clang-based Automatic Differentiator -// version: $Id: ClangPlugin.cpp 7 2013-06-01 22:48:03Z v.g.vassilev@gmail.com $ -// author: Vassil Vassilev -//------------------------------------------------------------------------------ - -#include "clad/Differentiator/ForwardModeVisitor.h" -#include "clad/Differentiator/BaseForwardModeVisitor.h" - -#include "clad/Differentiator/CladUtils.h" - -#include "llvm/Support/SaveAndRestore.h" - -using namespace clang; - -namespace clad { -ForwardModeVisitor::ForwardModeVisitor(DerivativeBuilder& builder) - : BaseForwardModeVisitor(builder) {} - -ForwardModeVisitor::~ForwardModeVisitor() {} - -clang::QualType ForwardModeVisitor::ComputePushforwardFnReturnType() { - assert(m_Mode == DiffMode::experimental_pushforward); - QualType originalFnRT = m_Function->getReturnType(); - if (originalFnRT->isVoidType()) - return m_Context.VoidTy; - TemplateDecl* valueAndPushforward = - LookupTemplateDeclInCladNamespace("ValueAndPushforward"); - assert(valueAndPushforward && - "clad::ValueAndPushforward template not found!!"); - QualType RT = - InstantiateTemplate(valueAndPushforward, {originalFnRT, originalFnRT}); - return RT; -} - - DerivativeAndOverload - ForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, - const DiffRequest& request) { - m_Function = FD; - m_Functor = request.Functor; - m_DerivativeOrder = request.CurrentDerivativeOrder; - m_Mode = DiffMode::experimental_pushforward; - assert(!m_DerivativeInFlight && - "Doesn't support recursive diff. Use DiffPlan."); - m_DerivativeInFlight = true; - - auto originalFnEffectiveName = utils::ComputeEffectiveFnName(m_Function); - - IdentifierInfo* derivedFnII = - &m_Context.Idents.get(originalFnEffectiveName + "_pushforward"); - DeclarationNameInfo derivedFnName(derivedFnII, noLoc); - llvm::SmallVector paramTypes, derivedParamTypes; - - // If we are differentiating an instance member function then - // create a parameter type for the parameter that will represent the - // derivative of `this` pointer with respect to the independent parameter. - if (auto MD = dyn_cast(FD)) { - if (MD->isInstance()) { - QualType thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD); - derivedParamTypes.push_back(thisType); - } - } - - for (auto* PVD : m_Function->parameters()) { - paramTypes.push_back(PVD->getType()); - - if (BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) - derivedParamTypes.push_back(PVD->getType()); - } - - paramTypes.insert(paramTypes.end(), derivedParamTypes.begin(), - derivedParamTypes.end()); - - auto originalFnType = dyn_cast(m_Function->getType()); - QualType returnType = ComputePushforwardFnReturnType(); - QualType derivedFnType = - m_Context.getFunctionType(returnType, paramTypes, - originalFnType->getExtProtoInfo()); - llvm::SaveAndRestore saveContext(m_Sema.CurContext); - llvm::SaveAndRestore saveScope(m_CurScope); - DeclContext* DC = const_cast(m_Function->getDeclContext()); - m_Sema.CurContext = DC; - - DeclWithContext cloneFunctionResult = m_Builder.cloneFunction( - m_Function, *this, DC, noLoc, derivedFnName, derivedFnType); - m_Derivative = cloneFunctionResult.first; - - llvm::SmallVector params; - llvm::SmallVector derivedParams; - beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | - Scope::DeclScope); - m_Sema.PushFunctionScope(); - m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); - - // If we are differentiating an instance member function then - // create a parameter for representing derivative of - // `this` pointer with respect to the independent parameter. - if (auto MFD = dyn_cast(FD)) { - if (MFD->isInstance()) { - auto thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MFD); - IdentifierInfo* derivedPVDII = CreateUniqueIdentifier("_d_this"); - auto derivedPVD = utils::BuildParmVarDecl(m_Sema, m_Sema.CurContext, - derivedPVDII, thisType); - m_Sema.PushOnScopeChains(derivedPVD, getCurrentScope(), - /*AddToContext=*/false); - derivedParams.push_back(derivedPVD); - m_ThisExprDerivative = BuildDeclRef(derivedPVD); - } - } - - std::size_t numParamsOriginalFn = m_Function->getNumParams(); - for (std::size_t i = 0; i < numParamsOriginalFn; ++i) { - auto PVD = m_Function->getParamDecl(i); - // Some of the special member functions created implicitly by compilers - // have missing parameter identifier. - bool identifierMissing = false; - IdentifierInfo* PVDII = PVD->getIdentifier(); - if (!PVDII || PVDII->getLength() == 0) { - PVDII = CreateUniqueIdentifier("param"); - identifierMissing = true; - } - auto newPVD = CloneParmVarDecl(PVD, PVDII, - /*pushOnScopeChains=*/true, - /*cloneDefaultArg=*/false); - params.push_back(newPVD); - - if (identifierMissing) - m_DeclReplacements[PVD] = newPVD; - - if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) - continue; - auto derivedPVDName = "_d_" + std::string(PVDII->getName()); - IdentifierInfo* derivedPVDII = CreateUniqueIdentifier(derivedPVDName); - auto derivedPVD = CloneParmVarDecl(PVD, derivedPVDII, - /*pushOnScopeChains=*/true, - /*cloneDefaultArg=*/false); - derivedParams.push_back(derivedPVD); - m_Variables[newPVD] = BuildDeclRef(derivedPVD); - } - - params.insert(params.end(), derivedParams.begin(), derivedParams.end()); - m_Derivative->setParams(params); - m_Derivative->setBody(nullptr); - - beginScope(Scope::FnScope | Scope::DeclScope); - m_DerivativeFnScope = getCurrentScope(); - beginBlock(); - - Stmt* bodyDiff = Visit(FD->getBody()).getStmt(); - CompoundStmt* CS = cast(bodyDiff); - for (Stmt* S : CS->body()) - addToCurrentBlock(S); - - Stmt* derivativeBody = endBlock(); - m_Derivative->setBody(derivativeBody); - - endScope(); // Function body scope - m_Sema.PopFunctionScopeInfo(); - m_Sema.PopDeclContext(); - endScope(); // Function decl scope - - m_DerivativeInFlight = false; - return DerivativeAndOverload{cloneFunctionResult.first}; - } - - StmtDiff ForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) { - //If there is no return value, we must not attempt to differentiate - if (!RS->getRetValue()) - return nullptr; - - StmtDiff retValDiff = Visit(RS->getRetValue()); - llvm::SmallVector returnValues = {retValDiff.getExpr(), - retValDiff.getExpr_dx()}; - SourceLocation fakeInitLoc = utils::GetValidSLoc(m_Sema); - // This can instantiate as part of the move or copy initialization and - // needs a fake source location. - Expr* initList = - m_Sema.ActOnInitList(fakeInitLoc, returnValues, noLoc).get(); - - SourceLocation fakeRetLoc = utils::GetValidSLoc(m_Sema); - Stmt* returnStmt = - m_Sema.ActOnReturnStmt(fakeRetLoc, initList, getCurrentScope()).get(); - return StmtDiff(returnStmt); - } -} // end namespace clad diff --git a/lib/Differentiator/PushForwardModeVisitor.cpp b/lib/Differentiator/PushForwardModeVisitor.cpp new file mode 100644 index 000000000..b44779149 --- /dev/null +++ b/lib/Differentiator/PushForwardModeVisitor.cpp @@ -0,0 +1,38 @@ +//--------------------------------------------------------------------*- C++ -*- +// clad - the C++ Clang-based Automatic Differentiator +// version: $Id: ClangPlugin.cpp 7 2013-06-01 22:48:03Z v.g.vassilev@gmail.com $ +// author: Vassil Vassilev +//------------------------------------------------------------------------------ + +#include "clad/Differentiator/PushForwardModeVisitor.h" +#include "clad/Differentiator/BaseForwardModeVisitor.h" + +#include "clad/Differentiator/CladUtils.h" + +#include "llvm/Support/SaveAndRestore.h" + +using namespace clang; + +namespace clad { +PushForwardModeVisitor::PushForwardModeVisitor(DerivativeBuilder& builder) + : BaseForwardModeVisitor(builder) {} + +PushForwardModeVisitor::~PushForwardModeVisitor() = default; + +StmtDiff PushForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) { + // If there is no return value, we must not attempt to differentiate + if (!RS->getRetValue()) + return nullptr; + + StmtDiff retValDiff = Visit(RS->getRetValue()); + llvm::SmallVector returnValues = {retValDiff.getExpr(), + retValDiff.getExpr_dx()}; + // This can instantiate as part of the move or copy initialization and + // needs a fake source location. + SourceLocation fakeLoc = utils::GetValidSLoc(m_Sema); + Expr* initList = m_Sema.ActOnInitList(fakeLoc, returnValues, noLoc).get(); + Stmt* returnStmt = + m_Sema.ActOnReturnStmt(fakeLoc, initList, getCurrentScope()).get(); + return StmtDiff(returnStmt); +} +} // end namespace clad diff --git a/lib/Differentiator/VectorForwardModeVisitor.cpp b/lib/Differentiator/VectorForwardModeVisitor.cpp index 9e38bc6a8..87f268176 100644 --- a/lib/Differentiator/VectorForwardModeVisitor.cpp +++ b/lib/Differentiator/VectorForwardModeVisitor.cpp @@ -15,6 +15,43 @@ VectorForwardModeVisitor::VectorForwardModeVisitor(DerivativeBuilder& builder) VectorForwardModeVisitor::~VectorForwardModeVisitor() {} +std::string VectorForwardModeVisitor::GetPushForwardFunctionSuffix() { + return "_vector_pushforward"; +} + +DiffMode VectorForwardModeVisitor::GetPushForwardMode() { + return DiffMode::experimental_vector_pushforward; +} + +QualType +VectorForwardModeVisitor::GetPushForwardDerivativeType(QualType ParamType) { + QualType valueType = utils::GetValueType(ParamType); + QualType resType; + if (utils::isArrayOrPointerType(ParamType)) { + // If the parameter is a pointer or an array, then the derivative will be a + // reference to the matrix. + resType = GetCladMatrixOfType(valueType); + resType = m_Context.getLValueReferenceType(resType); + } else { + // If the parameter is not a pointer or an array, then the derivative will + // be a clad array. + resType = GetCladArrayOfType(valueType); + + // Add const qualifier if the parameter is const. + if (ParamType.getNonReferenceType().isConstQualified()) + resType.addConst(); + + // Add reference qualifier if the parameter is a reference. + if (ParamType->isReferenceType()) + resType = m_Context.getLValueReferenceType(resType); + } + return resType; +} + +void VectorForwardModeVisitor::SetIndependentVarsExpr(Expr* IndVarCountExpr) { + m_IndVarCountExpr = IndVarCountExpr; +} + DerivativeAndOverload VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD, const DiffRequest& request) { diff --git a/lib/Differentiator/VectorPushForwardModeVisitor.cpp b/lib/Differentiator/VectorPushForwardModeVisitor.cpp new file mode 100644 index 000000000..bfd8da43e --- /dev/null +++ b/lib/Differentiator/VectorPushForwardModeVisitor.cpp @@ -0,0 +1,68 @@ +#include "clad/Differentiator/VectorPushForwardModeVisitor.h" + +#include "ConstantFolder.h" +#include "clad/Differentiator/CladUtils.h" + +#include "llvm/Support/SaveAndRestore.h" + +using namespace clang; + +namespace clad { +VectorPushForwardModeVisitor::VectorPushForwardModeVisitor( + DerivativeBuilder& builder) + : VectorForwardModeVisitor(builder) {} + +VectorPushForwardModeVisitor::~VectorPushForwardModeVisitor() = default; + +void VectorPushForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() { + // Extract the last parameter of the m_Derivative function. + // This parameter will either be a clad array or a matrix. + // If it's a clad array, use it's size, or if it's a clad matrix + // use the size of 0th element of the matrix. + ParmVarDecl* lastParam = + m_Derivative->getParamDecl(m_Derivative->getNumParams() - 1); + QualType lastParamType = utils::GetValueType(lastParam->getType()); + Expr* lastParamExpr = BuildDeclRef(lastParam); + Expr* lastParamSizeExpr = nullptr; + if (isCladArrayType(lastParamType)) { + lastParamSizeExpr = BuildArrayRefSizeExpr(lastParamExpr); + } else { + auto* zero = + ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); + Expr* arraySubscriptExpr = + m_Sema + .ActOnArraySubscriptExpr(getCurrentScope(), lastParamExpr, + lastParamExpr->getExprLoc(), zero, noLoc) + .get(); + lastParamSizeExpr = BuildArrayRefSizeExpr(arraySubscriptExpr); + } + + // Create a variable to store the total number of independent variables. + Expr* indVarCountExpr = lastParamSizeExpr; + auto* totalIndVars = + BuildVarDecl(m_Context.UnsignedLongTy, "indepVarCount", indVarCountExpr); + addToCurrentBlock(BuildDeclStmt(totalIndVars)); + SetIndependentVarsExpr(BuildDeclRef(totalIndVars)); + + BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock(); +} + +StmtDiff +VectorPushForwardModeVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) { + // If there is no return value, we must not attempt to differentiate + if (!RS->getRetValue()) + return nullptr; + + StmtDiff retValDiff = Visit(RS->getRetValue()); + llvm::SmallVector returnValues = {retValDiff.getExpr(), + retValDiff.getExpr_dx()}; + // This can instantiate as part of the move or copy initialization and + // needs a fake source location. + SourceLocation fakeLoc = utils::GetValidSLoc(m_Sema); + Expr* initList = m_Sema.ActOnInitList(fakeLoc, returnValues, noLoc).get(); + Stmt* returnStmt = + m_Sema.ActOnReturnStmt(fakeLoc, initList, getCurrentScope()).get(); + return StmtDiff(returnStmt); +} + +} // end namespace clad diff --git a/test/FirstDerivative/FunctionCalls.C b/test/FirstDerivative/FunctionCalls.C index be8006bd7..953ccc6cf 100644 --- a/test/FirstDerivative/FunctionCalls.C +++ b/test/FirstDerivative/FunctionCalls.C @@ -144,6 +144,28 @@ double test_7(double i, double j) { // CHECK-NEXT: return _d_res; // CHECK-NEXT: } +enum E {A, B, C}; +double func_with_enum(double x, E e) { + return x*x; +} + +double test_8(double x) { + E e; + return func_with_enum(x, e); +} + +// CHECK: clad::ValueAndPushforward func_with_enum_pushforward(double x, E e, double _d_x) { +// CHECK-NEXT: return {x * x, _d_x * x + x * _d_x}; +// CHECK-NEXT: } + +// CHECK: double test_8_darg0(double x) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: E _d_e; +// CHECK-NEXT: E e; +// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward _t0 = func_with_enum_pushforward(x, e, _d_x); +// CHECK-NEXT: return _t0.pushforward; +// CHECK-NEXT: } + int main () { clad::differentiate(test_1, 0); clad::differentiate(test_2, 0); @@ -152,6 +174,7 @@ int main () { clad::differentiate(test_5, 0); clad::differentiate(test_6, "x"); clad::differentiate(test_7, "i"); + clad::differentiate(test_8, "x"); return 0; } diff --git a/test/ForwardMode/VectorMode.C b/test/ForwardMode/VectorMode.C index aacdb37a8..17be014da 100644 --- a/test/ForwardMode/VectorMode.C +++ b/test/ForwardMode/VectorMode.C @@ -137,6 +137,7 @@ double f5(double x, double y, double z) { // CHECK-NEXT: *_d_z = _d_vector_return[2UL]; // CHECK-NEXT: return; // CHECK-NEXT: } +// CHECK-NEXT: } // x, y // CHECK: void f5_dvec_0_1(double x, double y, double z, double *_d_x, double *_d_y) { @@ -150,6 +151,7 @@ double f5(double x, double y, double z) { // CHECK-NEXT: *_d_y = _d_vector_return[1UL]; // CHECK-NEXT: return; // CHECK-NEXT: } +// CHECK-NEXT: } // x, z // CHECK: void f5_dvec_0_2(double x, double y, double z, double *_d_x, double *_d_z) { @@ -163,6 +165,7 @@ double f5(double x, double y, double z) { // CHECK-NEXT: *_d_z = _d_vector_return[1UL]; // CHECK-NEXT: return; // CHECK-NEXT: } +// CHECK-NEXT: } // y, z // CHECK: void f5_dvec_1_2(double x, double y, double z, double *_d_y, double *_d_z) { @@ -176,6 +179,7 @@ double f5(double x, double y, double z) { // CHECK-NEXT: *_d_z = _d_vector_return[1UL]; // CHECK-NEXT: return; // CHECK-NEXT: } +// CHECK-NEXT: } // z // CHECK: void f5_dvec_2(double x, double y, double z, double *_d_z) { @@ -188,7 +192,117 @@ double f5(double x, double y, double z) { // CHECK-NEXT: *_d_z = _d_vector_return[0UL]; // CHECK-NEXT: return; // CHECK-NEXT: } +// CHECK-NEXT: } + +double square(const double& x) { + double z = x*x; + return z; +} + +// CHECK: clad::ValueAndPushforward > square_vector_pushforward(const double &x, const clad::array &_d_x) { +// CHECK-NEXT: unsigned long indepVarCount = _d_x.size(); +// CHECK-NEXT: clad::array _d_vector_z(clad::array(indepVarCount, _d_x * x + x * _d_x)); +// CHECK-NEXT: double z = x * x; +// CHECK-NEXT: return {z, _d_vector_z}; +// CHECK-NEXT: } +double f6(double x, double y) { + return square(x) + square(y); +} + +// CHECK: void f6_dvec(double x, double y, double *_d_x, double *_d_y) { +// CHECK-NEXT: unsigned long indepVarCount = 2UL; +// CHECK-NEXT: clad::array _d_vector_x = clad::one_hot_vector(indepVarCount, 0UL); +// CHECK-NEXT: clad::array _d_vector_y = clad::one_hot_vector(indepVarCount, 1UL); +// CHECK-NEXT: clad::ValueAndPushforward > _t0 = square_vector_pushforward(x, _d_vector_x); +// CHECK-NEXT: clad::ValueAndPushforward > _t1 = square_vector_pushforward(y, _d_vector_y); +// CHECK-NEXT: { +// CHECK-NEXT: clad::array _d_vector_return(clad::array(indepVarCount, _t0.pushforward + _t1.pushforward)); +// CHECK-NEXT: *_d_x = _d_vector_return[0UL]; +// CHECK-NEXT: *_d_y = _d_vector_return[1UL]; +// CHECK-NEXT: return; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double weighted_array_squared_sum(const double* arr, double w, int n) { + double sum = 0; + for (int i = 0; i < n; ++i) { + sum += w * square(arr[i]); + } + return sum; +} + +// CHECK: clad::ValueAndPushforward > weighted_array_squared_sum_vector_pushforward(const double *arr, double w, int n, clad::matrix &_d_arr, clad::array _d_w, clad::array _d_n) { +// CHECK-NEXT: unsigned long indepVarCount = _d_n.size(); +// CHECK-NEXT: clad::array _d_vector_sum(clad::array(indepVarCount, 0)); +// CHECK-NEXT: double sum = 0; +// CHECK-NEXT: { +// CHECK-NEXT: clad::array _d_vector_i(clad::array(indepVarCount, 0)); +// CHECK-NEXT: for (int i = 0; i < n; ++i) { +// CHECK-NEXT: clad::ValueAndPushforward > _t0 = square_vector_pushforward(arr[i], _d_arr[i]); +// CHECK-NEXT: double &_t1 = _t0.value; +// CHECK-NEXT: _d_vector_sum += _d_w * _t1 + w * _t0.pushforward; +// CHECK-NEXT: sum += w * _t1; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return {sum, _d_vector_sum}; +// CHECK-NEXT: } + +double f7(const double* arr, double w, int n) { + return weighted_array_squared_sum(arr, w, n); +} + +// CHECK: void f7_dvec_0_1(const double *arr, double w, int n, clad::array_ref _d_arr, double *_d_w) { +// CHECK-NEXT: unsigned long indepVarCount = _d_arr.size() + 1UL; +// CHECK-NEXT: clad::matrix _d_vector_arr = clad::identity_matrix(_d_arr.size(), indepVarCount, 0UL); +// CHECK-NEXT: clad::array _d_vector_w = clad::one_hot_vector(indepVarCount, _d_arr.size()); +// CHECK-NEXT: clad::array _d_vector_n = clad::zero_vector(indepVarCount); +// CHECK-NEXT: clad::ValueAndPushforward > _t0 = weighted_array_squared_sum_vector_pushforward(arr, w, n, _d_vector_arr, _d_vector_w, _d_vector_n); +// CHECK-NEXT: { +// CHECK-NEXT: clad::array _d_vector_return(clad::array(indepVarCount, _t0.pushforward)); +// CHECK-NEXT: _d_arr = _d_vector_return.slice(0UL, _d_arr.size()); +// CHECK-NEXT: *_d_w = _d_vector_return[_d_arr.size()]; +// CHECK-NEXT: return; +// CHECK-NEXT: } +// CHECK-NEXT: } + +void sum_ref(double& res, int n, const double* arr) { + for(int i=0; i &_d_res, clad::array _d_n, clad::matrix &_d_arr) { +// CHECK-NEXT: unsigned long indepVarCount = _d_arr[0].size(); +// CHECK-NEXT: { +// CHECK-NEXT: clad::array _d_vector_i(clad::array(indepVarCount, 0)); +// CHECK-NEXT: for (int i = 0; i < n; ++i) { +// CHECK-NEXT: _d_res += _d_arr[i]; +// CHECK-NEXT: res += arr[i]; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +double f8(int n, const double* arr) { + double res = 0; + sum_ref(res, n, arr); + return res; +} + +// CHECK: void f8_dvec_1(int n, const double *arr, clad::array_ref _d_arr) { +// CHECK-NEXT: unsigned long indepVarCount = _d_arr.size(); +// CHECK-NEXT: clad::array _d_vector_n = clad::zero_vector(indepVarCount); +// CHECK-NEXT: clad::matrix _d_vector_arr = clad::identity_matrix(_d_arr.size(), indepVarCount, 0UL); +// CHECK-NEXT: clad::array _d_vector_res(clad::array(indepVarCount, 0)); +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: sum_ref_vector_pushforward(res, n, arr, _d_vector_res, _d_vector_n, _d_vector_arr); +// CHECK-NEXT: { +// CHECK-NEXT: clad::array _d_vector_return(clad::array(indepVarCount, _d_vector_res)); +// CHECK-NEXT: _d_arr = _d_vector_return.slice(0UL, _d_arr.size()); +// CHECK-NEXT: return; +// CHECK-NEXT: } +// CHECK-NEXT: } #define TEST(F, x, y) \ { \ @@ -230,4 +344,26 @@ int main() { auto f_dvec_z = clad::differentiate(f5, "z"); f_dvec_z.execute(1, 2, 3, &result[0]); printf("Result is = {%.2f}\n", result[0]); // CHECK-EXEC: Result is = {3.00} + + // Testing derivatives of function calls. + auto f6_dvec = clad::differentiate(f6); + f6_dvec.execute(1, 2, &result[0], &result[1]); + printf("Result is = {%.2f, %.2f}\n", result[0], result[1]); // CHECK-EXEC: Result is = {2.00, 4.00} + + // Testing derivatives of function calls with array parameters. + auto f7_dvec = clad::differentiate(f7, "arr,w"); + double arr[3] = {1, 2, 3}; + double w = 2, dw = 0; + double darr[3] = {0, 0, 0}; + clad::array_ref darr_ref(darr, 3); + f7_dvec.execute(arr, 2, 3, darr_ref, &dw); + printf("Result is = {%.2f, %.2f, %.2f, %.2f}\n", darr[0], darr[1], darr[2], dw); // CHECK-EXEC: Result is = {4.00, 8.00, 12.00, 14.00} + + // Testing derivatives of function calls with array and reference parameters. + auto f8_dvec = clad::differentiate(f8, "arr"); + double arr2[3] = {1, 2, 3}; + double darr2[3] = {0, 0, 0}; + clad::array_ref darr2_ref(darr2, 3); + f8_dvec.execute(3, arr2, darr2_ref); + printf("Result is = {%.2f, %.2f, %.2f}\n", darr2[0], darr2[1], darr2[2]); // CHECK-EXEC: Result is = {1.00, 1.00, 1.00} }