diff --git a/README.md b/README.md
index 5f3e1a9..636449f 100644
--- a/README.md
+++ b/README.md
@@ -59,6 +59,7 @@ All in `std::` namespace.
 
 ### Other
 * [`poolstl::iota_iter`](include/poolstl/iota_iter.hpp) - Iterate over integers. Same as iterating over output of [`std::iota`](https://en.cppreference.com/w/cpp/algorithm/iota) but without materializing anything. Iterator version of [`std::ranges::iota_view`](https://en.cppreference.com/w/cpp/ranges/iota_view).
+* `poolstl::for_each_chunk` - Like `std::for_each`, but explicitly splits the input range into chunks then exposes the chunked parallelism. A user-specified chunk constructor is called for each parallel chunk then its output is passed to each loop iteration. Useful for workloads that need an expensive workspace that can be reused between iterations, but not simultaneously by all iterations in parallel.
 
 ## Usage
 
diff --git a/include/poolstl/algorithm b/include/poolstl/algorithm
index 5bb1bd3..d4db75e 100644
--- a/include/poolstl/algorithm
+++ b/include/poolstl/algorithm
@@ -294,4 +294,38 @@ namespace std {
     }
 }
 
+namespace poolstl {
+
+    template <class RandIt, class ChunkConstructor, class UnaryFunction>
+    void for_each_chunk(RandIt first, RandIt last, ChunkConstructor construct, UnaryFunction f) {
+        if (first == last) {
+            return;
+        }
+
+        auto chunk_data = construct();
+        for (; first != last; ++first) {
+            f(*first, chunk_data);
+        }
+    }
+
+    /**
+     * NOTE: Iterators are expected to be random access.
+     *
+     * Like `std::for_each`, but exposes the chunking. The `construct` method is called once per parallel chunk and
+     * its output is passed to `f`.
+     *
+     * Useful for cases where an expensive workspace can be shared between loop iterations
+     * but cannot be shared by all parallel iterations.
+     */
+    template <class ExecPolicy, class RandIt, class ChunkConstructor, class UnaryFunction>
+    poolstl::internal::enable_if_par<ExecPolicy, void>
+    for_each_chunk(ExecPolicy&& policy, RandIt first, RandIt last, ChunkConstructor construct, UnaryFunction f) {
+        auto futures = poolstl::internal::parallel_chunk_for(std::forward<ExecPolicy>(policy), first, last,
+                                                             [&construct, &f](RandIt chunk_first, RandIt chunk_last) {
+                                                                 for_each_chunk(chunk_first, chunk_last, construct, f);
+                                                             });
+        poolstl::internal::get_futures(futures);
+    }
+}
+
 #endif
diff --git a/include/poolstl/seq_fwd.hpp b/include/poolstl/seq_fwd.hpp
index 58e9053..b08a5ca 100644
--- a/include/poolstl/seq_fwd.hpp
+++ b/include/poolstl/seq_fwd.hpp
@@ -92,4 +92,8 @@ namespace std {
 #endif
 }
 
+namespace poolstl {
+    POOLSTL_DEFINE_BOTH_SEQ_FWD_AND_PAR_IF_VOID(poolstl, for_each_chunk)
+}
+
 #endif
diff --git a/tests/poolstl_test.cpp b/tests/poolstl_test.cpp
index f8cb13d..92caf7a 100644
--- a/tests/poolstl_test.cpp
+++ b/tests/poolstl_test.cpp
@@ -242,6 +242,38 @@ TEST_CASE("for_each_n", "[alg][algorithm]") {
     }
 }
 
+TEST_CASE("for_each_chunk", "[alg][algorithm][poolstl]") {
+    std::atomic<int> sum{0};
+    std::atomic<int> num_chunks{0};
+    for (auto num_threads : test_thread_counts) {
+        ttp::task_thread_pool pool(num_threads);
+
+        for (auto num_iters : test_arr_sizes) {
+            auto v = iota_vector(num_iters);
+
+            for (auto is_sequential : {true, false}) {
+                num_chunks = 0;
+                sum = 0;
+                auto cc = [&]() { ++num_chunks; return 1; };
+                auto f = [&](auto, auto) { ++sum; };
+                if (is_sequential) {
+                    poolstl::for_each_chunk(poolstl::par_if(false), v.cbegin(), v.cend(), cc, f);
+                    REQUIRE(num_chunks == (v.empty() ? 0 : 1));
+                } else {
+                    poolstl::for_each_chunk(poolstl::par.on(pool), v.cbegin(), v.cend(), cc, f);
+                    if (num_threads != 0) {
+                        REQUIRE(num_chunks <= std::min((int)v.size(), num_threads));
+                    }
+                    if (!v.empty()) {
+                        REQUIRE(num_chunks > 0);
+                    }
+                }
+                REQUIRE(sum == num_iters);
+            }
+        }
+    }
+}
+
 TEST_CASE("sort", "[alg][algorithm]") {
     for (auto num_threads : test_thread_counts) {
         ttp::task_thread_pool pool(num_threads);