Skip to content

Commit

Permalink
Add primitive support for custom constructor pushforward functions
Browse files Browse the repository at this point in the history
This commit adds primitive support for custom pushforward functions for
constructors. Custom constructor pushforward function support will enable the
below features:
- Class differentiation support for classes whose constructor Clad
  cannot automatically differentiate. Now, We can enable differentiation of
  entire C++ standard library by providing custom derivatives.
- Remove the restriction of default-constructible for class types. This
  was a troublesome restriction. Now, the only restriction for class
  types is to have a sensible copy-constructor. That is, copy
  constructor should copy the class members and after copy-construction,
  both the objects should be equivalent, mathematically speaking.

Constructor pushforward functions differ from ordinary pushforward
functions in two important ways:
- Constructor pushforward functions initialize the primal class object
  and the corresponding derivative object. Ordinary member function
  pushforwards takes an already-existing primal class object and the
  corresponding derivative object as inputs.
- Constructor pushforward functions return a value even though
  constructor do not return anything. Constructor pushforward functions
  return initialized primal object and the derivative object. These are
  then used to initialize primal object and the derivative in the
  derivative function code.

How to write custom constructor pushforward functions
----------------------------------------

Let's see how to write custom pushforward function for a
constructor:

- Custom constructor pushforwards must have the name `constructor_pushforward`
- Custom constructor pushforwards must be defined in
  `::clad::custom_derivatives::class_functions` namespace.
- The parameters of the custom constructor pushforward must be:
  {`::clad::ConstructorPushforwardTag<Class>`, original params...,
    derivative params...}.

'original parameters...' and 'derivative parameters...' is same as what
we have for other pushforward functions. We will soon see why do we need
`::clad::ConstructorPushforwardTag<T>` for constructor custom pushforwards.

Let's see a basic example of how to write custom constructor
pushforward.

```cpp
class Coordinates {
  Coordinates(double px, double py, double pz) :
    x(px), y(py), z(pz) {}

  public:
  double x, y, z;
}

namespace clad {
namespace custom_derivatives {
namespace class_functions {
// custom constructor pushforward function
clad::ValueAndPushforward<Coordinates, Coordinates>
constructor_pushforward(clad::ConstructorPushforwardTag<Coordinates>, double x,
                        double y, double z, double d_x, double d_y, double d_z) {
  return {Coordinates(x, y, z), Coordinates(d_x, d_y, d_z) };
}
} // namespace class_functions
} // namespace custom_derivatives
} // namespace clad

// custom constructor pushforward is used as follows:
// primal code
Constructor c(u, v, w);

// derivative code
clad::ValueAndPushforward<Coordinates, Coordinates> _t0 =
  constructor_pushforward(clad::ConstructorPushforwardTag<Coordinates>, u, v,
                          w, _d_u, _d_v, _d_w);
Coordinates _d_c = _t0.pushforward;
Coordinates c = _t0.value;
```

Now, let's see a bit advanced example based on `std::vector` constructor.

```cpp
namespace clad {
namespace custom_derivatives {
namespace class_functions {
// Custom pushforward for: vector(size_t n, const typename
::std::vector<T>::allocator_type alloc)
template <typename T>
clad::ValueAndPushforward<::std::vector<T>, ::std::vector<T>>
constructor_pushforward(
    ConstructorPushforwardTag<::std::vector<T>>, size_t n,
    const typename ::std::vector<T>::allocator_type alloc, size_t d_n,
    const typename ::std::vector<T>::allocator_type d_alloc) {
  ::std::vector<T> v(n, alloc);
  ::std::vector<T> d_v(n, 0, alloc);
  return {v, d_v};
}

// Custom pushfoward for: vector(size_t n, T val, const typename
::std::vector<T>::allocator_type alloc)
template <typename T>
clad::ValueAndPushforward<::std::vector<T>, ::std::vector<T>>
constructor_pushforward(
    ConstructorPushforwardTag<::std::vector<T>>, size_t n, T val,
    const typename ::std::vector<T>::allocator_type alloc, size_t d_n, T d_val,
    const typename ::std::vector<T>::allocator_type d_alloc) {
  ::std::vector<T> v(n, val, alloc);
  ::std::vector<T> d_v(n, d_val, alloc);
  return {v, d_v};
}
} // namespace class_functions
} // namespace custom_derivatives
} // namespace clad

// The custom constructor pushforwards is used as follows:

// Primal code:
std::vector<double> v(10, u);

// Derivative code:
clad::ValueAndPushforward<std::vector<double>, std::vector<double>> _t0 =
  clad::custom_derivatives::class_functions::constructor_pushforward(
    clad::ConstructorPushforwardTag<std::vector<double> >(), 10, u,
    allocator_type(), 0, _d_u, allocator_type());
std::vector<double> d_v = _t0.pushforward;
std::vector<double> v = _t0.value;
```

Why `clad::ConstructorPushforwardTag<T>`?
------------------------

So, why do we need clad::ConstructorPushforwardTag<T>?

For a constructor that takes two parameters of types `size_t` and `double`,
the custom pushforward will have the following signature if we do not
include `::clad::ConstructorPushforwardTag<T>`:

```cpp
clad::ValueAndPushforward<Class, Class> constructor_pushforward(size_t n,
  double val, size_t d_n, double d_val);
```

Now, the question is: How to distinguish custom constructor pushforwards
for different classes?

```cpp
MyClassA a(3, 5.0);
MyClassB b(7, 9.0);
```

There is no way for overload resolution selector to distinguish
constructor_pushforward for classes `MyClassA` and `MyClassB`.

`clad::ConstructorPushforwardTag<T>` is used to identify the class for which
custom constructor pushforward is defined.

Please note that we cannot use the same strategy which we use for custom
member function pushforwards because member function pushforwards
always have parameters of the class type which are used for identifying
the class.

We also cannot simply ask users to define the pushforwards inside the
declaration context of the class because it may not always be feasible
to modify the source code of external libraries.

--------------------------------------------

Fixes #965
  • Loading branch information
parth-07 authored and vgvassilev committed Jul 2, 2024
1 parent 22b2590 commit 645d2b6
Show file tree
Hide file tree
Showing 9 changed files with 330 additions and 14 deletions.
11 changes: 11 additions & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "Compatibility.h"
#include "VisitorBase.h"
#include "clang/AST/ExprCXX.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/StmtVisitor.h"
#include "clang/Sema/Sema.h"
Expand Down Expand Up @@ -140,6 +141,16 @@ class BaseForwardModeVisitor
/// \return active switch case label after processing `stmt`
clang::SwitchCase* DeriveSwitchStmtBodyHelper(const clang::Stmt* stmt,
clang::SwitchCase* activeSC);

/// Tries to build custom derivative constructor pushforward call for the
/// given CXXConstructExpr.
///
/// \return A call expression if a suitable custom derivative is found;
/// Otherwise returns nullptr.
clang::Expr* BuildCustomDerivativeConstructorPFCall(
const clang::CXXConstructExpr* CE,
llvm::SmallVectorImpl<clang::Expr*>& clonedArgs,
llvm::SmallVectorImpl<clang::Expr*>& derivedArgs);
};
} // end namespace clad

Expand Down
20 changes: 20 additions & 0 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,26 @@ template <typename T, typename U> struct ValueAndPushforward {
T value;
U pushforward;
};

/// It is used to identify constructor custom pushforwards. For
/// constructor custom pushforward functions, we cannot use the same
/// strategy which we use for custom pushforward for member
/// functions. Member functions custom pushforward have the following
/// signature:
///
/// mem_fn_pushforward(ClassName *c, ..., ClassName *d_c, ...)
///
/// We use the first argument 'ClassName *c' to determine the class of member
/// function for which the pushforward is defined.
///
/// In the case of constructor pushforward, there are no objects of the class
/// type passed to the constructor. Therefore, we cannot simply use arguments
/// to determine the class. To solve this, 'ConstructorPushforwardTag<T>' is
/// used. A custom_derivative pushforward for constructor is required to have
/// 'ConstructorPushforwardTag<T>' as the first argument, where 'T' is the
/// class for which constructor pushforward is defined.
template <class T> class ConstructorPushforwardTag {};

namespace custom_derivatives {
#ifdef __CUDACC__
template <typename T>
Expand Down
87 changes: 86 additions & 1 deletion include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#ifndef CLAD_STL_BUILTINS_H
#define CLAD_STL_BUILTINS_H

#include <clad/Differentiator/BuiltinDerivatives.h>
#include <initializer_list>
#include <vector>
#include "clad/Differentiator/BuiltinDerivatives.h"

namespace clad {
namespace custom_derivatives {
Expand Down Expand Up @@ -43,6 +44,90 @@ end_pushforward(const ::std::initializer_list<T>* il,
const ::std::initializer_list<T>* d_il) {
return {il->end(), d_il->end()};
}

template <typename T>
clad::ValueAndPushforward<::std::vector<T>, ::std::vector<T>>
constructor_pushforward(
ConstructorPushforwardTag<::std::vector<T>>,
typename ::std::vector<T>::size_type n,
const typename ::std::vector<T>::allocator_type& alloc,
typename ::std::vector<T>::size_type d_n,
const typename ::std::vector<T>::allocator_type& d_alloc) {
::std::vector<T> v(n, alloc);
::std::vector<T> d_v(n, 0, alloc);
return {v, d_v};
}

template <typename T>
clad::ValueAndPushforward<::std::vector<T>, ::std::vector<T>>
constructor_pushforward(
ConstructorPushforwardTag<::std::vector<T>>,
typename ::std::vector<T>::size_type n, T val,
const typename ::std::vector<T>::allocator_type& alloc,
typename ::std::vector<T>::size_type d_n, T d_val,
const typename ::std::vector<T>::allocator_type& d_alloc) {
::std::vector<T> v(n, val, alloc);
::std::vector<T> d_v(n, d_val, alloc);
return {v, d_v};
}

template <typename T>
clad::ValueAndPushforward<::std::vector<T>, ::std::vector<T>>
constructor_pushforward(
ConstructorPushforwardTag<::std::vector<T>>,
::std::initializer_list<T> list,
const typename ::std::vector<T>::allocator_type& alloc,
::std::initializer_list<T> dlist,
const typename ::std::vector<T>::allocator_type& dalloc) {
::std::vector<T> v(list, alloc);
::std::vector<T> d_v(dlist, dalloc);
return {v, d_v};
}

template <typename T>
clad::ValueAndPushforward<::std::vector<T>, ::std::vector<T>>
constructor_pushforward(ConstructorPushforwardTag<::std::vector<T>>,
typename ::std::vector<T>::size_type n,
typename ::std::vector<T>::size_type d_n) {
::std::vector<T> v(n);
::std::vector<T> d_v(n, 0);
return {v, d_v};
}

template <typename T>
clad::ValueAndPushforward<::std::vector<T>, ::std::vector<T>>
constructor_pushforward(ConstructorPushforwardTag<::std::vector<T>>,
typename ::std::vector<T>::size_type n, T val,
typename ::std::vector<T>::size_type d_n, T d_val) {
::std::vector<T> v(n, val);
::std::vector<T> d_v(n, d_val);
return {v, d_v};
}

template <typename T>
clad::ValueAndPushforward<::std::vector<T>, ::std::vector<T>>
constructor_pushforward(ConstructorPushforwardTag<::std::vector<T>>,
::std::initializer_list<T> list,
::std::initializer_list<T> dlist) {
::std::vector<T> v(list);
::std::vector<T> d_v(dlist);
return {v, d_v};
}

template <typename T>
ValueAndPushforward<T&, T&>
operator_subscript_pushforward(::std::vector<T>* v, unsigned idx,
::std::vector<T>* d_v, unsigned d_idx) {
return {(*v)[idx], (*d_v)[idx]};
}

template <typename T>
ValueAndPushforward<const T&, const T&>
operator_subscript_pushforward(const ::std::vector<T>* v, unsigned idx,
const ::std::vector<T>* d_v, unsigned d_idx) {
return {(*v)[idx], (*d_v)[idx]};
}

} // namespace class_functions
} // namespace custom_derivatives
} // namespace clad
Expand Down
11 changes: 10 additions & 1 deletion include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ namespace clad {
#include "clang/AST/StmtVisitor.h"
#include "clang/Sema/ParsedAttr.h"
#include "clang/Sema/Sema.h"

#include <array>
#include <stack>
#include <unordered_map>
Expand Down Expand Up @@ -600,6 +599,13 @@ namespace clad {
bool isDerived);

clang::QualType DetermineCladArrayValueType(clang::QualType T);

/// Returns clad::Identify template declaration.
clang::TemplateDecl* GetCladConstructorPushforwardTag();

/// Returns type clad::Identify<T>
clang::QualType GetCladConstructorPushforwardTagOfType(clang::QualType T);

public:
/// Rebuild a sequence of nested namespaces ending with DC.
clang::NamespaceDecl* RebuildEnclosingNamespaces(clang::DeclContext* DC);
Expand Down Expand Up @@ -644,6 +650,9 @@ namespace clad {
void ComputeEffectiveDOperands(StmtDiff& LDiff, StmtDiff& RDiff,
clang::Expr*& derivedL,
clang::Expr*& derivedR);

private:
clang::TemplateDecl* m_CladConstructorPushforwardTag = nullptr;
};
} // end namespace clad

Expand Down
67 changes: 57 additions & 10 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1985,6 +1985,26 @@ BaseForwardModeVisitor::VisitCXXConstructExpr(const CXXConstructExpr* CE) {
clonedArgs.push_back(argDiff.getExpr());
derivedArgs.push_back(argDiff.getExpr_dx());
}

Expr* pushforwardCall =
BuildCustomDerivativeConstructorPFCall(CE, clonedArgs, derivedArgs);
if (pushforwardCall) {
auto valueAndPushforwardE = StoreAndRef(pushforwardCall);
Expr* valueE = utils::BuildMemberExpr(m_Sema, getCurrentScope(),
valueAndPushforwardE, "value");
Expr* pushforwardE = utils::BuildMemberExpr(
m_Sema, getCurrentScope(), valueAndPushforwardE, "pushforward");
return StmtDiff(valueE, pushforwardE);
}

// Custom derivative not found. Create simple constructor calls based on the
// given arguments. For example, if the primal constructor call is
// 'C(a, b, c)' then we use the constructor call 'C(d_a, d_b, d_c)' for the
// derivative.
// FIXME: This is incorrect. It only works for very simple types such as
// std::complex. We should ideally treat a constructor like a function and
// thus differentiate its body, create a pushforward and use the pushforward
// in the derivative code instead of the original constructor.
Expr* clonedArgsE = nullptr;
Expr* derivedArgsE = nullptr;
// FIXME: Currently if the original initialisation expression is `{a, 1,
Expand All @@ -1996,17 +2016,14 @@ BaseForwardModeVisitor::VisitCXXConstructExpr(const CXXConstructExpr* CE) {
if (CE->isListInitialization()) {
clonedArgsE = m_Sema.ActOnInitList(noLoc, clonedArgs, noLoc).get();
derivedArgsE = m_Sema.ActOnInitList(noLoc, derivedArgs, noLoc).get();
} else if (CE->getNumArgs() == 0) {
// ParenList is empty -- default initialisation.
// Passing empty parenList here will silently cause 'most vexing
// parse' issue.
return StmtDiff();
} else {
if (CE->getNumArgs() == 0) {
// ParenList is empty -- default initialisation.
// Passing empty parenList here will silently cause 'most vexing
// parse' issue.
return StmtDiff();
} else {
clonedArgsE = m_Sema.ActOnParenListExpr(noLoc, noLoc, clonedArgs).get();
derivedArgsE =
m_Sema.ActOnParenListExpr(noLoc, noLoc, derivedArgs).get();
}
clonedArgsE = m_Sema.ActOnParenListExpr(noLoc, noLoc, clonedArgs).get();
derivedArgsE = m_Sema.ActOnParenListExpr(noLoc, noLoc, derivedArgs).get();
}
} else {
clonedArgsE = clonedArgs[0];
Expand Down Expand Up @@ -2045,6 +2062,7 @@ StmtDiff BaseForwardModeVisitor::VisitCXXTemporaryObjectExpr(
clonedArgs.push_back(argDiff.getExpr());
derivedArgs.push_back(argDiff.getExpr_dx());
}

Expr* clonedTOE =
m_Sema
.ActOnCXXTypeConstructExpr(OpaquePtr<QualType>::make(TOE->getType()),
Expand Down Expand Up @@ -2183,4 +2201,33 @@ StmtDiff BaseForwardModeVisitor::VisitCXXStdInitializerListExpr(
const clang::CXXStdInitializerListExpr* ILE) {
return Visit(ILE->getSubExpr());
}

clang::Expr* BaseForwardModeVisitor::BuildCustomDerivativeConstructorPFCall(
const clang::CXXConstructExpr* CE,
llvm::SmallVectorImpl<clang::Expr*>& clonedArgs,
llvm::SmallVectorImpl<clang::Expr*>& derivedArgs) {
llvm::SmallVector<Expr*, 4> customPushforwardArgs;
QualType constructorPushforwardTagT = GetCladConstructorPushforwardTagOfType(
CE->getType().withoutLocalFastQualifiers());
// Builds clad::ConstructorPushforwardTag<T> declaration
Expr* constructorPushforwardTagArg =
m_Sema
.BuildCXXTypeConstructExpr(
m_Context.getTrivialTypeSourceInfo(constructorPushforwardTagT,
utils::GetValidSLoc(m_Sema)),
noLoc, MultiExprArg{}, noLoc, /*ListInitialization=*/false)
.get();
customPushforwardArgs.push_back(constructorPushforwardTagArg);
customPushforwardArgs.append(clonedArgs.begin(), clonedArgs.end());
customPushforwardArgs.append(derivedArgs.begin(), derivedArgs.end());
std::string customPushforwardName =
clad::utils::ComputeEffectiveFnName(CE->getConstructor()) +
GetPushForwardFunctionSuffix();
// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
Expr* pushforwardCall = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforwardName, customPushforwardArgs, getCurrentScope(),
const_cast<DeclContext*>(CE->getConstructor()->getDeclContext()));
return pushforwardCall;
}
} // end namespace clad
3 changes: 3 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "clang/Basic/SourceLocation.h"
#include "clang/Sema/Lookup.h"
#include "llvm/ADT/SmallVector.h"
#include <clang/AST/DeclCXX.h>
#include "clad/Differentiator/Compatibility.h"

using namespace clang;
Expand Down Expand Up @@ -98,6 +99,8 @@ namespace clad {
case OverloadedOperatorKind::OO_Subscript:
return "operator_subscript";
default:
if (isa<CXXConstructorDecl>(FD))
return "constructor";
return FD->getNameAsString();
}
}
Expand Down
12 changes: 12 additions & 0 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -811,4 +811,16 @@ namespace clad {
return m_Sema.ActOnCallExpr(getCurrentScope(), pushDRE, noLoc, args, noLoc)
.get();
}

clang::TemplateDecl* VisitorBase::GetCladConstructorPushforwardTag() {
if (!m_CladConstructorPushforwardTag)
m_CladConstructorPushforwardTag =
LookupTemplateDeclInCladNamespace("ConstructorPushforwardTag");
return m_CladConstructorPushforwardTag;
}

clang::QualType
VisitorBase::GetCladConstructorPushforwardTagOfType(clang::QualType T) {
return InstantiateTemplate(GetCladConstructorPushforwardTag(), {T});
}
} // end namespace clad
Loading

0 comments on commit 645d2b6

Please sign in to comment.