From 59a2322e3572ecb132bc328f62c15866d40c5d43 Mon Sep 17 00:00:00 2001 From: Pierre Marchand Date: Tue, 16 Jul 2024 19:32:52 +0200 Subject: [PATCH] refactor utilities for distributed_operator --- .../htool/distributed_operator/utility.hpp | 37 +++++++++++++------ .../hmatrix/tree_builder/tree_builder.hpp | 6 +++ 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/include/htool/distributed_operator/utility.hpp b/include/htool/distributed_operator/utility.hpp index 3ffdd2c6..eb8cac7a 100644 --- a/include/htool/distributed_operator/utility.hpp +++ b/include/htool/distributed_operator/utility.hpp @@ -8,7 +8,7 @@ namespace htool { template -class DefaultApproximationBuilder { +class DistributedOperatorFromHMatrix { private: const PartitionFromCluster target_partition, source_partition; std::function get_rankWorld = [](MPI_Comm comm) { @@ -26,36 +26,49 @@ class DefaultApproximationBuilder { DistributedOperator distributed_operator; const HMatrix *block_diagonal_hmatrix{nullptr}; - DefaultApproximationBuilder(const VirtualGenerator &generator, const Cluster &target_cluster, const Cluster &source_cluster, htool::underlying_type epsilon, htool::underlying_type eta, char symmetry, char UPLO, MPI_Comm communicator) : target_partition(target_cluster), source_partition(source_cluster), hmatrix(HMatrixTreeBuilder(target_cluster, source_cluster, epsilon, eta, symmetry, UPLO, -1, get_rankWorld(communicator), get_rankWorld(communicator)).build(generator)), local_hmatrix(hmatrix, target_cluster.get_cluster_on_partition(get_rankWorld(communicator)), source_cluster, symmetry, UPLO, false, false), distributed_operator(target_partition, source_partition, symmetry, UPLO, communicator) { + DistributedOperatorFromHMatrix(const VirtualGenerator &generator, const Cluster &target_cluster, const Cluster &source_cluster, HMatrixTreeBuilder &hmatrix_builder, MPI_Comm communicator) : target_partition(target_cluster), source_partition(source_cluster), hmatrix(hmatrix_builder.build(generator)), local_hmatrix(hmatrix, hmatrix_builder.get_target_cluster().get_cluster_on_partition(get_rankWorld(communicator)), hmatrix_builder.get_source_cluster(), hmatrix_builder.get_symmetry(), hmatrix_builder.get_UPLO(), false, false), distributed_operator(target_partition, source_partition, hmatrix_builder.get_symmetry(), hmatrix_builder.get_UPLO(), communicator) { distributed_operator.add_local_operator(&local_hmatrix); - block_diagonal_hmatrix = hmatrix.get_sub_hmatrix(target_cluster.get_cluster_on_partition(get_rankWorld(communicator)), source_cluster.get_cluster_on_partition(get_rankWorld(communicator))); + block_diagonal_hmatrix = hmatrix.get_sub_hmatrix(hmatrix_builder.get_target_cluster().get_cluster_on_partition(get_rankWorld(communicator)), hmatrix_builder.get_source_cluster().get_cluster_on_partition(get_rankWorld(communicator))); } }; template -class DefaultLocalApproximationBuilder { +class DefaultApproximationBuilder { private: - const PartitionFromCluster target_partition, source_partition; std::function get_rankWorld = [](MPI_Comm comm) { int rankWorld; MPI_Comm_rank(comm, &rankWorld); return rankWorld; }; + DistributedOperatorFromHMatrix distributed_operator_builder; public: - const HMatrix hmatrix; + const HMatrix &hmatrix; + public: + DistributedOperator &distributed_operator; + const HMatrix *block_diagonal_hmatrix{nullptr}; + + DefaultApproximationBuilder(const VirtualGenerator &generator, const Cluster &target_cluster, const Cluster &source_cluster, htool::underlying_type epsilon, htool::underlying_type eta, char symmetry, char UPLO, MPI_Comm communicator) : distributed_operator_builder(generator, target_cluster, source_cluster, HMatrixTreeBuilder(target_cluster, source_cluster, epsilon, eta, symmetry, UPLO, -1, get_rankWorld(communicator), get_rankWorld(communicator)), communicator), hmatrix(distributed_operator_builder.hmatrix), distributed_operator(distributed_operator_builder.distributed_operator), block_diagonal_hmatrix(distributed_operator_builder.block_diagonal_hmatrix) {} +}; + +template +class DefaultLocalApproximationBuilder { private: - const LocalHMatrix local_hmatrix; + std::function get_rankWorld = [](MPI_Comm comm) { + int rankWorld; + MPI_Comm_rank(comm, &rankWorld); + return rankWorld; }; + DistributedOperatorFromHMatrix distributed_operator_builder; public: - DistributedOperator distributed_operator; + const HMatrix &hmatrix; + + public: + DistributedOperator &distributed_operator; const HMatrix *block_diagonal_hmatrix{nullptr}; public: - DefaultLocalApproximationBuilder(const VirtualGenerator &generator, const Cluster &target_cluster, const Cluster &source_cluster, htool::underlying_type epsilon, htool::underlying_type eta, char symmetry, char UPLO, MPI_Comm communicator) : target_partition(target_cluster), source_partition(source_cluster), hmatrix(HMatrixTreeBuilder(target_cluster.get_cluster_on_partition(get_rankWorld(communicator)), source_cluster.get_cluster_on_partition(get_rankWorld(communicator)), epsilon, eta, symmetry, UPLO, -1, -1, -1).build(generator)), local_hmatrix(hmatrix, target_cluster.get_cluster_on_partition(get_rankWorld(communicator)), source_cluster.get_cluster_on_partition(get_rankWorld(communicator)), symmetry, UPLO, false, false), distributed_operator(target_partition, source_partition, symmetry, UPLO, communicator) { - distributed_operator.add_local_operator(&local_hmatrix); - block_diagonal_hmatrix = hmatrix.get_sub_hmatrix(target_cluster.get_cluster_on_partition(get_rankWorld(communicator)), source_cluster.get_cluster_on_partition(get_rankWorld(communicator))); - } + DefaultLocalApproximationBuilder(const VirtualGenerator &generator, const Cluster &target_cluster, const Cluster &source_cluster, htool::underlying_type epsilon, htool::underlying_type eta, char symmetry, char UPLO, MPI_Comm communicator) : distributed_operator_builder(generator, target_cluster, source_cluster, HMatrixTreeBuilder(target_cluster.get_cluster_on_partition(get_rankWorld(communicator)), source_cluster.get_cluster_on_partition(get_rankWorld(communicator)), epsilon, eta, symmetry, UPLO, -1, get_rankWorld(communicator), get_rankWorld(communicator)), communicator), hmatrix(distributed_operator_builder.hmatrix), distributed_operator(distributed_operator_builder.distributed_operator), block_diagonal_hmatrix(distributed_operator_builder.block_diagonal_hmatrix) {} }; } // namespace htool diff --git a/include/htool/hmatrix/tree_builder/tree_builder.hpp b/include/htool/hmatrix/tree_builder/tree_builder.hpp index d236f289..d82fd90b 100644 --- a/include/htool/hmatrix/tree_builder/tree_builder.hpp +++ b/include/htool/hmatrix/tree_builder/tree_builder.hpp @@ -131,6 +131,12 @@ class HMatrixTreeBuilder { void set_minimal_source_depth(int minimal_source_depth) { m_minsourcedepth = minimal_source_depth; } void set_minimal_target_depth(int minimal_target_depth) { m_mintargetdepth = minimal_target_depth; } void set_dense_blocks_generator(std::shared_ptr> dense_blocks_generator) { m_dense_blocks_generator = dense_blocks_generator; } + + // Getters + char get_symmetry() { return m_symmetry_type; } + char get_UPLO() { return m_UPLO_type; } + const Cluster &get_target_cluster() { return m_target_root_cluster; } + const Cluster &get_source_cluster() { return m_source_root_cluster; } }; template