Skip to content

Commit

Permalink
refactor utilities for distributed_operator
Browse files Browse the repository at this point in the history
  • Loading branch information
PierreMarchand20 committed Jul 16, 2024
1 parent 8de449c commit 59a2322
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
37 changes: 25 additions & 12 deletions include/htool/distributed_operator/utility.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
namespace htool {

template <typename CoefficientPrecision, typename CoordinatePrecision>
class DefaultApproximationBuilder {
class DistributedOperatorFromHMatrix {
private:
const PartitionFromCluster<CoefficientPrecision, CoordinatePrecision> target_partition, source_partition;
std::function<int(MPI_Comm)> get_rankWorld = [](MPI_Comm comm) {
Expand All @@ -26,36 +26,49 @@ class DefaultApproximationBuilder {
DistributedOperator<CoefficientPrecision> distributed_operator;
const HMatrix<CoefficientPrecision, CoordinatePrecision> *block_diagonal_hmatrix{nullptr};

DefaultApproximationBuilder(const VirtualGenerator<CoefficientPrecision> &generator, const Cluster<CoordinatePrecision> &target_cluster, const Cluster<CoordinatePrecision> &source_cluster, htool::underlying_type<CoefficientPrecision> epsilon, htool::underlying_type<CoefficientPrecision> eta, char symmetry, char UPLO, MPI_Comm communicator) : target_partition(target_cluster), source_partition(source_cluster), hmatrix(HMatrixTreeBuilder<CoefficientPrecision, CoordinatePrecision>(target_cluster, source_cluster, epsilon, eta, symmetry, UPLO, -1, get_rankWorld(communicator), get_rankWorld(communicator)).build(generator)), local_hmatrix(hmatrix, target_cluster.get_cluster_on_partition(get_rankWorld(communicator)), source_cluster, symmetry, UPLO, false, false), distributed_operator(target_partition, source_partition, symmetry, UPLO, communicator) {
DistributedOperatorFromHMatrix(const VirtualGenerator<CoefficientPrecision> &generator, const Cluster<CoordinatePrecision> &target_cluster, const Cluster<CoordinatePrecision> &source_cluster, HMatrixTreeBuilder<CoefficientPrecision, CoordinatePrecision> &hmatrix_builder, MPI_Comm communicator) : target_partition(target_cluster), source_partition(source_cluster), hmatrix(hmatrix_builder.build(generator)), local_hmatrix(hmatrix, hmatrix_builder.get_target_cluster().get_cluster_on_partition(get_rankWorld(communicator)), hmatrix_builder.get_source_cluster(), hmatrix_builder.get_symmetry(), hmatrix_builder.get_UPLO(), false, false), distributed_operator(target_partition, source_partition, hmatrix_builder.get_symmetry(), hmatrix_builder.get_UPLO(), communicator) {
distributed_operator.add_local_operator(&local_hmatrix);
block_diagonal_hmatrix = hmatrix.get_sub_hmatrix(target_cluster.get_cluster_on_partition(get_rankWorld(communicator)), source_cluster.get_cluster_on_partition(get_rankWorld(communicator)));
block_diagonal_hmatrix = hmatrix.get_sub_hmatrix(hmatrix_builder.get_target_cluster().get_cluster_on_partition(get_rankWorld(communicator)), hmatrix_builder.get_source_cluster().get_cluster_on_partition(get_rankWorld(communicator)));
}
};

template <typename CoefficientPrecision, typename CoordinatePrecision>
class DefaultLocalApproximationBuilder {
class DefaultApproximationBuilder {
private:
const PartitionFromCluster<CoefficientPrecision, CoordinatePrecision> target_partition, source_partition;
std::function<int(MPI_Comm)> get_rankWorld = [](MPI_Comm comm) {
int rankWorld;
MPI_Comm_rank(comm, &rankWorld);
return rankWorld; };
DistributedOperatorFromHMatrix<CoefficientPrecision, CoordinatePrecision> distributed_operator_builder;

public:
const HMatrix<CoefficientPrecision, CoordinatePrecision> hmatrix;
const HMatrix<CoefficientPrecision, CoordinatePrecision> &hmatrix;

public:
DistributedOperator<CoefficientPrecision> &distributed_operator;
const HMatrix<CoefficientPrecision, CoordinatePrecision> *block_diagonal_hmatrix{nullptr};

DefaultApproximationBuilder(const VirtualGenerator<CoefficientPrecision> &generator, const Cluster<CoordinatePrecision> &target_cluster, const Cluster<CoordinatePrecision> &source_cluster, htool::underlying_type<CoefficientPrecision> epsilon, htool::underlying_type<CoefficientPrecision> eta, char symmetry, char UPLO, MPI_Comm communicator) : distributed_operator_builder(generator, target_cluster, source_cluster, HMatrixTreeBuilder<CoefficientPrecision, CoordinatePrecision>(target_cluster, source_cluster, epsilon, eta, symmetry, UPLO, -1, get_rankWorld(communicator), get_rankWorld(communicator)), communicator), hmatrix(distributed_operator_builder.hmatrix), distributed_operator(distributed_operator_builder.distributed_operator), block_diagonal_hmatrix(distributed_operator_builder.block_diagonal_hmatrix) {}
};

template <typename CoefficientPrecision, typename CoordinatePrecision>
class DefaultLocalApproximationBuilder {
private:
const LocalHMatrix<CoefficientPrecision, CoordinatePrecision> local_hmatrix;
std::function<int(MPI_Comm)> get_rankWorld = [](MPI_Comm comm) {
int rankWorld;
MPI_Comm_rank(comm, &rankWorld);
return rankWorld; };
DistributedOperatorFromHMatrix<CoefficientPrecision, CoordinatePrecision> distributed_operator_builder;

public:
DistributedOperator<CoefficientPrecision> distributed_operator;
const HMatrix<CoefficientPrecision, CoordinatePrecision> &hmatrix;

public:
DistributedOperator<CoefficientPrecision> &distributed_operator;
const HMatrix<CoefficientPrecision, CoordinatePrecision> *block_diagonal_hmatrix{nullptr};

public:
DefaultLocalApproximationBuilder(const VirtualGenerator<CoefficientPrecision> &generator, const Cluster<CoordinatePrecision> &target_cluster, const Cluster<CoordinatePrecision> &source_cluster, htool::underlying_type<CoefficientPrecision> epsilon, htool::underlying_type<CoefficientPrecision> eta, char symmetry, char UPLO, MPI_Comm communicator) : target_partition(target_cluster), source_partition(source_cluster), hmatrix(HMatrixTreeBuilder<CoefficientPrecision, CoordinatePrecision>(target_cluster.get_cluster_on_partition(get_rankWorld(communicator)), source_cluster.get_cluster_on_partition(get_rankWorld(communicator)), epsilon, eta, symmetry, UPLO, -1, -1, -1).build(generator)), local_hmatrix(hmatrix, target_cluster.get_cluster_on_partition(get_rankWorld(communicator)), source_cluster.get_cluster_on_partition(get_rankWorld(communicator)), symmetry, UPLO, false, false), distributed_operator(target_partition, source_partition, symmetry, UPLO, communicator) {
distributed_operator.add_local_operator(&local_hmatrix);
block_diagonal_hmatrix = hmatrix.get_sub_hmatrix(target_cluster.get_cluster_on_partition(get_rankWorld(communicator)), source_cluster.get_cluster_on_partition(get_rankWorld(communicator)));
}
DefaultLocalApproximationBuilder(const VirtualGenerator<CoefficientPrecision> &generator, const Cluster<CoordinatePrecision> &target_cluster, const Cluster<CoordinatePrecision> &source_cluster, htool::underlying_type<CoefficientPrecision> epsilon, htool::underlying_type<CoefficientPrecision> eta, char symmetry, char UPLO, MPI_Comm communicator) : distributed_operator_builder(generator, target_cluster, source_cluster, HMatrixTreeBuilder<CoefficientPrecision, CoordinatePrecision>(target_cluster.get_cluster_on_partition(get_rankWorld(communicator)), source_cluster.get_cluster_on_partition(get_rankWorld(communicator)), epsilon, eta, symmetry, UPLO, -1, get_rankWorld(communicator), get_rankWorld(communicator)), communicator), hmatrix(distributed_operator_builder.hmatrix), distributed_operator(distributed_operator_builder.distributed_operator), block_diagonal_hmatrix(distributed_operator_builder.block_diagonal_hmatrix) {}
};

} // namespace htool
Expand Down
6 changes: 6 additions & 0 deletions include/htool/hmatrix/tree_builder/tree_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ class HMatrixTreeBuilder {
void set_minimal_source_depth(int minimal_source_depth) { m_minsourcedepth = minimal_source_depth; }
void set_minimal_target_depth(int minimal_target_depth) { m_mintargetdepth = minimal_target_depth; }
void set_dense_blocks_generator(std::shared_ptr<VirtualDenseBlocksGenerator<CoefficientPrecision>> dense_blocks_generator) { m_dense_blocks_generator = dense_blocks_generator; }

// Getters
char get_symmetry() { return m_symmetry_type; }
char get_UPLO() { return m_UPLO_type; }
const Cluster<CoordinatePrecision> &get_target_cluster() { return m_target_root_cluster; }
const Cluster<CoordinatePrecision> &get_source_cluster() { return m_source_root_cluster; }
};

template <typename CoefficientPrecision, typename CoordinatePrecision>
Expand Down

0 comments on commit 59a2322

Please sign in to comment.