Skip to content

Commit

Permalink
Keep positions in ST.
Browse files Browse the repository at this point in the history
  • Loading branch information
ye-luo committed Oct 17, 2024
1 parent 47218c2 commit 105a3b3
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/QMCWaveFunctions/BsplineFactory/SplineC2ROMPTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ void SplineC2ROMPTarget<ST>::mw_evaluateDetRatios(const RefVectorWithLeader<SPOS
for (const VirtualParticleSet& VP : vp_list)
mw_nVP += VP.getTotalNum();

const size_t packed_size = nw * sizeof(ValueType*) + mw_nVP * (6 * sizeof(TT) + sizeof(int));
const size_t packed_size = nw * sizeof(ValueType*) + mw_nVP * (6 * sizeof(ST) + sizeof(int));
det_ratios_buffer_H2D.resize(packed_size);

// pack invRow_ptr_list to det_ratios_buffer_H2D
Expand All @@ -297,9 +297,9 @@ void SplineC2ROMPTarget<ST>::mw_evaluateDetRatios(const RefVectorWithLeader<SPOS
ptr_buffer[iw] = invRow_ptr_list[iw];

// pack particle positions
auto* pos_ptr = reinterpret_cast<TT*>(det_ratios_buffer_H2D.data() + nw * sizeof(ValueType*));
auto* pos_ptr = reinterpret_cast<ST*>(det_ratios_buffer_H2D.data() + nw * sizeof(ValueType*));
auto* ref_id_ptr =
reinterpret_cast<int*>(det_ratios_buffer_H2D.data() + nw * sizeof(ValueType*) + mw_nVP * 6 * sizeof(TT));
reinterpret_cast<int*>(det_ratios_buffer_H2D.data() + nw * sizeof(ValueType*) + mw_nVP * 6 * sizeof(ST));
size_t iVP = 0;
for (size_t iw = 0; iw < nw; iw++)
{
Expand Down Expand Up @@ -353,14 +353,14 @@ void SplineC2ROMPTarget<ST>::mw_evaluateDetRatios(const RefVectorWithLeader<SPOS

auto* restrict offload_scratch_iat_ptr = offload_scratch_ptr + spline_padded_size * iat;
auto* restrict psi_iat_ptr = results_scratch_ptr + sposet_padded_size * iat;
auto* ref_id_ptr = reinterpret_cast<int*>(buffer_H2D_ptr + nw * sizeof(ValueType*) + mw_nVP * 6 * sizeof(TT));
auto* ref_id_ptr = reinterpret_cast<int*>(buffer_H2D_ptr + nw * sizeof(ValueType*) + mw_nVP * 6 * sizeof(ST));
auto* restrict psiinv_ptr = reinterpret_cast<const ValueType**>(buffer_H2D_ptr)[ref_id_ptr[iat]];
auto* restrict pos_scratch = reinterpret_cast<TT*>(buffer_H2D_ptr + nw * sizeof(ValueType*));
auto* restrict pos_scratch = reinterpret_cast<ST*>(buffer_H2D_ptr + nw * sizeof(ValueType*));

int ix, iy, iz;
ST a[4], b[4], c[4];
spline2::computeLocationAndFractional(spline_ptr, ST(pos_scratch[iat * 6 + 3]), ST(pos_scratch[iat * 6 + 4]),
ST(pos_scratch[iat * 6 + 5]), ix, iy, iz, a, b, c);
spline2::computeLocationAndFractional(spline_ptr, pos_scratch[iat * 6 + 3], pos_scratch[iat * 6 + 4],
pos_scratch[iat * 6 + 5], ix, iy, iz, a, b, c);

PRAGMA_OFFLOAD("omp parallel for")
for (int index = 0; index < last - first; index++)
Expand All @@ -370,7 +370,7 @@ void SplineC2ROMPTarget<ST>::mw_evaluateDetRatios(const RefVectorWithLeader<SPOS
const size_t last_cplx = omptarget::min(last / 2, num_complex_splines);
PRAGMA_OFFLOAD("omp parallel for")
for (int index = first_cplx; index < last_cplx; index++)
C2R::assign_v(ST(pos_scratch[iat * 6]), ST(pos_scratch[iat * 6 + 1]), ST(pos_scratch[iat * 6 + 2]),
C2R::assign_v(pos_scratch[iat * 6], pos_scratch[iat * 6 + 1], pos_scratch[iat * 6 + 2],
psi_iat_ptr, offload_scratch_iat_ptr, myKcart_ptr, myKcart_padded_size, first_spo_local,
nComplexBands_local, index);

Expand Down

0 comments on commit 105a3b3

Please sign in to comment.