Skip to content

Commit

Permalink
Add parallel_for prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed Oct 16, 2024
1 parent c73a158 commit 56e4faa
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions include/clad/Differentiator/KokkosBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,9 @@ template <class... Args> void resize_pullback(Args... /*args*/) {}
template <typename S> void fence_pushforward(const S& s, const S& /*d_s*/) {
::Kokkos::fence(s);
}
template <typename... Args> void fence_pullback(Args...) { ::Kokkos::fence(); }

/// Parallel for
/// Parallel for (forward mode)
template <class... PolicyParams, class FunctorType> // range policy
void parallel_for_pushforward(
const ::std::string& str,
Expand All @@ -528,7 +529,6 @@ void parallel_for_pushforward(
functor.operator_call_pushforward(i, &d_functor, 0);
});
}

// This structure is used to dispatch parallel for pushforward calls based on
// the rank and the work tag of the MDPolicy
template <class Policy, class FunctorType, class T, int Rank>
Expand Down Expand Up @@ -649,7 +649,6 @@ struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 6> {
});
}
};

template <class PolicyP, class... PolicyParams,
class FunctorType> // multi-dimensional policy
void parallel_for_pushforward(
Expand All @@ -666,7 +665,6 @@ void parallel_for_pushforward(
functor,
d_functor);
}

// This structure is used to dispatch parallel for pushforward calls based on
// the work tag of other types of policies
template <class Policy, class FunctorType, class T>
Expand All @@ -690,7 +688,6 @@ struct diff_parallel_for_OP_call_dispatch<Policy, FunctorType, void> {
});
}
};

// This structure is used to dispatch parallel for pushforward calls for
// integral policies
template <class Policy, class FunctorType, bool IsInt>
Expand All @@ -713,7 +710,6 @@ struct diff_parallel_for_int_call_dispatch<Policy, FunctorType, true> {
});
}
};

template <class Policy, class FunctorType> // other policy type
void parallel_for_pushforward(const ::std::string& str, const Policy& policy,
const FunctorType& functor,
Expand All @@ -726,15 +722,13 @@ void parallel_for_pushforward(const ::std::string& str, const Policy& policy,
functor,
d_functor);
}

template <class Policy, class FunctorType> // anonymous loop
void parallel_for_pushforward(const Policy& policy, const FunctorType& functor,
const Policy& d_policy,
const FunctorType& d_functor) {
parallel_for_pushforward(::std::string("anonymous_parallel_for"), policy,
functor, ::std::string(""), d_policy, d_functor);
}

template <class Policy, class FunctorType> // anonymous loop
void parallel_for_pushforward(
const Policy& policy, const FunctorType& functor,
Expand All @@ -745,15 +739,13 @@ void parallel_for_pushforward(
parallel_for_pushforward(::std::string("anonymous_parallel_for"), policy,
functor, ::std::string(""), d_policy, d_functor);
}

template <typename Policy, class FunctorType> // alternative signature
void parallel_for_pushforward(Policy policy, const FunctorType& functor,
const ::std::string& str, Policy d_policy,
const FunctorType& d_functor,
const ::std::string& d_str) {
parallel_for_pushforward(str, policy, functor, d_str, d_policy, d_functor);
}

template <typename Policy, class FunctorType> // alternative signature
void parallel_for_pushforward(
const Policy& policy, const FunctorType& functor, const ::std::string& str,
Expand All @@ -765,10 +757,16 @@ void parallel_for_pushforward(
parallel_for_pushforward(str, policy, functor, d_str, d_policy, d_functor);
}

/// Parallel reduce
/// Parallel for (reverse mode)
template <typename F>
void parallel_for_pullback(const size_t work_count, const F& functor,
size_t* d_work_count, F* d_functor) {
// TODO: implement parallel_for pullbacks
}

/// Parallel reduce (forward mode)
// TODO: ADD SUPORT FOR MULTIPLE REDUCED ARGUMENTS
// TODO: ADD SUPPORT FOR UNNAMED LOOPS

// This structure is used to dispatch parallel reduce pushforward calls for
// multidimentional policies
template <class Policy, class FunctorType, class Reduced, class WT, int Rank>
Expand Down Expand Up @@ -1022,7 +1020,6 @@ struct diff_parallel_reduce_MDP_dispatch<
d_res);
}
};

// This structure is used to dispatch parallel reduce pushforward calls for
// integral policies
template <class Policy, class FunctorType, class Reduced, bool isInt>
Expand All @@ -1048,7 +1045,6 @@ struct diff_parallel_reduce_int_dispatch<Policy, FunctorType, Reduced, true> {
res, d_res);
}
};

template <class Policy, class FunctorType,
class Reduced> // generally, this is matched
void parallel_reduce_pushforward(const ::std::string& str, const Policy& policy,
Expand Down

0 comments on commit 56e4faa

Please sign in to comment.