Skip to content

Commit

Permalink
potential way of implementing discussion #587
Browse files Browse the repository at this point in the history
  • Loading branch information
mreineck committed Nov 20, 2024
1 parent 4d25c21 commit 64c963e
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 187 deletions.
58 changes: 26 additions & 32 deletions include/finufft/finufft_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,7 @@ static inline void MY_OMP_SET_NUM_THREADS [[maybe_unused]] (int) {}

// group together a bunch of type 3 rescaling/centering/phasing parameters:
template<typename T> struct type3params {
T X1, C1, D1, h1, gam1; // x dim: X=halfwid C=center D=freqcen h,gam=rescale
T X2, C2, D2, h2, gam2; // y
T X3, C3, D3, h3, gam3; // z
std::array<T, 3> X, C, D, h, gam; // x dim: X=halfwid C=center D=freqcen h,gam=rescale
};

template<typename TF> struct FINUFFT_PLAN_T { // the main plan class, fully C++
Expand All @@ -151,30 +149,26 @@ template<typename TF> struct FINUFFT_PLAN_T { // the main plan class, fully C++
FINUFFT_PLAN_T &operator=(const FINUFFT_PLAN_T &) = delete;
~FINUFFT_PLAN_T();

int type; // transform type (Rokhlin naming): 1,2 or 3
int dim; // overall dimension: 1,2 or 3
int ntrans; // how many transforms to do at once (vector or "many" mode)
BIGINT nj; // num of NU pts in type 1,2 (for type 3, num input x pts)
BIGINT nk; // number of NU freq pts (type 3 only)
TF tol; // relative user tolerance
int batchSize; // # strength vectors to group together for FFTW, etc
int nbatch; // how many batches done to cover all ntrans vectors
int type; // transform type (Rokhlin naming): 1,2 or 3
int dim; // overall dimension: 1,2 or 3
int ntrans; // how many transforms to do at once (vector or "many" mode)
BIGINT nj; // num of NU pts in type 1,2 (for type 3, num input x pts)
BIGINT nk; // number of NU freq pts (type 3 only)
TF tol; // relative user tolerance
int batchSize; // # strength vectors to group together for FFTW, etc
int nbatch; // how many batches done to cover all ntrans vectors

BIGINT ms; // number of modes in x (1) dir (historical CMCL name) = N1
BIGINT mt; // number of modes in y (2) direction = N2
BIGINT mu; // number of modes in z (3) direction = N3
BIGINT N; // total # modes (prod of above three)
std::array<UBIGINT, 3> mstu; // number of modes in x/y/z dir (historical CMCL name) =
// N1/N2/N3
UBIGINT N; // total # modes (prod of above three)

BIGINT nf1 = 1; // size of internal fine grid in x (1) direction
BIGINT nf2 = 1; // " y (2)
BIGINT nf3 = 1; // " z (3)
BIGINT nf = 1; // total # fine grid points (product of the above three)
std::array<UBIGINT, 3> nf123 = {1, 1, 1}; // size of internal fine grid in x/y/z
// direction
UBIGINT nf = 1; // total # fine grid points (product of the above three)

int fftSign; // sign in exponential for NUFFT defn, guaranteed to be +-1
int fftSign; // sign in exponential for NUFFT defn, guaranteed to be +-1

std::vector<TF> phiHat1; // FT of kernel in t1,2, on x-axis mode grid
std::vector<TF> phiHat2; // " y-axis.
std::vector<TF> phiHat3; // " z-axis.
std::array<std::vector<TF>, 3> phiHat; // FT of kernel in t1,2, on x/y/z-axis mode grid

// fwBatch: (batches of) fine working grid(s) for the FFT to plan & act on.
// Usually the largest internal array. Its allocator is 64-byte (cache-line) aligned:
Expand All @@ -185,17 +179,17 @@ template<typename TF> struct FINUFFT_PLAN_T { // the main plan class, fully C++

// for t1,2: ptr to user-supplied NU pts (no new allocs).
// for t3: will become ptr to internally allocated "primed" (scaled) Xp, Yp, Zp vecs.
TF *X = nullptr, *Y = nullptr, *Z = nullptr;
std::array<TF *, 3> XYZ = {nullptr, nullptr, nullptr};

// type 3 specific
TF *S = nullptr, *T = nullptr, *U = nullptr; // ptrs to user's target NU-point arrays
// (no new allocs)
std::vector<TC> prephase; // pre-phase, for all input NU pts
std::vector<TC> deconv; // reciprocal of kernel FT, phase, all output NU pts
std::vector<TC> CpBatch; // working array of prephased strengths
std::vector<TF> Xp, Yp, Zp; // internal primed NU points (x'_j, etc)
std::vector<TF> Sp, Tp, Up; // internal primed targs (s'_k, etc)
type3params<TF> t3P; // groups together type 3 shift, scale, phase, parameters
std::array<TF *, 3> STU = {nullptr, nullptr, nullptr}; // ptrs to user's target NU-point
// arrays (no new allocs)
std::vector<TC> prephase; // pre-phase, for all input NU pts
std::vector<TC> deconv; // reciprocal of kernel FT, phase, all output NU pts
std::vector<TC> CpBatch; // working array of prephased strengths
std::array<std::vector<TF>, 3> XYZp; // internal primed NU points (x'_j, etc)
std::array<std::vector<TF>, 3> STUp; // internal primed targs (s'_k, etc)
type3params<TF> t3P; // groups together type 3 shift, scale, phase, parameters
std::unique_ptr<FINUFFT_PLAN_T<TF>> innerT2plan; // ptr used for type 2 in step 2 of
// type 3

Expand Down
24 changes: 12 additions & 12 deletions src/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ using namespace std;
template<typename TF> std::vector<int> gridsize_for_fft(FINUFFT_PLAN_T<TF> *p) {
// local helper func returns a new int array of length dim, extracted from
// the finufft plan, that fftw_plan_many_dft needs as its 2nd argument.
if (p->dim == 1) return {(int)p->nf1};
if (p->dim == 2) return {(int)p->nf2, (int)p->nf1};
if (p->dim == 1) return {(int)p->nf123[0]};
if (p->dim == 2) return {(int)p->nf123[1], (int)p->nf123[0]};
// if (p->dim == 3)
return {(int)p->nf3, (int)p->nf2, (int)p->nf1};
return {(int)p->nf123[2], (int)p->nf123[1], (int)p->nf123[0]};
}
template std::vector<int> gridsize_for_fft<float>(FINUFFT_PLAN_T<float> *p);
template std::vector<int> gridsize_for_fft<double>(FINUFFT_PLAN_T<double> *p);
Expand Down Expand Up @@ -49,11 +49,11 @@ template<typename TF> void do_fft(FINUFFT_PLAN_T<TF> *p) {
if (p->dim == 1) // 1D: no chance for FFT shortcuts
ducc0::c2c(data, data, axes, p->fftSign < 0, TF(1), nthreads);
else if (p->dim == 2) { // 2D: do partial FFTs
if (p->ms < 2) // something is weird, do standard FFT
if (p->mstu[0] < 2) // something is weird, do standard FFT
ducc0::c2c(data, data, axes, p->fftSign < 0, TF(1), nthreads);
else {
size_t y_lo = size_t((p->ms + 1) / 2);
size_t y_hi = size_t(ns[1] - p->ms / 2);
size_t y_lo = size_t((p->mstu[0] + 1) / 2);
size_t y_hi = size_t(ns[1] - p->mstu[0] / 2);
// the next line is analogous to the Python statement "sub1 = data[:, :, :y_lo]"
auto sub1 = ducc0::subarray(data, {{}, {}, {0, y_lo}});
// the next line is analogous to the Python statement "sub2 = data[:, :, y_hi:]"
Expand All @@ -68,14 +68,14 @@ template<typename TF> void do_fft(FINUFFT_PLAN_T<TF> *p) {
// do axis 2 in full
ducc0::c2c(data, data, {2}, p->fftSign < 0, TF(1), nthreads);
}
} else { // 3D
if ((p->ms < 2) || (p->mt < 2)) // something is weird, do standard FFT
} else { // 3D
if ((p->mstu[0] < 2) || (p->mstu[1] < 2)) // something is weird, do standard FFT
ducc0::c2c(data, data, axes, p->fftSign < 0, TF(1), nthreads);
else {
size_t z_lo = size_t((p->ms + 1) / 2);
size_t z_hi = size_t(ns[2] - p->ms / 2);
size_t y_lo = size_t((p->mt + 1) / 2);
size_t y_hi = size_t(ns[1] - p->mt / 2);
size_t z_lo = size_t((p->mstu[0] + 1) / 2);
size_t z_hi = size_t(ns[2] - p->mstu[0] / 2);
size_t y_lo = size_t((p->mstu[1] + 1) / 2);
size_t y_hi = size_t(ns[1] - p->mstu[1] / 2);
auto sub1 = ducc0::subarray(data, {{}, {}, {}, {0, z_lo}});
auto sub2 = ducc0::subarray(data, {{}, {}, {}, {z_hi, ducc0::MAXIDX}});
auto sub3 = ducc0::subarray(sub1, {{}, {}, {0, y_lo}, {}});
Expand Down
Loading

0 comments on commit 64c963e

Please sign in to comment.