diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 068d2170d..6f356dd05 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -93,6 +93,11 @@ class BaseForwardModeVisitor VisitUnaryExprOrTypeTraitExpr(const clang::UnaryExprOrTypeTraitExpr* UE); StmtDiff VisitPseudoObjectExpr(const clang::PseudoObjectExpr* POE); + virtual clang::QualType + GetPushForwardDerivativeType(clang::QualType ParamType); + virtual std::string GetPushForwardFunctionSuffix(); + virtual DiffMode GetPushForwardMode(); + protected: /// Helper function for differentiating the switch statement body. /// diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index c1015db68..75a500244 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -72,8 +72,9 @@ namespace clad { private: friend class VisitorBase; friend class BaseForwardModeVisitor; - friend class ForwardModeVisitor; + friend class PushForwardModeVisitor; friend class VectorForwardModeVisitor; + friend class VectorPushForwardModeVisitor; friend class ReverseModeVisitor; friend class HessianModeVisitor; friend class JacobianModeVisitor; diff --git a/include/clad/Differentiator/DiffMode.h b/include/clad/Differentiator/DiffMode.h index a03e77e49..919d22c80 100644 --- a/include/clad/Differentiator/DiffMode.h +++ b/include/clad/Differentiator/DiffMode.h @@ -8,6 +8,7 @@ enum class DiffMode { vector_forward_mode, experimental_pushforward, experimental_pullback, + experimental_vector_pushforward, reverse, hessian, jacobian, diff --git a/include/clad/Differentiator/ForwardModeVisitor.h b/include/clad/Differentiator/PushForwardModeVisitor.h similarity index 75% rename from include/clad/Differentiator/ForwardModeVisitor.h rename to include/clad/Differentiator/PushForwardModeVisitor.h index 244285ff2..6b8f79c3d 100644 --- a/include/clad/Differentiator/ForwardModeVisitor.h +++ b/include/clad/Differentiator/PushForwardModeVisitor.h @@ -10,17 +10,19 @@ #include "BaseForwardModeVisitor.h" namespace clad { - /// A visitor for processing the function code in forward mode. - /// Used to compute derivatives by clad::differentiate. -class ForwardModeVisitor : public BaseForwardModeVisitor { +/// A visitor for processing the function code in forward mode. +/// Used to compute derivatives by clad::differentiate. +class PushForwardModeVisitor : virtual public BaseForwardModeVisitor { public: - ForwardModeVisitor(DerivativeBuilder& builder); - ~ForwardModeVisitor(); + PushForwardModeVisitor(DerivativeBuilder& builder); + ~PushForwardModeVisitor(); DerivativeAndOverload DerivePushforward(const clang::FunctionDecl* FD, const DiffRequest& request); + virtual void ExecuteInsideFunctionBlock(); + /// Returns the return type for the pushforward function of the function /// `m_Function`. /// \note `m_Function` field should be set before using this function. diff --git a/include/clad/Differentiator/VectorForwardModeVisitor.h b/include/clad/Differentiator/VectorForwardModeVisitor.h index a5a21d13a..544cc699a 100644 --- a/include/clad/Differentiator/VectorForwardModeVisitor.h +++ b/include/clad/Differentiator/VectorForwardModeVisitor.h @@ -8,7 +8,7 @@ namespace clad { /// A visitor for processing the function code in vector forward mode. /// Used to compute derivatives by clad::vector_forward_differentiate. -class VectorForwardModeVisitor : public BaseForwardModeVisitor { +class VectorForwardModeVisitor : virtual public BaseForwardModeVisitor { private: llvm::SmallVector m_IndependentVars; /// Map used to keep track of parameter variables w.r.t which the @@ -78,6 +78,14 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor { StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override; // Decl is not Stmt, so it cannot be visited directly. VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD) override; + + clang::QualType + GetPushForwardDerivativeType(clang::QualType ParamType) override; + std::string GetPushForwardFunctionSuffix() override; + DiffMode GetPushForwardMode() override; + + // Function for setting the independent variables for vector mode. + void SetIndependentVarsExpr(clang::Expr* IndVarsExpr); }; } // end namespace clad diff --git a/include/clad/Differentiator/VectorPushForwardModeVisitor.h b/include/clad/Differentiator/VectorPushForwardModeVisitor.h new file mode 100644 index 000000000..995911695 --- /dev/null +++ b/include/clad/Differentiator/VectorPushForwardModeVisitor.h @@ -0,0 +1,21 @@ +#ifndef CLAD_DIFFERENTIATOR_VECTORPUSHFORWARDMODEVISITOR_H +#define CLAD_DIFFERENTIATOR_VECTORPUSHFORWARDMODEVISITOR_H + +#include "PushForwardModeVisitor.h" +#include "VectorForwardModeVisitor.h" + +namespace clad { +class VectorPushForwardModeVisitor : public PushForwardModeVisitor, + public VectorForwardModeVisitor { + +public: + VectorPushForwardModeVisitor(DerivativeBuilder& builder); + ~VectorPushForwardModeVisitor(); + + void ExecuteInsideFunctionBlock() override; + + StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override; +}; +} // end namespace clad + +#endif // CLAD_DIFFERENTIATOR_VECTORPUSHFORWARDMODEVISITOR_H diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index f092adcd2..36b0d122b 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -914,6 +914,19 @@ Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff( return OverloadedFn; } +QualType +BaseForwardModeVisitor::GetPushForwardDerivativeType(QualType ParamType) { + return ParamType; +} + +std::string BaseForwardModeVisitor::GetPushForwardFunctionSuffix() { + return "_pushforward"; +} + +DiffMode BaseForwardModeVisitor::GetPushForwardMode() { + return DiffMode::experimental_pushforward; +} + StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { const FunctionDecl* FD = CE->getDirectCallee(); if (!FD) { @@ -1008,10 +1021,6 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { CallArgs.push_back(argDiff.getExpr()); if (BaseForwardModeVisitor::IsDifferentiableType(arg->getType())) { Expr* dArg = argDiff.getExpr_dx(); - QualType CallArgTy = CallArgs.back()->getType(); - assert((!dArg || m_Context.hasSameType(CallArgTy, dArg->getType())) && - "Type mismatch, we might fail to instantiate a pullback"); - (void)CallArgTy; // FIXME: What happens when dArg is nullptr? diffArgs.push_back(dArg); } @@ -1034,14 +1043,13 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { // Try to find a user-defined overloaded derivative. std::string customPushforward = - clad::utils::ComputeEffectiveFnName(FD) + "_pushforward"; + clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix(); Expr* callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( customPushforward, customDerivativeArgs, getCurrentScope(), const_cast(FD->getDeclContext())); // Check if it is a recursive call. - if (!callDiff && (FD == m_Function) && - m_Mode == DiffMode::experimental_pushforward) { + if (!callDiff && (FD == m_Function) && m_Mode == GetPushForwardMode()) { // The differentiated function is called recursively. Expr* derivativeRef = m_Sema @@ -1060,7 +1068,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { // derive the called function. DiffRequest pushforwardFnRequest; pushforwardFnRequest.Function = FD; - pushforwardFnRequest.Mode = DiffMode::experimental_pushforward; + pushforwardFnRequest.Mode = GetPushForwardMode(); pushforwardFnRequest.BaseFunctionName = FD->getNameAsString(); // pushforwardFnRequest.RequestedDerivativeOrder = m_DerivativeOrder; // Silence diag outputs in nested derivation process. diff --git a/lib/Differentiator/CMakeLists.txt b/lib/Differentiator/CMakeLists.txt index f5eddb2c6..2bb0c831a 100644 --- a/lib/Differentiator/CMakeLists.txt +++ b/lib/Differentiator/CMakeLists.txt @@ -27,14 +27,15 @@ add_llvm_library(cladDifferentiator DiffPlanner.cpp ErrorEstimator.cpp EstimationModel.cpp - ForwardModeVisitor.cpp HessianModeVisitor.cpp JacobianModeVisitor.cpp MultiplexExternalRMVSource.cpp + PushForwardModeVisitor.cpp ReverseModeForwPassVisitor.cpp ReverseModeVisitor.cpp StmtClone.cpp VectorForwardModeVisitor.cpp + VectorPushForwardModeVisitor.cpp Version.cpp VisitorBase.cpp ${version_inc} diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index c07e70057..23417dfda 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -18,13 +18,14 @@ #include "clad/Differentiator/CladUtils.h" #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/ErrorEstimator.h" -#include "clad/Differentiator/ForwardModeVisitor.h" #include "clad/Differentiator/HessianModeVisitor.h" #include "clad/Differentiator/JacobianModeVisitor.h" +#include "clad/Differentiator/PushForwardModeVisitor.h" #include "clad/Differentiator/ReverseModeForwPassVisitor.h" #include "clad/Differentiator/ReverseModeVisitor.h" #include "clad/Differentiator/StmtClone.h" #include "clad/Differentiator/VectorForwardModeVisitor.h" +#include "clad/Differentiator/VectorPushForwardModeVisitor.h" #include @@ -209,11 +210,14 @@ namespace clad { BaseForwardModeVisitor V(*this); result = V.Derive(FD, request); } else if (request.Mode == DiffMode::experimental_pushforward) { - ForwardModeVisitor V(*this); + PushForwardModeVisitor V(*this); result = V.DerivePushforward(FD, request); } else if (request.Mode == DiffMode::vector_forward_mode) { VectorForwardModeVisitor V(*this); result = V.DeriveVectorMode(FD, request); + } else if (request.Mode == DiffMode::experimental_vector_pushforward) { + VectorPushForwardModeVisitor V(*this); + result = V.DerivePushforward(FD, request); } else if (request.Mode == DiffMode::reverse) { ReverseModeVisitor V(*this); result = V.Derive(FD, request); diff --git a/lib/Differentiator/ForwardModeVisitor.cpp b/lib/Differentiator/ForwardModeVisitor.cpp deleted file mode 100644 index d702dcbfc..000000000 --- a/lib/Differentiator/ForwardModeVisitor.cpp +++ /dev/null @@ -1,185 +0,0 @@ -//--------------------------------------------------------------------*- C++ -*- -// clad - the C++ Clang-based Automatic Differentiator -// version: $Id: ClangPlugin.cpp 7 2013-06-01 22:48:03Z v.g.vassilev@gmail.com $ -// author: Vassil Vassilev -//------------------------------------------------------------------------------ - -#include "clad/Differentiator/ForwardModeVisitor.h" -#include "clad/Differentiator/BaseForwardModeVisitor.h" - -#include "clad/Differentiator/CladUtils.h" - -#include "llvm/Support/SaveAndRestore.h" - -using namespace clang; - -namespace clad { -ForwardModeVisitor::ForwardModeVisitor(DerivativeBuilder& builder) - : BaseForwardModeVisitor(builder) {} - -ForwardModeVisitor::~ForwardModeVisitor() {} - -clang::QualType ForwardModeVisitor::ComputePushforwardFnReturnType() { - assert(m_Mode == DiffMode::experimental_pushforward); - QualType originalFnRT = m_Function->getReturnType(); - if (originalFnRT->isVoidType()) - return m_Context.VoidTy; - TemplateDecl* valueAndPushforward = - LookupTemplateDeclInCladNamespace("ValueAndPushforward"); - assert(valueAndPushforward && - "clad::ValueAndPushforward template not found!!"); - QualType RT = - InstantiateTemplate(valueAndPushforward, {originalFnRT, originalFnRT}); - return RT; -} - - DerivativeAndOverload - ForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, - const DiffRequest& request) { - m_Function = FD; - m_Functor = request.Functor; - m_DerivativeOrder = request.CurrentDerivativeOrder; - m_Mode = DiffMode::experimental_pushforward; - assert(!m_DerivativeInFlight && - "Doesn't support recursive diff. Use DiffPlan."); - m_DerivativeInFlight = true; - - auto originalFnEffectiveName = utils::ComputeEffectiveFnName(m_Function); - - IdentifierInfo* derivedFnII = - &m_Context.Idents.get(originalFnEffectiveName + "_pushforward"); - DeclarationNameInfo derivedFnName(derivedFnII, noLoc); - llvm::SmallVector paramTypes, derivedParamTypes; - - // If we are differentiating an instance member function then - // create a parameter type for the parameter that will represent the - // derivative of `this` pointer with respect to the independent parameter. - if (auto MD = dyn_cast(FD)) { - if (MD->isInstance()) { - QualType thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD); - derivedParamTypes.push_back(thisType); - } - } - - for (auto* PVD : m_Function->parameters()) { - paramTypes.push_back(PVD->getType()); - - if (BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) - derivedParamTypes.push_back(PVD->getType()); - } - - paramTypes.insert(paramTypes.end(), derivedParamTypes.begin(), - derivedParamTypes.end()); - - auto originalFnType = dyn_cast(m_Function->getType()); - QualType returnType = ComputePushforwardFnReturnType(); - QualType derivedFnType = - m_Context.getFunctionType(returnType, paramTypes, - originalFnType->getExtProtoInfo()); - llvm::SaveAndRestore saveContext(m_Sema.CurContext); - llvm::SaveAndRestore saveScope(m_CurScope); - DeclContext* DC = const_cast(m_Function->getDeclContext()); - m_Sema.CurContext = DC; - - DeclWithContext cloneFunctionResult = m_Builder.cloneFunction( - m_Function, *this, DC, noLoc, derivedFnName, derivedFnType); - m_Derivative = cloneFunctionResult.first; - - llvm::SmallVector params; - llvm::SmallVector derivedParams; - beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | - Scope::DeclScope); - m_Sema.PushFunctionScope(); - m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); - - // If we are differentiating an instance member function then - // create a parameter for representing derivative of - // `this` pointer with respect to the independent parameter. - if (auto MFD = dyn_cast(FD)) { - if (MFD->isInstance()) { - auto thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MFD); - IdentifierInfo* derivedPVDII = CreateUniqueIdentifier("_d_this"); - auto derivedPVD = utils::BuildParmVarDecl(m_Sema, m_Sema.CurContext, - derivedPVDII, thisType); - m_Sema.PushOnScopeChains(derivedPVD, getCurrentScope(), - /*AddToContext=*/false); - derivedParams.push_back(derivedPVD); - m_ThisExprDerivative = BuildDeclRef(derivedPVD); - } - } - - std::size_t numParamsOriginalFn = m_Function->getNumParams(); - for (std::size_t i = 0; i < numParamsOriginalFn; ++i) { - auto PVD = m_Function->getParamDecl(i); - // Some of the special member functions created implicitly by compilers - // have missing parameter identifier. - bool identifierMissing = false; - IdentifierInfo* PVDII = PVD->getIdentifier(); - if (!PVDII || PVDII->getLength() == 0) { - PVDII = CreateUniqueIdentifier("param"); - identifierMissing = true; - } - auto newPVD = CloneParmVarDecl(PVD, PVDII, - /*pushOnScopeChains=*/true, - /*cloneDefaultArg=*/false); - params.push_back(newPVD); - - if (identifierMissing) - m_DeclReplacements[PVD] = newPVD; - - if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) - continue; - auto derivedPVDName = "_d_" + std::string(PVDII->getName()); - IdentifierInfo* derivedPVDII = CreateUniqueIdentifier(derivedPVDName); - auto derivedPVD = CloneParmVarDecl(PVD, derivedPVDII, - /*pushOnScopeChains=*/true, - /*cloneDefaultArg=*/false); - derivedParams.push_back(derivedPVD); - m_Variables[newPVD] = BuildDeclRef(derivedPVD); - } - - params.insert(params.end(), derivedParams.begin(), derivedParams.end()); - m_Derivative->setParams(params); - m_Derivative->setBody(nullptr); - - beginScope(Scope::FnScope | Scope::DeclScope); - m_DerivativeFnScope = getCurrentScope(); - beginBlock(); - - Stmt* bodyDiff = Visit(FD->getBody()).getStmt(); - CompoundStmt* CS = cast(bodyDiff); - for (Stmt* S : CS->body()) - addToCurrentBlock(S); - - Stmt* derivativeBody = endBlock(); - m_Derivative->setBody(derivativeBody); - - endScope(); // Function body scope - m_Sema.PopFunctionScopeInfo(); - m_Sema.PopDeclContext(); - endScope(); // Function decl scope - - m_DerivativeInFlight = false; - return DerivativeAndOverload{cloneFunctionResult.first}; - } - - StmtDiff ForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) { - //If there is no return value, we must not attempt to differentiate - if (!RS->getRetValue()) - return nullptr; - - StmtDiff retValDiff = Visit(RS->getRetValue()); - llvm::SmallVector returnValues = {retValDiff.getExpr(), - retValDiff.getExpr_dx()}; - SourceLocation fakeInitLoc = utils::GetValidSLoc(m_Sema); - // This can instantiate as part of the move or copy initialization and - // needs a fake source location. - Expr* initList = - m_Sema.ActOnInitList(fakeInitLoc, returnValues, noLoc).get(); - - SourceLocation fakeRetLoc = utils::GetValidSLoc(m_Sema); - Stmt* returnStmt = - m_Sema.ActOnReturnStmt(fakeRetLoc, initList, getCurrentScope()).get(); - return StmtDiff(returnStmt); - } -} // end namespace clad diff --git a/lib/Differentiator/PushForwardModeVisitor.cpp b/lib/Differentiator/PushForwardModeVisitor.cpp new file mode 100644 index 000000000..daac08235 --- /dev/null +++ b/lib/Differentiator/PushForwardModeVisitor.cpp @@ -0,0 +1,188 @@ +//--------------------------------------------------------------------*- C++ -*- +// clad - the C++ Clang-based Automatic Differentiator +// version: $Id: ClangPlugin.cpp 7 2013-06-01 22:48:03Z v.g.vassilev@gmail.com $ +// author: Vassil Vassilev +//------------------------------------------------------------------------------ + +#include "clad/Differentiator/PushForwardModeVisitor.h" +#include "clad/Differentiator/BaseForwardModeVisitor.h" + +#include "clad/Differentiator/CladUtils.h" + +#include "llvm/Support/SaveAndRestore.h" + +using namespace clang; + +namespace clad { +PushForwardModeVisitor::PushForwardModeVisitor(DerivativeBuilder& builder) + : BaseForwardModeVisitor(builder) {} + +PushForwardModeVisitor::~PushForwardModeVisitor() {} + +clang::QualType PushForwardModeVisitor::ComputePushforwardFnReturnType() { + assert(m_Mode == GetPushForwardMode()); + QualType originalFnRT = m_Function->getReturnType(); + if (originalFnRT->isVoidType()) + return m_Context.VoidTy; + TemplateDecl* valueAndPushforward = + LookupTemplateDeclInCladNamespace("ValueAndPushforward"); + assert(valueAndPushforward && + "clad::ValueAndPushforward template not found!!"); + QualType RT = InstantiateTemplate( + valueAndPushforward, + {originalFnRT, GetPushForwardDerivativeType(originalFnRT)}); + return RT; +} + +DerivativeAndOverload +PushForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, + const DiffRequest& request) { + m_Function = FD; + m_Functor = request.Functor; + m_DerivativeOrder = request.CurrentDerivativeOrder; + m_Mode = GetPushForwardMode(); + assert(!m_DerivativeInFlight && + "Doesn't support recursive diff. Use DiffPlan."); + m_DerivativeInFlight = true; + + auto originalFnEffectiveName = utils::ComputeEffectiveFnName(m_Function); + + IdentifierInfo* derivedFnII = &m_Context.Idents.get( + originalFnEffectiveName + GetPushForwardFunctionSuffix()); + DeclarationNameInfo derivedFnName(derivedFnII, noLoc); + llvm::SmallVector paramTypes, derivedParamTypes; + + // If we are differentiating an instance member function then + // create a parameter type for the parameter that will represent the + // derivative of `this` pointer with respect to the independent parameter. + if (auto MD = dyn_cast(FD)) { + if (MD->isInstance()) { + QualType thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD); + derivedParamTypes.push_back(thisType); + } + } + + for (auto* PVD : m_Function->parameters()) { + paramTypes.push_back(PVD->getType()); + + if (BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) + derivedParamTypes.push_back(GetPushForwardDerivativeType(PVD->getType())); + } + + paramTypes.insert(paramTypes.end(), derivedParamTypes.begin(), + derivedParamTypes.end()); + + auto originalFnType = dyn_cast(m_Function->getType()); + QualType returnType = ComputePushforwardFnReturnType(); + QualType derivedFnType = m_Context.getFunctionType( + returnType, paramTypes, originalFnType->getExtProtoInfo()); + llvm::SaveAndRestore saveContext(m_Sema.CurContext); + llvm::SaveAndRestore saveScope(m_CurScope); + DeclContext* DC = const_cast(m_Function->getDeclContext()); + m_Sema.CurContext = DC; + + DeclWithContext cloneFunctionResult = m_Builder.cloneFunction( + m_Function, *this, DC, noLoc, derivedFnName, derivedFnType); + m_Derivative = cloneFunctionResult.first; + + llvm::SmallVector params; + llvm::SmallVector derivedParams; + beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | + Scope::DeclScope); + m_Sema.PushFunctionScope(); + m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); + + // If we are differentiating an instance member function then + // create a parameter for representing derivative of + // `this` pointer with respect to the independent parameter. + if (auto MFD = dyn_cast(FD)) { + if (MFD->isInstance()) { + auto thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MFD); + IdentifierInfo* derivedPVDII = CreateUniqueIdentifier("_d_this"); + auto derivedPVD = utils::BuildParmVarDecl(m_Sema, m_Sema.CurContext, + derivedPVDII, thisType); + m_Sema.PushOnScopeChains(derivedPVD, getCurrentScope(), + /*AddToContext=*/false); + derivedParams.push_back(derivedPVD); + m_ThisExprDerivative = BuildDeclRef(derivedPVD); + } + } + + std::size_t numParamsOriginalFn = m_Function->getNumParams(); + for (std::size_t i = 0; i < numParamsOriginalFn; ++i) { + auto PVD = m_Function->getParamDecl(i); + // Some of the special member functions created implicitly by compilers + // have missing parameter identifier. + bool identifierMissing = false; + IdentifierInfo* PVDII = PVD->getIdentifier(); + if (!PVDII || PVDII->getLength() == 0) { + PVDII = CreateUniqueIdentifier("param"); + identifierMissing = true; + } + auto newPVD = CloneParmVarDecl(PVD, PVDII, + /*pushOnScopeChains=*/true, + /*cloneDefaultArg=*/false); + params.push_back(newPVD); + + if (identifierMissing) + m_DeclReplacements[PVD] = newPVD; + + if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) + continue; + auto derivedPVDName = "_d_" + std::string(PVDII->getName()); + IdentifierInfo* derivedPVDII = CreateUniqueIdentifier(derivedPVDName); + auto derivedPVD = utils::BuildParmVarDecl( + m_Sema, m_Derivative, derivedPVDII, + GetPushForwardDerivativeType(PVD->getType()), PVD->getStorageClass()); + derivedParams.push_back(derivedPVD); + m_Variables[newPVD] = BuildDeclRef(derivedPVD); + } + + params.insert(params.end(), derivedParams.begin(), derivedParams.end()); + m_Derivative->setParams(params); + m_Derivative->setBody(nullptr); + + beginScope(Scope::FnScope | Scope::DeclScope); + m_DerivativeFnScope = getCurrentScope(); + beginBlock(); + + ExecuteInsideFunctionBlock(); + + Stmt* derivativeBody = endBlock(); + m_Derivative->setBody(derivativeBody); + + endScope(); // Function body scope + m_Sema.PopFunctionScopeInfo(); + m_Sema.PopDeclContext(); + endScope(); // Function decl scope + + m_DerivativeInFlight = false; + return DerivativeAndOverload{cloneFunctionResult.first}; +} + +void PushForwardModeVisitor::ExecuteInsideFunctionBlock() { + Stmt* bodyDiff = Visit(m_Function->getBody()).getStmt(); + CompoundStmt* CS = cast(bodyDiff); + for (Stmt* S : CS->body()) + addToCurrentBlock(S); +} + +StmtDiff PushForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) { + // If there is no return value, we must not attempt to differentiate + if (!RS->getRetValue()) + return nullptr; + + StmtDiff retValDiff = Visit(RS->getRetValue()); + llvm::SmallVector returnValues = {retValDiff.getExpr(), + retValDiff.getExpr_dx()}; + SourceLocation fakeInitLoc = utils::GetValidSLoc(m_Sema); + // This can instantiate as part of the move or copy initialization and + // needs a fake source location. + Expr* initList = m_Sema.ActOnInitList(fakeInitLoc, returnValues, noLoc).get(); + + SourceLocation fakeRetLoc = utils::GetValidSLoc(m_Sema); + Stmt* returnStmt = + m_Sema.ActOnReturnStmt(fakeRetLoc, initList, getCurrentScope()).get(); + return StmtDiff(returnStmt); +} +} // end namespace clad diff --git a/lib/Differentiator/VectorForwardModeVisitor.cpp b/lib/Differentiator/VectorForwardModeVisitor.cpp index 9e38bc6a8..87f268176 100644 --- a/lib/Differentiator/VectorForwardModeVisitor.cpp +++ b/lib/Differentiator/VectorForwardModeVisitor.cpp @@ -15,6 +15,43 @@ VectorForwardModeVisitor::VectorForwardModeVisitor(DerivativeBuilder& builder) VectorForwardModeVisitor::~VectorForwardModeVisitor() {} +std::string VectorForwardModeVisitor::GetPushForwardFunctionSuffix() { + return "_vector_pushforward"; +} + +DiffMode VectorForwardModeVisitor::GetPushForwardMode() { + return DiffMode::experimental_vector_pushforward; +} + +QualType +VectorForwardModeVisitor::GetPushForwardDerivativeType(QualType ParamType) { + QualType valueType = utils::GetValueType(ParamType); + QualType resType; + if (utils::isArrayOrPointerType(ParamType)) { + // If the parameter is a pointer or an array, then the derivative will be a + // reference to the matrix. + resType = GetCladMatrixOfType(valueType); + resType = m_Context.getLValueReferenceType(resType); + } else { + // If the parameter is not a pointer or an array, then the derivative will + // be a clad array. + resType = GetCladArrayOfType(valueType); + + // Add const qualifier if the parameter is const. + if (ParamType.getNonReferenceType().isConstQualified()) + resType.addConst(); + + // Add reference qualifier if the parameter is a reference. + if (ParamType->isReferenceType()) + resType = m_Context.getLValueReferenceType(resType); + } + return resType; +} + +void VectorForwardModeVisitor::SetIndependentVarsExpr(Expr* IndVarCountExpr) { + m_IndVarCountExpr = IndVarCountExpr; +} + DerivativeAndOverload VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD, const DiffRequest& request) { diff --git a/lib/Differentiator/VectorPushForwardModeVisitor.cpp b/lib/Differentiator/VectorPushForwardModeVisitor.cpp new file mode 100644 index 000000000..67e22d082 --- /dev/null +++ b/lib/Differentiator/VectorPushForwardModeVisitor.cpp @@ -0,0 +1,35 @@ +#include "clad/Differentiator/VectorPushForwardModeVisitor.h" + +#include "clad/Differentiator/CladUtils.h" + +#include "llvm/Support/SaveAndRestore.h" + +using namespace clang; + +namespace clad { +VectorPushForwardModeVisitor::VectorPushForwardModeVisitor( + DerivativeBuilder& builder) + : BaseForwardModeVisitor(builder), PushForwardModeVisitor(builder), + VectorForwardModeVisitor(builder) {} + +VectorPushForwardModeVisitor::~VectorPushForwardModeVisitor() {} + +void VectorPushForwardModeVisitor::ExecuteInsideFunctionBlock() { + Expr* indVarCountExpr = nullptr; + auto* totalIndVars = + BuildVarDecl(m_Context.UnsignedLongTy, "indepVarCount", indVarCountExpr); + addToCurrentBlock(BuildDeclStmt(totalIndVars)); + SetIndependentVarsExpr(BuildDeclRef(totalIndVars)); + + Stmt* bodyDiff = Visit(m_Function->getBody()).getStmt(); + CompoundStmt* CS = cast(bodyDiff); + for (Stmt* S : CS->body()) + addToCurrentBlock(S); +} + +StmtDiff +VectorPushForwardModeVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) { + return PushForwardModeVisitor::VisitReturnStmt(RS); +} + +} // end namespace clad diff --git a/test/ForwardMode/VectorMode.C b/test/ForwardMode/VectorMode.C index 08fdfad2b..9a8f2d2d5 100644 --- a/test/ForwardMode/VectorMode.C +++ b/test/ForwardMode/VectorMode.C @@ -135,6 +135,7 @@ double f5(double x, double y, double z) { // CHECK-NEXT: *_d_z = _d_vector_return[2UL]; // CHECK-NEXT: return; // CHECK-NEXT: } +// CHECK-NEXT: } // x, y // CHECK: void f5_dvec_0_1(double x, double y, double z, double *_d_x, double *_d_y) { @@ -148,6 +149,7 @@ double f5(double x, double y, double z) { // CHECK-NEXT: *_d_y = _d_vector_return[1UL]; // CHECK-NEXT: return; // CHECK-NEXT: } +// CHECK-NEXT: } // x, z // CHECK: void f5_dvec_0_2(double x, double y, double z, double *_d_x, double *_d_z) { @@ -161,6 +163,7 @@ double f5(double x, double y, double z) { // CHECK-NEXT: *_d_z = _d_vector_return[1UL]; // CHECK-NEXT: return; // CHECK-NEXT: } +// CHECK-NEXT: } // y, z // CHECK: void f5_dvec_1_2(double x, double y, double z, double *_d_y, double *_d_z) { @@ -174,6 +177,7 @@ double f5(double x, double y, double z) { // CHECK-NEXT: *_d_z = _d_vector_return[1UL]; // CHECK-NEXT: return; // CHECK-NEXT: } +// CHECK-NEXT: } // z // CHECK: void f5_dvec_2(double x, double y, double z, double *_d_z) { @@ -186,7 +190,79 @@ double f5(double x, double y, double z) { // CHECK-NEXT: *_d_z = _d_vector_return[0UL]; // CHECK-NEXT: return; // CHECK-NEXT: } +// CHECK-NEXT: } +double square(const double& x) { + double z = x*x; + return z; +} + +// CHECK: clad::ValueAndPushforward > square_vector_pushforward(const double &x, const clad::array &_d_x) { +// CHECK-NEXT: unsigned long indepVarCount; +// CHECK-NEXT: clad::array _d_vector_z(clad::array(indepVarCount, _d_x * x + x * _d_x)); +// CHECK-NEXT: double z = x * x; +// CHECK-NEXT: return {z, _d_vector_z}; +// CHECK-NEXT: } + +double f6(double x, double y) { + return square(x) + square(y); +} + +// CHECK: void f6_dvec(double x, double y, double *_d_x, double *_d_y) { +// CHECK-NEXT: unsigned long indepVarCount = 2UL; +// CHECK-NEXT: clad::array _d_vector_x = clad::one_hot_vector(indepVarCount, 0UL); +// CHECK-NEXT: clad::array _d_vector_y = clad::one_hot_vector(indepVarCount, 1UL); +// CHECK-NEXT: clad::ValueAndPushforward > _t0 = square_vector_pushforward(x, _d_vector_x); +// CHECK-NEXT: clad::ValueAndPushforward > _t1 = square_vector_pushforward(y, _d_vector_y); +// CHECK-NEXT: { +// CHECK-NEXT: clad::array _d_vector_return(clad::array(indepVarCount, _t0.pushforward + _t1.pushforward)); +// CHECK-NEXT: *_d_x = _d_vector_return[0UL]; +// CHECK-NEXT: *_d_y = _d_vector_return[1UL]; +// CHECK-NEXT: return; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double weighted_array_squared_sum(const double* arr, double w, int n) { + double sum = 0; + for (int i = 0; i < n; ++i) { + sum += w * square(arr[i]); + } + return sum; +} + +// CHECK: clad::ValueAndPushforward > weighted_array_squared_sum_vector_pushforward(const double *arr, double w, int n, clad::matrix &_d_arr, clad::array _d_w, clad::array _d_n) { +// CHECK-NEXT: unsigned long indepVarCount; +// CHECK-NEXT: clad::array _d_vector_sum(clad::array(indepVarCount, 0)); +// CHECK-NEXT: double sum = 0; +// CHECK-NEXT: { +// CHECK-NEXT: clad::array _d_vector_i(clad::array(indepVarCount, 0)); +// CHECK-NEXT: for (int i = 0; i < n; ++i) { +// CHECK-NEXT: clad::ValueAndPushforward > _t0 = square_vector_pushforward(arr[i], _d_arr[i]); +// CHECK-NEXT: double &_t1 = _t0.value; +// CHECK-NEXT: _d_vector_sum += _d_w * _t1 + w * _t0.pushforward; +// CHECK-NEXT: sum += w * _t1; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return {sum, _d_vector_sum}; +// CHECK-NEXT: } + +double f7(const double* arr, double w, int n) { + return weighted_array_squared_sum(arr, w, n); +} + +// CHECK: void f7_dvec_0_1(const double *arr, double w, int n, clad::array_ref _d_arr, double *_d_w) { +// CHECK-NEXT: unsigned long indepVarCount = _d_arr.size() + 1UL; +// CHECK-NEXT: clad::matrix _d_vector_arr = clad::identity_matrix(_d_arr.size(), indepVarCount, 0UL); +// CHECK-NEXT: clad::array _d_vector_w = clad::one_hot_vector(indepVarCount, _d_arr.size()); +// CHECK-NEXT: clad::array _d_vector_n = clad::zero_vector(indepVarCount); +// CHECK-NEXT: clad::ValueAndPushforward > _t0 = weighted_array_squared_sum_vector_pushforward(arr, w, n, _d_vector_arr, _d_vector_w, _d_vector_n); +// CHECK-NEXT: { +// CHECK-NEXT: clad::array _d_vector_return(clad::array(indepVarCount, _t0.pushforward)); +// CHECK-NEXT: _d_arr = _d_vector_return.slice(0UL, _d_arr.size()); +// CHECK-NEXT: *_d_w = _d_vector_return[_d_arr.size()]; +// CHECK-NEXT: return; +// CHECK-NEXT: } +// CHECK-NEXT: } #define TEST(F, x, y) \ { \ @@ -228,4 +304,18 @@ int main() { auto f_dvec_z = clad::differentiate(f5, "z"); f_dvec_z.execute(1, 2, 3, &result[0]); printf("Result is = {%.2f}\n", result[0]); // CHECK-EXEC: Result is = {3.00} + + // Testing derivatives of function calls. + auto f6_dvec = clad::differentiate(f6); + f6_dvec.execute(1, 2, &result[0], &result[1]); + printf("Result is = {%.2f, %.2f}\n", result[0], result[1]); // CHECK-EXEC: Result is = {2.00, 4.00} + + // Testing derivatives of function calls with array parameters. + auto f7_dvec = clad::differentiate(f7, "arr,w"); + double arr[3] = {1, 2, 3}; + double w = 2, dw = 0; + double darr[3] = {0, 0, 0}; + clad::array_ref darr_ref(darr, 3); + f7_dvec.execute(arr, 2, 3, darr_ref, &dw); + printf("Result is = {%.2f, %.2f, %.2f, %.2f}\n", darr[0], darr[1], darr[2], dw); // CHECK-EXEC: Result is = {4.00, 8.00, 12.00, 14.00} }