Skip to content

Commit

Permalink
Only do kain on states whose residual is greater than 25% target resi…
Browse files Browse the repository at this point in the history
…dual
  • Loading branch information
ahurta92 committed Oct 26, 2023
1 parent 4a350d3 commit 3efccf6
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/apps/molresponse/ExcitedResponse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2161,7 +2161,7 @@ auto ExcitedResponse::update_response(World &world, X_space &Chi, XCOperator<dou
// kain if iteration >0 or first run where there should not be a problem
// computed new_chi and res
if (r_params.kain() && (iter > 0) && true) {
new_chi = kain_x_space_update(world, rotated_chi, new_res, kain_x_space);
new_chi = kain_x_space_update(world, rotated_chi, new_res, kain_x_space, 1e-8);
}
if (false) { x_space_step_restriction(world, rotated_chi, new_chi, compute_y, maxrotn); }

Expand Down
13 changes: 2 additions & 11 deletions src/apps/molresponse/FrequencyResponse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,7 @@ void FrequencyResponse::iterate(World &world) {
auto bsh_x_ops = make_bsh_operators_response(world, x_shifts, omega);
std::vector<poperatorT> bsh_y_ops;
bsh_y_ops = (compute_y) ? make_bsh_operators_response(world, y_shifts, -omega) : bsh_x_ops;
auto max_rotation = .5;
if (thresh >= 1e-2) {
max_rotation = 2;
} else if (thresh >= 1e-4) {
max_rotation = 2 * x_residual_target;
} else if (thresh >= 1e-6) {
max_rotation = 2 * x_residual_target;
} else if (thresh >= 1e-7) {
max_rotation = .01;
}
auto max_rotation = .25 * x_residual_target + x_residual_target;
PQ = generator(world, *this);
PQ.truncate();

Expand Down Expand Up @@ -280,7 +271,7 @@ auto FrequencyResponse::update_response(World &world, X_space &chi, XCOperator<d
auto [new_res, bsh] = update_residual(world, chi, new_chi, r_params.calc_type(), old_residuals, xres_old);
inner_to_json(world, "r_x", response_context.inner(new_res, new_res), iter_function_data);
if (iteration >= 0) {// & (iteration % 3 == 0)) {
new_chi = kain_x_space_update(world, chi, new_res, kain_x_space);
new_chi = kain_x_space_update(world, chi, new_res, kain_x_space, max_rotation);
}
inner_to_json(world, "x_update", response_context.inner(new_chi, new_chi), iter_function_data);

Expand Down
28 changes: 23 additions & 5 deletions src/apps/molresponse/ResponseBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1174,24 +1174,42 @@ auto ResponseBase::update_residual(World &world, const X_space &chi, const X_spa
}

auto ResponseBase::kain_x_space_update(World &world, const X_space &chi, const X_space &residual_chi,
response_solver &kain_x_space) -> X_space {
response_solver &kain_x_space, double max_rotation) -> X_space {
if (r_params.print_level() >= 1) { molresponse::start_timer(world); }
size_t m = chi.num_states();
size_t n = chi.num_orbitals();
X_space kain_update = chi.copy();
response_matrix update(m);

// compute the norm of the residuals


Tensor<double> residual_norms(m);

bool compute_y = r_params.omega() != 0.0;
if (compute_y) {
auto x_vectors = to_response_matrix(chi);
auto x_residuals = to_response_matrix(residual_chi);
// compute the norm of the residuals
for (const auto &b: chi.active) { residual_norms(b) = norm2(world, x_residuals[b]); }// / norm2(world, gx[b]); }

for (const auto &i: Chi.active) {
auto temp = kain_x_space[i].update(x_vectors[i], x_residuals[i]);
std::copy(temp.begin(), temp.begin() + n, kain_update.x[i].begin());
std::copy(temp.begin() + n, temp.end(), kain_update.y[i].begin());
if (residual_norms(i) > max_rotation) {
auto temp = kain_x_space[i].update(x_vectors[i], x_residuals[i]);
std::copy(temp.begin(), temp.begin() + n, kain_update.x[i].begin());
std::copy(temp.begin() + n, temp.end(), kain_update.y[i].begin());
}
};
} else {
for (const auto &i: Chi.active) { kain_update.x[i] = kain_x_space[i].update(chi.x[i], residual_chi.x[i]); }
// first compute the residuals
for (const auto &b: chi.active) {
residual_norms(b) = norm2(world, residual_chi.x[b]);
}// / norm2(world, g_chi.x[b]); }
for (const auto &i: Chi.active) {
if (residual_norms(i) > max_rotation) {
kain_update.x[i] = kain_x_space[i].update(chi.x[i], residual_chi.x[i]);
}
}
}
if (r_params.print_level() >= 1) { molresponse::end_timer(world, "kain_x_update", "kain_x_update", iter_timing); }
return kain_update;
Expand Down
2 changes: 1 addition & 1 deletion src/apps/molresponse/ResponseBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ class ResponseBase {


auto kain_x_space_update(World &world, const X_space &chi, const X_space &residual_chi,
response_solver &kain_x_space) -> X_space;
response_solver &kain_x_space, double max_rotation) -> X_space;

void x_space_step_restriction(World &world, const X_space &old_Chi, X_space &temp, bool restrict_y,
const double &max_bsh_rotation);
Expand Down

0 comments on commit 3efccf6

Please sign in to comment.