From 57083b2ebe6265ed3c5438e9b923cad602b9198c Mon Sep 17 00:00:00 2001 From: lucas Date: Wed, 31 Jan 2024 15:31:56 +0100 Subject: [PATCH] minor code issues --- cpp_test/TestUnionFind.cpp | 4 +- python_test/test_qcodes.py | 6 +- src_cpp/union_find.hpp | 884 ++++++++---------- .../_belief_find_decoder.pxd | 2 +- .../ldpc/bplsd_decoder/_bplsd_decoder.pxd | 4 +- .../_union_find_decoder.pxd | 2 +- .../_union_find_decoder.pyx | 4 +- 7 files changed, 412 insertions(+), 494 deletions(-) diff --git a/cpp_test/TestUnionFind.cpp b/cpp_test/TestUnionFind.cpp index 3c45686..26aad77 100644 --- a/cpp_test/TestUnionFind.cpp +++ b/cpp_test/TestUnionFind.cpp @@ -112,7 +112,7 @@ TEST(UfDecoder, ring_code3){ auto syndrome = vector{1, 0, 1}; - auto decoding = bpd.peel_decode(syndrome, ldpc::uf::NULL_DOUBLE_VECTOR,3); + auto decoding = bpd.peel_decode(syndrome, ldpc::uf::EMPTY_DOUBLE_VECTOR, 3); ASSERT_TRUE(bpd.pcm_max_bit_degree_2); tsl::robin_set boundary_bits = {}; @@ -133,7 +133,7 @@ TEST(UfDecoder, rep_code){ auto syndrome = vector(n,0); syndrome[0] = 1; syndrome[n-2] = 1; - auto decoding = bpd.peel_decode(syndrome, ldpc::uf::NULL_DOUBLE_VECTOR,1); + auto decoding = bpd.peel_decode(syndrome, ldpc::uf::EMPTY_DOUBLE_VECTOR, 1); auto expected_decoding = vector(n,0); expected_decoding[0] = 1; expected_decoding[n-1] = 1; diff --git a/python_test/test_qcodes.py b/python_test/test_qcodes.py index 2805a91..aa33bc2 100644 --- a/python_test/test_qcodes.py +++ b/python_test/test_qcodes.py @@ -61,8 +61,8 @@ def quantum_mc_sim(hx, lx, error_rate, run_count, seed, DECODER, run_label, DEBU def test_400_16_6_hgp(): - hx = scipy.sparse.load_npz("/home/luca/Documents/codeRepos/ldpc/python_test/pcms/hx_400_16_6.npz") - lx = scipy.sparse.load_npz("/home/luca/Documents/codeRepos/ldpc/python_test/pcms/lx_400_16_6.npz") + hx = scipy.sparse.load_npz("pcms/hx_400_16_6.npz") + lx = scipy.sparse.load_npz("pcms/lx_400_16_6.npz") error_rate = 0.03 run_count = 1000 @@ -99,7 +99,7 @@ def test_400_16_6_hgp(): decoder = BpLsdDecoder(hx, error_rate=error_rate, max_iter=max_iter, bp_method="ms", ms_scaling_factor=0.625, schedule="parallel", bits_per_step=1, lsd_order=0) ler, min_logical, speed, _ = quantum_mc_sim(hx, lx, error_rate, run_count, seed, decoder, - f"Min-sum LSD parallel schedule osd={osd_order}") + f"Min-sum LSD-0 parallel schedule") decoder = BpLsdDecoder(hx, error_rate=error_rate, max_iter=5, bp_method="ms", ms_scaling_factor=0.625, schedule="parallel", bits_per_step=1, lsd_order=osd_order) diff --git a/src_cpp/union_find.hpp b/src_cpp/union_find.hpp index 544cd33..d316c5b 100644 --- a/src_cpp/union_find.hpp +++ b/src_cpp/union_find.hpp @@ -18,455 +18,388 @@ #include "gf2sparse_linalg.hpp" #include "bp.hpp" -namespace ldpc::uf{ - -const std::vector NULL_DOUBLE_VECTOR = {}; -tsl::robin_set NULL_INT_ROBIN_SET = {}; - -std::vector sort_indices(std::vector& B){ - std::vector indices(B.size()); - std::iota(indices.begin(),indices.end(),0); - std::sort(indices.begin(), indices.end(), [&](int i, int j) { return B[i] < B[j];}); - return indices; -} - -struct Cluster{ - ldpc::bp::BpSparse& pcm; - tsl::robin_set& planar_code_boundary_bits; - int cluster_id; - bool contains_boundary_bits; - bool active; - bool valid; - tsl::robin_set bit_nodes; - tsl::robin_set check_nodes; - tsl::robin_set boundary_check_nodes; - std::vector candidate_bit_nodes; - tsl::robin_set enclosed_syndromes; - tsl::robin_map spanning_tree_check_roots; - tsl::robin_set spanning_tree_bits; - tsl::robin_set spanning_tree_leaf_nodes; - int spanning_tree_boundary_bit; - - Cluster** global_check_membership; - Cluster** global_bit_membership; - tsl::robin_set merge_list; - - std::vector cluster_decoding; - std::vector matrix_to_cluster_bit_map; - tsl::robin_map cluster_to_matrix_bit_map; - std::vector matrix_to_cluster_check_map; - tsl::robin_map cluster_to_matrix_check_map; - - Cluster() = default; - - Cluster(ldpc::bp::BpSparse& parity_check_matrix, int syndrome_index, Cluster** ccm, Cluster** bcm, tsl::robin_set& planar_code_boundary_bits = NULL_INT_ROBIN_SET): - pcm(parity_check_matrix), planar_code_boundary_bits(planar_code_boundary_bits){ - - this->active=true; - this->valid=false; - this->cluster_id = syndrome_index; - this->boundary_check_nodes.insert(syndrome_index); - this->check_nodes.insert(syndrome_index); - this->enclosed_syndromes.insert(syndrome_index); - this->global_check_membership = ccm; - this->global_bit_membership = bcm; - this->global_check_membership[syndrome_index]=this; - this->contains_boundary_bits = false; - this->spanning_tree_boundary_bit = -1; +namespace ldpc::uf { + const std::vector EMPTY_DOUBLE_VECTOR = {}; + tsl::robin_set EMPTY_INT_ROBIN_SET = {}; - - } - ~Cluster(){ - this->bit_nodes.clear(); - this->check_nodes.clear(); - this->boundary_check_nodes.clear(); - this->candidate_bit_nodes.clear(); - this->enclosed_syndromes.clear(); - this->merge_list.clear(); + std::vector sort_indices(std::vector &B) { + std::vector indices(B.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&](int i, int j) { return B[i] < B[j]; }); + return indices; } - int parity(){ - return this->enclosed_syndromes.size() % 2; - } - - void get_candidate_bit_nodes(){ - - std::vector erase_boundary_check; - this->candidate_bit_nodes.clear(); - - for(int check_index: boundary_check_nodes){ - bool erase = true; - for(auto& e: this->pcm.iterate_row(check_index)){ - if(this->global_bit_membership[e.col_index] != this ){ - candidate_bit_nodes.push_back(e.col_index); - erase = false; + struct Cluster { + ldpc::bp::BpSparse &pcm; + tsl::robin_set &planar_code_boundary_bits; + int cluster_id; + bool contains_boundary_bits; + bool active; + bool valid; + tsl::robin_set bit_nodes; + tsl::robin_set check_nodes; + tsl::robin_set boundary_check_nodes; + std::vector candidate_bit_nodes; + tsl::robin_set enclosed_syndromes; + tsl::robin_map spanning_tree_check_roots; + tsl::robin_set spanning_tree_bits; + tsl::robin_set spanning_tree_leaf_nodes; + int spanning_tree_boundary_bit; + Cluster **global_check_membership; + Cluster **global_bit_membership; + tsl::robin_set merge_list; + std::vector cluster_decoding; + std::vector matrix_to_cluster_bit_map; + tsl::robin_map cluster_to_matrix_bit_map; + std::vector matrix_to_cluster_check_map; + tsl::robin_map cluster_to_matrix_check_map; + + Cluster() = default; + + Cluster(ldpc::bp::BpSparse &parity_check_matrix, int syndrome_index, Cluster **ccm, Cluster **bcm, + tsl::robin_set &planar_code_boundary_bits = EMPTY_INT_ROBIN_SET) : + pcm(parity_check_matrix), planar_code_boundary_bits(planar_code_boundary_bits) { + this->active = true; + this->valid = false; + this->cluster_id = syndrome_index; + this->boundary_check_nodes.insert(syndrome_index); + this->check_nodes.insert(syndrome_index); + this->enclosed_syndromes.insert(syndrome_index); + this->global_check_membership = ccm; + this->global_bit_membership = bcm; + this->global_check_membership[syndrome_index] = this; + this->contains_boundary_bits = false; + this->spanning_tree_boundary_bit = -1; + } + + ~Cluster() { + this->bit_nodes.clear(); + this->check_nodes.clear(); + this->boundary_check_nodes.clear(); + this->candidate_bit_nodes.clear(); + this->enclosed_syndromes.clear(); + this->merge_list.clear(); + } + + int parity() { + return this->enclosed_syndromes.size() % 2; + } + + void get_candidate_bit_nodes() { + std::vector erase_boundary_check; + this->candidate_bit_nodes.clear(); + for (int check_index: boundary_check_nodes) { + bool erase = true; + for (auto &e: this->pcm.iterate_row(check_index)) { + if (this->global_bit_membership[e.col_index] != this) { + candidate_bit_nodes.push_back(e.col_index); + erase = false; + } + } + if (erase) { + erase_boundary_check.push_back(check_index); } } - if(erase) erase_boundary_check.push_back(check_index); - } - - - - for(int check_index: erase_boundary_check){ - this->boundary_check_nodes.erase(check_index); - } - - } - - int add_bit_node_to_cluster(int bit_index){ - - auto bit_membership = this->global_bit_membership[bit_index]; - if(bit_membership == this) return 0; //if the bit is already in the cluster terminate. - else if(bit_membership == NULL){ - //if the bit has not yet been assigned to a cluster we add it. - this->bit_nodes.insert(bit_index); - this->global_bit_membership[bit_index] = this; - } - else{ - //if the bit already exists in a cluster, we mark down that this cluster should be - //merged with the exisiting cluster. - this->merge_list.insert(bit_membership); - this->global_bit_membership[bit_index] = this; - } - - if(this->planar_code_boundary_bits.contains(bit_index)){ - this->contains_boundary_bits = true; + for (int check_index: erase_boundary_check) { + this->boundary_check_nodes.erase(check_index); + } + } + + int add_bit_node_to_cluster(int bit_index) { + auto bit_membership = this->global_bit_membership[bit_index]; + if (bit_membership == this) return 0; //if the bit is already in the cluster terminate. + else if (bit_membership == NULL) { + //if the bit has not yet been assigned to a cluster we add it. + this->bit_nodes.insert(bit_index); + this->global_bit_membership[bit_index] = this; + } else { + //if the bit already exists in a cluster, we mark down that this cluster should be + //merged with the exisiting cluster. + this->merge_list.insert(bit_membership); + this->global_bit_membership[bit_index] = this; + } + if (this->planar_code_boundary_bits.contains(bit_index)) { + this->contains_boundary_bits = true; + } + for (auto &e: this->pcm.iterate_column(bit_index)) { + int check_index = e.row_index; + auto check_membership = this->global_check_membership[check_index]; + if (check_membership == this) continue; + else if (check_membership == NULL) { + this->check_nodes.insert(check_index); + this->boundary_check_nodes.insert(check_index); + this->global_check_membership[check_index] = this; + } else { + this->check_nodes.insert(check_index); + this->boundary_check_nodes.insert(check_index); + this->merge_list.insert(check_membership); + this->global_check_membership[check_index] = this; + } + } + return 1; } - for(auto& e: this->pcm.iterate_column(bit_index)){ - int check_index = e.row_index; - auto check_membership = this->global_check_membership[check_index]; - if(check_membership == this) continue; - else if (check_membership == NULL){ + void merge_with_cluster(Cluster *cl2) { + for (auto bit_index: cl2->bit_nodes) { + this->bit_nodes.insert(bit_index); + this->global_bit_membership[bit_index] = this; + } + for (auto check_index: cl2->check_nodes) { this->check_nodes.insert(check_index); - this->boundary_check_nodes.insert(check_index); this->global_check_membership[check_index] = this; } - else{ - this->check_nodes.insert(check_index); + for (auto check_index: cl2->boundary_check_nodes) { this->boundary_check_nodes.insert(check_index); - this->merge_list.insert(check_membership); - this->global_check_membership[check_index] = this; + } + if (cl2->contains_boundary_bits) { + this->contains_boundary_bits = true; + } + cl2->active = false; + for (auto j: cl2->enclosed_syndromes) { + this->enclosed_syndromes.insert(j); } } - return 1; - - } - - void merge_with_cluster(Cluster* cl2){ - - for(auto bit_index: cl2->bit_nodes){ - this->bit_nodes.insert(bit_index); - this->global_bit_membership[bit_index] = this; - } - - for(auto check_index: cl2->check_nodes){ - this->check_nodes.insert(check_index); - this->global_check_membership[check_index] = this; - } - - for(auto check_index: cl2->boundary_check_nodes){ - this->boundary_check_nodes.insert(check_index); - } - - if(cl2->contains_boundary_bits == true){ - this->contains_boundary_bits = true; - } - - cl2->active = false; - for(auto j: cl2->enclosed_syndromes){ - this->enclosed_syndromes.insert(j); - } - } - - int grow_cluster(const std::vector& bit_weights = NULL_DOUBLE_VECTOR, int bits_per_step = 0){ - if(!this->active) return 0; - - this->get_candidate_bit_nodes(); - - - this->merge_list.clear(); - - if(bit_weights == NULL_DOUBLE_VECTOR){ - for(int bit_index: this->candidate_bit_nodes){ - this->add_bit_node_to_cluster(bit_index); + int grow_cluster(const std::vector &bit_weights = EMPTY_DOUBLE_VECTOR, int bits_per_step = 0) { + if (!this->active) { + return 0; } - } + this->get_candidate_bit_nodes(); + this->merge_list.clear(); - else{ - std::vector cluster_bit_weights; - for(int bit: this->candidate_bit_nodes){ - cluster_bit_weights.push_back(bit_weights[bit]); + if (bit_weights == EMPTY_DOUBLE_VECTOR) { + for (int bit_index: this->candidate_bit_nodes) { + this->add_bit_node_to_cluster(bit_index); + } + } else { + std::vector cluster_bit_weights; + for (auto bit: this->candidate_bit_nodes) { + cluster_bit_weights.push_back(bit_weights[bit]); + } + auto sorted_indices = sort_indices(cluster_bit_weights); + int count = 0; + for (auto i: sorted_indices) { + if (count == bits_per_step) break; + int bit_index = this->candidate_bit_nodes[i]; + this->add_bit_node_to_cluster(bit_index); + count++; + } } - auto sorted_indices = sort_indices(cluster_bit_weights); - int count = 0; - for(int i: sorted_indices){ - if(count == bits_per_step) break; - int bit_index = this->candidate_bit_nodes[i]; - this->add_bit_node_to_cluster(bit_index); - count++; + for (auto cl: merge_list) { + this->merge_with_cluster(cl); + cl->active = false; } - - } - - for(auto cl: merge_list){ - this->merge_with_cluster(cl); - cl->active = false; + return 1; } - return 1; - } - int find_spanning_tree_parent(int check_index){ + int find_spanning_tree_parent(const int check_index) { int parent = this->spanning_tree_check_roots[check_index]; - if(parent != check_index){ + if (parent != check_index) { return find_spanning_tree_parent(parent); + } else { + return parent; } - else return parent; } - void find_spanning_tree(){ - - this->spanning_tree_bits.clear(); - this->spanning_tree_check_roots.clear(); - this->spanning_tree_leaf_nodes.clear(); - - for(int bit_index: this->bit_nodes){ - this->spanning_tree_bits.insert(bit_index); - } - - // add the virtual boundary check - if(this->contains_boundary_bits == true){ - this->check_nodes.insert(-1); - } - - for(int check_index: this->check_nodes){ - this->spanning_tree_check_roots[check_index] = check_index; - } - - - - int check_neighbours[2]; - for(int bit_index: this->bit_nodes){ - check_neighbours[0] = this->pcm.column_heads[bit_index]->up->row_index; - check_neighbours[1] = this->pcm.column_heads[bit_index]->down->row_index; - - if(check_neighbours[0] == check_neighbours[1]){ - check_neighbours[1] = -1; //set the first check neighbour to the boundary check. - this->spanning_tree_boundary_bit = bit_index; + void find_spanning_tree() { + this->spanning_tree_bits.clear(); + this->spanning_tree_check_roots.clear(); + this->spanning_tree_leaf_nodes.clear(); + for (int bit_index: this->bit_nodes) { + this->spanning_tree_bits.insert(bit_index); } - - int root0 = this->find_spanning_tree_parent(check_neighbours[0]); - int root1 = this->find_spanning_tree_parent(check_neighbours[1]); - - if(root0!=root1){ - this->spanning_tree_check_roots[root1] = root0; + // add the virtual boundary check + if (this->contains_boundary_bits) { + this->check_nodes.insert(-1); } - else{ - this->spanning_tree_bits.erase(bit_index); + for (int check_index: this->check_nodes) { + this->spanning_tree_check_roots[check_index] = check_index; } - } + int check_neighbours[2]; - - for(int check_index: this->check_nodes){ - if(check_index == -1){ - this->spanning_tree_leaf_nodes.insert(check_index); + for (int bit_index: this->bit_nodes) { + check_neighbours[0] = this->pcm.column_heads[bit_index]->up->row_index; + check_neighbours[1] = this->pcm.column_heads[bit_index]->down->row_index; + if (check_neighbours[0] == check_neighbours[1]) { + check_neighbours[1] = -1; //set the first check neighbour to the boundary check. + this->spanning_tree_boundary_bit = bit_index; + } + int root0 = this->find_spanning_tree_parent(check_neighbours[0]); + int root1 = this->find_spanning_tree_parent(check_neighbours[1]); + if (root0 != root1) { + this->spanning_tree_check_roots[root1] = root0; + } else { + this->spanning_tree_bits.erase(bit_index); + } } - else{ - int spanning_tree_connectivity = 0; - for(auto& e: this->pcm.iterate_row(check_index)){ - if(this->spanning_tree_bits.contains(e.col_index)){ - spanning_tree_connectivity+=1; + for (int check_index: this->check_nodes) { + if (check_index == -1) { + this->spanning_tree_leaf_nodes.insert(check_index); + } else { + int spanning_tree_connectivity = 0; + for (auto &e: this->pcm.iterate_row(check_index)) { + if (this->spanning_tree_bits.contains(e.col_index)) { + spanning_tree_connectivity += 1; + } + } + if (spanning_tree_connectivity == 1) { + this->spanning_tree_leaf_nodes.insert(check_index); } - } - if(spanning_tree_connectivity == 1){ - this->spanning_tree_leaf_nodes.insert(check_index); } } } - } - - std::vector peel_decode(const std::vector& syndrome){ - - std::vector erasure; - tsl::robin_set synds; - for(auto check_index: check_nodes){ - if(syndrome[check_index] == 1) synds.insert(check_index); - } - if(this->contains_boundary_bits == true && this->parity() == 1 ){ - synds.insert(-1); - } - - this->find_spanning_tree(); - - while(synds.size()>0){ - - int leaf_node_index = *(this->spanning_tree_leaf_nodes.begin()); - int bit_index; - int check2 = -1; //we assume it is a boundary node at first. - - if(leaf_node_index == -1){ - bit_index = this->spanning_tree_boundary_bit; + std::vector peel_decode(const std::vector &syndrome) { + std::vector erasure; + tsl::robin_set synds; + for (auto check_index: check_nodes) { + if (syndrome[check_index] == 1) synds.insert(check_index); } - - else{ - for(auto& e: this->pcm.iterate_row(leaf_node_index)){ - bit_index = e.col_index; - if(this->spanning_tree_bits.contains(bit_index)) break; - } + if (this->contains_boundary_bits == true && this->parity() == 1) { + synds.insert(-1); } + this->find_spanning_tree(); + while (!synds.empty()) { + int leaf_node_index = *(this->spanning_tree_leaf_nodes.begin()); + int bit_index; + int check2 = -1; //we assume it is a boundary node at first. - for(auto& e: this->pcm.iterate_column(bit_index)){ - if(e.row_index!=leaf_node_index) { - check2 = e.row_index; + if (leaf_node_index == -1) { + bit_index = this->spanning_tree_boundary_bit; + } else { + for (auto &e: this->pcm.iterate_row(leaf_node_index)) { + bit_index = e.col_index; + if (this->spanning_tree_bits.contains(bit_index)) break; + } } - } - - - if(synds.contains(leaf_node_index)){ - this->spanning_tree_leaf_nodes.erase(leaf_node_index); - // this->spanning_tree_leaf_nodes.insert(check2); - erasure.push_back(bit_index); - this->spanning_tree_bits.erase(bit_index); - if(synds.contains(check2)) synds.erase(check2); - else synds.insert(check2); - synds.erase(leaf_node_index); - } - else{ - this->spanning_tree_leaf_nodes.erase(leaf_node_index); - // this->spanning_tree_leaf_nodes.insert(check2); - this->spanning_tree_bits.erase(bit_index); - } - - //check whether new check node is a leaf - if(check2 == -1){ - this->spanning_tree_leaf_nodes.insert(check2); - } - else{ - int spanning_tree_connectivity = 0; - for(auto& e: this->pcm.iterate_row(check2)){ - if(this->spanning_tree_bits.contains(e.col_index)){ - spanning_tree_connectivity+=1; + for (auto &e: this->pcm.iterate_column(bit_index)) { + if (e.row_index != leaf_node_index) { + check2 = e.row_index; } } - if(spanning_tree_connectivity == 1){ + + if (synds.contains(leaf_node_index)) { + this->spanning_tree_leaf_nodes.erase(leaf_node_index); + // this->spanning_tree_leaf_nodes.insert(check2); + erasure.push_back(bit_index); + this->spanning_tree_bits.erase(bit_index); + if (synds.contains(check2)) synds.erase(check2); + else synds.insert(check2); + synds.erase(leaf_node_index); + } else { + this->spanning_tree_leaf_nodes.erase(leaf_node_index); + // this->spanning_tree_leaf_nodes.insert(check2); + this->spanning_tree_bits.erase(bit_index); + } + //check whether new check node is a leaf + if (check2 == -1) { this->spanning_tree_leaf_nodes.insert(check2); + } else { + int spanning_tree_connectivity = 0; + for (auto &e: this->pcm.iterate_row(check2)) { + if (this->spanning_tree_bits.contains(e.col_index)) { + spanning_tree_connectivity += 1; + } + } + if (spanning_tree_connectivity == 1) { + this->spanning_tree_leaf_nodes.insert(check2); + } } } - + return erasure; } - return erasure; - } - - ldpc::bp::BpSparse convert_to_matrix(const std::vector& bit_weights = NULL_DOUBLE_VECTOR){ - - this->matrix_to_cluster_bit_map.clear(); - this->matrix_to_cluster_check_map.clear(); - this->cluster_to_matrix_bit_map.clear(); - this->cluster_to_matrix_check_map.clear(); + ldpc::bp::BpSparse convert_to_matrix(const std::vector &bit_weights = EMPTY_DOUBLE_VECTOR) { + this->matrix_to_cluster_bit_map.clear(); + this->matrix_to_cluster_check_map.clear(); + this->cluster_to_matrix_bit_map.clear(); + this->cluster_to_matrix_check_map.clear(); - - if(bit_weights!=NULL_DOUBLE_VECTOR){ - std::vector cluster_bit_weights; - std::vector bit_nodes_temp; - for(int bit: this->bit_nodes){ - cluster_bit_weights.push_back(bit_weights[bit]); - bit_nodes_temp.push_back(bit); + if (bit_weights != EMPTY_DOUBLE_VECTOR) { + std::vector cluster_bit_weights; + std::vector bit_nodes_temp; + for (int bit: this->bit_nodes) { + cluster_bit_weights.push_back(bit_weights[bit]); + bit_nodes_temp.push_back(bit); + } + auto sorted_indices = sort_indices(cluster_bit_weights); + int count = 0; + for (auto i: sorted_indices) { + int bit_index = bit_nodes_temp[i]; + this->matrix_to_cluster_bit_map.push_back(bit_index); + this->cluster_to_matrix_bit_map[bit_index] = count; + count++; + } + } else { + int count = 0; + for (auto bit_index: this->bit_nodes) { + this->matrix_to_cluster_bit_map.push_back(bit_index); + this->cluster_to_matrix_bit_map[bit_index] = count; + count++; + } } - auto sorted_indices = sort_indices(cluster_bit_weights); int count = 0; - for(int i: sorted_indices){ - int bit_index = bit_nodes_temp[i]; - this->matrix_to_cluster_bit_map.push_back(bit_index); - this->cluster_to_matrix_bit_map[bit_index] = count; - count++; - } - } - else{ - int count = 0; - for(int bit_index: this->bit_nodes){ - this->matrix_to_cluster_bit_map.push_back(bit_index); - this->cluster_to_matrix_bit_map[bit_index] = count; + for (auto check_index: this->check_nodes) { + this->matrix_to_cluster_check_map.push_back(check_index); + this->cluster_to_matrix_check_map[check_index] = count; count++; } + auto cluster_pcm = ldpc::bp::BpSparse(this->check_nodes.size(), this->bit_nodes.size()); - } - - int count = 0; - - for(int check_index: this->check_nodes){ - this->matrix_to_cluster_check_map.push_back(check_index); - this->cluster_to_matrix_check_map[check_index] = count; - count++; - } - - auto cluster_pcm = ldpc::bp::BpSparse(this->check_nodes.size(),this->bit_nodes.size()); - - for(int check_index: this->check_nodes){ - for(auto& e: this->pcm.iterate_row(check_index)){ - int bit_index = e.col_index; - if(this->bit_nodes.contains(bit_index)){ - int matrix_bit_index = cluster_to_matrix_bit_map[bit_index]; - int matrix_check_index = cluster_to_matrix_check_map[check_index]; - cluster_pcm.insert_entry(matrix_check_index,matrix_bit_index); + for (int check_index: this->check_nodes) { + for (auto &e: this->pcm.iterate_row(check_index)) { + int bit_index = e.col_index; + if (this->bit_nodes.contains(bit_index)) { + int matrix_bit_index = cluster_to_matrix_bit_map[bit_index]; + int matrix_check_index = cluster_to_matrix_check_map[check_index]; + cluster_pcm.insert_entry(matrix_check_index, matrix_bit_index); + } } } + return cluster_pcm; } - return cluster_pcm; - - } + std::vector invert_decode(const std::vector &syndrome, const std::vector &bit_weights) { + auto cluster_pcm = this->convert_to_matrix(bit_weights); + std::vector cluster_syndrome; + for (auto check_index: check_nodes) { + cluster_syndrome.push_back(syndrome[check_index]); + } + auto rr = ldpc::gf2sparse_linalg::RowReduce(cluster_pcm); + auto cluster_solution = rr.fast_solve(cluster_syndrome); + auto candidate_cluster_syndrome = cluster_pcm.mulvec(cluster_solution); + bool equal = true; - std::vector invert_decode(const std::vector& syndrome, const std::vector& bit_weights){ - - auto cluster_pcm = this->convert_to_matrix(bit_weights); - - std::vector cluster_syndrome; - for(int check_index: check_nodes){ - cluster_syndrome.push_back(syndrome[check_index]); - } - - auto rr = ldpc::gf2sparse_linalg::RowReduce(cluster_pcm); - auto cluster_solution = rr.fast_solve(cluster_syndrome); - - auto candidate_cluster_syndrome = cluster_pcm.mulvec(cluster_solution); - - bool equal = true; - for(int i =0; icluster_decoding.clear(); - this->valid = equal; - for(int i = 0; icluster_decoding.push_back(this->matrix_to_cluster_bit_map[i]); + this->cluster_decoding.clear(); + this->valid = equal; + for (auto i = 0; i < cluster_solution.size(); i++) { + if (cluster_solution[i] == 1) { + this->cluster_decoding.push_back(this->matrix_to_cluster_bit_map[i]); + } } + return this->cluster_decoding; } - return this->cluster_decoding; - - } - - void print(); + void print(); + }; -}; - -class UfDecoder{ + class UfDecoder { private: bool weighted; - ldpc::bp::BpSparse& pcm; + ldpc::bp::BpSparse &pcm; public: std::vector decoding; @@ -474,176 +407,161 @@ class UfDecoder{ int check_count; tsl::robin_set planar_code_boundary_bits; bool pcm_max_bit_degree_2; - UfDecoder(ldpc::bp::BpSparse& parity_check_matrix): pcm(parity_check_matrix){ + + UfDecoder(ldpc::bp::BpSparse &parity_check_matrix) : pcm(parity_check_matrix) { this->bit_count = pcm.n; this->check_count = pcm.m; this->decoding.resize(this->bit_count); this->weighted = false; this->pcm_max_bit_degree_2 = true; - for(int i = 0; ipcm.n; i++){ + for (auto i = 0; i < this->pcm.n; i++) { int col_weight = this->pcm.get_col_degree(i); - if(col_weight > 2){ + if (col_weight > 2) { this->pcm_max_bit_degree_2 = false; - } - else if(col_weight == 1){ + } else if (col_weight == 1) { this->planar_code_boundary_bits.insert(i); - } - else if(col_weight == 0){ - throw(std::runtime_error("Invalid parity check matrix. Column weight is zero.")); + } else if (col_weight == 0) { + throw (std::runtime_error("Invalid parity check matrix. Column weight is zero.")); } } - - } - std::vector& peel_decode(const std::vector& syndrome, const std::vector& bit_weights = NULL_DOUBLE_VECTOR, int bits_per_step = 1){ + std::vector & + peel_decode(const std::vector &syndrome, const std::vector &bit_weights = EMPTY_DOUBLE_VECTOR, + int bits_per_step = 1) { - if(!this->pcm_max_bit_degree_2){ - throw(std::runtime_error("Peel decoder only works for planar codes. Use the matrix_decode method for more general codes.")); + if (!this->pcm_max_bit_degree_2) { + throw (std::runtime_error( + "Peel decoder only works for planar codes. Use the matrix_decode method for more general codes.")); } - fill(this->decoding.begin(), this->decoding.end(), 0); - - std::vector clusters; - std::vector invalid_clusters; - Cluster** global_bit_membership = new Cluster*[pcm.n](); - Cluster** global_check_membership = new Cluster*[pcm.m](); - - for(int i =0; ipcm.m; i++){ - if(syndrome[i] == 1){ - Cluster* cl = new Cluster(this->pcm, i, global_check_membership, global_bit_membership, this->planar_code_boundary_bits); + std::vector clusters; + std::vector invalid_clusters; + auto **global_bit_membership = new Cluster *[pcm.n](); + auto **global_check_membership = new Cluster *[pcm.m](); + + for (int i = 0; i < this->pcm.m; i++) { + if (syndrome[i] == 1) { + auto *cl = new Cluster(this->pcm, i, global_check_membership, global_bit_membership, + this->planar_code_boundary_bits); clusters.push_back(cl); invalid_clusters.push_back(cl); } } - while(invalid_clusters.size()>0){ - - for(auto cl: invalid_clusters){ - if(cl->active){ - cl->grow_cluster(bit_weights,bits_per_step); + while (!invalid_clusters.empty()) { + for (auto cl: invalid_clusters) { + if (cl->active) { + cl->grow_cluster(bit_weights, bits_per_step); } } invalid_clusters.clear(); - for(auto cl: clusters){ - if(cl->active == true && cl->parity() == 1 && cl->contains_boundary_bits == false){ + for (auto cl: clusters) { + if (cl->active && cl->parity() == 1 && !cl->contains_boundary_bits) { invalid_clusters.push_back(cl); } } - - std::sort(invalid_clusters.begin(), invalid_clusters.end(), [](const Cluster* lhs, const Cluster* rhs){return lhs->bit_nodes.size() < rhs->bit_nodes.size();}); - + std::sort(invalid_clusters.begin(), invalid_clusters.end(), [](const Cluster *lhs, const Cluster *rhs) { + return lhs->bit_nodes.size() < rhs->bit_nodes.size(); + }); } - - - for(auto cl: clusters){ - if(cl->active){ + for (auto cl: clusters) { + if (cl->active) { auto erasure = cl->peel_decode(syndrome); - - for(int bit: erasure) this->decoding[bit] = 1; + for (int bit: erasure) this->decoding[bit] = 1; } delete cl; } - delete[] global_bit_membership; delete[] global_check_membership; - return this->decoding; - } - std::vector& matrix_decode(const std::vector& syndrome, const std::vector& bit_weights = NULL_DOUBLE_VECTOR, int bits_per_step = 1){ - + std::vector & + matrix_decode(const std::vector &syndrome, + const std::vector &bit_weights = EMPTY_DOUBLE_VECTOR, + int bits_per_step = 1) { fill(this->decoding.begin(), this->decoding.end(), 0); - std::vector clusters; - std::vector invalid_clusters; - Cluster** global_bit_membership = new Cluster*[pcm.n](); - Cluster** global_check_membership = new Cluster*[pcm.m](); + std::vector clusters; + std::vector invalid_clusters; + auto **global_bit_membership = new Cluster *[pcm.n](); + auto **global_check_membership = new Cluster *[pcm.m](); - for(int i =0; ipcm.m; i++){ - if(syndrome[i] == 1){ - Cluster* cl = new Cluster(this->pcm, i, global_check_membership, global_bit_membership); + for (auto i = 0; i < this->pcm.m; i++) { + if (syndrome[i] == 1) { + auto *cl = new Cluster(this->pcm, i, global_check_membership, global_bit_membership); clusters.push_back(cl); invalid_clusters.push_back(cl); } } - while(invalid_clusters.size()>0){ - - for(auto cl: invalid_clusters){ - if(cl->active){ - cl->grow_cluster(bit_weights,bits_per_step); - auto cluster_decoding = cl->invert_decode(syndrome,bit_weights); + while (!invalid_clusters.empty()) { + for (auto cl: invalid_clusters) { + if (cl->active) { + cl->grow_cluster(bit_weights, bits_per_step); + auto cluster_decoding = cl->invert_decode(syndrome, bit_weights); } } invalid_clusters.clear(); - for(auto cl: clusters){ - if(cl->active == true && cl->valid == false){ + for (auto cl: clusters) { + if (cl->active && !cl->valid) { invalid_clusters.push_back(cl); } } - - std::sort(invalid_clusters.begin(), invalid_clusters.end(), [](const Cluster* lhs, const Cluster* rhs){return lhs->bit_nodes.size() < rhs->bit_nodes.size();}); - + std::sort(invalid_clusters.begin(), invalid_clusters.end(), [](const Cluster *lhs, const Cluster *rhs) { + return lhs->bit_nodes.size() < rhs->bit_nodes.size(); + }); } - for(auto cl: clusters){ - if(cl->active){ - for(int bit: cl->cluster_decoding) this->decoding[bit] = 1; + for (auto cl: clusters) { + if (cl->active) { + for (int bit: cl->cluster_decoding) this->decoding[bit] = 1; } delete cl; } - delete[] global_bit_membership; delete[] global_check_membership; - return this->decoding; - } - - - -}; - -void Cluster::print(){ - std::cout<<"........."<cluster_id<active<enclosed_syndromes) std::cout<bit_nodes) std::cout<check_nodes) std::cout<candidate_bit_nodes) std::cout<boundary_check_nodes) std::cout<spanning_tree_bits) std::cout<spanning_tree_leaf_nodes) std::cout<contains_boundary_bits<planar_code_boundary_bits) std::cout<cluster_id << std::endl; + std::cout << "Active: " << this->active << std::endl; + std::cout << "Enclosed syndromes: "; + for (auto i: this->enclosed_syndromes) std::cout << i << " "; + std::cout << std::endl; + std::cout << "Cluster bits: "; + for (auto i: this->bit_nodes) std::cout << i << " "; + std::cout << std::endl; + std::cout << "Cluster checks: "; + for (auto i: this->check_nodes) std::cout << i << " "; + std::cout << std::endl; + std::cout << "Candidate bits: "; + for (auto i: this->candidate_bit_nodes) std::cout << i << " "; + std::cout << std::endl; + std::cout << "Boundary Checks: "; + for (auto i: this->boundary_check_nodes) std::cout << i << " "; + std::cout << std::endl; + std::cout << "Spanning tree bits: "; + for (auto i: this->spanning_tree_bits) std::cout << i << " "; + std::cout << std::endl; + std::cout << "Spanning tree leaf nodes: "; + for (auto i: this->spanning_tree_leaf_nodes) std::cout << i << " "; + std::cout << std::endl; + std::cout << "Contains boundary bits: " << this->contains_boundary_bits << std::endl; + std::cout << "Boundary bits: "; + for (auto i: this->planar_code_boundary_bits) std::cout << i << " "; + std::cout << std::endl; + std::cout << "........." << std::endl; } - - }//end namespace uf #endif \ No newline at end of file diff --git a/src_python/ldpc/belief_find_decoder/_belief_find_decoder.pxd b/src_python/ldpc/belief_find_decoder/_belief_find_decoder.pxd index 85e02fe..38c3347 100644 --- a/src_python/ldpc/belief_find_decoder/_belief_find_decoder.pxd +++ b/src_python/ldpc/belief_find_decoder/_belief_find_decoder.pxd @@ -15,7 +15,7 @@ cdef extern from "union_find.hpp" namespace "ldpc::uf": vector[uint8_t]& matrix_decode(vector[uint8_t]& syndrome, const vector[double]& bit_weights, int bits_per_step) vector[uint8_t] decoding - cdef const vector[double] NULL_DOUBLE_VECTOR + cdef const vector[double] EMPTY_DOUBLE_VECTOR cdef class BeliefFindDecoder(BpDecoderBase): cdef uf_decoder_cpp* ufd diff --git a/src_python/ldpc/bplsd_decoder/_bplsd_decoder.pxd b/src_python/ldpc/bplsd_decoder/_bplsd_decoder.pxd index 1fa32c5..ebbb9cd 100644 --- a/src_python/ldpc/bplsd_decoder/_bplsd_decoder.pxd +++ b/src_python/ldpc/bplsd_decoder/_bplsd_decoder.pxd @@ -9,11 +9,11 @@ ctypedef np.uint8_t uint8_t cdef extern from "lsd.hpp" namespace "ldpc::lsd": - cdef const vector[double] NULL_DOUBLE_VECTOR "ldpc::lsd::NULL_DOUBLE_VECTOR" + cdef const vector[double] EMPTY_DOUBLE_VECTOR "ldpc::lsd::EMPTY_DOUBLE_VECTOR" cdef cppclass lsd_decoder_cpp "ldpc::lsd::LsdDecoder": lsd_decoder_cpp(BpSparse& pcm) except + - # vector[uint8_t]& on_the_fly_decode(vector[uint8_t]& syndrome, const vector[double]& bit_weights = NULL_DOUBLE_VECTOR) + # vector[uint8_t]& on_the_fly_decode(vector[uint8_t]& syndrome, const vector[double]& bit_weights = EMPTY_DOUBLE_VECTOR) vector[uint8_t]& lsd_decode(vector[uint8_t]& syndrome, const vector[double]& bit_weights, int bits_per_step, bool on_the_fly_decode, int lsd_order) vector[uint8_t] decoding vector[int] cluster_size_stats diff --git a/src_python/ldpc/union_find_decoder/_union_find_decoder.pxd b/src_python/ldpc/union_find_decoder/_union_find_decoder.pxd index 0bcce64..98bad54 100644 --- a/src_python/ldpc/union_find_decoder/_union_find_decoder.pxd +++ b/src_python/ldpc/union_find_decoder/_union_find_decoder.pxd @@ -15,7 +15,7 @@ cdef extern from "union_find.hpp" namespace "ldpc::uf": vector[uint8_t]& matrix_decode(vector[uint8_t]& syndrome, const vector[double]& bit_weights, int bits_per_step) vector[uint8_t] decoding - cdef const vector[double] NULL_DOUBLE_VECTOR "ldpc::uf::NULL_DOUBLE_VECTOR" + cdef const vector[double] EMPTY_DOUBLE_VECTOR "ldpc::uf::EMPTY_DOUBLE_VECTOR" cdef class UnionFindDecoder(): cdef int m diff --git a/src_python/ldpc/union_find_decoder/_union_find_decoder.pyx b/src_python/ldpc/union_find_decoder/_union_find_decoder.pyx index 304bb9f..8bb97b8 100644 --- a/src_python/ldpc/union_find_decoder/_union_find_decoder.pyx +++ b/src_python/ldpc/union_find_decoder/_union_find_decoder.pyx @@ -146,14 +146,14 @@ cdef class UnionFindDecoder: if llrs is not None: self.ufd.decoding = self.ufd.matrix_decode(self._syndrome, self.uf_llrs,self.bits_per_step) else: - self.ufd.decoding = self.ufd.matrix_decode(self._syndrome, NULL_DOUBLE_VECTOR,self.bits_per_step) + self.ufd.decoding = self.ufd.matrix_decode(self._syndrome, EMPTY_DOUBLE_VECTOR,self.bits_per_step) else: if llrs is not None: self.ufd.decoding = self.ufd.peel_decode(self._syndrome, self.uf_llrs,self.bits_per_step) else: - self.ufd.decoding = self.ufd.peel_decode(self._syndrome, NULL_DOUBLE_VECTOR,self.bits_per_step) + self.ufd.decoding = self.ufd.peel_decode(self._syndrome, EMPTY_DOUBLE_VECTOR,self.bits_per_step) out = np.zeros(self.n,dtype=DTYPE)