Skip to content

Commit

Permalink
Merge pull request #1150 from hschreiber/matrix_r_reduction_fix
Browse files Browse the repository at this point in the history
[Persistence Matrix] bug fix in R reduction
  • Loading branch information
VincentRouvreau authored Nov 7, 2024
2 parents 139e285 + 795c235 commit 29f4312
Show file tree
Hide file tree
Showing 3 changed files with 249 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -913,24 +913,28 @@ inline void Chain_matrix<Master_matrix>::print() const
{
std::cout << "Column Matrix:\n";
if constexpr (!Master_matrix::Option_list::has_map_column_container) {
for (ID_index i = 0; i < pivotToColumnIndex_.size() && pivotToColumnIndex_[i] != static_cast<Index>(-1); ++i) {
for (ID_index i = 0; i < pivotToColumnIndex_.size(); ++i) {
Index pos = pivotToColumnIndex_[i];
const Column& col = matrix_[pos];
for (const auto& entry : col) {
std::cout << entry.get_row_index() << " ";
if (pos != static_cast<Index>(-1)){
const Column& col = matrix_[pos];
for (const auto& entry : col) {
std::cout << entry.get_row_index() << " ";
}
std::cout << "(" << i << ", " << pos << ")\n";
}
std::cout << "(" << i << ", " << pos << ")\n";
}
if constexpr (Master_matrix::Option_list::has_row_access) {
std::cout << "\n";
std::cout << "Row Matrix:\n";
for (ID_index i = 0; i < pivotToColumnIndex_.size() && pivotToColumnIndex_[i] != static_cast<Index>(-1); ++i) {
for (ID_index i = 0; i < pivotToColumnIndex_.size(); ++i) {
Index pos = pivotToColumnIndex_[i];
const Row& row = RA_opt::get_row(pos);
for (const auto& entry : row) {
std::cout << entry.get_column_index() << " ";
if (pos != static_cast<Index>(-1)){
const Row& row = RA_opt::get_row(pos);
for (const auto& entry : row) {
std::cout << entry.get_column_index() << " ";
}
std::cout << "(" << i << ", " << pos << ")\n";
}
std::cout << "(" << i << ", " << pos << ")\n";
}
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,11 @@ inline const typename Base_pairing<Master_matrix>::Barcode& Base_pairing<Master_
template <class Master_matrix>
inline void Base_pairing<Master_matrix>::_reduce()
{
std::unordered_map<ID_index, Index> pivotsToColumn(_matrix()->get_number_of_columns());
std::unordered_map<Index, Index> negativeColumns(_matrix()->get_number_of_columns());

auto dim = _matrix()->get_max_dimension();
std::vector<std::vector<Index> > columnsByDim(dim + 1);
for (auto& v : columnsByDim) v.reserve(_matrix()->get_number_of_columns());
for (unsigned int i = 0; i < _matrix()->get_number_of_columns(); i++) {
columnsByDim[dim - _matrix()->get_column_dimension(i)].push_back(i);
}
Expand All @@ -144,17 +145,21 @@ inline void Base_pairing<Master_matrix>::_reduce()
for (Index i : cols) {
auto& curr = _matrix()->get_column(i);
if (curr.is_empty()) {
if (pivotsToColumn.find(i) == pivotsToColumn.end()) {
if (negativeColumns.find(i) == negativeColumns.end()) {
barcode_.emplace_back(i, -1, dim);
}
} else {
ID_index pivot = curr.get_pivot();
auto it = idToPosition_.find(pivot);
Index pivotColumnNumber = it == idToPosition_.end() ? pivot : it->second;
auto itNeg = negativeColumns.find(pivotColumnNumber);
Index pivotKiller = itNeg == negativeColumns.end() ? -1 : itNeg->second;

while (pivot != static_cast<ID_index>(-1) && pivotsToColumn.find(pivot) != pivotsToColumn.end()) {
while (pivot != static_cast<ID_index>(-1) && pivotKiller != static_cast<Index>(-1)) {
if constexpr (Master_matrix::Option_list::is_z2) {
curr += _matrix()->get_column(pivotsToColumn.at(pivot));
curr += _matrix()->get_column(pivotKiller);
} else {
auto& toadd = _matrix()->get_column(pivotsToColumn.at(pivot));
auto& toadd = _matrix()->get_column(pivotKiller);
typename Master_matrix::Element coef = toadd.get_pivot_value();
auto& operators = _matrix()->colSettings_->operators;
coef = operators.get_inverse(coef);
Expand All @@ -163,12 +168,14 @@ inline void Base_pairing<Master_matrix>::_reduce()
}

pivot = curr.get_pivot();
it = idToPosition_.find(pivot);
pivotColumnNumber = it == idToPosition_.end() ? pivot : it->second;
itNeg = negativeColumns.find(pivotColumnNumber);
pivotKiller = itNeg == negativeColumns.end() ? -1 : itNeg->second;
}

if (pivot != static_cast<ID_index>(-1)) {
pivotsToColumn.emplace(pivot, i);
auto it = idToPosition_.find(pivot);
auto pivotColumnNumber = it == idToPosition_.end() ? pivot : it->second;
negativeColumns.emplace(pivotColumnNumber, i);
_matrix()->get_column(pivotColumnNumber).clear();
barcode_.emplace_back(pivotColumnNumber, i, dim - 1);
} else {
Expand Down
225 changes: 220 additions & 5 deletions src/Persistence_matrix/test/pm_matrix_tests.h
Original file line number Diff line number Diff line change
Expand Up @@ -1348,8 +1348,9 @@ template <class Matrix>
void test_barcode() {
struct BarComp {
bool operator()(const std::tuple<int, int, int>& c1, const std::tuple<int, int, int>& c2) const {
if (std::get<0>(c1) == std::get<0>(c2)) return std::get<1>(c1) < std::get<1>(c2);
return std::get<0>(c1) < std::get<0>(c2);
if (std::get<0>(c1) != std::get<0>(c2)) return std::get<0>(c1) < std::get<0>(c2);
if (std::get<1>(c1) != std::get<1>(c2)) return std::get<1>(c1) < std::get<1>(c2);
return std::get<2>(c1) < std::get<2>(c2);
}
};

Expand Down Expand Up @@ -1420,12 +1421,220 @@ void test_barcode() {
}

template <class Matrix>
void test_shifted_barcode() {
void test_shifted_barcode1() {
using C = typename Matrix::Column;
struct BarComp {
bool operator()(const std::tuple<int, int, int>& c1, const std::tuple<int, int, int>& c2) const {
if (std::get<0>(c1) != std::get<0>(c2)) return std::get<0>(c1) < std::get<0>(c2);
if (std::get<1>(c1) != std::get<1>(c2)) return std::get<1>(c1) < std::get<1>(c2);
return std::get<2>(c1) < std::get<2>(c2);
}
};

Matrix m(17, 2);
if constexpr (is_z2<C>()) {
m.insert_boundary(0, {}, 0);
m.insert_boundary(1, {}, 0);
m.insert_boundary(2, {}, 0);
m.insert_boundary(3, {}, 0);
m.insert_boundary(4, {}, 0);
m.insert_boundary(5, {}, 0);
m.insert_boundary(6, {}, 0);
m.insert_boundary(10, {0, 1}, 1);
m.insert_boundary(11, {1, 3}, 1);
m.insert_boundary(12, {2, 3}, 1);
m.insert_boundary(13, {2, 4}, 1);
m.insert_boundary(14, {3, 4}, 1);
m.insert_boundary(15, {2, 6}, 1);
m.insert_boundary(16, {4, 6}, 1);
m.insert_boundary(17, {5, 6}, 1);
m.insert_boundary(30, {12, 13, 14}, 2);
m.insert_boundary(31, {13, 15, 16}, 2);
} else {
m.insert_boundary(0, {}, 0);
m.insert_boundary(1, {}, 0);
m.insert_boundary(2, {}, 0);
m.insert_boundary(3, {}, 0);
m.insert_boundary(4, {}, 0);
m.insert_boundary(5, {}, 0);
m.insert_boundary(6, {}, 0);
m.insert_boundary(10, {{0, 1}, {1, 1}}, 1);
m.insert_boundary(11, {{1, 1}, {3, 1}}, 1);
m.insert_boundary(12, {{2, 1}, {3, 1}}, 1);
m.insert_boundary(13, {{2, 1}, {4, 1}}, 1);
m.insert_boundary(14, {{3, 1}, {4, 1}}, 1);
m.insert_boundary(15, {{2, 1}, {6, 1}}, 1);
m.insert_boundary(16, {{4, 1}, {6, 1}}, 1);
m.insert_boundary(17, {{5, 1}, {6, 1}}, 1);
m.insert_boundary(30, {{12, 1}, {13, 1}, {14, 1}}, 2);
m.insert_boundary(31, {{13, 1}, {15, 1}, {16, 1}}, 2);
}

const auto& barcode = m.get_current_barcode();

std::vector<witness_content<C> > reducedMatrix;
if constexpr (is_z2<C>()) {
if constexpr (Matrix::Option_list::is_of_boundary_type) {
reducedMatrix.emplace_back();
reducedMatrix.emplace_back();
reducedMatrix.emplace_back();
reducedMatrix.emplace_back();
reducedMatrix.emplace_back();
reducedMatrix.emplace_back();
reducedMatrix.emplace_back();
reducedMatrix.push_back({0, 1});
reducedMatrix.push_back({1, 3});
reducedMatrix.push_back({1, 2});
reducedMatrix.push_back({2, 4});
reducedMatrix.emplace_back();
reducedMatrix.push_back({2, 6});
reducedMatrix.emplace_back();
reducedMatrix.push_back({2, 5});
reducedMatrix.push_back({12, 13, 14});
reducedMatrix.push_back({13, 15, 16});
} else {
reducedMatrix.push_back({0});
reducedMatrix.push_back({0, 1});
reducedMatrix.push_back({0, 2});
reducedMatrix.push_back({0, 3});
reducedMatrix.push_back({0, 4});
reducedMatrix.push_back({0, 5});
reducedMatrix.push_back({0, 6});
reducedMatrix.push_back({10});
reducedMatrix.push_back({10, 11});
reducedMatrix.push_back({10, 11, 12});
reducedMatrix.push_back({10, 11, 12, 13});
reducedMatrix.push_back({12, 13, 14});
reducedMatrix.push_back({10, 11, 12, 15});
reducedMatrix.push_back({13, 15, 16});
reducedMatrix.push_back({10, 11, 12, 15, 17});
reducedMatrix.push_back({30});
reducedMatrix.push_back({31});
}
} else {
if constexpr (Matrix::Option_list::is_of_boundary_type) {
reducedMatrix.emplace_back();
reducedMatrix.emplace_back();
reducedMatrix.emplace_back();
reducedMatrix.emplace_back();
reducedMatrix.emplace_back();
reducedMatrix.emplace_back();
reducedMatrix.emplace_back();
reducedMatrix.push_back({{0, 1}, {1, 1}});
reducedMatrix.push_back({{1, 1}, {3, 1}});
reducedMatrix.push_back({{1, 1}, {2, 1}});
reducedMatrix.push_back({{2, 1}, {4, 1}});
reducedMatrix.emplace_back();
reducedMatrix.push_back({{2, 1}, {6, 1}});
reducedMatrix.emplace_back();
reducedMatrix.push_back({{2, 1}, {5, 1}});
reducedMatrix.push_back({{12, 1}, {13, 1}, {14, 1}});
reducedMatrix.push_back({{13, 1}, {15, 1}, {16, 1}});
} else {
reducedMatrix.push_back({{0, 1}});
reducedMatrix.push_back({{0, 1}, {1, 1}});
reducedMatrix.push_back({{0, 1}, {2, 1}});
reducedMatrix.push_back({{0, 1}, {3, 1}});
reducedMatrix.push_back({{0, 1}, {4, 1}});
reducedMatrix.push_back({{0, 1}, {5, 1}});
reducedMatrix.push_back({{0, 1}, {6, 1}});
reducedMatrix.push_back({{10, 1}});
reducedMatrix.push_back({{10, 1}, {11, 1}});
reducedMatrix.push_back({{10, 1}, {11, 1}, {12, 1}});
reducedMatrix.push_back({{10, 1}, {11, 1}, {12, 1}, {13, 1}});
reducedMatrix.push_back({{12, 1}, {13, 1}, {14, 1}});
reducedMatrix.push_back({{10, 1}, {11, 1}, {12, 1}, {15, 1}});
reducedMatrix.push_back({{13, 1}, {15, 1}, {16, 1}});
reducedMatrix.push_back({{10, 1}, {11, 1}, {12, 1}, {15, 1}, {17, 1}});
reducedMatrix.push_back({{30, 1}});
reducedMatrix.push_back({{31, 1}});
}
}

if constexpr (Matrix::Option_list::column_indexation_type == Column_indexation_types::IDENTIFIER){
test_column_equality<C>(reducedMatrix[0], get_column_content_via_iterators(m.get_column(0)));
test_column_equality<C>(reducedMatrix[1], get_column_content_via_iterators(m.get_column(1)));
test_column_equality<C>(reducedMatrix[2], get_column_content_via_iterators(m.get_column(2)));
test_column_equality<C>(reducedMatrix[3], get_column_content_via_iterators(m.get_column(3)));
test_column_equality<C>(reducedMatrix[4], get_column_content_via_iterators(m.get_column(4)));
test_column_equality<C>(reducedMatrix[5], get_column_content_via_iterators(m.get_column(5)));
test_column_equality<C>(reducedMatrix[6], get_column_content_via_iterators(m.get_column(6)));
test_column_equality<C>(reducedMatrix[7], get_column_content_via_iterators(m.get_column(10)));
test_column_equality<C>(reducedMatrix[8], get_column_content_via_iterators(m.get_column(11)));
test_column_equality<C>(reducedMatrix[9], get_column_content_via_iterators(m.get_column(12)));
test_column_equality<C>(reducedMatrix[10], get_column_content_via_iterators(m.get_column(13)));
test_column_equality<C>(reducedMatrix[11], get_column_content_via_iterators(m.get_column(14)));
test_column_equality<C>(reducedMatrix[12], get_column_content_via_iterators(m.get_column(15)));
test_column_equality<C>(reducedMatrix[13], get_column_content_via_iterators(m.get_column(16)));
test_column_equality<C>(reducedMatrix[14], get_column_content_via_iterators(m.get_column(17)));
test_column_equality<C>(reducedMatrix[15], get_column_content_via_iterators(m.get_column(30)));
test_column_equality<C>(reducedMatrix[16], get_column_content_via_iterators(m.get_column(31)));
} else {
test_content_equality(reducedMatrix, m);
}

std::set<std::tuple<int, int, int>, BarComp> bars1;
std::set<std::tuple<int, int, int>, BarComp> bars2;
std::set<std::tuple<int, int, int>, BarComp> bars3;
// bars are not ordered the same for all matrices
for (auto it = barcode.begin(); it != barcode.end(); ++it) {
//three access possibilities
bars1.emplace(it->dim, it->birth, it->death);
bars2.emplace(std::get<2>(*it), std::get<0>(*it), std::get<1>(*it));
auto [ x, y, z ] = *it;
bars3.emplace(z, x, y);
}
auto it = bars1.begin();
BOOST_CHECK_EQUAL(std::get<0>(*it), 0);
BOOST_CHECK_EQUAL(std::get<1>(*it), 0);
BOOST_CHECK_EQUAL(std::get<2>(*it), -1);
++it;
BOOST_CHECK_EQUAL(std::get<0>(*it), 0);
BOOST_CHECK_EQUAL(std::get<1>(*it), 1);
BOOST_CHECK_EQUAL(std::get<2>(*it), 7);
++it;
BOOST_CHECK_EQUAL(std::get<0>(*it), 0);
BOOST_CHECK_EQUAL(std::get<1>(*it), 2);
BOOST_CHECK_EQUAL(std::get<2>(*it), 9);
++it;
BOOST_CHECK_EQUAL(std::get<0>(*it), 0);
BOOST_CHECK_EQUAL(std::get<1>(*it), 3);
BOOST_CHECK_EQUAL(std::get<2>(*it), 8);
++it;
BOOST_CHECK_EQUAL(std::get<0>(*it), 0);
BOOST_CHECK_EQUAL(std::get<1>(*it), 4);
BOOST_CHECK_EQUAL(std::get<2>(*it), 10);
++it;
BOOST_CHECK_EQUAL(std::get<0>(*it), 0);
BOOST_CHECK_EQUAL(std::get<1>(*it), 5);
BOOST_CHECK_EQUAL(std::get<2>(*it), 14);
++it;
BOOST_CHECK_EQUAL(std::get<0>(*it), 0);
BOOST_CHECK_EQUAL(std::get<1>(*it), 6);
BOOST_CHECK_EQUAL(std::get<2>(*it), 12);
++it;
BOOST_CHECK_EQUAL(std::get<0>(*it), 1);
BOOST_CHECK_EQUAL(std::get<1>(*it), 11);
BOOST_CHECK_EQUAL(std::get<2>(*it), 15);
++it;
BOOST_CHECK_EQUAL(std::get<0>(*it), 1);
BOOST_CHECK_EQUAL(std::get<1>(*it), 13);
BOOST_CHECK_EQUAL(std::get<2>(*it), 16);
++it;
BOOST_CHECK(it == bars1.end());

BOOST_CHECK(bars1 == bars2);
BOOST_CHECK(bars1 == bars3);
}

template <class Matrix>
void test_shifted_barcode2() {
using C = typename Matrix::Column;
struct BarComp {
bool operator()(const std::tuple<int, int, int>& c1, const std::tuple<int, int, int>& c2) const {
if (std::get<0>(c1) == std::get<0>(c2)) return std::get<1>(c1) < std::get<1>(c2);
return std::get<0>(c1) < std::get<0>(c2);
if (std::get<0>(c1) != std::get<0>(c2)) return std::get<0>(c1) < std::get<0>(c2);
if (std::get<1>(c1) != std::get<1>(c2)) return std::get<1>(c1) < std::get<1>(c2);
return std::get<2>(c1) < std::get<2>(c2);
}
};

Expand Down Expand Up @@ -1553,6 +1762,12 @@ void test_shifted_barcode() {
BOOST_CHECK(bars1 == bars3);
}

template <class Matrix>
void test_shifted_barcode() {
test_shifted_barcode1<Matrix>();
test_shifted_barcode2<Matrix>();
}

template <class Matrix>
void test_base_swaps() {
auto columns = build_simple_boundary_matrix<typename Matrix::Column>();
Expand Down

0 comments on commit 29f4312

Please sign in to comment.