Skip to content

Commit

Permalink
improve coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
PierreMarchand20 committed Aug 8, 2024
1 parent 738da88 commit 140adb5
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 61 deletions.
63 changes: 3 additions & 60 deletions src/htool/hmatrix/hmatrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,66 +40,9 @@ void declare_HMatrix(py::module &m, const std::string &className) {

py_class.def("__deepcopy__", [](const Class &self, py::dict) { return Class(self); }, "memo"_a);

// // Getters
// py_class.def_property_readonly("shape", [](const Class &self) {
// return std::array<int, 2>{self.nb_rows(), self.nb_cols()};
// });
// py_class.def("get_perm_t", overload_cast_<>()(&Class::get_permt, py::const_));
// py_class.def("get_perm_s", overload_cast_<>()(&Class::get_perms, py::const_));
// py_class.def("get_MasterOffset_t", overload_cast_<>()(&Class::get_MasterOffset_t, py::const_));
// py_class.def("get_MasterOffset_s", overload_cast_<>()(&Class::get_MasterOffset_s, py::const_));

// // Linear algebra
// py_class.def("__mul__", [](const Class &self, std::vector<T> b) {
// return self * b;
// });
// py_class.def("matvec", [](const Class &self, std::vector<T> b) {
// return self * b;
// });
// py_class.def("__matmul__", [](const Class &self, py::array_t<T, py::array::f_style | py::array::forcecast> B) {
// int mu;

// if (B.ndim() == 1) {
// mu = 1;
// } else if (B.ndim() == 2) {
// mu = B.shape()[1];
// } else {
// throw std::runtime_error("Wrong dimension for HMatrix-matrix product"); // LCOV_EXCL_LINE
// }
// if (B.shape()[0] != self.nb_cols()) {
// throw std::runtime_error("Wrong size for HMatrix-matrix product"); // LCOV_EXCL_LINE
// }

// std::vector<T> result(self.nb_rows() * mu, 0);

// self.mvprod_global_to_global(B.data(), result.data(), mu);

// if (B.ndim() == 1) {
// std::array<long int, 1> shape{self.nb_rows()};
// return py::array_t<T, py::array::f_style>(shape, result.data());
// } else {
// std::array<long int, 2> shape{self.nb_rows(), mu};
// return py::array_t<T, py::array::f_style>(shape, result.data());
// }
// });

py_class.def(
"get_sub_hmatrix", [](const HMatrix<CoefficientPrecision, CoordinatePrecision> &hmatrix, const Cluster<CoordinatePrecision> &target_cluster, const Cluster<CoordinatePrecision> &source_cluster) {
return &*hmatrix.get_sub_hmatrix(target_cluster, source_cluster);
},
py::return_value_policy::reference_internal);
py_class.def("get_tree_parameters", [](const HMatrix<CoefficientPrecision, CoordinatePrecision> &hmatrix) {
auto tree_parameters = htool::get_tree_parameters(hmatrix);
return tree_parameters;
});
py_class.def("get_local_information", [](const HMatrix<CoefficientPrecision, CoordinatePrecision> &hmatrix) {
auto information = htool::get_hmatrix_information(hmatrix);
return information;
});
py_class.def("get_distributed_information", [](const HMatrix<CoefficientPrecision, CoordinatePrecision> &hmatrix, MPI_Comm_wrapper comm) {
auto information = htool::get_distributed_hmatrix_information(hmatrix, comm);
return information;
});
py_class.def("get_tree_parameters", [](const HMatrix<CoefficientPrecision, CoordinatePrecision> &hmatrix) { return htool::get_tree_parameters(hmatrix); });
py_class.def("get_local_information", [](const HMatrix<CoefficientPrecision, CoordinatePrecision> &hmatrix) { return htool::get_hmatrix_information(hmatrix); });
py_class.def("get_distributed_information", [](const HMatrix<CoefficientPrecision, CoordinatePrecision> &hmatrix, MPI_Comm_wrapper comm) { return htool::get_distributed_hmatrix_information(hmatrix, comm); });
}

#endif
19 changes: 18 additions & 1 deletion tests/test_distributed_operator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import Htool
import matplotlib.pyplot as plt
import mpi4py
import numpy as np
import pytest

Expand Down Expand Up @@ -41,7 +44,21 @@ def test_distributed_operator(
default_distributed_operator_holder = default_distributed_operator
distributed_operator = default_distributed_operator_holder.distributed_operator
local_hmatrix = default_distributed_operator_holder.hmatrix
print(local_hmatrix.get_local_information())

hmatrix_distributed_information = local_hmatrix.get_distributed_information(
mpi4py.MPI.COMM_WORLD
)
hmatrix_tree_parameter = local_hmatrix.get_tree_parameters()
hmatrix_local_information = local_hmatrix.get_local_information()
if mpi4py.MPI.COMM_WORLD.rank == 0:
print(hmatrix_distributed_information)
print(hmatrix_local_information)
print(hmatrix_tree_parameter)

fig = plt.figure()
ax1 = fig.add_subplot(1, 1, 1)
Htool.plot(ax1, local_hmatrix)

else:
distributed_operator = custom_distributed_operator

Expand Down

0 comments on commit 140adb5

Please sign in to comment.