From 8da21c26dfd67346d4d58205542834be8511050b Mon Sep 17 00:00:00 2001 From: Marco Barbone Date: Thu, 12 Sep 2024 11:09:29 -0400 Subject: [PATCH] vectorized --- src/spreadinterp.cpp | 65 ++++++++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 30 deletions(-) diff --git a/src/spreadinterp.cpp b/src/spreadinterp.cpp index 1a2660cbc..1707dac41 100644 --- a/src/spreadinterp.cpp +++ b/src/spreadinterp.cpp @@ -1889,13 +1889,24 @@ void bin_sort_singlethread_vector( static constexpr auto simd_size = simd_type::size; static constexpr auto alignment = arch_t::alignment(); - constexpr auto to_array = [](const auto &vec) constexpr noexcept { + static constexpr auto to_array = [](const auto &vec) constexpr noexcept { using T = decltype(std::decay_t()); alignas(alignment) std::array array{}; vec.store_aligned(array.data()); return array; }; + static constexpr auto has_duplicates = [](const auto &vec) constexpr noexcept { + using T = decltype(std::decay_t()); + for (auto i = 0; i < simd_size; i++) { + const auto rotated = xsimd::rotl(vec, (sizeof(typename T::value_type) * 8) * i); + if ((rotated == vec).mask() != 0) { + return true; + } + } + return false; + }; + const auto isky = (N2 > 1), iskz = (N3 > 1); // ky,kz avail? (cannot access if not) // here the +1 is needed to allow round-off error causing i1=N1/bin_size_x, // for kx near +pi, ie foldrescale gives N1 (exact arith would be 0 to N1-1). @@ -1916,8 +1927,6 @@ void bin_sort_singlethread_vector( // count how many pts in each bin alignas(alignment) std::vector> counts(nbins + simd_size, 0); - alignas(alignment) std::vector> ref_counts(nbins + simd_size, - 0); const auto simd_M = M & (-simd_size); // round down to simd_size multiple UBIGINT i{}; for (i = 0; i < simd_M; i += simd_size) { @@ -1931,14 +1940,17 @@ void bin_sort_singlethread_vector( iskz ? xsimd::to_int(fold_rescale(simd_type::load_unaligned(kz + i), N3) * inv_bin_size_z_vec) : zero; - const auto bin = i1 + nbins1 * (i2 + nbins2 * i3); - const auto bin_array = to_array(bin); - for (int j = 0; j < simd_size; j++) { - ++ref_counts[bin_array[j]]; + const auto bin = i1 + nbins1 * (i2 + nbins2 * i3); + if (has_duplicates(bin)) { + const auto bin_array = to_array(bin); + for (int j = 0; j < simd_size; j++) { + ++counts[bin_array[j]]; + } + } else { + const auto bins = int_simd_type::gather(counts.data(), bin); + const auto incr_bins = bins + 1; + incr_bins.scatter(counts.data(), bin); } - const auto bins = int_simd_type::gather(counts.data(), bin); - const auto incr_bins = bins + 1; - incr_bins.scatter(counts.data(), bin); } for (; i < M; i++) { @@ -1948,16 +1960,6 @@ void bin_sort_singlethread_vector( const auto i3 = iskz ? BIGINT(fold_rescale(kz[i], N3) * inv_bin_size_z) : 0; const auto bin = i1 + nbins1 * (i2 + nbins2 * i3); ++counts[bin]; - ++ref_counts[bin]; - } - - for (i = 0; i < nbins; i++) { - if (counts[i] != ref_counts[i]) { - std::cerr << "Error: bin count mismatch at bin " << i - << " counts[i] = " << counts[i] << " ref_counts[i] = " << ref_counts[i] - << std::endl; - std::abort(); - } } // compute the offsets directly in the counts array (no offset array) @@ -1979,16 +1981,19 @@ void bin_sort_singlethread_vector( iskz ? xsimd::to_int(fold_rescale(simd_type::load_unaligned(kz + i), N3) * inv_bin_size_z_vec) : zero; - const auto bin = i1 + nbins1 * (i2 + nbins2 * i3); - // const auto bins = decltype(bin)::gather(counts.data(), bin); - // const auto ret_elems = decltype(bins)::gather(ret, bins) + (increment+i); - // ret_elems.scatter(ret, bins); - // const auto inc_bins = bins+1; - // inc_bins.scatter(counts.data(), bin); - const auto bin_array = to_array(to_int(bin)); - for (int j = 0; j < simd_size; j++) { - ret[counts[bin_array[j]]] = j + i; - counts[bin_array[j]]++; + const auto bin = i1 + nbins1 * (i2 + nbins2 * i3); + const auto bins = decltype(bin)::gather(counts.data(), bin); + if (has_duplicates(bin) || has_duplicates(bins)) { + const auto bin_array = to_array(to_int(bin)); + for (int j = 0; j < simd_size; j++) { + ret[counts[bin_array[j]]] = j + i; + counts[bin_array[j]]++; + } + } else { + const auto incr_bins = bins + 1; + incr_bins.scatter(counts.data(), bin); + const auto result = increment + i; + result.scatter(ret, bins); } } for (; i < M; i++) {