diff --git a/src/htool/hmatrix/hmatrix.hpp b/src/htool/hmatrix/hmatrix.hpp index 15c124f..899f134 100644 --- a/src/htool/hmatrix/hmatrix.hpp +++ b/src/htool/hmatrix/hmatrix.hpp @@ -40,66 +40,9 @@ void declare_HMatrix(py::module &m, const std::string &className) { py_class.def("__deepcopy__", [](const Class &self, py::dict) { return Class(self); }, "memo"_a); - // // Getters - // py_class.def_property_readonly("shape", [](const Class &self) { - // return std::array{self.nb_rows(), self.nb_cols()}; - // }); - // py_class.def("get_perm_t", overload_cast_<>()(&Class::get_permt, py::const_)); - // py_class.def("get_perm_s", overload_cast_<>()(&Class::get_perms, py::const_)); - // py_class.def("get_MasterOffset_t", overload_cast_<>()(&Class::get_MasterOffset_t, py::const_)); - // py_class.def("get_MasterOffset_s", overload_cast_<>()(&Class::get_MasterOffset_s, py::const_)); - - // // Linear algebra - // py_class.def("__mul__", [](const Class &self, std::vector b) { - // return self * b; - // }); - // py_class.def("matvec", [](const Class &self, std::vector b) { - // return self * b; - // }); - // py_class.def("__matmul__", [](const Class &self, py::array_t B) { - // int mu; - - // if (B.ndim() == 1) { - // mu = 1; - // } else if (B.ndim() == 2) { - // mu = B.shape()[1]; - // } else { - // throw std::runtime_error("Wrong dimension for HMatrix-matrix product"); // LCOV_EXCL_LINE - // } - // if (B.shape()[0] != self.nb_cols()) { - // throw std::runtime_error("Wrong size for HMatrix-matrix product"); // LCOV_EXCL_LINE - // } - - // std::vector result(self.nb_rows() * mu, 0); - - // self.mvprod_global_to_global(B.data(), result.data(), mu); - - // if (B.ndim() == 1) { - // std::array shape{self.nb_rows()}; - // return py::array_t(shape, result.data()); - // } else { - // std::array shape{self.nb_rows(), mu}; - // return py::array_t(shape, result.data()); - // } - // }); - - py_class.def( - "get_sub_hmatrix", [](const HMatrix &hmatrix, const Cluster &target_cluster, const Cluster &source_cluster) { - return &*hmatrix.get_sub_hmatrix(target_cluster, source_cluster); - }, - py::return_value_policy::reference_internal); - py_class.def("get_tree_parameters", [](const HMatrix &hmatrix) { - auto tree_parameters = htool::get_tree_parameters(hmatrix); - return tree_parameters; - }); - py_class.def("get_local_information", [](const HMatrix &hmatrix) { - auto information = htool::get_hmatrix_information(hmatrix); - return information; - }); - py_class.def("get_distributed_information", [](const HMatrix &hmatrix, MPI_Comm_wrapper comm) { - auto information = htool::get_distributed_hmatrix_information(hmatrix, comm); - return information; - }); + py_class.def("get_tree_parameters", [](const HMatrix &hmatrix) { return htool::get_tree_parameters(hmatrix); }); + py_class.def("get_local_information", [](const HMatrix &hmatrix) { return htool::get_hmatrix_information(hmatrix); }); + py_class.def("get_distributed_information", [](const HMatrix &hmatrix, MPI_Comm_wrapper comm) { return htool::get_distributed_hmatrix_information(hmatrix, comm); }); } #endif diff --git a/tests/test_distributed_operator.py b/tests/test_distributed_operator.py index dbabcd0..28b11a9 100644 --- a/tests/test_distributed_operator.py +++ b/tests/test_distributed_operator.py @@ -1,3 +1,6 @@ +import Htool +import matplotlib.pyplot as plt +import mpi4py import numpy as np import pytest @@ -41,7 +44,21 @@ def test_distributed_operator( default_distributed_operator_holder = default_distributed_operator distributed_operator = default_distributed_operator_holder.distributed_operator local_hmatrix = default_distributed_operator_holder.hmatrix - print(local_hmatrix.get_local_information()) + + hmatrix_distributed_information = local_hmatrix.get_distributed_information( + mpi4py.MPI.COMM_WORLD + ) + hmatrix_tree_parameter = local_hmatrix.get_tree_parameters() + hmatrix_local_information = local_hmatrix.get_local_information() + if mpi4py.MPI.COMM_WORLD.rank == 0: + print(hmatrix_distributed_information) + print(hmatrix_local_information) + print(hmatrix_tree_parameter) + + fig = plt.figure() + ax1 = fig.add_subplot(1, 1, 1) + Htool.plot(ax1, local_hmatrix) + else: distributed_operator = custom_distributed_operator