
Commit

(wip) CladFunction constexpr
MihailMihov committed Aug 6, 2024
1 parent 8a9c9e2 commit ecb8959
Showing 2 changed files with 118 additions and 85 deletions.
157 changes: 100 additions & 57 deletions include/clad/Differentiator/Differentiator.h
@@ -117,7 +117,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
template <bool EnablePadding, class... Rest, class F, class... Args,
class... fArgTypes,
typename std::enable_if<EnablePadding, bool>::type = true>
CUDA_HOST_DEVICE return_type_t<F>
constexpr CUDA_HOST_DEVICE return_type_t<F>
execute_with_default_args(list<Rest...>, F f, list<fArgTypes...>,
Args&&... args) {
return f(static_cast<Args>(args)..., static_cast<Rest>(nullptr)...);
@@ -126,7 +126,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
template <bool EnablePadding, class... Rest, class F, class... Args,
class... fArgTypes,
typename std::enable_if<!EnablePadding, bool>::type = true>
return_type_t<F> execute_with_default_args(list<Rest...>, F f,
constexpr return_type_t<F> execute_with_default_args(list<Rest...>, F f,
list<fArgTypes...>,
Args&&... args) {
return f(static_cast<Args>(args)...);
@@ -136,7 +136,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
template <bool EnablePadding, class... Rest, class ReturnType, class C,
class Obj, class... Args, class... fArgTypes,
typename std::enable_if<EnablePadding, bool>::type = true>
CUDA_HOST_DEVICE auto
constexpr CUDA_HOST_DEVICE auto
execute_with_default_args(list<Rest...>, ReturnType C::*f, Obj&& obj,
list<fArgTypes...>, Args&&... args)
-> return_type_t<decltype(f)> {
@@ -147,7 +147,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
template <bool EnablePadding, class... Rest, class ReturnType, class C,
class Obj, class... Args, class... fArgTypes,
typename std::enable_if<!EnablePadding, bool>::type = true>
auto execute_with_default_args(list<Rest...>, ReturnType C::*f, Obj&& obj,
constexpr auto execute_with_default_args(list<Rest...>, ReturnType C::*f, Obj&& obj,
list<fArgTypes...>, Args&&... args)
-> return_type_t<decltype(f)> {
return (static_cast<Obj>(obj).*f)(static_cast<Args>(args)...);
@@ -169,53 +169,60 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {

private:
CladFunctionType m_Function;
char* m_Code;
char const* m_Code;
FunctorType *m_Functor = nullptr;

public:
CUDA_HOST_DEVICE CladFunction(CladFunctionType f,
constexpr CUDA_HOST_DEVICE CladFunction(CladFunctionType f,
const char* code,
FunctorType* functor = nullptr)
: m_Functor(functor) {
assert(f && "Must pass a non-0 argument.");
if (size_t length = GetLength(code)) {
m_Function = f;
char* temp = (char*)malloc(length + 1);
m_Code = temp;
while ((*temp++ = *code++));
} else {
// clad did not place the derivative in this object. This can happen
// upon error or if clad was disabled. Diagnose.
printf("clad failed to place the generated derivative in the object\n");
printf("Make sure calls to clad are within a #pragma clad ON region\n");

// Invalidate the placeholders.
m_Function = nullptr;
m_Code = nullptr;
}
: m_Function(f), m_Code(""), m_Functor(functor) {
if !consteval {
assert(f && "Must pass a non-0 argument.");
if (size_t length = GetLength(code)) {
char* temp = (char*)malloc(length + 1);
m_Code = temp;
while ((*temp++ = *code++));
} else {
// clad did not place the derivative in this object. This can happen
// upon error or if clad was disabled. Diagnose.
printf("clad failed to place the generated derivative in the object\n");
printf("Make sure calls to clad are within a #pragma clad ON region\n");

// Invalidate the placeholders.
m_Function = nullptr;
m_Code = nullptr;
}
}
}
/// Constructor overload for initializing `m_Functor` when functor
/// is passed by reference.
CUDA_HOST_DEVICE
CladFunction(CladFunctionType f, const char* code, FunctorType& functor)
constexpr CUDA_HOST_DEVICE CladFunction(CladFunctionType f,
const char* code,
FunctorType& functor)
: CladFunction(f, code, &functor) {};

// Intentionally leak m_Code, otherwise we have to link against c++ runtime,
// i.e -lstdc++.
//~CladFunction() { /*free(m_Code);*/ }

CladFunctionType getFunctionPtr() { return m_Function; }
constexpr CladFunctionType getFunctionPtr() { return m_Function; }

template <typename... Args, class FnType = CladFunctionType>
typename std::enable_if<!std::is_same<FnType, NoFunction*>::value,
return_type_t<F>>::type
execute(Args&&... args) CUDA_HOST_DEVICE {
if (!m_Function) {
printf("CladFunction is invalid\n");
return static_cast<return_type_t<F>>(return_type_t<F>());
}
// here static_cast is used to achieve perfect forwarding
return execute_helper(m_Function, static_cast<Args>(args)...);
constexpr execute(Args&&... args) CUDA_HOST_DEVICE {
if consteval {
if (m_Function)
return execute_helper(m_Function, static_cast<Args>(args)...);
} else {
if (!m_Function) {
printf("CladFunction is invalid\n");
return static_cast<return_type_t<F>>(return_type_t<F>());
}
// here static_cast is used to achieve perfect forwarding
return execute_helper(m_Function, static_cast<Args>(args)...);
}
}

/// `Execute` overload to be used when derived function type cannot be
@@ -226,7 +233,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
template <typename... Args, class FnType = CladFunctionType>
typename std::enable_if<std::is_same<FnType, NoFunction*>::value,
return_type_t<F>>::type
execute(Args&&... args) CUDA_HOST_DEVICE {
constexpr execute(Args&&... args) CUDA_HOST_DEVICE {
return static_cast<return_type_t<F>>(0);
}

@@ -261,8 +268,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
private:
/// Helper function for executing non-member derived functions.
template <class Fn, class... Args>
CUDA_HOST_DEVICE return_type_t<CladFunctionType>
execute_helper(Fn f, Args&&... args) {
constexpr CUDA_HOST_DEVICE
return_type_t<CladFunctionType> execute_helper(Fn f, Args&&... args) {
// `static_cast` is required here for perfect forwarding.
return execute_with_default_args<EnablePadding>(
DropArgs_t<sizeof...(Args), F>{}, f,
@@ -280,6 +287,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
class = typename std::enable_if<
std::is_same<typename std::decay<Obj>::type, C>::value>::type,
class... Args>
constexpr
return_type_t<CladFunctionType>
execute_helper(ReturnType C::*f, Obj&& obj, Args&&... args) {
// `static_cast` is required here for perfect forwarding.
@@ -292,11 +300,14 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
/// will be used and derived function will be called through the object
/// saved in `CladFunction`.
template <class ReturnType, class C, class... Args>
constexpr
return_type_t<CladFunctionType> execute_helper(ReturnType C::*f,
Args&&... args) {
assert(m_Functor &&
"No default object set, explicitly pass an object to "
"CladFunction::execute");
if !consteval {
assert(m_Functor &&
"No default object set, explicitly pass an object to "
"CladFunction::execute");
}
// `static_cast` is required here for perfect forwarding.
return execute_with_default_args<EnablePadding>(
DropArgs_t<sizeof...(Args), decltype(f)>{}, f, *m_Functor,
Expand All @@ -305,6 +316,19 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
}
};

template<typename RetTy, typename... ArgTys>
constexpr auto create_lambda_with_args(list<ArgTys...>) {
return [](ArgTys...) -> RetTy { return RetTy{}; };
}

template<typename DerivedFnType>
constexpr auto create_default_derived_lambda() {
using RetTy = typename function_traits<DerivedFnType>::return_type;
using ArgTys = typename function_traits<DerivedFnType>::argument_types;
auto lambda = create_lambda_with_args<RetTy>(ArgTys{});
return lambda;
}

// This is the function which will be instantiated with the concrete arguments
// After that our AD library will have all the needed information. For eg:
// which is the differentiated function, which is the argument with respect
@@ -331,12 +355,21 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
!std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
annotate("D")))
differentiate(F fn, ArgSpec args = "",
constexpr differentiate(F fn, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(fn && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(derivedFn,
code);
if consteval {
// If we are in a consteval context, clad hasn't yet replaced the nullptr above with the actual derivative, so we add a default to prevent a compiler error.
if(!derivedFn)
derivedFn = create_default_derived_lambda<DerivedFnType>();
if(fn && derivedFn)
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(derivedFn,
code);
} else {
assert(fn && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(derivedFn,
code);
}
}

/// Specialization for differentiating functors.
@@ -349,7 +382,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
!clad::HasOption(GetBitmaskedOpts(BitMaskedOpts...),
opts::vector_mode) &&
std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
constexpr CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
annotate("D")))
differentiate(F&& f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
@@ -372,12 +405,14 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
clad::HasOption(GetBitmaskedOpts(BitMaskedOpts...),
opts::vector_mode) &&
!std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true> __attribute__((
constexpr CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true> __attribute__((
annotate("D")))
differentiate(F fn, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(fn && "Must pass in a non-0 argument");
if !consteval {
assert(fn && "Must pass in a non-0 argument");
}
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn, code);
}
@@ -393,12 +428,14 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
typename F, typename DerivedFnType = GradientDerivedFnTraits_t<F>,
typename = typename std::enable_if<
!std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true> __attribute__((
constexpr CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true> __attribute__((
annotate("G"))) CUDA_HOST_DEVICE
gradient(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
if !consteval {
assert(f && "Must pass in a non-0 argument");
}
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by gradient*/, code);
}
@@ -410,7 +447,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
typename F, typename DerivedFnType = GradientDerivedFnTraits_t<F>,
typename = typename std::enable_if<
std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true> __attribute__((
constexpr CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true> __attribute__((
annotate("G"))) CUDA_HOST_DEVICE
gradient(F&& f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
@@ -430,12 +467,14 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
typename F, typename DerivedFnType = HessianDerivedFnTraits_t<F>,
typename = typename std::enable_if<
!std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
constexpr CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
annotate("H")))
hessian(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
if !consteval {
assert(f && "Must pass in a non-0 argument");
}
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by hessian*/, code);
}
@@ -447,7 +486,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
typename F, typename DerivedFnType = HessianDerivedFnTraits_t<F>,
typename = typename std::enable_if<
std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
constexpr CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
annotate("H")))
hessian(F&& f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
@@ -467,12 +506,14 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
typename F, typename DerivedFnType = JacobianDerivedFnTraits_t<F>,
typename = typename std::enable_if<
!std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
constexpr CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
annotate("J")))
jacobian(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
if !consteval {
assert(f && "Must pass in a non-0 argument");
}
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(
derivedFn /* will be replaced by Jacobian*/, code);
}
@@ -484,7 +525,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
typename F, typename DerivedFnType = JacobianDerivedFnTraits_t<F>,
typename = typename std::enable_if<
std::is_class<remove_reference_and_pointer_t<F>>::value>::type>
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
constexpr CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>> __attribute__((
annotate("J")))
jacobian(F&& f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
@@ -495,11 +536,13 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {

template <typename ArgSpec = const char*, typename F,
typename DerivedFnType = GradientDerivedEstFnTraits_t<F>>
CladFunction<DerivedFnType> __attribute__((annotate("E")))
constexpr CladFunction<DerivedFnType> __attribute__((annotate("E")))
estimate_error(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
if !consteval {
assert(f && "Must pass in a non-0 argument");
}
return CladFunction<
DerivedFnType>(derivedFn /* will be replaced by estimation code*/,
code);
(The diff of the second changed file was not loaded in this view.)
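
The pattern this patch leans on throughout is C++23 `if consteval` / `if !consteval`: the compile-time branch must avoid runtime-only facilities such as printf, malloc, and assert, while the runtime branch keeps the existing diagnostics. Below is a minimal standalone sketch of that dispatch. It is not clad code, the function name is invented, and it assumes a C++23 compiler.

// Minimal illustration of the `if consteval` dispatch used throughout this patch.
#include <cstdio>

constexpr int describe(int x) {
  if consteval {
    // Taken when the call is constant-evaluated; no printf/malloc here.
    return x * 2;
  } else {
    // Taken for ordinary runtime calls, where diagnostics are fine.
    std::printf("runtime call\n");
    return x * 2;
  }
}

int main() {
  constexpr int a = describe(21); // compile-time branch, silent
  int b = describe(21);           // runtime branch, prints the message
  return a - b;                   // 0
}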

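The new `create_default_derived_lambda` helper exists because, during constant evaluation, clad has not yet replaced the null `derivedFn` placeholder, and calling through a null function pointer would abort the constant evaluation. A rough standalone sketch of the same idea follows; the trait names here are hypothetical and simplified and do not match clad's `function_traits` exactly.

// Sketch of the "default derived function" fallback, with invented trait names.
template <typename> struct fn_traits;   // primary template, intentionally undefined

template <typename R, typename... As>
struct fn_traits<R (*)(As...)> {        // function-pointer specialization
  static constexpr auto make_default() {
    // Captureless lambda with a matching signature that returns a
    // value-initialized result, so there is something valid to call.
    return [](As...) -> R { return R{}; };
  }
};

using DerivedFnType = double (*)(double, double);

constexpr double call_default(double x, double y) {
  auto fallback = fn_traits<DerivedFnType>::make_default();
  return fallback(x, y); // always 0.0 in this sketch
}

static_assert(call_default(1.0, 2.0) == 0.0,
              "the fallback returns a value-initialized result");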

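For context, the point of marking `CladFunction` and the `clad::differentiate` entry points constexpr is presumably to allow derivative evaluation inside constant expressions. The sketch below is a hypothetical usage, not something this WIP commit is known to compile yet: the generated derivative body would also have to be constexpr, so treat it as the intent rather than a working test.

#include "clad/Differentiator/Differentiator.h"

constexpr double square(double x) { return x * x; }

constexpr double dsquare_at_3() {
  // CladFunction constructed in a constant-evaluated context.
  auto d = clad::differentiate(square, "x");
  return d.execute(3.0);
}

// static_assert(dsquare_at_3() == 6.0); // only once the whole pipeline is constexpr-capable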