Skip to content

Commit

Permalink
Merge pull request #503 from m-a-d-n-e-s-s/pr-tree-operations
Browse files Browse the repository at this point in the history
unified the tree status: reconstruct, compress, nonstandard, redundan…
  • Loading branch information
fbischoff authored Oct 16, 2023
2 parents e4ee892 + a368591 commit d993694
Show file tree
Hide file tree
Showing 8 changed files with 510 additions and 80 deletions.
2 changes: 1 addition & 1 deletion src/madness/mra/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ if(BUILD_TESTING)

set(MRA_TEST_SOURCES testbsh.cc testproj.cc
testpdiff.cc testdiff1Db.cc testgconv.cc testopdir.cc testinnerext.cc
testgaxpyext.cc testvmra.cc, test_vectormacrotask.cc test_cloud.cc
testgaxpyext.cc testvmra.cc, test_vectormacrotask.cc test_cloud.cc test_tree_state.cc
test_macrotaskpartitioner.cc test_QCCalculationParametersBase.cc)
add_unittests(mra "${MRA_TEST_SOURCES}" "MADmra;MADgtest" "unittests;short")
set(MRA_SEPOP_TEST_SOURCES testsuite.cc
Expand Down
14 changes: 14 additions & 0 deletions src/madness/mra/funcdefaults.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,25 @@ namespace madness {
compressed, ///< d coeffs in internal nodes, s and d coeffs at the root
nonstandard, ///< s and d coeffs in internal nodes
nonstandard_with_leaves, ///< like nonstandard, with s coeffs at the leaves
nonstandard_after_apply, ///< s and d coeffs, state after operator application
redundant, ///< s coeffs everywhere
on_demand, ///< no coeffs anywhere, but a functor providing if necessary
unknown
};

template<std::size_t NDIM=1>
std::ostream& operator<<(std::ostream& os, const TreeState treestate) {
if (treestate==reconstructed) os << "reconstructed";
if (treestate==compressed) os << "compressed";
if (treestate==nonstandard) os << "nonstandard";
if (treestate==nonstandard_with_leaves) os << "nonstandard_with_leaves";
if (treestate==nonstandard_after_apply) os << "nonstandard_after_apply";
if (treestate==redundant) os << "redundant";
if (treestate==on_demand) os << "on_demand";
if (treestate==unknown) os << "unknown";
return os;
}

/*!
\brief This class is used to specify boundary conditions for all operators
\ingroup mrabcext
Expand Down
28 changes: 25 additions & 3 deletions src/madness/mra/funcimpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2151,7 +2151,6 @@ namespace madness {
};

/// remove all coefficients of internal nodes
/// presumably to switch from redundant to reconstructed state
struct remove_internal_coeffs {
typedef Range<typename dcT::iterator> rangeT;

Expand All @@ -2168,6 +2167,22 @@ namespace madness {

};

/// remove all coefficients of leaf nodes
struct remove_leaf_coeffs {
typedef Range<typename dcT::iterator> rangeT;

/// constructor need impl for cdata
remove_leaf_coeffs() {}

bool operator()(typename rangeT::iterator& it) const {
nodeT& node = it->second;
if (not node.has_children()) node.clear_coeff();
return true;
}
template <typename Archive> void serialize(const Archive& ar) {}

};


/// keep only the sum coefficients in each node
struct do_keep_sum_coeffs {
Expand Down Expand Up @@ -4438,11 +4453,12 @@ namespace madness {
/// cf reconstruct_op
void trickle_down_op(const keyT& key, const coeffT& s);

/// reconstruct this tree -- respects fence
void reconstruct(bool fence);

// Invoked on node where key is local
// void reconstruct_op(const keyT& key, const tensorT& s);
void reconstruct_op(const keyT& key, const coeffT& s);
void reconstruct_op(const keyT& key, const coeffT& s, const bool accumulate_NS=true);

/// compress the wave function

Expand All @@ -4464,6 +4480,9 @@ namespace madness {
/// convert this from redundant to standard reconstructed form
void undo_redundant(const bool fence);

void remove_internal_coefficients(const bool fence);
void remove_leaf_coefficients(const bool fence);


/// compute for each FunctionNode the norm of the function inside that node
void norm_tree(bool fence);
Expand Down Expand Up @@ -4768,7 +4787,7 @@ namespace madness {
if (fence)
world.gop.fence();

set_tree_state(nonstandard);
set_tree_state(nonstandard_after_apply);
// this->compressed=true;
// this->nonstandard=true;
// this->redundant=false;
Expand Down Expand Up @@ -4913,6 +4932,7 @@ namespace madness {
}
}
if (fence) world.gop.fence();
set_tree_state(TreeState::nonstandard_after_apply);
}

/// after apply we need to do some cleanup;
Expand Down Expand Up @@ -4950,6 +4970,7 @@ namespace madness {

}
if (fence) world.gop.fence();
set_tree_state(TreeState::nonstandard_after_apply);
}

/// recursive part of recursive_apply
Expand Down Expand Up @@ -5080,6 +5101,7 @@ namespace madness {

}
if (fence) world.gop.fence();
set_tree_state(TreeState::nonstandard_after_apply);
}

/// recursive part of recursive_apply
Expand Down
156 changes: 139 additions & 17 deletions src/madness/mra/mra.h
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ namespace madness {
PROFILE_MEMBER_FUNC(Function);
if (!impl || is_compressed()) return *this;
if (VERIFY_TREE) verify_tree();
if (impl->is_nonstandard()) {
if (impl->is_nonstandard() or impl->is_nonstandard_with_leaves()) {
impl->standard(fence);
} else {
const_cast<Function<T,NDIM>*>(this)->impl->compress(TreeState::compressed, fence);
Expand Down Expand Up @@ -767,6 +767,21 @@ namespace madness {
if (fence && VERIFY_TREE) verify_tree();
}

/// Converts the function to redundant form, i.e. sum coefficients on all levels

/// By default fence=true meaning that this operation completes before returning,
/// otherwise if fence=false it returns without fencing and the user must invoke
/// world.gop.fence() to assure global completion before using the function
/// for other purposes.
///
/// Must be already compressed.
void make_redundant(bool fence = true) {
PROFILE_MEMBER_FUNC(Function);
verify();
change_tree_state(redundant, fence);
if (fence && VERIFY_TREE) verify_tree();
}

/// Reconstructs the function, transforming into scaling function basis. Possible non-blocking comm.

/// By default fence=true meaning that this operation completes before returning,
Expand All @@ -781,7 +796,114 @@ namespace madness {
const Function<T,NDIM>& reconstruct(bool fence = true) const {
PROFILE_MEMBER_FUNC(Function);
if (!impl || impl->is_reconstructed()) return *this;
const_cast<Function<T,NDIM>*>(this)->impl->reconstruct(fence);
change_tree_state(reconstructed, fence);
if (fence && VERIFY_TREE) verify_tree(); // Must be after in case nonstandard
return *this;
}

/// changes tree state to given state

/// Since reconstruction/compression do not discard information we define them
/// as const ... "logical constness" not "bitwise constness".
/// @param[in] finalstate The final state of the tree
/// @param[in] fence Fence after the operation (might not be respected!!!)
const Function<T,NDIM>& change_tree_state(const TreeState finalstate, bool fence = true) const {
PROFILE_MEMBER_FUNC(Function);
if (not impl) return *this;
TreeState current_state=impl->get_tree_state();
if (finalstate==current_state) return *this;
MADNESS_CHECK_THROW(current_state!=TreeState::unknown,"unknown tree state");

// very special case
if (impl->get_tree_state()==nonstandard_after_apply) {
MADNESS_CHECK(finalstate==reconstructed);
impl->reconstruct(fence);
current_state=impl->get_tree_state();
}
MADNESS_CHECK_THROW(current_state!=TreeState::nonstandard_after_apply,"unknown tree state");
bool must_fence=false;

if (finalstate==reconstructed) {
if (current_state==reconstructed) return *this;
if (current_state==compressed) impl->reconstruct(fence);
if (current_state==nonstandard) impl->reconstruct(fence);
if (current_state==nonstandard_with_leaves) impl->remove_internal_coefficients(fence);
if (current_state==redundant) impl->remove_internal_coefficients(fence);
impl->set_tree_state(reconstructed);
} else if (finalstate==compressed) {
if (current_state==reconstructed) impl->compress(compressed,fence);
if (current_state==compressed) return *this;
if (current_state==nonstandard) impl->standard(fence);
if (current_state==nonstandard_with_leaves) impl->standard(fence);
if (current_state==redundant) {
impl->remove_internal_coefficients(true);
must_fence=true;
impl->set_tree_state(reconstructed);
impl->compress(compressed,fence);
}
impl->set_tree_state(compressed);
} else if (finalstate==nonstandard) {
if (current_state==reconstructed) impl->compress(nonstandard,fence);
if (current_state==compressed) {
impl->reconstruct(true);
must_fence=true;
impl->compress(nonstandard,fence);
}
if (current_state==nonstandard) return *this;
if (current_state==nonstandard_with_leaves) impl->remove_leaf_coefficients(fence);
if (current_state==redundant) {
impl->remove_internal_coefficients(true);
must_fence=true;
impl->set_tree_state(reconstructed);
impl->compress(nonstandard,fence);
}
impl->set_tree_state(nonstandard);
} else if (finalstate==nonstandard_with_leaves) {
if (current_state==reconstructed) impl->compress(nonstandard_with_leaves,fence);
if (current_state==compressed) {
impl->reconstruct(true);
must_fence=true;
impl->compress(nonstandard_with_leaves,fence);
}
if (current_state==nonstandard) {
impl->standard(true);
must_fence=true;
impl->reconstruct(true);
impl->compress(nonstandard_with_leaves,fence);
}
if (current_state==nonstandard_with_leaves) return *this;
if (current_state==redundant) {
impl->remove_internal_coefficients(true);
must_fence=true;
impl->set_tree_state(reconstructed);
impl->compress(nonstandard_with_leaves,fence);
}
impl->set_tree_state(nonstandard_with_leaves);
} else if (finalstate==redundant) {
if (current_state==reconstructed) impl->make_redundant(fence);
if (current_state==compressed) {
impl->reconstruct(true);
must_fence=true;
impl->make_redundant(fence);
}
if (current_state==nonstandard) {
impl->standard(true);
must_fence=true;
impl->reconstruct(true);
impl->make_redundant(fence);
}
if (current_state==nonstandard_with_leaves) {
impl->remove_internal_coefficients(true);
must_fence=true;
impl->set_tree_state(reconstructed);
impl->make_redundant(fence);
}
if (current_state==redundant) return *this;
impl->set_tree_state(redundant);
} else {
MADNESS_EXCEPTION("unknown/unsupported final tree state",1);
}
if (must_fence and world().rank()==0) print("could not respect fence in change_tree_state");
if (fence && VERIFY_TREE) verify_tree(); // Must be after in case nonstandard
return *this;
}
Expand Down Expand Up @@ -1401,15 +1523,16 @@ namespace madness {
.k(g.k()).thresh(g.thresh());
Function<resultT,KDIM> result=factory; // no empty() here!

FunctionImpl<R,LDIM>* gimpl = const_cast< FunctionImpl<R,LDIM>* >(g.get_impl().get());

this->reconstruct();
gimpl->make_redundant(true);
this->get_impl()->project_out(result.get_impl().get(),gimpl,dim,true);
change_tree_state(reconstructed,false);
g.change_tree_state(redundant,false);
world().gop.fence();
this->get_impl()->project_out(result.get_impl().get(),g.get_impl().get(),dim,true);
// result.get_impl()->project_out2(this->get_impl().get(),gimpl,dim);
result.world().gop.fence();
result.get_impl()->trickle_down(true);
gimpl->undo_redundant(true);
g.change_tree_state(reconstructed,false);
result.get_impl()->trickle_down(false);
result.get_impl()->set_tree_state(reconstructed);
result.world().gop.fence();
return result;
}

Expand Down Expand Up @@ -2076,7 +2199,7 @@ namespace madness {
if (op.modified()) {
result.get_impl()->trickle_down(true);
} else {
result.reconstruct();
result.get_impl()->reconstruct(true);
}
standard(world,ff1,false);
if (not same) standard(world,ff2,false);
Expand Down Expand Up @@ -2188,7 +2311,8 @@ namespace madness {
op.print_timer();
}

result.reconstruct();
result.get_impl()->reconstruct(true);

// fff.clear();
if (op.destructive()) {
ff.world().gop.fence();
Expand Down Expand Up @@ -2425,8 +2549,8 @@ namespace madness {
MADNESS_CHECK(world.size() == 1);

if (prepare) {
f.make_nonstandard(false, false);
g.make_nonstandard(false, false);
f.change_tree_state(nonstandard);
g.change_tree_state(nonstandard);
world.gop.fence();
f.get_impl()->compute_snorm_and_dnorm(false);
g.get_impl()->compute_snorm_and_dnorm(false);
Expand All @@ -2440,13 +2564,14 @@ namespace madness {
result=FunctionFactory<resultT,NDIM>(world)
.k(f.k()).thresh(f.thresh()).empty().nofence();
result.get_impl()->partial_inner(*f.get_impl(),*g.get_impl(),v1,v2);
result.get_impl()->set_tree_state(nonstandard);
result.get_impl()->set_tree_state(nonstandard_after_apply);
world.gop.set_forbid_fence(false);
}

if (finish) {

world.gop.fence();
result.get_impl()->reconstruct(true);
result.reconstruct();
FunctionImpl<T,LDIM>& f_nc=const_cast<FunctionImpl<T,LDIM>&>(*f.get_impl());
FunctionImpl<R,KDIM>& g_nc=const_cast<FunctionImpl<R,KDIM>&>(*g.get_impl());
Expand All @@ -2465,13 +2590,10 @@ namespace madness {

for (auto& key : erase_list(f_nc)) f_nc.get_coeffs().erase(key);
for (auto& key : erase_list(g_nc)) g_nc.get_coeffs().erase(key);
g_nc.standard(false);
f_nc.standard(false);
world.gop.fence();
g_nc.reconstruct(false);
f_nc.reconstruct(false);
world.gop.fence();
// print("timings: get_lists, recur, contract",wall_get_lists,wall_recur,wall_contract);

}

Expand Down
Loading

0 comments on commit d993694

Please sign in to comment.