Skip to content

Commit

Permalink
Add parallel_for prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed Oct 15, 2024
1 parent 8fd700c commit 25c180a
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions include/clad/Differentiator/KokkosBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,9 @@ template <class... Args> void resize_pullback(Args... /*args*/) {}
template <typename S> void fence_pushforward(const S& s, const S& /*d_s*/) {
::Kokkos::fence(s);
}
template <typename... Args> void fence_pullback(Args...) { ::Kokkos::fence(); }

/// Parallel for
/// Parallel for (forward mode)
template <class... PolicyParams, class FunctorType> // range policy
void parallel_for_pushforward(
const ::std::string& str,
Expand All @@ -541,7 +542,6 @@ void parallel_for_pushforward(
functor.operator_call_pushforward(i, &d_functor, 0);
});
}

// This structure is used to dispatch parallel for pushforward calls based on
// the rank and the work tag of the MDPolicy
template <class Policy, class FunctorType, class T, int Rank>
Expand Down Expand Up @@ -662,7 +662,6 @@ struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 6> {
});
}
};

template <class PolicyP, class... PolicyParams,
class FunctorType> // multi-dimensional policy
void parallel_for_pushforward(
Expand All @@ -679,7 +678,6 @@ void parallel_for_pushforward(
functor,
d_functor);
}

// This structure is used to dispatch parallel for pushforward calls based on
// the work tag of other types of policies
template <class Policy, class FunctorType, class T>
Expand All @@ -703,7 +701,6 @@ struct diff_parallel_for_OP_call_dispatch<Policy, FunctorType, void> {
});
}
};

// This structure is used to dispatch parallel for pushforward calls for
// integral policies
template <class Policy, class FunctorType, bool IsInt>
Expand All @@ -726,7 +723,6 @@ struct diff_parallel_for_int_call_dispatch<Policy, FunctorType, true> {
});
}
};

template <class Policy, class FunctorType> // other policy type
void parallel_for_pushforward(const ::std::string& str, const Policy& policy,
const FunctorType& functor,
Expand All @@ -739,15 +735,13 @@ void parallel_for_pushforward(const ::std::string& str, const Policy& policy,
functor,
d_functor);
}

template <class Policy, class FunctorType> // anonymous loop
void parallel_for_pushforward(const Policy& policy, const FunctorType& functor,
const Policy& d_policy,
const FunctorType& d_functor) {
parallel_for_pushforward(::std::string("anonymous_parallel_for"), policy,
functor, ::std::string(""), d_policy, d_functor);
}

template <class Policy, class FunctorType> // anonymous loop
void parallel_for_pushforward(
const Policy& policy, const FunctorType& functor,
Expand All @@ -758,15 +752,13 @@ void parallel_for_pushforward(
parallel_for_pushforward(::std::string("anonymous_parallel_for"), policy,
functor, ::std::string(""), d_policy, d_functor);
}

template <typename Policy, class FunctorType> // alternative signature
void parallel_for_pushforward(Policy policy, const FunctorType& functor,
const ::std::string& str, Policy d_policy,
const FunctorType& d_functor,
const ::std::string& d_str) {
parallel_for_pushforward(str, policy, functor, d_str, d_policy, d_functor);
}

template <typename Policy, class FunctorType> // alternative signature
void parallel_for_pushforward(
const Policy& policy, const FunctorType& functor, const ::std::string& str,
Expand All @@ -778,10 +770,16 @@ void parallel_for_pushforward(
parallel_for_pushforward(str, policy, functor, d_str, d_policy, d_functor);
}

/// Parallel reduce
/// Parallel for (reverse mode)
template <typename F>
void parallel_for_pullback(const size_t work_count, const F& functor,
size_t* d_work_count, F* d_functor) {
// TODO: implement parallel_for pullbacks
}

/// Parallel reduce (forward mode)
// TODO: ADD SUPORT FOR MULTIPLE REDUCED ARGUMENTS
// TODO: ADD SUPPORT FOR UNNAMED LOOPS

// This structure is used to dispatch parallel reduce pushforward calls for
// multidimentional policies
template <class Policy, class FunctorType, class Reduced, class WT, int Rank>
Expand Down Expand Up @@ -1035,7 +1033,6 @@ struct diff_parallel_reduce_MDP_dispatch<
d_res);
}
};

// This structure is used to dispatch parallel reduce pushforward calls for
// integral policies
template <class Policy, class FunctorType, class Reduced, bool isInt>
Expand All @@ -1061,7 +1058,6 @@ struct diff_parallel_reduce_int_dispatch<Policy, FunctorType, Reduced, true> {
res, d_res);
}
};

template <class Policy, class FunctorType,
class Reduced> // generally, this is matched
void parallel_reduce_pushforward(const ::std::string& str, const Policy& policy,
Expand Down

0 comments on commit 25c180a

Please sign in to comment.