Skip to content

Commit

Permalink
Merge pull request #39 from quantumgizmos/fix-stats
Browse files Browse the repository at this point in the history
Fix stats
  • Loading branch information
lucasberent authored Jun 3, 2024
2 parents ba4ea7b + e953dd5 commit 4ad8af3
Show file tree
Hide file tree
Showing 6 changed files with 241 additions and 7 deletions.
26 changes: 20 additions & 6 deletions cpp_test/TestLsd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ TEST(LsdCluster, init1){
auto gcm = std::make_shared<std::vector<LsdCluster*>>(std::vector<LsdCluster*>(pcm.m));
auto syndrome_index = 0;
auto cl = ldpc::lsd::LsdCluster(pcm, syndrome_index, gcm, gbm);

ASSERT_TRUE(cl.active);
ASSERT_FALSE(cl.valid);

Expand Down Expand Up @@ -55,7 +55,7 @@ TEST(LsdCluster, add_bitANDadd_check_add){
auto gcm = std::make_shared<std::vector<LsdCluster*>>(std::vector<LsdCluster*>(pcm.m));
auto syndrome_index = 1;
auto cl = ldpc::lsd::LsdCluster(pcm, syndrome_index, gcm, gbm);

cl.compute_growth_candidate_bit_nodes();
auto expected_candidate_bit_nodes = tsl::robin_set<int>{1, 2};
ASSERT_EQ(expected_candidate_bit_nodes, cl.candidate_bit_nodes);
Expand Down Expand Up @@ -458,11 +458,11 @@ TEST(LsdDecoder, test_cluster_stats) {
auto lsd = LsdDecoder(pcm, ldpc::osd::OsdMethod::EXHAUSTIVE, 0);
lsd.set_do_stats(true);
auto syndrome = std::vector<uint8_t>({1, 1, 0, 0, 0});
lsd.statistics.error = std::vector<uint8_t>(pcm.n, 1);
lsd.statistics.syndrome = std::vector<uint8_t>(pcm.m, 1);
lsd.statistics.compare_recover = std::vector<uint8_t>(pcm.n, 0);
auto decoding = lsd.lsd_decode(syndrome, bp.log_prob_ratios, 1, true);
lsd.setLsdMethod(ldpc::osd::OsdMethod::EXHAUSTIVE);
auto decoding = lsd.lsd_decode(syndrome, bp.log_prob_ratios, 1, true);
lsd.statistics.error = std::vector<uint8_t>(pcm.n, 1);
lsd.statistics.compare_recover = std::vector<uint8_t>(pcm.n, 0);

auto stats = lsd.statistics;
std::cout << stats.toString() << std::endl;
Expand All @@ -475,7 +475,7 @@ TEST(LsdDecoder, test_cluster_stats) {
ASSERT_TRUE(stats.global_timestep_bit_history[0].size() == 2);
ASSERT_TRUE(stats.global_timestep_bit_history[0][0].size() == 1);
ASSERT_TRUE(stats.global_timestep_bit_history[0][1].size() == 2);
ASSERT_TRUE(stats.global_timestep_bit_history[1].size() == 0);
ASSERT_TRUE(stats.global_timestep_bit_history[1].empty());
ASSERT_TRUE(stats.elapsed_time > 0.0);
ASSERT_TRUE(stats.individual_cluster_stats[0].active == false);
ASSERT_TRUE(stats.individual_cluster_stats[0].got_inactive_in_timestep == 0);
Expand All @@ -487,6 +487,20 @@ TEST(LsdDecoder, test_cluster_stats) {
ASSERT_TRUE(stats.error.size() == pcm.n);
ASSERT_TRUE(stats.syndrome.size() == pcm.n);
ASSERT_TRUE(stats.compare_recover.size() == pcm.n);

// now reset
lsd.reset_cluster_stats();
stats = lsd.statistics;
ASSERT_TRUE(lsd.get_do_stats());
ASSERT_TRUE(stats.lsd_method = ldpc::osd::OsdMethod::COMBINATION_SWEEP);
ASSERT_TRUE(stats.lsd_order == 0);
ASSERT_TRUE(stats.individual_cluster_stats.empty());
ASSERT_TRUE(stats.elapsed_time == 0.0);
ASSERT_TRUE(stats.global_timestep_bit_history.empty());
ASSERT_TRUE(stats.bit_llrs.empty());
ASSERT_TRUE(stats.error.empty());
ASSERT_TRUE(stats.syndrome.empty());
ASSERT_TRUE(stats.compare_recover.empty());
}

TEST(LsdDecoder, test_reshuffle_same_wt_indices) {
Expand Down
179 changes: 179 additions & 0 deletions python_test/test_bplsd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import pytest
import numpy as np
import scipy.sparse
from ldpc.codes import rep_code, ring_code, hamming_code
from ldpc.bplsd_decoder import BpLsdDecoder

# Define valid inputs for testing
valid_pcms = [
np.array([[1, 0, 1], [0, 1, 1]]),
scipy.sparse.csr_matrix([[1, 0, 1], [0, 1, 1]]),
]

valid_error_rates = [
0.1,
]

valid_max_iters = [
0,
10,
]

valid_bp_methods = [
"product_sum",
"minimum_sum",
]

valid_ms_scaling_factors = [
1.0,
0.5,
]

valid_schedules = [
"parallel",
"serial",
]

valid_omp_thread_counts = [
1,
4,
]

valid_random_schedule_seeds = [
42,
]

valid_serial_schedule_orders = [
None,
[1, 0, 2],
]

# Combine valid inputs for parameterized testing
valid_input_permutations = pytest.mark.parametrize(
"pcm, error_rate, max_iter, bp_method, ms_scaling_factor, schedule, omp_thread_count, random_schedule_seed, serial_schedule_order",
[
(pcm, error, max_iter, bp_method, ms_factor, schedule, omp_count, seed, order)
for pcm in valid_pcms
for error in valid_error_rates
for max_iter in valid_max_iters
for bp_method in valid_bp_methods
for ms_factor in valid_ms_scaling_factors
for schedule in valid_schedules
for omp_count in valid_omp_thread_counts
for seed in valid_random_schedule_seeds
for order in valid_serial_schedule_orders
],
)

def test_BpLsdDecoder_init():

# test with numpy ndarray as pcm
pcm = np.array([[1, 0, 1], [0, 1, 1]])
decoder = BpLsdDecoder(pcm, error_rate=0.1, max_iter=10, bp_method='prod_sum', ms_scaling_factor=0.5, schedule='parallel', omp_thread_count=4, random_schedule_seed=1, serial_schedule_order=[1,2,0],input_vector_type = "syndrome")
assert decoder is not None
assert decoder.check_count == 2
assert decoder.bit_count == 3
assert decoder.max_iter == 10
assert decoder.bp_method == "product_sum"
assert decoder.ms_scaling_factor == 0.5
assert decoder.schedule == "parallel"
assert decoder.omp_thread_count == 4
assert decoder.random_schedule_seed == 1
assert np.array_equal(decoder.serial_schedule_order, np.array([1, 2, 0]))
assert np.array_equal(decoder.input_vector_type, "syndrome")

# test with scipy.sparse scipy.sparse.csr_matrix as pcm
pcm = scipy.sparse.csr_matrix([[1, 0, 1], [0, 1, 1]])
decoder = BpLsdDecoder(pcm, error_channel=[0.1, 0.2, 0.3],input_vector_type = "syndrome")
assert decoder is not None
assert decoder.check_count == 2
assert decoder.bit_count == 3
assert decoder.max_iter == 3
assert decoder.bp_method == "product_sum"
assert decoder.ms_scaling_factor == 1.0
assert decoder.schedule == "parallel"
assert decoder.omp_thread_count == 1
assert decoder.random_schedule_seed == 0
assert np.array_equal(decoder.serial_schedule_order, np.array([0, 1, 2]))
assert np.array_equal(decoder.input_vector_type, "syndrome")


# test with invalid pcm type
with pytest.raises(TypeError):
decoder = BpLsdDecoder('invalid', error_rate=0.1)

# test with invalid max_iter type
with pytest.raises(TypeError):
decoder = BpLsdDecoder(pcm, error_rate=0.1,max_iter='invalid')

# test with invalid max_iter value
with pytest.raises(ValueError):
decoder = BpLsdDecoder(pcm, error_rate =0.1, max_iter=-1)

# test with invalid bp_method value
with pytest.raises(ValueError):
decoder = BpLsdDecoder(pcm,error_rate=0.1, bp_method='invalid')

# test with invalid schedule value
with pytest.raises(ValueError):
decoder = BpLsdDecoder(pcm,error_rate=0.1, schedule='invalid')

# test with invalid ms_scaling_factor value
with pytest.raises(TypeError):
decoder = BpLsdDecoder(pcm,error_rate=0.1, ms_scaling_factor='invalid')

# test with invalid omp_thread_count value
with pytest.raises(TypeError):
decoder = BpLsdDecoder(pcm, error_rate=0.1,omp_thread_count='invalid')

# test with invalid random_schedule_seed value
with pytest.raises(TypeError):
decoder = BpLsdDecoder(pcm, error_rate=0.1, random_schedule_seed='invalid')

# test with invalid serial_schedule_order value
with pytest.raises(Exception):
decoder = BpLsdDecoder(pcm, error_rate=0.1, serial_schedule_order=[1, 2])

def test_rep_code_ms():

H = rep_code(3)

lsd = BpLsdDecoder(H,error_rate=0.1, bp_method='min_sum', ms_scaling_factor=1.0)
assert lsd is not None
assert lsd.bp_method == "minimum_sum"
assert lsd.schedule == "parallel"
assert np.array_equal(lsd.error_channel,np.array([0.1, 0.1, 0.1]))


decoding = lsd.decode(np.array([1, 1]))
assert(np.array_equal(decoding, np.array([0, 1,0])))

lsd.error_channel = np.array([0.1, 0, 0.1])
assert np.array_equal(lsd.error_channel,np.array([0.1, 0, 0.1]))

decoding=lsd.decode(np.array([1, 1]))
assert(np.array_equal(decoding, np.array([1, 0, 1])))

def test_stats_reset():

H = rep_code(5)

lsd = BpLsdDecoder(H,max_iter=1,error_rate=0.1, bp_method='min_sum', ms_scaling_factor=1.0)
lsd.set_do_stats(True)
syndrome = np.array([1,1,0,1])
lsd.decode(syndrome)

stats = lsd.statistics
assert stats['lsd_order'] == 0
assert stats["lsd_method"] == 1
assert len(stats["bit_llrs"]) == H.shape[1]
assert len(stats["individual_cluster_stats"])>0
assert len(stats["global_timestep_bit_history"])>0

syndrome = np.array([0,0,0,0])
lsd.decode(syndrome)

stats = lsd.statistics
assert len(stats["bit_llrs"]) == 0
assert len(stats["individual_cluster_stats"])==0
assert len(stats["global_timestep_bit_history"])==0
12 changes: 12 additions & 0 deletions src_cpp/lsd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,13 @@ namespace ldpc::lsd {
void clear() {
this->individual_cluster_stats.clear();
this->global_timestep_bit_history.clear();
this->elapsed_time = 0.0;
this->lsd_method = osd::OsdMethod::COMBINATION_SWEEP;
this->lsd_order = 0;
this->bit_llrs = {};
this->error = {};
this->syndrome= {};
this->compare_recover = {};
}

[[nodiscard]] std::string toString() const {
Expand Down Expand Up @@ -552,6 +559,10 @@ namespace ldpc::lsd {
osd::OsdMethod lsd_method;
int lsd_order;

void reset_cluster_stats(){
this->statistics.clear();
}

explicit LsdDecoder(ldpc::bp::BpSparse &parity_check_matrix,
osd::OsdMethod lsdMethod = osd::OsdMethod::COMBINATION_SWEEP,
int lsd_order = 0) : pcm(parity_check_matrix),
Expand Down Expand Up @@ -615,6 +626,7 @@ namespace ldpc::lsd {
const bool is_on_the_fly = true) {
auto start_time = std::chrono::high_resolution_clock::now();
this->statistics.clear();
this->statistics.syndrome = syndrome;

fill(this->decoding.begin(), this->decoding.end(), 0);

Expand Down
18 changes: 18 additions & 0 deletions src_python/ldpc/bplsd_decoder/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,21 @@ class BpLsdDecoder(BpDecoderBase):
"""


def set_additional_stat_fields(self, error, syndrome, compare_recover) -> None:
"""
Sets additional fields to be collected in the statistics.
Parameters
----------
fields : List[str]
A list of strings representing the additional fields to be collected in the statistics.
"""


def reset_cluster_stats(self) -> None:
"""
Resets cluster statistics of the decoder.
Note that this also resets the additional stat fields, such as the error, and compare_recovery vectors
"""

1 change: 1 addition & 0 deletions src_python/ldpc/bplsd_decoder/_bplsd_decoder.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ cdef extern from "lsd.hpp" namespace "ldpc::lsd":
bool get_do_stats()
void set_do_stats(bool do_stats)
void set_additional_stat_fields(vector[int] error, vector[int] syndrome, vector[int] compare_recover)
void reset_cluster_stats()

cdef class BpLsdDecoder(BpDecoderBase):
cdef LsdDecoderCpp* lsd
Expand Down
12 changes: 11 additions & 1 deletion src_python/ldpc/bplsd_decoder/_bplsd_decoder.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,10 @@ cdef class BpLsdDecoder(BpDecoderBase):
self.bpd.decoding = self.bpd.decode(self._syndrome)
out = np.zeros(self.n,dtype=DTYPE)
if self.bpd.converge:
for i in range(self.n): out[i] = self.bpd.decoding[i]
for i in range(self.n):
out[i] = self.bpd.decoding[i]
self.lsd.reset_cluster_stats()


if not self.bpd.converge:
self.lsd.decoding = self.lsd.lsd_decode(self._syndrome, self.bpd.log_prob_ratios,self.bits_per_step, True)
Expand Down Expand Up @@ -291,3 +294,10 @@ cdef class BpLsdDecoder(BpDecoderBase):
self.lsd.statistics.error = error
self.lsd.statistics.syndrome = syndrome
self.lsd.statistics.compare_recover = compare_recover

def reset_cluster_stats(self) -> None:
"""
Resets cluster statistics of the decoder.
Note that this also resets the additional stat fields, such as the error, and compare_recovery vectors
"""
self.lsd.reset_cluster_stats()

0 comments on commit 4ad8af3

Please sign in to comment.