Skip to content

Commit

Permalink
Fix format
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Sep 12, 2024
1 parent 37e9e47 commit 18ebe7a
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 51 deletions.
14 changes: 7 additions & 7 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,13 @@ struct DiffRequest {

// Define the hash function for DiffRequest.
template <> struct std::hash<clad::DiffRequest> {
std::size_t operator()(const clad::DiffRequest& DR) const {
// Use the function pointer as the hash of the DiffRequest, it
// is sufficient to break a reasonable number of collisions.
if (DR.Function->getPreviousDecl())
return std::hash<const void*>{}(DR.Function->getPreviousDecl());
return std::hash<const void*>{}(DR.Function);
}
std::size_t operator()(const clad::DiffRequest& DR) const {
// Use the function pointer as the hash of the DiffRequest, it
// is sufficient to break a reasonable number of collisions.
if (DR.Function->getPreviousDecl())
return std::hash<const void*>{}(DR.Function->getPreviousDecl());
return std::hash<const void*>{}(DR.Function);
}
};

#endif
9 changes: 5 additions & 4 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -365,10 +365,11 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
return_type_t<CladFunctionType>
execute_helper(ReturnType C::*f, Obj&& obj, Args&&... args) {
// `static_cast` is required here for perfect forwarding.
return execute_with_default_args<EnablePadding>(
DropArgs_t<sizeof...(Args), decltype(f)>{}, f, static_cast<Obj>(obj),
TakeNFirstArgs_t<sizeof...(Args), decltype(f)>{},
static_cast<Args>(args)...);
return execute_with_default_args<EnablePadding>(
DropArgs_t<sizeof...(Args), decltype(f)>{}, f,
static_cast<Obj>(obj),
TakeNFirstArgs_t<sizeof...(Args), decltype(f)>{},
static_cast<Args>(args)...);
}
/// If user have not passed object explicitly, then this specialization
/// will be used and derived function will be called through the object
Expand Down
37 changes: 18 additions & 19 deletions include/clad/Differentiator/FunctionTraits.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,85 +240,84 @@ namespace clad {
using argument_types = list<Args...>;
};
template <class ReturnType, class C, class... Args>
struct function_traits<ReturnType (C::*)(Args...) & noexcept> {
struct function_traits<ReturnType (C::*)(Args...)& noexcept> {
using return_type = ReturnType;
using argument_types = list<Args...>;
};
template <class ReturnType, class C, class... Args>
struct function_traits<ReturnType (C::*)(Args..., ...) & noexcept> {
struct function_traits<ReturnType (C::*)(Args..., ...)& noexcept> {
using return_type = ReturnType;
using argument_types = list<Args...>;
};
template <class ReturnType, class C, class... Args>
struct function_traits<ReturnType (C::*)(Args...) const & noexcept> {
struct function_traits<ReturnType (C::*)(Args...) const& noexcept> {
using return_type = ReturnType;
using argument_types = list<Args...>;
};
template <class ReturnType, class C, class... Args>
struct function_traits<ReturnType (C::*)(Args..., ...) const & noexcept> {
struct function_traits<ReturnType (C::*)(Args..., ...) const& noexcept> {
using return_type = ReturnType;
using argument_types = list<Args...>;
};
template <class ReturnType, class C, class... Args>
struct function_traits<ReturnType (C::*)(Args...) volatile & noexcept> {
struct function_traits<ReturnType (C::*)(Args...) volatile& noexcept> {
using return_type = ReturnType;
using argument_types = list<Args...>;
};
template <class ReturnType, class C, class... Args>
struct function_traits<ReturnType (C::*)(Args..., ...) volatile & noexcept> {
struct function_traits<ReturnType (C::*)(Args..., ...) volatile& noexcept> {
using return_type = ReturnType;
using argument_types = list<Args...>;
};
template <class ReturnType, class C, class... Args>
struct function_traits<ReturnType (C::*)(Args...) const volatile & noexcept> {
struct function_traits<ReturnType (C::*)(Args...) const volatile& noexcept> {
using return_type = ReturnType;
using argument_types = list<Args...>;
};
template <class ReturnType, class C, class... Args>
struct function_traits<ReturnType (C::*)(Args..., ...) const volatile &
noexcept> {
struct function_traits<ReturnType (C::*)(Args..., ...)
const volatile& noexcept> {
using return_type = ReturnType;
using argument_types = list<Args...>;
};
template <class ReturnType, class C, class... Args>
struct function_traits<ReturnType (C::*)(Args...) && noexcept> {
struct function_traits<ReturnType (C::*)(Args...)&& noexcept> {
using return_type = ReturnType;
using argument_types = list<Args...>;
};
template <class ReturnType, class C, class... Args>
struct function_traits<ReturnType (C::*)(Args..., ...) && noexcept> {
struct function_traits<ReturnType (C::*)(Args..., ...)&& noexcept> {
using return_type = ReturnType;
using argument_types = list<Args...>;
};
template <class ReturnType, class C, class... Args>
struct function_traits<ReturnType (C::*)(Args...) const && noexcept> {
struct function_traits<ReturnType (C::*)(Args...) const&& noexcept> {
using return_type = ReturnType;
using argument_types = list<Args...>;
};
template <class ReturnType, class C, class... Args>
struct function_traits<ReturnType (C::*)(Args..., ...) const && noexcept> {
struct function_traits<ReturnType (C::*)(Args..., ...) const&& noexcept> {
using return_type = ReturnType;
using argument_types = list<Args...>;
};
template <class ReturnType, class C, class... Args>
struct function_traits<ReturnType (C::*)(Args...) volatile && noexcept> {
struct function_traits<ReturnType (C::*)(Args...) volatile&& noexcept> {
using return_type = ReturnType;
using argument_types = list<Args...>;
};
template <class ReturnType, class C, class... Args>
struct function_traits<ReturnType (C::*)(Args..., ...) volatile && noexcept> {
struct function_traits<ReturnType (C::*)(Args..., ...) volatile&& noexcept> {
using return_type = ReturnType;
using argument_types = list<Args...>;
};
template <class ReturnType, class C, class... Args>
struct function_traits<ReturnType (C::*)(Args...) const volatile &&
noexcept> {
struct function_traits<ReturnType (C::*)(Args...) const volatile&& noexcept> {
using return_type = ReturnType;
using argument_types = list<Args...>;
};
template <class ReturnType, class C, class... Args>
struct function_traits<ReturnType (C::*)(Args..., ...) const volatile &&
noexcept> {
struct function_traits<ReturnType (C::*)(Args..., ...)
const volatile&& noexcept> {
using return_type = ReturnType;
using argument_types = list<Args...>;
};
Expand Down
29 changes: 18 additions & 11 deletions include/clad/Differentiator/KokkosBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 2> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for(
"_diff_" + str, policy, [&functor, &d_functor](auto&&... args) {
"_diff_" + str, policy,
[&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(args..., &d_functor, 0, 0);
});
}
Expand All @@ -213,7 +214,8 @@ struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 3> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for(
"_diff_" + str, policy, [&functor, &d_functor](auto&&... args) {
"_diff_" + str, policy,
[&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(args..., &d_functor, 0, 0, 0);
});
}
Expand All @@ -235,7 +237,8 @@ struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 4> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for(
"_diff_" + str, policy, [&functor, &d_functor](auto&&... args) {
"_diff_" + str, policy,
[&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(args..., &d_functor, 0, 0, 0, 0);
});
}
Expand All @@ -257,7 +260,8 @@ struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 5> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for(
"_diff_" + str, policy, [&functor, &d_functor](auto&&... args) {
"_diff_" + str, policy,
[&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(args..., &d_functor, 0, 0, 0, 0, 0);
});
}
Expand All @@ -278,11 +282,12 @@ template <class Policy, class FunctorType>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 6> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(
args..., &d_functor, 0, 0, 0, 0, 0, 0);
});
::Kokkos::parallel_for(
"_diff_" + str, policy,
[&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(args..., &d_functor, 0, 0, 0, 0, 0,
0);
});
}
};

Expand Down Expand Up @@ -321,7 +326,8 @@ struct diff_parallel_for_OP_call_dispatch<Policy, FunctorType, void> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for(
"_diff_" + str, policy, [&functor, &d_functor](auto&&... args) {
"_diff_" + str, policy,
[&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(args..., &d_functor, {});
});
}
Expand All @@ -344,7 +350,8 @@ struct diff_parallel_for_int_call_dispatch<Policy, FunctorType, true> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for(
"_diff_" + str, policy, [&functor, &d_functor](const int i) {
"_diff_" + str, policy,
[&functor, &d_functor](const int i) {
functor.operator_call_pushforward(i, &d_functor, 0);
});
}
Expand Down
9 changes: 4 additions & 5 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,12 +304,11 @@ namespace clad {
i < e && paramIdx < FD->getNumParams(); ++i) {
const auto* param = DVI[i].param;
while (paramIdx < FD->getNumParams() &&
FD->getParamDecl(paramIdx) != param) {
++paramIdx;
}
FD->getParamDecl(paramIdx) != param)
++paramIdx;
if (paramIdx != FD->getNumParams())
// Update the parameter to point to the definition parameter.
DVI[i].param = Function->getParamDecl(paramIdx);
// Update the parameter to point to the definition parameter.
DVI[i].param = Function->getParamDecl(paramIdx);
}
return;
}
Expand Down
7 changes: 4 additions & 3 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4070,9 +4070,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

Expr* ReverseModeVisitor::BreakContStmtHandler::
CreateCFTapeSizeExprForCurrentCase() {
return m_RMV.BuildOp(BinaryOperatorKind::BO_NE, m_ControlFlowTape->Size(),
ConstantFolder::synthesizeLiteral(
m_RMV.m_Context.IntTy, m_RMV.m_Context, 0));
return m_RMV.BuildOp(
BinaryOperatorKind::BO_NE, m_ControlFlowTape->Size(),
ConstantFolder::synthesizeLiteral(m_RMV.m_Context.IntTy,
m_RMV.m_Context, /*val=*/0));
}

void ReverseModeVisitor::BreakContStmtHandler::UpdateForwAndRevBlocks(
Expand Down
3 changes: 1 addition & 2 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,8 @@ namespace clad {
(isa<CXXOperatorCallExpr>(ENoCasts) &&
cast<CXXOperatorCallExpr>(ENoCasts)->getNumArgs() == 2) ||
isa<ConditionalOperator>(ENoCasts) ||
isa<CXXBindTemporaryExpr>(ENoCasts)) {
isa<CXXBindTemporaryExpr>(ENoCasts))
return m_Sema.ActOnParenExpr(E->getBeginLoc(), E->getEndLoc(), E).get();
}
return E;
}

Expand Down

0 comments on commit 18ebe7a

Please sign in to comment.