Skip to content

Commit

Permalink
Add single precision support
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang committed Oct 15, 2024
1 parent 4eb5f43 commit d7c313f
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 29 deletions.
111 changes: 83 additions & 28 deletions Src/FFT/AMReX_FFT.H
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,7 @@ R2C<T>::R2C (Box const& domain)
domain.length(0)/2,
domain.bigEnd(1))))
{
AMREX_ALWAYS_ASSERT(domain.smallEnd() == 0);
AMREX_ALWAYS_ASSERT((std::is_same_v<double,Real>)); // xxxxx todo
AMREX_ALWAYS_ASSERT(domain.smallEnd() == 0 && domain.cellCentered());

int myproc = ParallelDescriptor::MyProc();
int nprocs = ParallelDescriptor::NProcs();
Expand All @@ -208,7 +207,7 @@ R2C<T>::R2C (Box const& domain)

// plans for x-direction
{
double* in = m_rx[myproc].dataPtr();
auto* in = m_rx[myproc].dataPtr();
auto* out = (FFTComplex*)(m_cx[myproc].dataPtr());

Box const local_box = m_rx.boxArray()[myproc];
Expand All @@ -228,15 +227,27 @@ R2C<T>::R2C (Box const& domain)
static_assert(false);

#else
m_fft_fwd_x = fftw_plan_many_dft_r2c
(1, &n, howmany, in, nullptr, 1, m_real_domain.length(0),
out, nullptr, 1, m_spectral_domain_x.length(0),
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);

m_fft_bwd_x = fftw_plan_many_dft_c2r
(1, &n, howmany, out, nullptr, 1, m_spectral_domain_x.length(0),
in, nullptr, 1, m_real_domain.length(0),
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
if constexpr (std::is_same_v<float,T>) {
m_fft_fwd_x = fftwf_plan_many_dft_r2c
(1, &n, howmany, in, nullptr, 1, m_real_domain.length(0),
out, nullptr, 1, m_spectral_domain_x.length(0),
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);

m_fft_bwd_x = fftwf_plan_many_dft_c2r
(1, &n, howmany, out, nullptr, 1, m_spectral_domain_x.length(0),
in, nullptr, 1, m_real_domain.length(0),
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
} else {
m_fft_fwd_x = fftw_plan_many_dft_r2c
(1, &n, howmany, in, nullptr, 1, m_real_domain.length(0),
out, nullptr, 1, m_spectral_domain_x.length(0),
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);

m_fft_bwd_x = fftw_plan_many_dft_c2r
(1, &n, howmany, out, nullptr, 1, m_spectral_domain_x.length(0),
in, nullptr, 1, m_real_domain.length(0),
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
}
#endif
}

Expand Down Expand Up @@ -285,40 +296,74 @@ R2C<T>::~R2C<T> ()
static_assert(false);

#else
fftw_destroy_plan(m_fft_fwd_x);
fftw_destroy_plan(m_fft_bwd_x);
fftw_destroy_plan(m_fft_fwd_y);
fftw_destroy_plan(m_fft_bwd_y);
fftw_destroy_plan(m_fft_fwd_z);
fftw_destroy_plan(m_fft_bwd_z);

if constexpr (std::is_same_v<float,T>) {
fftwf_destroy_plan(m_fft_fwd_x);
fftwf_destroy_plan(m_fft_bwd_x);
fftwf_destroy_plan(m_fft_fwd_y);
fftwf_destroy_plan(m_fft_bwd_y);
fftwf_destroy_plan(m_fft_fwd_z);
fftwf_destroy_plan(m_fft_bwd_z);
} else {
fftw_destroy_plan(m_fft_fwd_x);
fftw_destroy_plan(m_fft_bwd_x);
fftw_destroy_plan(m_fft_fwd_y);
fftw_destroy_plan(m_fft_bwd_y);
fftw_destroy_plan(m_fft_fwd_z);
fftw_destroy_plan(m_fft_bwd_z);
}
#endif
}

/**
 * \brief Run the forward transform, one direction at a time (x, then y, then z).
 *
 * The input is first copied into the internal real buffer m_rx. After each
 * 1D transform the data is redistributed into the buffer of the next
 * direction (m_cx -> m_cy -> m_cz) before that direction's plan is executed.
 *
 * Each plan is executed exactly once, using the FFTW entry point that matches
 * the precision the plan was created with: fftwf_execute when T is float,
 * fftw_execute otherwise. Passing a single-precision plan to fftw_execute
 * (or executing a plan twice) would be incorrect.
 *
 * \param inmf     source data; copied into m_rx with ParallelCopy.
 * \param scaling  not referenced in this body.
 */
template <typename T>
void R2C<T>::forward_doit (MF const& inmf, Scaling scaling)
{
    // x-direction: bring the input into m_rx, then run the forward x plan.
    m_rx.ParallelCopy(inmf, 0, 0, 1);
    if constexpr (std::is_same_v<float,T>) {
        fftwf_execute(m_fft_fwd_x);
    } else {
        fftw_execute(m_fft_fwd_x);
    }

    // y-direction: redistribute m_cx into m_cy, then run the forward y plan.
    ParallelCopy(m_cy, m_cx, *m_cmd_x2y, 0, 0, 1, m_dtos_x2y);
    if constexpr (std::is_same_v<float,T>) {
        fftwf_execute(m_fft_fwd_y);
    } else {
        fftw_execute(m_fft_fwd_y);
    }

    // z-direction: redistribute m_cy into m_cz, then run the forward z plan.
    ParallelCopy(m_cz, m_cy, *m_cmd_y2z, 0, 0, 1, m_dtos_y2z);
    if constexpr (std::is_same_v<float,T>) {
        fftwf_execute(m_fft_fwd_z);
    } else {
        fftw_execute(m_fft_fwd_z);
    }
}

/**
 * \brief Run the backward transform, one direction at a time (z, then y, then x).
 *
 * This mirrors the forward pipeline in reverse: the z plan is executed on the
 * data already held in m_cz, the result is redistributed m_cz -> m_cy for the
 * y plan, then m_cy -> m_cx for the x plan, and finally the real buffer m_rx
 * is copied into the output.
 *
 * Each plan is executed exactly once, using the FFTW entry point that matches
 * the precision the plan was created with: fftwf_execute when T is float,
 * fftw_execute otherwise. Executing a stage twice would be wrong here, since
 * the x plans are created with FFTW_DESTROY_INPUT.
 *
 * \param outmf    destination; receives m_rx via ParallelCopy.
 * \param scaling  not yet applied (see TODO below).
 */
template <typename T>
void R2C<T>::backward_doit (MF& outmf, Scaling scaling)
{
    // xxxxx todo: scaling

    // z-direction: run the backward z plan, then redistribute m_cz into m_cy.
    if constexpr (std::is_same_v<float,T>) {
        fftwf_execute(m_fft_bwd_z);
    } else {
        fftw_execute(m_fft_bwd_z);
    }
    ParallelCopy(m_cy, m_cz, *m_cmd_z2y, 0, 0, 1, m_dtos_z2y);

    // y-direction: run the backward y plan, then redistribute m_cy into m_cx.
    if constexpr (std::is_same_v<float,T>) {
        fftwf_execute(m_fft_bwd_y);
    } else {
        fftw_execute(m_fft_bwd_y);
    }
    ParallelCopy(m_cx, m_cy, *m_cmd_y2x, 0, 0, 1, m_dtos_y2x);

    // x-direction: run the backward x plan, then hand m_rx to the caller.
    if constexpr (std::is_same_v<float,T>) {
        fftwf_execute(m_fft_bwd_x);
    } else {
        fftw_execute(m_fft_bwd_x);
    }
    outmf.ParallelCopy(m_rx, 0, 0, 1);
}

Expand All @@ -333,6 +378,9 @@ R2C<T>::make_c2c_plans (cMF& inout)
int n = local_box.length(0);
int howmany = local_box.length(1) * local_box.length(2);

FFTPlan fwd;
FFTPlan bwd;

#if defined(AMREX_USE_CUDA)

static_assert(false);
Expand All @@ -346,10 +394,17 @@ R2C<T>::make_c2c_plans (cMF& inout)
static_assert(false);

#else
auto* fwd = fftw_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n,
pinout, nullptr, 1, n, -1, FFTW_ESTIMATE);
auto* bwd = fftw_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n,
pinout, nullptr, 1, n, +1, FFTW_ESTIMATE);
if constexpr (std::is_same_v<float,T>) {
fwd = fftwf_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n,
pinout, nullptr, 1, n, -1, FFTW_ESTIMATE);
bwd = fftwf_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n,
pinout, nullptr, 1, n, +1, FFTW_ESTIMATE);
} else {
fwd = fftw_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n,
pinout, nullptr, 1, n, -1, FFTW_ESTIMATE);
bwd = fftw_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n,
pinout, nullptr, 1, n, +1, FFTW_ESTIMATE);
}
#endif

return {fwd,bwd};
Expand Down
2 changes: 1 addition & 1 deletion Tests/FFT/Poisson/GNUmakefile
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ else ifeq ($(USE_HIP),TRUE)
LIBRARY_LOCATIONS += $(ROC_PATH)/rocfft/lib
LIBRARIES += -L$(ROC_PATH)/rocfft/lib -lrocfft
else
libraries += -lfftw3
libraries += -lfftw3f -lfftw3
endif

include $(AMREX_HOME)/Tools/GNUMake/Make.rules

0 comments on commit d7c313f

Please sign in to comment.