From f381914dc946793425f67097cddcb0e0da882e8b Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Mon, 9 Dec 2024 18:12:18 +0100 Subject: [PATCH] Combine bit update loops into a single loop --- .../detail/rmat_rectangular_generator.cuh | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh b/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh index 9ad7c68f87..24207ba6db 100644 --- a/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh +++ b/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh @@ -151,15 +151,16 @@ RAFT_KERNEL rmat_gen_kernel(IdxT* out, raft::random::PCGenerator gen{r.seed, r.base_subsequence + idx, 0}; auto min_scale = min(r_scale, c_scale); IdxT i = 0; - for (; i < min_scale; ++i) { - gen_and_update_bits(src_id, dst_id, a, a + b, a + b + c, r_scale, c_scale, i, gen); - } - for (; i < r_scale; ++i) { - gen_and_update_bits(src_id, dst_id, a + b, a + b, ProbT(1), r_scale, c_scale, i, gen); - } - for (; i < c_scale; ++i) { - gen_and_update_bits(src_id, dst_id, a + c, ProbT(1), ProbT(1), r_scale, c_scale, i, gen); + // Whether we have more rows than columns. + const bool more_rows = r_scale > c_scale; + + for (; i < max_scale; ++i) { + ProbT A = (i < min_scale) ? a : (more_rows ? a + b : a + c); + ProbT AB = (i < min_scale) ? a + b : (more_rows ? a + b : ProbT(1)); + ProbT ABC = (i < min_scale) ? a + b + c : ProbT(1); + gen_and_update_bits(src_id, dst_id, A, AB, ABC, r_scale, c_scale, i, gen); } + store_ids(out, out_src, out_dst, src_id, dst_id, idx, n_edges); }