Skip to content

Commit

Permalink
Add poolstl::for_each_chunk
Browse files Browse the repository at this point in the history
  • Loading branch information
alugowski committed Dec 7, 2023
1 parent 379833c commit f973646
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ All in `std::` namespace.
### Other
* [`poolstl::iota_iter`](include/poolstl/iota_iter.hpp) - Iterate over integers. Same as iterating over output of [`std::iota`](https://en.cppreference.com/w/cpp/algorithm/iota) but without materializing anything. Iterator version of [`std::ranges::iota_view`](https://en.cppreference.com/w/cpp/ranges/iota_view).
* `poolstl::for_each_chunk` - Like `std::for_each`, but explicitly splits the input range into chunks then exposes the chunked parallelism. A user-specified chunk constructor is called for each parallel chunk then its output is passed to each loop iteration. Useful for workloads that need an expensive workspace that can be reused between iterations, but not simultaneously by all iterations in parallel.
## Usage
Expand Down
34 changes: 34 additions & 0 deletions include/poolstl/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -294,4 +294,38 @@ namespace std {
}
}

namespace poolstl {

template <class RandIt, class ChunkConstructor, class UnaryFunction>
void for_each_chunk(RandIt first, RandIt last, ChunkConstructor construct, UnaryFunction f) {
if (first == last) {
return;
}

auto chunk_data = construct();
for (; first != last; ++first) {
f(*first, chunk_data);
}
}

/**
* NOTE: Iterators are expected to be random access.
*
* Like `std::for_each`, but exposes the chunking. The `construct` method is called once per parallel chunk and
* its output is passed to `f`.
*
* Useful for cases where an expensive workspace can be shared between loop iterations
* but cannot be shared by all parallel iterations.
*/
template <class ExecPolicy, class RandIt, class ChunkConstructor, class UnaryFunction>
poolstl::internal::enable_if_par<ExecPolicy, void>
for_each_chunk(ExecPolicy&& policy, RandIt first, RandIt last, ChunkConstructor construct, UnaryFunction f) {
auto futures = poolstl::internal::parallel_chunk_for(std::forward<ExecPolicy>(policy), first, last,
[&construct, &f](RandIt chunk_first, RandIt chunk_last) {
for_each_chunk(chunk_first, chunk_last, construct, f);
});
poolstl::internal::get_futures(futures);
}
}

#endif
4 changes: 4 additions & 0 deletions include/poolstl/seq_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,8 @@ namespace std {
#endif
}

namespace poolstl {
POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF_VOID(poolstl, for_each_chunk)
}

#endif
32 changes: 32 additions & 0 deletions tests/poolstl_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,38 @@ TEST_CASE("for_each_n", "[alg][algorithm]") {
}
}

TEST_CASE("for_each_chunk", "[alg][algorithm][poolstl]") {
std::atomic<int> sum{0};
std::atomic<int> num_chunks{0};
for (auto num_threads : test_thread_counts) {
ttp::task_thread_pool pool(num_threads);

for (auto num_iters : test_arr_sizes) {
auto v = iota_vector(num_iters);

for (auto is_sequential : {true, false}) {
num_chunks = 0;
sum = 0;
auto cc = [&]() { ++num_chunks; return 1; };
auto f = [&](auto, auto) { ++sum; };
if (is_sequential) {
poolstl::for_each_chunk(poolstl::par_if(false), v.cbegin(), v.cend(), cc, f);
REQUIRE(num_chunks == (v.empty() ? 0 : 1));
} else {
poolstl::for_each_chunk(poolstl::par.on(pool), v.cbegin(), v.cend(), cc, f);
if (num_threads != 0) {
REQUIRE(num_chunks <= std::min((int)v.size(), num_threads));
}
if (!v.empty()) {
REQUIRE(num_chunks > 0);
}
}
REQUIRE(sum == num_iters);
}
}
}
}

TEST_CASE("sort", "[alg][algorithm]") {
for (auto num_threads : test_thread_counts) {
ttp::task_thread_pool pool(num_threads);
Expand Down

0 comments on commit f973646

Please sign in to comment.