Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add primitive support for custom constructor pushforward functions
This commit adds primitive support for custom pushforward functions for constructors. Custom constructor pushforward function support will enable the below features: - Class differentiation support for classes whose constructor Clad cannot automatically differentiate. Now, We can enable differentiation of entire C++ standard library by providing custom derivatives. - Remove the restriction of default-constructible for class types. This was a troublesome restriction. Now, the only restriction for class types is to have a sensible copy-constructor. That is, copy constructor should copy the class members and after copy-construction, both the objects should be equivalent, mathematically speaking. Constructor pushforward functions differ from ordinary pushforward functions in two important ways: - Constructor pushforward functions initialize the primal class object and the corresponding derivative object. Ordinary member function pushforwards takes an already-existing primal class object and the corresponding derivative object as inputs. - Constructor pushforward functions return a value even though constructor do not return anything. Constructor pushforward functions return initialized primal object and the derivative object. These are then used to initialize primal object and the derivative in the derivative function code. How to write custom constructor pushforward functions ---------------------------------------- Let's see how to write custom pushforward function for a constructor: - Custom constructor pushforwards must have the name `constructor_pushforward` - Custom constructor pushforwards must be defined in `::clad::custom_derivatives::class_functions` namespace. - The parameters of the custom constructor pushforward must be: {`::clad::ConstructorPushforwardTag<Class>`, original params..., derivative params...}. 'original parameters...' and 'derivative parameters...' is same as what we have for other pushforward functions. We will soon see why do we need `::clad::ConstructorPushforwardTag<T>` for constructor custom pushforwards. Let's see a basic example of how to write custom constructor pushforward. ```cpp class Coordinates { Coordinates(double px, double py, double pz) : x(px), y(py), z(pz) {} public: double x, y, z; } namespace clad { namespace custom_derivatives { namespace class_functions { // custom constructor pushforward function clad::ValueAndPushforward<Coordinates, Coordinates> constructor_pushforward(clad::ConstructorPushforwardTag<Coordinates>, double x, double y, double z, double d_x, double d_y, double d_z) { return {Coordinates(x, y, z), Coordinates(d_x, d_y, d_z) }; } } // namespace class_functions } // namespace custom_derivatives } // namespace clad // custom constructor pushforward is used as follows: // primal code Constructor c(u, v, w); // derivative code clad::ValueAndPushforward<Coordinates, Coordinates> _t0 = constructor_pushforward(clad::ConstructorPushforwardTag<Coordinates>, u, v, w, _d_u, _d_v, _d_w); Coordinates _d_c = _t0.pushforward; Coordinates c = _t0.value; ``` Now, let's see a bit advanced example based on `std::vector` constructor. ```cpp namespace clad { namespace custom_derivatives { namespace class_functions { // Custom pushforward for: vector(size_t n, const typename ::std::vector<T>::allocator_type alloc) template <typename T> clad::ValueAndPushforward<::std::vector<T>, ::std::vector<T>> constructor_pushforward( ConstructorPushforwardTag<::std::vector<T>>, size_t n, const typename ::std::vector<T>::allocator_type alloc, size_t d_n, const typename ::std::vector<T>::allocator_type d_alloc) { ::std::vector<T> v(n, alloc); ::std::vector<T> d_v(n, 0, alloc); return {v, d_v}; } // Custom pushfoward for: vector(size_t n, T val, const typename ::std::vector<T>::allocator_type alloc) template <typename T> clad::ValueAndPushforward<::std::vector<T>, ::std::vector<T>> constructor_pushforward( ConstructorPushforwardTag<::std::vector<T>>, size_t n, T val, const typename ::std::vector<T>::allocator_type alloc, size_t d_n, T d_val, const typename ::std::vector<T>::allocator_type d_alloc) { ::std::vector<T> v(n, val, alloc); ::std::vector<T> d_v(n, d_val, alloc); return {v, d_v}; } } // namespace class_functions } // namespace custom_derivatives } // namespace clad // The custom constructor pushforwards is used as follows: // Primal code: std::vector<double> v(10, u); // Derivative code: clad::ValueAndPushforward<std::vector<double>, std::vector<double>> _t0 = clad::custom_derivatives::class_functions::constructor_pushforward( clad::ConstructorPushforwardTag<std::vector<double> >(), 10, u, allocator_type(), 0, _d_u, allocator_type()); std::vector<double> d_v = _t0.pushforward; std::vector<double> v = _t0.value; ``` Why `clad::ConstructorPushforwardTag<T>`? ------------------------ So, why do we need clad::ConstructorPushforwardTag<T>? For a constructor that takes two parameters of types `size_t` and `double`, the custom pushforward will have the following signature if we do not include `::clad::ConstructorPushforwardTag<T>`: ```cpp clad::ValueAndPushforward<Class, Class> constructor_pushforward(size_t n, double val, size_t d_n, double d_val); ``` Now, the question is: How to distinguish custom constructor pushforwards for different classes? ```cpp MyClassA a(3, 5.0); MyClassB b(7, 9.0); ``` There is no way for overload resolution selector to distinguish constructor_pushforward for classes `MyClassA` and `MyClassB`. `clad::ConstructorPushforwardTag<T>` is used to identify the class for which custom constructor pushforward is defined. Please note that we cannot use the same strategy which we use for custom member function pushforwards because member function pushforwards always have parameters of the class type which are used for identifying the class. We also cannot simply ask users to define the pushforwards inside the declaration context of the class because it may not always be feasible to modify the source code of external libraries. -------------------------------------------- Fixes #965
- Loading branch information