From ecb895915c3dd6770b79e9c1b04a3fe6166fbe2a Mon Sep 17 00:00:00 2001 From: Mihail Mihov Date: Wed, 31 Jul 2024 17:38:01 +0300 Subject: [PATCH] (wip) CladFunction `constexpr` --- include/clad/Differentiator/Differentiator.h | 157 ++++++++++++------- include/clad/Differentiator/FunctionTraits.h | 46 +++--- 2 files changed, 118 insertions(+), 85 deletions(-) diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index 4a089c095..9fbee4b35 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -117,7 +117,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { template ::type = true> - CUDA_HOST_DEVICE return_type_t + constexpr CUDA_HOST_DEVICE return_type_t execute_with_default_args(list, F f, list, Args&&... args) { return f(static_cast(args)..., static_cast(nullptr)...); @@ -126,7 +126,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { template ::type = true> - return_type_t execute_with_default_args(list, F f, + constexpr return_type_t execute_with_default_args(list, F f, list, Args&&... args) { return f(static_cast(args)...); @@ -136,7 +136,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { template ::type = true> - CUDA_HOST_DEVICE auto + constexpr CUDA_HOST_DEVICE auto execute_with_default_args(list, ReturnType C::*f, Obj&& obj, list, Args&&... args) -> return_type_t { @@ -147,7 +147,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { template ::type = true> - auto execute_with_default_args(list, ReturnType C::*f, Obj&& obj, + constexpr auto execute_with_default_args(list, ReturnType C::*f, Obj&& obj, list, Args&&... args) -> return_type_t { return (static_cast(obj).*f)(static_cast(args)...); @@ -169,53 +169,60 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { private: CladFunctionType m_Function; - char* m_Code; + char const* m_Code; FunctorType *m_Functor = nullptr; public: - CUDA_HOST_DEVICE CladFunction(CladFunctionType f, + constexpr CUDA_HOST_DEVICE CladFunction(CladFunctionType f, const char* code, FunctorType* functor = nullptr) - : m_Functor(functor) { - assert(f && "Must pass a non-0 argument."); - if (size_t length = GetLength(code)) { - m_Function = f; - char* temp = (char*)malloc(length + 1); - m_Code = temp; - while ((*temp++ = *code++)); - } else { - // clad did not place the derivative in this object. This can happen - // upon error of if clad was disabled. Diagnose. - printf("clad failed to place the generated derivative in the object\n"); - printf("Make sure calls to clad are within a #pragma clad ON region\n"); - - // Invalidate the placeholders. - m_Function = nullptr; - m_Code = nullptr; - } + : m_Function(f), m_Code(""), m_Functor(functor) { + if !consteval { + assert(f && "Must pass a non-0 argument."); + if (size_t length = GetLength(code)) { + char* temp = (char*)malloc(length + 1); + m_Code = temp; + while ((*temp++ = *code++)); + } else { + // clad did not place the derivative in this object. This can happen + // upon error of if clad was disabled. Diagnose. + printf("clad failed to place the generated derivative in the object\n"); + printf("Make sure calls to clad are within a #pragma clad ON region\n"); + + // Invalidate the placeholders. + m_Function = nullptr; + m_Code = nullptr; + } + } } /// Constructor overload for initializing `m_Functor` when functor /// is passed by reference. - CUDA_HOST_DEVICE - CladFunction(CladFunctionType f, const char* code, FunctorType& functor) + constexpr CUDA_HOST_DEVICE CladFunction(CladFunctionType f, + const char* code, + FunctorType& functor) : CladFunction(f, code, &functor) {}; // Intentionally leak m_Code, otherwise we have to link against c++ runtime, // i.e -lstdc++. //~CladFunction() { /*free(m_Code);*/ } - CladFunctionType getFunctionPtr() { return m_Function; } + constexpr CladFunctionType getFunctionPtr() { return m_Function; } template typename std::enable_if::value, return_type_t>::type - execute(Args&&... args) CUDA_HOST_DEVICE { - if (!m_Function) { - printf("CladFunction is invalid\n"); - return static_cast>(return_type_t()); - } - // here static_cast is used to achieve perfect forwarding - return execute_helper(m_Function, static_cast(args)...); + constexpr execute(Args&&... args) CUDA_HOST_DEVICE { + if consteval { + if (m_Function) + return execute_helper(m_Function, static_cast(args)...); + } else { + if (!m_Function) { + printf("CladFunction is invalid\n"); + return static_cast>(return_type_t()); + } + // here static_cast is used to achieve perfect forwarding + return execute_helper(m_Function, static_cast(args)...); + } } /// `Execute` overload to be used when derived function type cannot be @@ -226,7 +233,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { template typename std::enable_if::value, return_type_t>::type - execute(Args&&... args) CUDA_HOST_DEVICE { + constexpr execute(Args&&... args) CUDA_HOST_DEVICE { return static_cast>(0); } @@ -261,8 +268,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { private: /// Helper function for executing non-member derived functions. template - CUDA_HOST_DEVICE return_type_t - execute_helper(Fn f, Args&&... args) { + constexpr CUDA_HOST_DEVICE + return_type_t execute_helper(Fn f, Args&&... args) { // `static_cast` is required here for perfect forwarding. return execute_with_default_args( DropArgs_t{}, f, @@ -280,6 +287,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { class = typename std::enable_if< std::is_same::type, C>::value>::type, class... Args> + constexpr return_type_t execute_helper(ReturnType C::*f, Obj&& obj, Args&&... args) { // `static_cast` is required here for perfect forwarding. @@ -292,11 +300,14 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { /// will be used and derived function will be called through the object /// saved in `CladFunction`. template + constexpr return_type_t execute_helper(ReturnType C::*f, Args&&... args) { - assert(m_Functor && - "No default object set, explicitly pass an object to " - "CladFunction::execute"); + if !consteval { + assert(m_Functor && + "No default object set, explicitly pass an object to " + "CladFunction::execute"); + } // `static_cast` is required here for perfect forwarding. return execute_with_default_args( DropArgs_t{}, f, *m_Functor, @@ -305,6 +316,19 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { } }; + template + constexpr auto create_lambda_with_args(list) { + return [](ArgTys...) -> RetTy { return RetTy{}; }; + } + + template + constexpr auto create_default_derived_lambda() { + using RetTy = typename function_traits::return_type; + using ArgTys = typename function_traits::argument_types; + auto lambda = create_lambda_with_args(ArgTys{}); + return lambda; + } + // This is the function which will be instantiated with the concrete arguments // After that our AD library will have all the needed information. For eg: // which is the differentiated function, which is the argument with respect @@ -331,12 +355,21 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { !std::is_class>::value>::type> CladFunction> __attribute__(( annotate("D"))) - differentiate(F fn, ArgSpec args = "", + constexpr differentiate(F fn, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(fn && "Must pass in a non-0 argument"); - return CladFunction>(derivedFn, - code); + if consteval { + // If we are in a consteval context, clad hasn't yet replaced the nullptr above with the actual derivative, so we add a default to prevent a compiler error. + if(!derivedFn) + derivedFn = create_default_derived_lambda(); + if(fn && derivedFn) + return CladFunction>(derivedFn, + code); + } else { + assert(fn && "Must pass in a non-0 argument"); + return CladFunction>(derivedFn, + code); + } } /// Specialization for differentiating functors. @@ -349,7 +382,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { !clad::HasOption(GetBitmaskedOpts(BitMaskedOpts...), opts::vector_mode) && std::is_class>::value>::type> - CladFunction> __attribute__(( + constexpr CladFunction> __attribute__(( annotate("D"))) differentiate(F&& f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), @@ -372,12 +405,14 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { clad::HasOption(GetBitmaskedOpts(BitMaskedOpts...), opts::vector_mode) && !std::is_class>::value>::type> - CladFunction, true> __attribute__(( + constexpr CladFunction, true> __attribute__(( annotate("D"))) differentiate(F fn, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(fn && "Must pass in a non-0 argument"); + if !consteval { + assert(fn && "Must pass in a non-0 argument"); + } return CladFunction, true>( derivedFn, code); } @@ -393,12 +428,14 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { typename F, typename DerivedFnType = GradientDerivedFnTraits_t, typename = typename std::enable_if< !std::is_class>::value>::type> - CladFunction, true> __attribute__(( + constexpr CladFunction, true> __attribute__(( annotate("G"))) CUDA_HOST_DEVICE gradient(F f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(f && "Must pass in a non-0 argument"); + if !consteval { + assert(f && "Must pass in a non-0 argument"); + } return CladFunction, true>( derivedFn /* will be replaced by gradient*/, code); } @@ -410,7 +447,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { typename F, typename DerivedFnType = GradientDerivedFnTraits_t, typename = typename std::enable_if< std::is_class>::value>::type> - CladFunction, true> __attribute__(( + constexpr CladFunction, true> __attribute__(( annotate("G"))) CUDA_HOST_DEVICE gradient(F&& f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), @@ -430,12 +467,14 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { typename F, typename DerivedFnType = HessianDerivedFnTraits_t, typename = typename std::enable_if< !std::is_class>::value>::type> - CladFunction> __attribute__(( + constexpr CladFunction> __attribute__(( annotate("H"))) hessian(F f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(f && "Must pass in a non-0 argument"); + if !consteval { + assert(f && "Must pass in a non-0 argument"); + } return CladFunction>( derivedFn /* will be replaced by hessian*/, code); } @@ -447,7 +486,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { typename F, typename DerivedFnType = HessianDerivedFnTraits_t, typename = typename std::enable_if< std::is_class>::value>::type> - CladFunction> __attribute__(( + constexpr CladFunction> __attribute__(( annotate("H"))) hessian(F&& f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), @@ -467,12 +506,14 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { typename F, typename DerivedFnType = JacobianDerivedFnTraits_t, typename = typename std::enable_if< !std::is_class>::value>::type> - CladFunction> __attribute__(( + constexpr CladFunction> __attribute__(( annotate("J"))) jacobian(F f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(f && "Must pass in a non-0 argument"); + if !consteval { + assert(f && "Must pass in a non-0 argument"); + } return CladFunction>( derivedFn /* will be replaced by Jacobian*/, code); } @@ -484,7 +525,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { typename F, typename DerivedFnType = JacobianDerivedFnTraits_t, typename = typename std::enable_if< std::is_class>::value>::type> - CladFunction> __attribute__(( + constexpr CladFunction> __attribute__(( annotate("J"))) jacobian(F&& f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), @@ -495,11 +536,13 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { template > - CladFunction __attribute__((annotate("E"))) + constexpr CladFunction __attribute__((annotate("E"))) estimate_error(F f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(f && "Must pass in a non-0 argument"); + if !consteval { + assert(f && "Must pass in a non-0 argument"); + } return CladFunction< DerivedFnType>(derivedFn /* will be replaced by estimation code*/, code); diff --git a/include/clad/Differentiator/FunctionTraits.h b/include/clad/Differentiator/FunctionTraits.h index c15eeb270..b63298ff6 100644 --- a/include/clad/Differentiator/FunctionTraits.h +++ b/include/clad/Differentiator/FunctionTraits.h @@ -13,8 +13,7 @@ namespace clad { /// then removes any associated pointer. Resulting type is provided /// as member typedef `type`. template struct remove_reference_and_pointer { - using type = typename std::remove_pointer< - typename std::remove_reference::type>::type; + using type = std::remove_pointer_t>; }; /// Helper type for remove_reference_and_pointer. @@ -31,8 +30,8 @@ namespace clad { template struct has_call_operator< C, - typename std::enable_if<( - sizeof(&remove_reference_and_pointer_t::operator()) > 0)>::type> + std::enable_if_t<( + sizeof(&remove_reference_and_pointer_t::operator()) > 0)>> : std::true_type {}; /// Placeholder type for denoting no function type exists @@ -742,7 +741,7 @@ namespace clad { // OutputVecParamType is used to deduce the type of derivative arguments // for vector forward mode. template struct OutputVecParamType { - using type = array_ref::type>; + using type = array_ref>; }; template @@ -762,19 +761,14 @@ namespace clad { /// Specialization for free function pointer type template - struct ExtractDerivedFnTraitsForwMode< - F*, - typename std::enable_if::value>::type> { + struct ExtractDerivedFnTraitsForwMode>> { using type = remove_reference_and_pointer_t*; }; /// Specialization for member function pointer type template - struct ExtractDerivedFnTraitsForwMode< - F, - typename std::enable_if< - std::is_member_function_pointer::value>::type> { - using type = typename std::decay::type; + struct ExtractDerivedFnTraitsForwMode>> { + using type = std::decay_t; }; /// Specialization for class types @@ -782,21 +776,17 @@ namespace clad { /// member typedef `type` same as the type of the call operator, otherwise /// defines member typedef `type` as the type of `NoFunction*`. template - struct ExtractDerivedFnTraitsForwMode< - F, - typename std::enable_if< - std::is_class>::value && - has_call_operator::value>::type> { - using ClassType = - typename std::decay>::type; + struct ExtractDerivedFnTraitsForwMode> && has_call_operator::value> + > { + using ClassType = std::decay_t>; using type = decltype(&ClassType::operator()); }; + template - struct ExtractDerivedFnTraitsForwMode< - F, - typename std::enable_if< - std::is_class>::value && - !has_call_operator::value>::type> { + struct ExtractDerivedFnTraitsForwMode> && !has_call_operator::value> + > { using type = NoFunction*; }; @@ -834,7 +824,7 @@ namespace clad { template struct ExtractFunctorTraits< F*, - typename std::enable_if::value>::type> { + std::enable_if_t>> { using type = NoObject; }; @@ -848,8 +838,8 @@ namespace clad { template struct ExtractFunctorTraits< F, - typename std::enable_if< - std::is_class>::value>::type> { + std::enable_if_t< + std::is_class_v>>> { using type = remove_reference_and_pointer_t; }; } // namespace clad