From aaefb0efd3240e63459a517d53d666a6b9b0b9cd Mon Sep 17 00:00:00 2001 From: lucas Date: Mon, 20 May 2024 20:52:59 +0200 Subject: [PATCH 1/3] * add reset py method, * fix resetting at each decoding run for bplsd, also if BP converges --- cpp_test/TestLsd.cpp | 20 ++++++++++++++++--- src_cpp/lsd.hpp | 12 +++++++++++ src_python/ldpc/bplsd_decoder/__init__.pyi | 11 ++++++++++ .../ldpc/bplsd_decoder/_bplsd_decoder.pxd | 1 + .../ldpc/bplsd_decoder/_bplsd_decoder.pyx | 12 ++++++++++- 5 files changed, 52 insertions(+), 4 deletions(-) diff --git a/cpp_test/TestLsd.cpp b/cpp_test/TestLsd.cpp index b9f3193..bef2c21 100644 --- a/cpp_test/TestLsd.cpp +++ b/cpp_test/TestLsd.cpp @@ -448,10 +448,10 @@ TEST(LsdDecoder, test_cluster_stats) { auto lsd = LsdDecoder(pcm, ldpc::osd::OsdMethod::EXHAUSTIVE, 0); lsd.set_do_stats(true); auto syndrome = std::vector({1, 1, 0, 0, 0}); - lsd.statistics.error = std::vector(pcm.n, 1); lsd.statistics.syndrome = std::vector(pcm.m, 1); - lsd.statistics.compare_recover = std::vector(pcm.n, 0); auto decoding = lsd.lsd_decode(syndrome, bp.log_prob_ratios, 1, true); + lsd.statistics.compare_recover = std::vector(pcm.n, 0); + lsd.statistics.error = std::vector(pcm.n, 1); auto stats = lsd.statistics; std::cout << stats.toString() << std::endl; @@ -464,7 +464,7 @@ TEST(LsdDecoder, test_cluster_stats) { ASSERT_TRUE(stats.global_timestep_bit_history[0].size() == 2); ASSERT_TRUE(stats.global_timestep_bit_history[0][0].size() == 1); ASSERT_TRUE(stats.global_timestep_bit_history[0][1].size() == 2); - ASSERT_TRUE(stats.global_timestep_bit_history[1].size() == 0); + ASSERT_TRUE(stats.global_timestep_bit_history[1].empty()); ASSERT_TRUE(stats.elapsed_time > 0.0); ASSERT_TRUE(stats.individual_cluster_stats[0].active == false); ASSERT_TRUE(stats.individual_cluster_stats[0].got_inactive_in_timestep == 0); @@ -476,6 +476,20 @@ TEST(LsdDecoder, test_cluster_stats) { ASSERT_TRUE(stats.error.size() == pcm.n); ASSERT_TRUE(stats.syndrome.size() == pcm.n); ASSERT_TRUE(stats.compare_recover.size() == pcm.n); + + // now reset + lsd.reset_cluster_stats(); + stats = lsd.statistics; + ASSERT_TRUE(lsd.get_do_stats()); + ASSERT_TRUE(stats.lsd_method = ldpc::osd::OsdMethod::COMBINATION_SWEEP); + ASSERT_TRUE(stats.lsd_order == 0); + ASSERT_TRUE(stats.individual_cluster_stats.empty()); + ASSERT_TRUE(stats.elapsed_time == 0.0); + ASSERT_TRUE(stats.global_timestep_bit_history.empty()); + ASSERT_TRUE(stats.bit_llrs.empty()); + ASSERT_TRUE(stats.error.empty()); + ASSERT_TRUE(stats.syndrome.empty()); + ASSERT_TRUE(stats.compare_recover.empty()); } TEST(LsdDecoder, test_reshuffle_same_wt_indices) { diff --git a/src_cpp/lsd.hpp b/src_cpp/lsd.hpp index f542f9c..a53a69e 100644 --- a/src_cpp/lsd.hpp +++ b/src_cpp/lsd.hpp @@ -443,6 +443,13 @@ namespace ldpc::lsd { void clear() { this->individual_cluster_stats.clear(); this->global_timestep_bit_history.clear(); + this->elapsed_time = 0.0; + this->lsd_method = osd::OsdMethod::COMBINATION_SWEEP; + this->lsd_order = 0; + this->bit_llrs = {}; + this->error = {}; + this->syndrome= {}; + this->compare_recover = {}; } [[nodiscard]] std::string toString() const { @@ -555,6 +562,10 @@ namespace ldpc::lsd { osd::OsdMethod lsd_method; int lsd_order; + void reset_cluster_stats(){ + this->statistics.clear(); + } + explicit LsdDecoder(ldpc::bp::BpSparse &parity_check_matrix, osd::OsdMethod lsdMethod = osd::OsdMethod::COMBINATION_SWEEP, int lsd_order = 0) : pcm(parity_check_matrix), @@ -618,6 +629,7 @@ namespace ldpc::lsd { const bool is_on_the_fly = true) { auto start_time = std::chrono::high_resolution_clock::now(); this->statistics.clear(); + this->statistics.syndrome = syndrome; fill(this->decoding.begin(), this->decoding.end(), 0); diff --git a/src_python/ldpc/bplsd_decoder/__init__.pyi b/src_python/ldpc/bplsd_decoder/__init__.pyi index ee5b7e7..aead045 100644 --- a/src_python/ldpc/bplsd_decoder/__init__.pyi +++ b/src_python/ldpc/bplsd_decoder/__init__.pyi @@ -166,3 +166,14 @@ class BpLsdDecoder(BpDecoderBase): """ + + def set_additional_stat_fields(self, error, syndrome, compare_recover) -> None: + """ + Sets additional fields to be collected in the statistics. + + Parameters + ---------- + fields : List[str] + A list of strings representing the additional fields to be collected in the statistics. + """ + diff --git a/src_python/ldpc/bplsd_decoder/_bplsd_decoder.pxd b/src_python/ldpc/bplsd_decoder/_bplsd_decoder.pxd index 947638f..2a3b69a 100644 --- a/src_python/ldpc/bplsd_decoder/_bplsd_decoder.pxd +++ b/src_python/ldpc/bplsd_decoder/_bplsd_decoder.pxd @@ -52,6 +52,7 @@ cdef extern from "lsd.hpp" namespace "ldpc::lsd": bool get_do_stats() void set_do_stats(bool do_stats) void set_additional_stat_fields(vector[int] error, vector[int] syndrome, vector[int] compare_recover) + void reset_cluster_stats() cdef class BpLsdDecoder(BpDecoderBase): cdef LsdDecoderCpp* lsd diff --git a/src_python/ldpc/bplsd_decoder/_bplsd_decoder.pyx b/src_python/ldpc/bplsd_decoder/_bplsd_decoder.pyx index 57c188d..c049688 100644 --- a/src_python/ldpc/bplsd_decoder/_bplsd_decoder.pyx +++ b/src_python/ldpc/bplsd_decoder/_bplsd_decoder.pyx @@ -140,7 +140,10 @@ cdef class BpLsdDecoder(BpDecoderBase): self.bpd.decoding = self.bpd.decode(self._syndrome) out = np.zeros(self.n,dtype=DTYPE) if self.bpd.converge: - for i in range(self.n): out[i] = self.bpd.decoding[i] + for i in range(self.n): + out[i] = self.bpd.decoding[i] + self.lsd.reset_cluster_stats() + if not self.bpd.converge: self.lsd.decoding = self.lsd.lsd_decode(self._syndrome, self.bpd.log_prob_ratios,self.bits_per_step, True) @@ -291,3 +294,10 @@ cdef class BpLsdDecoder(BpDecoderBase): self.lsd.statistics.error = error self.lsd.statistics.syndrome = syndrome self.lsd.statistics.compare_recover = compare_recover + + def reset_cluster_stats() -> None: + """ + Resets cluster statistics of the decoder. + Note that this also resets the additional stat fields, such as the error, and compare_recovery vectors + """ + self.lsd.reset_cluster_stats() \ No newline at end of file From 98575ecc1a044eacc52bf03e37c1f9c3d60ffeb4 Mon Sep 17 00:00:00 2001 From: lucas Date: Mon, 20 May 2024 21:20:16 +0200 Subject: [PATCH 2/3] some tests --- python_test/test_bplsd.py | 179 ++++++++++++++++++ src_python/ldpc/bplsd_decoder/__init__.pyi | 7 + .../ldpc/bplsd_decoder/_bplsd_decoder.pyx | 6 +- 3 files changed, 189 insertions(+), 3 deletions(-) create mode 100644 python_test/test_bplsd.py diff --git a/python_test/test_bplsd.py b/python_test/test_bplsd.py new file mode 100644 index 0000000..f9489e2 --- /dev/null +++ b/python_test/test_bplsd.py @@ -0,0 +1,179 @@ +import pytest +import numpy as np +import scipy.sparse +from ldpc.codes import rep_code, ring_code, hamming_code +from ldpc.bplsd_decoder import BpLsdDecoder + +# Define valid inputs for testing +valid_pcms = [ + np.array([[1, 0, 1], [0, 1, 1]]), + scipy.sparse.csr_matrix([[1, 0, 1], [0, 1, 1]]), +] + +valid_error_rates = [ + 0.1, +] + +valid_max_iters = [ + 0, + 10, +] + +valid_bp_methods = [ + "product_sum", + "minimum_sum", +] + +valid_ms_scaling_factors = [ + 1.0, + 0.5, +] + +valid_schedules = [ + "parallel", + "serial", +] + +valid_omp_thread_counts = [ + 1, + 4, +] + +valid_random_schedule_seeds = [ + 42, +] + +valid_serial_schedule_orders = [ + None, + [1, 0, 2], +] + +# Combine valid inputs for parameterized testing +valid_input_permutations = pytest.mark.parametrize( + "pcm, error_rate, max_iter, bp_method, ms_scaling_factor, schedule, omp_thread_count, random_schedule_seed, serial_schedule_order", + [ + (pcm, error, max_iter, bp_method, ms_factor, schedule, omp_count, seed, order) + for pcm in valid_pcms + for error in valid_error_rates + for max_iter in valid_max_iters + for bp_method in valid_bp_methods + for ms_factor in valid_ms_scaling_factors + for schedule in valid_schedules + for omp_count in valid_omp_thread_counts + for seed in valid_random_schedule_seeds + for order in valid_serial_schedule_orders + ], +) + +def test_BpLsdDecoder_init(): + + # test with numpy ndarray as pcm + pcm = np.array([[1, 0, 1], [0, 1, 1]]) + decoder = BpLsdDecoder(pcm, error_rate=0.1, max_iter=10, bp_method='prod_sum', ms_scaling_factor=0.5, schedule='parallel', omp_thread_count=4, random_schedule_seed=1, serial_schedule_order=[1,2,0],input_vector_type = "syndrome") + assert decoder is not None + assert decoder.check_count == 2 + assert decoder.bit_count == 3 + assert decoder.max_iter == 10 + assert decoder.bp_method == "product_sum" + assert decoder.ms_scaling_factor == 0.5 + assert decoder.schedule == "parallel" + assert decoder.omp_thread_count == 4 + assert decoder.random_schedule_seed == 1 + assert np.array_equal(decoder.serial_schedule_order, np.array([1, 2, 0])) + assert np.array_equal(decoder.input_vector_type, "syndrome") + + # test with scipy.sparse scipy.sparse.csr_matrix as pcm + pcm = scipy.sparse.csr_matrix([[1, 0, 1], [0, 1, 1]]) + decoder = BpLsdDecoder(pcm, error_channel=[0.1, 0.2, 0.3],input_vector_type = "syndrome") + assert decoder is not None + assert decoder.check_count == 2 + assert decoder.bit_count == 3 + assert decoder.max_iter == 3 + assert decoder.bp_method == "product_sum" + assert decoder.ms_scaling_factor == 1.0 + assert decoder.schedule == "parallel" + assert decoder.omp_thread_count == 1 + assert decoder.random_schedule_seed == 0 + assert np.array_equal(decoder.serial_schedule_order, np.array([0, 1, 2])) + assert np.array_equal(decoder.input_vector_type, "syndrome") + + + # test with invalid pcm type + with pytest.raises(TypeError): + decoder = BpLsdDecoder('invalid', error_rate=0.1) + + # test with invalid max_iter type + with pytest.raises(TypeError): + decoder = BpLsdDecoder(pcm, error_rate=0.1,max_iter='invalid') + + # test with invalid max_iter value + with pytest.raises(ValueError): + decoder = BpLsdDecoder(pcm, error_rate =0.1, max_iter=-1) + + # test with invalid bp_method value + with pytest.raises(ValueError): + decoder = BpLsdDecoder(pcm,error_rate=0.1, bp_method='invalid') + + # test with invalid schedule value + with pytest.raises(ValueError): + decoder = BpLsdDecoder(pcm,error_rate=0.1, schedule='invalid') + + # test with invalid ms_scaling_factor value + with pytest.raises(TypeError): + decoder = BpLsdDecoder(pcm,error_rate=0.1, ms_scaling_factor='invalid') + + # test with invalid omp_thread_count value + with pytest.raises(TypeError): + decoder = BpLsdDecoder(pcm, error_rate=0.1,omp_thread_count='invalid') + + # test with invalid random_schedule_seed value + with pytest.raises(TypeError): + decoder = BpLsdDecoder(pcm, error_rate=0.1, random_schedule_seed='invalid') + + # test with invalid serial_schedule_order value + with pytest.raises(Exception): + decoder = BpLsdDecoder(pcm, error_rate=0.1, serial_schedule_order=[1, 2]) + +def test_rep_code_ms(): + + H = rep_code(3) + + lsd = BpLsdDecoder(H,error_rate=0.1, bp_method='min_sum', ms_scaling_factor=1.0) + assert lsd is not None + assert lsd.bp_method == "minimum_sum" + assert lsd.schedule == "parallel" + assert np.array_equal(lsd.error_channel,np.array([0.1, 0.1, 0.1])) + + + decoding = lsd.decode(np.array([1, 1])) + assert(np.array_equal(decoding, np.array([0, 1,0]))) + + lsd.error_channel = np.array([0.1, 0, 0.1]) + assert np.array_equal(lsd.error_channel,np.array([0.1, 0, 0.1])) + + decoding=lsd.decode(np.array([1, 1])) + assert(np.array_equal(decoding, np.array([1, 0, 1]))) + +def test_stats_reset(): + + H = rep_code(5) + + lsd = BpLsdDecoder(H,max_iter=1,error_rate=0.1, bp_method='min_sum', ms_scaling_factor=1.0) + lsd.set_do_stats(True) + syndrome = np.array([1,1,0,1]) + lsd.decode(syndrome) + + stats = lsd.statistics + assert stats['lsd_order'] == 0 + assert stats["lsd_method"] == 1 + assert len(stats["bit_llrs"]) == H.shape[1] + assert len(stats["individual_cluster_stats"])>0 + assert len(stats["global_timestep_bit_history"])>0 + + syndrome = np.array([0,0,0,0]) + lsd.decode(syndrome) + + stats = lsd.statistics + assert len(stats["bit_llrs"]) == 0 + assert len(stats["individual_cluster_stats"])==0 + assert len(stats["global_timestep_bit_history"])==0 diff --git a/src_python/ldpc/bplsd_decoder/__init__.pyi b/src_python/ldpc/bplsd_decoder/__init__.pyi index aead045..6cb662d 100644 --- a/src_python/ldpc/bplsd_decoder/__init__.pyi +++ b/src_python/ldpc/bplsd_decoder/__init__.pyi @@ -177,3 +177,10 @@ class BpLsdDecoder(BpDecoderBase): A list of strings representing the additional fields to be collected in the statistics. """ + + def reset_cluster_stats(self) -> None: + """ + Resets cluster statistics of the decoder. + Note that this also resets the additional stat fields, such as the error, and compare_recovery vectors + """ + diff --git a/src_python/ldpc/bplsd_decoder/_bplsd_decoder.pyx b/src_python/ldpc/bplsd_decoder/_bplsd_decoder.pyx index c049688..7e86a25 100644 --- a/src_python/ldpc/bplsd_decoder/_bplsd_decoder.pyx +++ b/src_python/ldpc/bplsd_decoder/_bplsd_decoder.pyx @@ -295,9 +295,9 @@ cdef class BpLsdDecoder(BpDecoderBase): self.lsd.statistics.syndrome = syndrome self.lsd.statistics.compare_recover = compare_recover - def reset_cluster_stats() -> None: + def reset_cluster_stats(self) -> None: """ - Resets cluster statistics of the decoder. - Note that this also resets the additional stat fields, such as the error, and compare_recovery vectors + Resets cluster statistics of the decoder. + Note that this also resets the additional stat fields, such as the error, and compare_recovery vectors """ self.lsd.reset_cluster_stats() \ No newline at end of file From e953dd5cbefa75f82ba7265baa26e1405af56b00 Mon Sep 17 00:00:00 2001 From: lucas Date: Mon, 3 Jun 2024 15:02:56 +0200 Subject: [PATCH 3/3] fix test --- cpp_test/TestLsd.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp_test/TestLsd.cpp b/cpp_test/TestLsd.cpp index 4a902f9..04965f6 100644 --- a/cpp_test/TestLsd.cpp +++ b/cpp_test/TestLsd.cpp @@ -459,10 +459,10 @@ TEST(LsdDecoder, test_cluster_stats) { lsd.set_do_stats(true); auto syndrome = std::vector({1, 1, 0, 0, 0}); lsd.statistics.syndrome = std::vector(pcm.m, 1); - lsd.statistics.compare_recover = std::vector(pcm.n, 0); lsd.setLsdMethod(ldpc::osd::OsdMethod::EXHAUSTIVE); auto decoding = lsd.lsd_decode(syndrome, bp.log_prob_ratios, 1, true); lsd.statistics.error = std::vector(pcm.n, 1); + lsd.statistics.compare_recover = std::vector(pcm.n, 0); auto stats = lsd.statistics; std::cout << stats.toString() << std::endl;