diff --git a/CHANGELOG b/CHANGELOG index 6b34c4b46..fe852d418 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -2,9 +2,15 @@ List of features / changes made / release notes, in reverse chronological order. If not stated, FINUFFT is assumed (cuFINUFFT <=1.3 is listed separately). Master (9/10/24) - * reduced roundoff error in a[n] phase calc in CPU onedim_fseries_kernel(). #534 (Barnett). +* Support for type 3 in 1D, 2D, and 3D in the GPU library cufinufft (PR #517). + - Removed the CPU fseries computation (only used for benchmark no longer needed). + - Added complex arithmetic support for cuda_complex type + - Added tests for type 3 in 1D, 2D, and 3D and cuda_complex arithmetic + - Minor fixes on the GPU code: + a) removed memory leaks in case of errors + b) renamed maxbatchsize to batchsize V 2.3.0 (9/5/24) diff --git a/docs/devnotes.rst b/docs/devnotes.rst index adbaddc30..733d38302 100644 --- a/docs/devnotes.rst +++ b/docs/devnotes.rst @@ -54,6 +54,8 @@ Developer notes * CMake compiling on linux at Flatiron Institute (Rusty cluster): We have had a report that if you want to use LLVM, you need to ``module load llvm/16.0.3`` otherwise the default ``llvm/14.0.6`` does not find ``OpenMP_CXX``. +* Note to the nvcc developer. nvcc with debug symbols causes a stack overflow that is undetected at both compile and runtime. This goes undetected until ns>=10 and dim=3, for ns<10 or dim < 3, one can use -G and debug the code with cuda-gdb. The way to avoid is to not use Debug symbols, possibly using ``--generate-line-info`` might work (not tested). As a side note, compute-sanitizers do not detect the issue. + * Testing cufinufft (for FI, mostly): .. code-block:: sh diff --git a/include/cufinufft/common.h b/include/cufinufft/common.h index e1149df1f..5c11e8815 100644 --- a/include/cufinufft/common.h +++ b/include/cufinufft/common.h @@ -7,31 +7,37 @@ #include #include -#include +#include namespace cufinufft { namespace common { template -__global__ void fseries_kernel_compute(int nf1, int nf2, int nf3, T *f, - cuDoubleComplex *a, T *fwkerhalf1, T *fwkerhalf2, +__global__ void fseries_kernel_compute(int nf1, int nf2, int nf3, T *f, T *a, + T *fwkerhalf1, T *fwkerhalf2, T *fwkerhalf3, + int ns); +template +__global__ void cu_nuft_kernel_compute(int nf1, int nf2, int nf3, T *f, T *z, T *kx, + T *ky, T *kz, T *fwkerhalf1, T *fwkerhalf2, T *fwkerhalf3, int ns); template -int cufserieskernelcompute(int dim, int nf1, int nf2, int nf3, T *d_f, - cuDoubleComplex *d_a, T *d_fwkerhalf1, T *d_fwkerhalf2, - T *d_fwkerhalf3, int ns, cudaStream_t stream); +int fseries_kernel_compute(int dim, int nf1, int nf2, int nf3, T *d_f, T *d_phase, + T *d_fwkerhalf1, T *d_fwkerhalf2, T *d_fwkerhalf3, int ns, + cudaStream_t stream); +template +int nuft_kernel_compute(int dim, int nf1, int nf2, int nf3, T *d_f, T *d_z, T *d_kx, + T *d_ky, T *d_kz, T *d_fwkerhalf1, T *d_fwkerhalf2, + T *d_fwkerhalf3, int ns, cudaStream_t stream); template int setup_spreader_for_nufft(finufft_spread_opts &spopts, T eps, cufinufft_opts opts); void set_nf_type12(CUFINUFFT_BIGINT ms, cufinufft_opts opts, finufft_spread_opts spopts, CUFINUFFT_BIGINT *nf, CUFINUFFT_BIGINT b); + template -void onedim_fseries_kernel(CUFINUFFT_BIGINT nf, T *fwkerhalf, finufft_spread_opts opts); -template -void onedim_fseries_kernel_precomp(CUFINUFFT_BIGINT nf, T *f, std::complex *a, +void onedim_fseries_kernel_precomp(CUFINUFFT_BIGINT nf, T *f, T *a, finufft_spread_opts opts); template -void onedim_fseries_kernel_compute(CUFINUFFT_BIGINT nf, T *f, std::complex *a, - T *fwkerhalf, finufft_spread_opts opts); +void onedim_nuft_kernel_precomp(T *f, T *zout, finufft_spread_opts opts); template std::size_t shared_memory_required(int dim, int ns, int bin_size_x, int bin_size_y, diff --git a/include/cufinufft/contrib/helper_math.h b/include/cufinufft/contrib/helper_math.h new file mode 100644 index 000000000..119aca6b6 --- /dev/null +++ b/include/cufinufft/contrib/helper_math.h @@ -0,0 +1,173 @@ +#ifndef FINUFFT_INCLUDE_CUFINUFFT_CONTRIB_HELPER_MATH_H +#define FINUFFT_INCLUDE_CUFINUFFT_CONTRIB_HELPER_MATH_H + +#include + +// This header provides some helper functions for cuComplex types. +// It mainly wraps existing CUDA implementations to provide operator overloads +// e.g. cuAdd, cuSub, cuMul, cuDiv, cuCreal, cuCimag, cuCabs, cuCarg, cuConj are all +// provided by CUDA + +// Addition for cuDoubleComplex (double) with cuDoubleComplex (double) +__host__ __device__ __forceinline__ cuDoubleComplex operator+( + const cuDoubleComplex &a, const cuDoubleComplex &b) noexcept { + return cuCadd(a, b); +} + +// Subtraction for cuDoubleComplex (double) with cuDoubleComplex (double) +__host__ __device__ __forceinline__ cuDoubleComplex operator-( + const cuDoubleComplex &a, const cuDoubleComplex &b) noexcept { + return cuCsub(a, b); +} + +// Multiplication for cuDoubleComplex (double) with cuDoubleComplex (double) +__host__ __device__ __forceinline__ cuDoubleComplex operator*( + const cuDoubleComplex &a, const cuDoubleComplex &b) noexcept { + return cuCmul(a, b); +} + +// Division for cuDoubleComplex (double) with cuDoubleComplex (double) +__host__ __device__ __forceinline__ cuDoubleComplex operator/( + const cuDoubleComplex &a, const cuDoubleComplex &b) noexcept { + return cuCdiv(a, b); +} + +// Equality for cuDoubleComplex (double) with cuDoubleComplex (double) +__host__ __device__ __forceinline__ bool operator==(const cuDoubleComplex &a, + const cuDoubleComplex &b) noexcept { + return cuCreal(a) == cuCreal(b) && cuCimag(a) == cuCimag(b); +} + +// Inequality for cuDoubleComplex (double) with cuDoubleComplex (double) +__host__ __device__ __forceinline__ bool operator!=(const cuDoubleComplex &a, + const cuDoubleComplex &b) noexcept { + return !(a == b); +} + +// Addition for cuDoubleComplex (double) with double +__host__ __device__ __forceinline__ cuDoubleComplex operator+(const cuDoubleComplex &a, + double b) noexcept { + return make_cuDoubleComplex(cuCreal(a) + b, cuCimag(a)); +} + +__host__ __device__ __forceinline__ cuDoubleComplex operator+( + double a, const cuDoubleComplex &b) noexcept { + return make_cuDoubleComplex(a + cuCreal(b), cuCimag(b)); +} + +// Subtraction for cuDoubleComplex (double) with double +__host__ __device__ __forceinline__ cuDoubleComplex operator-(const cuDoubleComplex &a, + double b) noexcept { + return make_cuDoubleComplex(cuCreal(a) - b, cuCimag(a)); +} + +__host__ __device__ __forceinline__ cuDoubleComplex operator-( + double a, const cuDoubleComplex &b) noexcept { + return make_cuDoubleComplex(a - cuCreal(b), -cuCimag(b)); +} + +// Multiplication for cuDoubleComplex (double) with double +__host__ __device__ __forceinline__ cuDoubleComplex operator*(const cuDoubleComplex &a, + double b) noexcept { + return make_cuDoubleComplex(cuCreal(a) * b, cuCimag(a) * b); +} + +__host__ __device__ __forceinline__ cuDoubleComplex operator*( + double a, const cuDoubleComplex &b) noexcept { + return make_cuDoubleComplex(a * cuCreal(b), a * cuCimag(b)); +} + +// Division for cuDoubleComplex (double) with double +__host__ __device__ __forceinline__ cuDoubleComplex operator/(const cuDoubleComplex &a, + double b) noexcept { + return make_cuDoubleComplex(cuCreal(a) / b, cuCimag(a) / b); +} + +__host__ __device__ __forceinline__ cuDoubleComplex operator/( + double a, const cuDoubleComplex &b) noexcept { + double denom = cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b); + return make_cuDoubleComplex((a * cuCreal(b)) / denom, (-a * cuCimag(b)) / denom); +} + +// Addition for cuFloatComplex (float) with cuFloatComplex (float) +__host__ __device__ __forceinline__ cuFloatComplex operator+( + const cuFloatComplex &a, const cuFloatComplex &b) noexcept { + return cuCaddf(a, b); +} + +// Subtraction for cuFloatComplex (float) with cuFloatComplex (float) +__host__ __device__ __forceinline__ cuFloatComplex operator-( + const cuFloatComplex &a, const cuFloatComplex &b) noexcept { + return cuCsubf(a, b); +} + +// Multiplication for cuFloatComplex (float) with cuFloatComplex (float) +__host__ __device__ __forceinline__ cuFloatComplex operator*( + const cuFloatComplex &a, const cuFloatComplex &b) noexcept { + return cuCmulf(a, b); +} + +// Division for cuFloatComplex (float) with cuFloatComplex (float) +__host__ __device__ __forceinline__ cuFloatComplex operator/( + const cuFloatComplex &a, const cuFloatComplex &b) noexcept { + return cuCdivf(a, b); +} + +// Equality for cuFloatComplex (float) with cuFloatComplex (float) +__host__ __device__ __forceinline__ bool operator==(const cuFloatComplex &a, + const cuFloatComplex &b) noexcept { + return cuCrealf(a) == cuCrealf(b) && cuCimagf(a) == cuCimagf(b); +} + +// Inequality for cuFloatComplex (float) with cuFloatComplex (float) +__host__ __device__ __forceinline__ bool operator!=(const cuFloatComplex &a, + const cuFloatComplex &b) noexcept { + return !(a == b); +} + +// Addition for cuFloatComplex (float) with float +__host__ __device__ __forceinline__ cuFloatComplex operator+(const cuFloatComplex &a, + float b) noexcept { + return make_cuFloatComplex(cuCrealf(a) + b, cuCimagf(a)); +} + +__host__ __device__ __forceinline__ cuFloatComplex operator+( + float a, const cuFloatComplex &b) noexcept { + return make_cuFloatComplex(a + cuCrealf(b), cuCimagf(b)); +} + +// Subtraction for cuFloatComplex (float) with float +__host__ __device__ __forceinline__ cuFloatComplex operator-(const cuFloatComplex &a, + float b) noexcept { + return make_cuFloatComplex(cuCrealf(a) - b, cuCimagf(a)); +} + +__host__ __device__ __forceinline__ cuFloatComplex operator-( + float a, const cuFloatComplex &b) noexcept { + return make_cuFloatComplex(a - cuCrealf(b), -cuCimagf(b)); +} + +// Multiplication for cuFloatComplex (float) with float +__host__ __device__ __forceinline__ cuFloatComplex operator*(const cuFloatComplex &a, + float b) noexcept { + return make_cuFloatComplex(cuCrealf(a) * b, cuCimagf(a) * b); +} + +__host__ __device__ __forceinline__ cuFloatComplex operator*( + float a, const cuFloatComplex &b) noexcept { + return make_cuFloatComplex(a * cuCrealf(b), a * cuCimagf(b)); +} + +// Division for cuFloatComplex (float) with float +__host__ __device__ __forceinline__ cuFloatComplex operator/(const cuFloatComplex &a, + float b) noexcept { + return make_cuFloatComplex(cuCrealf(a) / b, cuCimagf(a) / b); +} + +__host__ __device__ __forceinline__ cuFloatComplex operator/( + float a, const cuFloatComplex &b) noexcept { + float denom = cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b); + return make_cuFloatComplex((a * cuCrealf(b)) / denom, (-a * cuCimagf(b)) / denom); +} + +#endif // FINUFFT_INCLUDE_CUFINUFFT_CONTRIB_HELPER_MATH_H diff --git a/include/cufinufft/defs.h b/include/cufinufft/defs.h index 6b2a075ea..630989a26 100644 --- a/include/cufinufft/defs.h +++ b/include/cufinufft/defs.h @@ -2,14 +2,16 @@ #define CUFINUFFT_DEFS_H #include - // constants needed within common // upper bound on w, ie nspread, even when padded (see evaluate_kernel_vector); also for // common -#define MAX_NSPREAD 16 +#define MAX_NSPREAD 16 // max number of positive quadr nodes -#define MAX_NQUAD 100 +#define MAX_NQUAD 100 + +// Fraction growth cut-off in utils:arraywidcen, sets when translate in type-3 +#define ARRAYWIDCEN_GROWFRAC 0.1 // FIXME: If cufft ever takes N > INT_MAX... constexpr int32_t MAX_NF = std::numeric_limits::max(); diff --git a/include/cufinufft/impl.h b/include/cufinufft/impl.h index 3a9fd6877..0913a404e 100644 --- a/include/cufinufft/impl.h +++ b/include/cufinufft/impl.h @@ -4,9 +4,9 @@ #include #include +#include #include -#include #include #include #include @@ -14,6 +14,7 @@ #include #include +#include // 1d template @@ -22,7 +23,9 @@ int cufinufft1d1_exec(cuda_complex *d_c, cuda_complex *d_fk, template int cufinufft1d2_exec(cuda_complex *d_c, cuda_complex *d_fk, cufinufft_plan_t *d_plan); - +template +int cufinufft1d3_exec(cuda_complex *d_c, cuda_complex *d_fk, + cufinufft_plan_t *d_plan); // 2d template int cufinufft2d1_exec(cuda_complex *d_c, cuda_complex *d_fk, @@ -31,6 +34,9 @@ template int cufinufft2d2_exec(cuda_complex *d_c, cuda_complex *d_fk, cufinufft_plan_t *d_plan); +template +int cufinufft2d3_exec(cuda_complex *d_c, cuda_complex *d_fk, + cufinufft_plan_t *d_plan); // 3d template int cufinufft3d1_exec(cuda_complex *d_c, cuda_complex *d_fk, @@ -38,6 +44,9 @@ int cufinufft3d1_exec(cuda_complex *d_c, cuda_complex *d_fk, template int cufinufft3d2_exec(cuda_complex *d_c, cuda_complex *d_fk, cufinufft_plan_t *d_plan); +template +int cufinufft3d3_exec(cuda_complex *d_c, cuda_complex *d_fk, + cufinufft_plan_t *d_plan); template int cufinufft_makeplan_impl(int type, int dim, int *nmodes, int iflag, int ntransf, T tol, @@ -62,11 +71,9 @@ int cufinufft_makeplan_impl(int type, int dim, int *nmodes, int iflag, int ntran Melody Shih 07/25/19. Use-facing moved to markdown, Barnett 2/16/21. Marco Barbone 07/26/24. Using SM when shared memory available is enough. */ + using namespace cufinufft::common; int ier; - cuDoubleComplex *d_a = nullptr; // fseries temp data - T *d_f = nullptr; // fseries temp data - - if (type < 1 || type > 2) { + if (type < 1 || type > 3) { fprintf(stderr, "[%s] Invalid type (%d): should be 1 or 2.\n", __func__, type); return FINUFFT_ERR_TYPE_NOTVALID; } @@ -76,205 +83,253 @@ int cufinufft_makeplan_impl(int type, int dim, int *nmodes, int iflag, int ntran return FINUFFT_ERR_NTRANS_NOTVALID; } - // Mult-GPU support: set the CUDA Device ID: - const int device_id = opts == nullptr ? 0 : opts->gpu_device_id; - cufinufft::utils::WithCudaDevice device_swapper(device_id); - /* allocate the plan structure, assign address to user pointer. */ auto *d_plan = new cufinufft_plan_t; - *d_plan_ptr = d_plan; - // Zero out your struct, (sets all pointers to NULL) memset(d_plan, 0, sizeof(*d_plan)); + *d_plan_ptr = d_plan; + + // Zero out your struct, (sets all pointers to NULL) + // set nf1, nf2, nf3 to 1 for type 3, type 1, type 2 will overwrite this + d_plan->nf1 = 1; + d_plan->nf2 = 1; + d_plan->nf3 = 1; + d_plan->tol = tol; /* If a user has not supplied their own options, assign defaults for them. */ if (opts == nullptr) { // use default opts cufinufft_default_opts(&(d_plan->opts)); } else { // or read from what's passed in d_plan->opts = *opts; // keep a deep copy; changing *opts now has no effect } - - // cudaMallocAsync isn't supported for all devices, regardless of cuda version. Check - // for support - cudaDeviceGetAttribute(&d_plan->supports_pools, cudaDevAttrMemoryPoolsSupported, - device_id); - static bool warned = false; - if (!warned && !d_plan->supports_pools && d_plan->opts.gpu_stream != nullptr) { - fprintf(stderr, - "[cufinufft] Warning: cudaMallocAsync not supported on this device. Use of " - "CUDA streams may not perform optimally.\n"); - warned = true; + d_plan->dim = dim; + d_plan->opts.gpu_maxbatchsize = std::max(d_plan->opts.gpu_maxbatchsize, 1); + + if (type != 3) { + d_plan->ms = nmodes[0]; + d_plan->mt = nmodes[1]; + d_plan->mu = nmodes[2]; + printf("[cufinufft] (ms,mt,mu): %d %d %d\n", d_plan->ms, d_plan->mt, d_plan->mu); + } else { + d_plan->opts.gpu_spreadinterponly = 1; } - auto &stream = d_plan->stream = (cudaStream_t)d_plan->opts.gpu_stream; - using namespace cufinufft::common; - /* Setup Spreader */ + int fftsign = (iflag >= 0) ? 1 : -1; + d_plan->iflag = fftsign; + d_plan->ntransf = ntransf; - // can return FINUFFT_WARN_EPS_TOO_SMALL=1, which is OK - if ((ier = setup_spreader_for_nufft(d_plan->spopts, tol, d_plan->opts)) > 1) { - delete *d_plan_ptr; - *d_plan_ptr = nullptr; - return ier; - } + int batchsize = (opts != nullptr) ? opts->gpu_maxbatchsize : 0; + // TODO: check if this is the right heuristic + if (batchsize == 0) // implies: use a heuristic. + batchsize = std::min(ntransf, 8); // heuristic from test codes + d_plan->batchsize = batchsize; - d_plan->dim = dim; - d_plan->ms = nmodes[0]; - d_plan->mt = nmodes[1]; - d_plan->mu = nmodes[2]; + const auto stream = d_plan->stream = (cudaStream_t)d_plan->opts.gpu_stream; - cufinufft_setup_binsize(type, d_plan->spopts.nspread, dim, &d_plan->opts); - RETURN_IF_CUDA_ERROR - - CUFINUFFT_BIGINT nf1 = 1, nf2 = 1, nf3 = 1; - set_nf_type12(d_plan->ms, d_plan->opts, d_plan->spopts, &nf1, - d_plan->opts.gpu_obinsizex); - if (dim > 1) - set_nf_type12(d_plan->mt, d_plan->opts, d_plan->spopts, &nf2, - d_plan->opts.gpu_obinsizey); - if (dim > 2) - set_nf_type12(d_plan->mu, d_plan->opts, d_plan->spopts, &nf3, - d_plan->opts.gpu_obinsizez); - - // dynamically request the maximum amount of shared memory available - // for the spreader - - /* Automatically set GPU method. */ - if (d_plan->opts.gpu_method == 0) { - /* For type 1, we default to method 2 (SM) since this is generally faster - * if there is enough shared memory available. Otherwise, we default to GM. - * - * For type 2, we always default to method 1 (GM). - */ - if (type == 2) { - d_plan->opts.gpu_method = 1; - } else { - // query the device for the amount of shared memory available - int shared_mem_per_block{}; - cudaDeviceGetAttribute(&shared_mem_per_block, - cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id); - RETURN_IF_CUDA_ERROR - // compute the amount of shared memory required for the method - const auto shared_mem_required = shared_memory_required( - dim, d_plan->spopts.nspread, d_plan->opts.gpu_binsizex, - d_plan->opts.gpu_binsizey, d_plan->opts.gpu_binsizez); - if ((shared_mem_required > shared_mem_per_block)) { - d_plan->opts.gpu_method = 1; - } else { - d_plan->opts.gpu_method = 2; - } + // Mult-GPU support: set the CUDA Device ID: + const int device_id = d_plan->opts.gpu_device_id; + const cufinufft::utils::WithCudaDevice FromID{device_id}; + + // cudaMallocAsync isn't supported for all devices, regardless of cuda version. Check + // for support + { + cudaDeviceGetAttribute(&d_plan->supports_pools, cudaDevAttrMemoryPoolsSupported, + device_id); + static bool warned = false; + if (!warned && !d_plan->supports_pools && d_plan->opts.gpu_stream != nullptr) { + fprintf(stderr, + "[cufinufft] Warning: cudaMallocAsync not supported on this device. Use of " + "CUDA streams may not perform optimally.\n"); + warned = true; } } - int fftsign = (iflag >= 0) ? 1 : -1; - - d_plan->nf1 = nf1; - d_plan->nf2 = nf2; - d_plan->nf3 = nf3; - d_plan->iflag = fftsign; - d_plan->ntransf = ntransf; - int maxbatchsize = opts ? opts->gpu_maxbatchsize : 0; - if (maxbatchsize == 0) // implies: use a heuristic. - maxbatchsize = std::min(ntransf, 8); // heuristic from test codes - d_plan->maxbatchsize = maxbatchsize; - d_plan->type = type; - - if (d_plan->type == 1) d_plan->spopts.spread_direction = 1; - if (d_plan->type == 2) d_plan->spopts.spread_direction = 2; - - using namespace cufinufft::memtransfer; - switch (d_plan->dim) { - case 1: { - if ((ier = allocgpumem1d_plan(d_plan))) goto finalize; - } break; - case 2: { - if ((ier = allocgpumem2d_plan(d_plan))) goto finalize; - } break; - case 3: { - if ((ier = allocgpumem3d_plan(d_plan))) goto finalize; - } break; + // simple check to use upsampfac=1.25 if tol is big + // FIXME: since cufft is really fast we should use 1.25 only if we run out of vram + if (d_plan->opts.upsampfac == 0.0) { // indicates auto-choose + d_plan->opts.upsampfac = 2.0; // default, and need for tol small + if (tol >= (T)1E-9 && type == 3) { // the tol sigma=5/4 can reach + d_plan->opts.upsampfac = 1.25; + } + if (d_plan->opts.debug) { + printf("[cufinufft] upsampfac automatically set to %.3g\n", d_plan->opts.upsampfac); + } } - cufftHandle fftplan; - cufftResult_t cufft_status; - switch (d_plan->dim) { - case 1: { - int n[] = {(int)nf1}; - int inembed[] = {(int)nf1}; - - cufft_status = cufftPlanMany(&fftplan, 1, n, inembed, 1, inembed[0], inembed, 1, - inembed[0], cufft_type(), maxbatchsize); - } break; - case 2: { - int n[] = {(int)nf2, (int)nf1}; - int inembed[] = {(int)nf2, (int)nf1}; + /* Setup Spreader */ + if ((ier = setup_spreader_for_nufft(d_plan->spopts, tol, d_plan->opts)) > 1) { + // can return FINUFFT_WARN_EPS_TOO_SMALL=1, which is OK + goto finalize; + } - cufft_status = - cufftPlanMany(&fftplan, 2, n, inembed, 1, inembed[0] * inembed[1], inembed, 1, - inembed[0] * inembed[1], cufft_type(), maxbatchsize); - } break; - case 3: { - int n[] = {(int)nf3, (int)nf2, (int)nf1}; - int inembed[] = {(int)nf3, (int)nf2, (int)nf1}; + d_plan->type = type; + d_plan->spopts.spread_direction = d_plan->type; - cufft_status = cufftPlanMany( - &fftplan, 3, n, inembed, 1, inembed[0] * inembed[1] * inembed[2], inembed, 1, - inembed[0] * inembed[1] * inembed[2], cufft_type(), maxbatchsize); - } break; + if (d_plan->opts.debug) { + // print the spreader options + printf("[cufinufft] spreader options:\n"); + printf("[cufinufft] nspread: %d\n", d_plan->spopts.nspread); } - if (cufft_status != CUFFT_SUCCESS) { - fprintf(stderr, "[%s] cufft makeplan error: %s", __func__, - cufftGetErrorString(cufft_status)); - ier = FINUFFT_ERR_CUDA_FAILURE; + cufinufft_setup_binsize(type, d_plan->spopts.nspread, dim, &d_plan->opts); + if (ier = cudaGetLastError(), ier != cudaSuccess) { goto finalize; } - cufftSetStream(fftplan, stream); - - d_plan->fftplan = fftplan; - { - std::complex *a = d_plan->fseries_precomp_a; - T *f = d_plan->fseries_precomp_f; + if (d_plan->opts.debug) { + printf("[cufinufft] bin size x: %d", d_plan->opts.gpu_binsizex); + if (dim > 1) printf(" bin size y: %d", d_plan->opts.gpu_binsizey); + if (dim > 2) printf(" bin size z: %d", d_plan->opts.gpu_binsizez); + printf("\n"); + // shared memory required for the spreader vs available shared memory + int shared_mem_per_block{}; + cudaDeviceGetAttribute(&shared_mem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, + device_id); + const auto mem_required = + shared_memory_required(dim, d_plan->spopts.nspread, d_plan->opts.gpu_binsizex, + d_plan->opts.gpu_binsizey, d_plan->opts.gpu_binsizez); + printf("[cufinufft] shared memory required for the spreader: %ld\n", mem_required); + } - onedim_fseries_kernel_precomp(nf1, f, a, d_plan->spopts); + if (type == 1 || type == 2) { + CUFINUFFT_BIGINT nf1 = 1, nf2 = 1, nf3 = 1; + set_nf_type12(d_plan->ms, d_plan->opts, d_plan->spopts, &nf1, + d_plan->opts.gpu_obinsizex); if (dim > 1) - onedim_fseries_kernel_precomp(nf2, f + MAX_NQUAD, a + MAX_NQUAD, d_plan->spopts); + set_nf_type12(d_plan->mt, d_plan->opts, d_plan->spopts, &nf2, + d_plan->opts.gpu_obinsizey); if (dim > 2) - onedim_fseries_kernel_precomp(nf3, f + 2 * MAX_NQUAD, a + 2 * MAX_NQUAD, - d_plan->spopts); + set_nf_type12(d_plan->mu, d_plan->opts, d_plan->spopts, &nf3, + d_plan->opts.gpu_obinsizez); + + // dynamically request the maximum amount of shared memory available + // for the spreader + + /* Automatically set GPU method. */ + if (d_plan->opts.gpu_method == 0) { + /* For type 1, we default to method 2 (SM) since this is generally faster + * if there is enough shared memory available. Otherwise, we default to GM. + * + * For type 2, we always default to method 1 (GM). + */ + if (type == 2) { + d_plan->opts.gpu_method = 1; + } else { + // query the device for the amount of shared memory available + int shared_mem_per_block{}; + cudaDeviceGetAttribute(&shared_mem_per_block, + cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id); + // compute the amount of shared memory required for the method + const auto shared_mem_required = shared_memory_required( + dim, d_plan->spopts.nspread, d_plan->opts.gpu_binsizex, + d_plan->opts.gpu_binsizey, d_plan->opts.gpu_binsizez); + if ((shared_mem_required > shared_mem_per_block)) { + d_plan->opts.gpu_method = 1; + } else { + d_plan->opts.gpu_method = 2; + } + } + } - if ((ier = checkCudaErrors( - cudaMallocWrapper(&d_a, dim * MAX_NQUAD * sizeof(cuDoubleComplex), stream, - d_plan->supports_pools)))) - goto finalize; - if ((ier = checkCudaErrors(cudaMallocWrapper(&d_f, dim * MAX_NQUAD * sizeof(T), - stream, d_plan->supports_pools)))) - goto finalize; - if ((ier = checkCudaErrors( - cudaMemcpyAsync(d_a, a, dim * MAX_NQUAD * sizeof(cuDoubleComplex), - cudaMemcpyHostToDevice, stream)))) + if ((ier = cudaGetLastError())) { goto finalize; - if ((ier = checkCudaErrors(cudaMemcpyAsync(d_f, f, dim * MAX_NQUAD * sizeof(T), - cudaMemcpyHostToDevice, stream)))) + } + + d_plan->nf1 = nf1; + d_plan->nf2 = nf2; + d_plan->nf3 = nf3; + d_plan->nf = nf1 * nf2 * nf3; + if (d_plan->opts.debug) { + printf("[cufinufft] (nf1,nf2,nf3) = (%d, %d, %d)\n", d_plan->nf1, d_plan->nf2, + d_plan->nf3); + } + + using namespace cufinufft::memtransfer; + switch (d_plan->dim) { + case 1: { + if ((ier = allocgpumem1d_plan(d_plan))) goto finalize; + } break; + case 2: { + if ((ier = allocgpumem2d_plan(d_plan))) goto finalize; + } break; + case 3: { + if ((ier = allocgpumem3d_plan(d_plan))) goto finalize; + } break; + } + + cufftHandle fftplan; + cufftResult_t cufft_status; + switch (d_plan->dim) { + case 1: { + int n[] = {(int)nf1}; + int inembed[] = {(int)nf1}; + + cufft_status = cufftPlanMany(&fftplan, 1, n, inembed, 1, inembed[0], inembed, 1, + inembed[0], cufft_type(), batchsize); + } break; + case 2: { + int n[] = {(int)nf2, (int)nf1}; + int inembed[] = {(int)nf2, (int)nf1}; + + cufft_status = + cufftPlanMany(&fftplan, 2, n, inembed, 1, inembed[0] * inembed[1], inembed, 1, + inembed[0] * inembed[1], cufft_type(), batchsize); + } break; + case 3: { + int n[] = {(int)nf3, (int)nf2, (int)nf1}; + int inembed[] = {(int)nf3, (int)nf2, (int)nf1}; + + cufft_status = cufftPlanMany( + &fftplan, 3, n, inembed, 1, inembed[0] * inembed[1] * inembed[2], inembed, 1, + inembed[0] * inembed[1] * inembed[2], cufft_type(), batchsize); + } break; + } + + if (cufft_status != CUFFT_SUCCESS) { + fprintf(stderr, "[%s] cufft makeplan error: %s", __func__, + cufftGetErrorString(cufft_status)); + ier = FINUFFT_ERR_CUDA_FAILURE; goto finalize; - if ((ier = cufserieskernelcompute( - d_plan->dim, nf1, nf2, nf3, d_f, d_a, d_plan->fwkerhalf1, d_plan->fwkerhalf2, - d_plan->fwkerhalf3, d_plan->spopts.nspread, stream))) + } + cufftSetStream(fftplan, stream); + + d_plan->fftplan = fftplan; + + // compute up to 3 * NQUAD precomputed values on CPU + T fseries_precomp_phase[3 * MAX_NQUAD]; + T fseries_precomp_f[3 * MAX_NQUAD]; + thrust::device_vector d_fseries_precomp_phase(3 * MAX_NQUAD); + thrust::device_vector d_fseries_precomp_f(3 * MAX_NQUAD); + onedim_fseries_kernel_precomp(d_plan->nf1, fseries_precomp_f, + fseries_precomp_phase, d_plan->spopts); + if (d_plan->dim > 1) + onedim_fseries_kernel_precomp(d_plan->nf2, fseries_precomp_f + MAX_NQUAD, + fseries_precomp_phase + MAX_NQUAD, d_plan->spopts); + if (d_plan->dim > 2) + onedim_fseries_kernel_precomp(d_plan->nf3, fseries_precomp_f + 2 * MAX_NQUAD, + fseries_precomp_phase + 2 * MAX_NQUAD, + d_plan->spopts); + // copy the precomputed data to the device using thrust + thrust::copy(fseries_precomp_phase, fseries_precomp_phase + 3 * MAX_NQUAD, + d_fseries_precomp_phase.begin()); + thrust::copy(fseries_precomp_f, fseries_precomp_f + 3 * MAX_NQUAD, + d_fseries_precomp_f.begin()); + // the full fseries is done on the GPU here + if ((ier = fseries_kernel_compute( + d_plan->dim, d_plan->nf1, d_plan->nf2, d_plan->nf3, + d_fseries_precomp_f.data().get(), d_fseries_precomp_phase.data().get(), + d_plan->fwkerhalf1, d_plan->fwkerhalf2, d_plan->fwkerhalf3, + d_plan->spopts.nspread, stream))) goto finalize; } - finalize: - cudaFreeWrapper(d_a, stream, d_plan->supports_pools); - cudaFreeWrapper(d_f, stream, d_plan->supports_pools); - if (ier > 1) { - delete *d_plan_ptr; + cufinufft_destroy_impl(*d_plan_ptr); *d_plan_ptr = nullptr; } - return ier; } template -int cufinufft_setpts_impl(int M, T *d_kx, T *d_ky, T *d_kz, int N, T *d_s, T *d_t, T *d_u, - cufinufft_plan_t *d_plan) +int cufinufft_setpts_12_impl(int M, T *d_kx, T *d_ky, T *d_kz, + cufinufft_plan_t *d_plan) /* "setNUpts" stage (in single or double precision). @@ -312,7 +367,7 @@ Notes: the type T means either single or double, matching the Melody Shih 07/25/19; Barnett 2/16/21 moved out docs. */ { - cufinufft::utils::WithCudaDevice device_swapper(d_plan->opts.gpu_device_id); + const cufinufft::utils::WithCudaDevice FromID(d_plan->opts.gpu_device_id); int nf1 = d_plan->nf1; int nf2 = d_plan->nf2; @@ -378,9 +433,381 @@ Notes: the type T means either single or double, matching the } break; } + if (d_plan->opts.debug) { + printf("[cufinufft] plan->M=%d\n", M); + } return ier; } +template +int cufinufft_setpts_impl(int M, T *d_kx, T *d_ky, T *d_kz, int N, T *d_s, T *d_t, T *d_u, + cufinufft_plan_t *d_plan) { + // type 1 and type 2 setpts + if (d_plan->type == 1 || d_plan->type == 2) { + return cufinufft_setpts_12_impl(M, d_kx, d_ky, d_kz, d_plan); + } + // type 3 setpts + + // This code follows the same implementation of the CPU code in finufft and uses similar + // variables names where possible. However, the use of GPU routines and paradigms make + // it harder to follow. To understand the code, it is recommended to read the CPU code + // first. + + if (d_plan->type != 3) { + fprintf(stderr, "[%s] Invalid type (%d): should be 1, 2, or 3.\n", __func__, + d_plan->type); + return FINUFFT_ERR_TYPE_NOTVALID; + } + if (N < 0) { + fprintf(stderr, "[cufinufft] Invalid N (%d): cannot be negative.\n", N); + return FINUFFT_ERR_NUM_NU_PTS_INVALID; + } + if (N > MAX_NF) { + fprintf(stderr, "[cufinufft] Invalid N (%d): cannot be greater than %d.\n", N, + MAX_NF); + return FINUFFT_ERR_NUM_NU_PTS_INVALID; + } + const auto stream = d_plan->stream; + d_plan->N = N; + if (d_plan->dim > 0 && d_s == nullptr) { + fprintf(stderr, "[%s] Error: d_s is nullptr but dim > 0.\n", __func__); + return FINUFFT_ERR_INVALID_ARGUMENT; + } + d_plan->d_Sp = d_plan->dim > 0 ? d_s : nullptr; + + if (d_plan->dim > 1 && d_t == nullptr) { + fprintf(stderr, "[%s] Error: d_t is nullptr but dim > 1.\n", __func__); + return FINUFFT_ERR_INVALID_ARGUMENT; + } + d_plan->d_Tp = d_plan->dim > 1 ? d_t : nullptr; + + if (d_plan->dim > 2 && d_u == nullptr) { + fprintf(stderr, "[%s] Error: d_u is nullptr but dim > 2.\n", __func__); + return FINUFFT_ERR_INVALID_ARGUMENT; + } + d_plan->d_Up = d_plan->dim > 2 ? d_u : nullptr; + + const auto dim = d_plan->dim; + // no need to set the params to zero, as they are already zeroed out in the plan + // memset(d_plan->type3_params, 0, sizeof(d_plan->type3_params)); + using namespace cufinufft::utils; + if (d_plan->dim > 0) { + const auto [x1, c1] = arraywidcen(M, d_kx, stream); + d_plan->type3_params.X1 = x1; + d_plan->type3_params.C1 = c1; + const auto [S1, D1] = arraywidcen(N, d_s, stream); + const auto [nf1, h1, gam1] = set_nhg_type3(S1, x1, d_plan->opts, d_plan->spopts); + d_plan->nf1 = nf1; + d_plan->type3_params.S1 = S1; + d_plan->type3_params.D1 = D1; + d_plan->type3_params.h1 = h1; + d_plan->type3_params.gam1 = gam1; + } + if (d_plan->dim > 1) { + const auto [x2, c2] = arraywidcen(M, d_ky, stream); + d_plan->type3_params.X2 = x2; + d_plan->type3_params.C2 = c2; + const auto [S2, D2] = arraywidcen(N, d_t, stream); + const auto [nf2, h2, gam2] = set_nhg_type3(S2, x2, d_plan->opts, d_plan->spopts); + d_plan->nf2 = nf2; + d_plan->type3_params.S2 = S2; + d_plan->type3_params.D2 = D2; + d_plan->type3_params.h2 = h2; + d_plan->type3_params.gam2 = gam2; + } + if (d_plan->dim > 2) { + const auto [x3, c3] = arraywidcen(M, d_kz, stream); + d_plan->type3_params.X3 = x3; + d_plan->type3_params.C3 = c3; + const auto [S3, D3] = arraywidcen(N, d_u, stream); + const auto [nf3, h3, gam3] = set_nhg_type3(S3, x3, d_plan->opts, d_plan->spopts); + d_plan->nf3 = nf3; + d_plan->type3_params.S3 = S3; + d_plan->type3_params.D3 = D3; + d_plan->type3_params.h3 = h3; + d_plan->type3_params.gam3 = gam3; + } + if (d_plan->opts.debug) { + printf("[%s]", __func__); + printf("\tM=%d N=%d\n", M, N); + printf("\tX1=%.3g C1=%.3g S1=%.3g D1=%.3g gam1=%g nf1=%d h1=%.3g\t\n", + d_plan->type3_params.X1, d_plan->type3_params.C1, d_plan->type3_params.S1, + d_plan->type3_params.D1, d_plan->type3_params.gam1, d_plan->nf1, + d_plan->type3_params.h1); + if (d_plan->dim > 1) { + printf("\tX2=%.3g C2=%.3g S2=%.3g D2=%.3g gam2=%g nf2=%d h2=%.3g\n", + d_plan->type3_params.X2, d_plan->type3_params.C2, d_plan->type3_params.S2, + d_plan->type3_params.D2, d_plan->type3_params.gam2, d_plan->nf2, + d_plan->type3_params.h2); + } + if (d_plan->dim > 2) { + printf("\tX3=%.3g C3=%.3g S3=%.3g D3=%.3g gam3=%g nf3=%d h3=%.3g\n", + d_plan->type3_params.X3, d_plan->type3_params.C3, d_plan->type3_params.S3, + d_plan->type3_params.D3, d_plan->type3_params.gam3, d_plan->nf3, + d_plan->type3_params.h3); + } + } + d_plan->nf = d_plan->nf1 * d_plan->nf2 * d_plan->nf3; + + // FIXME: MAX_NF might be too small... + if (d_plan->nf * d_plan->opts.gpu_maxbatchsize > MAX_NF) { + fprintf(stderr, + "[%s t3] fwBatch would be bigger than MAX_NF, not attempting malloc!\n", + __func__); + return FINUFFT_ERR_MAXNALLOC; + } + + // A macro might be better as it has access to __line__ and __func__ + const auto checked_free = [stream, pool = d_plan->supports_pools](auto x) constexpr { + if (!x) return cudaFreeWrapper(x, stream, pool); + return cudaSuccess; + }; + const auto checked_realloc = [checked_free, pool = d_plan->supports_pools, stream]( + auto &x, const auto size) constexpr { + if (auto ier = checked_free(x); ier != cudaSuccess) return ier; + return cudaMallocWrapper(&x, size, stream, pool); + }; + // FIXME: check the size of the allocs for the batch interface + if (checked_realloc(d_plan->fw, sizeof(cuda_complex) * d_plan->nf * + d_plan->batchsize) != cudaSuccess) + goto finalize; + if (checked_realloc(d_plan->CpBatch, sizeof(cuda_complex) * M * d_plan->batchsize) != + cudaSuccess) + goto finalize; + if (checked_realloc(d_plan->kx, sizeof(T) * M) != cudaSuccess) goto finalize; + if (checked_realloc(d_plan->d_Sp, sizeof(T) * N) != cudaSuccess) goto finalize; + if (d_plan->dim > 1) { + if (checked_realloc(d_plan->ky, sizeof(T) * M) != cudaSuccess) goto finalize; + if (checked_realloc(d_plan->d_Tp, sizeof(T) * N) != cudaSuccess) goto finalize; + } + if (d_plan->dim > 2) { + if (checked_realloc(d_plan->kz, sizeof(T) * M) != cudaSuccess) goto finalize; + if (checked_realloc(d_plan->d_Up, sizeof(T) * N) != cudaSuccess) goto finalize; + } + if (checked_realloc(d_plan->prephase, sizeof(cuda_complex) * M) != cudaSuccess) + goto finalize; + if (checked_realloc(d_plan->deconv, sizeof(cuda_complex) * N) != cudaSuccess) + goto finalize; + + // NOTE: init-captures are not allowed for extended __host__ __device__ lambdas + + if (d_plan->dim > 0) { + const auto ig1 = T(1) / d_plan->type3_params.gam1; + const auto C1 = -d_plan->type3_params.C1; + thrust::transform( + thrust::cuda::par.on(stream), d_kx, d_kx + M, d_plan->kx, + [ig1, C1] __host__ __device__(const T x) -> T { return (x + C1) * ig1; }); + } + if (d_plan->dim > 1) { + const auto ig2 = T(1) / d_plan->type3_params.gam2; + const auto C2 = -d_plan->type3_params.C2; + thrust::transform( + thrust::cuda::par.on(stream), d_ky, d_ky + M, d_plan->ky, + [ig2, C2] __host__ __device__(const T x) -> T { return (x + C2) * ig2; }); + } + if (d_plan->dim > 2) { + const auto ig3 = T(1) / d_plan->type3_params.gam3; + const auto C3 = -d_plan->type3_params.C3; + thrust::transform( + thrust::cuda::par.on(stream), d_kz, d_kz + M, d_plan->kz, + [ig3, C3] __host__ __device__(const T x) -> T { return (x + C3) * ig3; }); + } + if (d_plan->type3_params.D1 != 0 || d_plan->type3_params.D2 != 0 || + d_plan->type3_params.D3 != 0) { + // if ky is null, use kx for ky and kz + // this is not the most efficient implementation, but it is the most compact + const auto iterator = + thrust::make_zip_iterator(thrust::make_tuple(d_kx, + // to avoid out of bounds access, use + // kx if ky is null + (d_plan->dim > 1) ? d_ky : d_kx, + // same idea as above + (d_plan->dim > 2) ? d_kz : d_kx)); + const auto D1 = d_plan->type3_params.D1; + const auto D2 = d_plan->type3_params.D2; // this should be 0 if dim < 2 + const auto D3 = d_plan->type3_params.D3; // this should be 0 if dim < 3 + const auto realsign = d_plan->iflag >= 0 ? T(1) : T(-1); + thrust::transform( + thrust::cuda::par.on(stream), iterator, iterator + M, d_plan->prephase, + [D1, D2, D3, realsign] __host__ __device__( + const thrust::tuple &tuple) -> cuda_complex { + const auto x = thrust::get<0>(tuple); + const auto y = thrust::get<1>(tuple); + const auto z = thrust::get<2>(tuple); + // no branching because D2 and D3 are 0 if dim < 2 and dim < 3 + // this is generally faster on GPU + const auto phase = D1 * x + D2 * y + D3 * z; + // TODO: nvcc should have the sincos function + // check the cos + i*sin + // ref: https://en.wikipedia.org/wiki/Cis_(mathematics) + return cuda_complex{std::cos(phase), std::sin(phase) * realsign}; + }); + } else { + thrust::fill(thrust::cuda::par.on(stream), d_plan->prephase, d_plan->prephase + M, + cuda_complex{1, 0}); + } + + if (d_plan->dim > 0) { + const auto scale = d_plan->type3_params.h1 * d_plan->type3_params.gam1; + const auto D1 = -d_plan->type3_params.D1; + thrust::transform( + thrust::cuda::par.on(stream), d_s, d_s + N, d_plan->d_Sp, + [scale, D1] __host__ __device__(const T s) -> T { return scale * (s + D1); }); + } + if (d_plan->dim > 1) { + const auto scale = d_plan->type3_params.h2 * d_plan->type3_params.gam2; + const auto D2 = -d_plan->type3_params.D2; + thrust::transform( + thrust::cuda::par.on(stream), d_t, d_t + N, d_plan->d_Tp, + [scale, D2] __host__ __device__(const T t) -> T { return scale * (t + D2); }); + } + if (d_plan->dim > 2) { + const auto scale = d_plan->type3_params.h3 * d_plan->type3_params.gam3; + const auto D3 = -d_plan->type3_params.D3; + thrust::transform( + thrust::cuda::par.on(stream), d_u, d_u + N, d_plan->d_Up, + [scale, D3] __host__ __device__(const T u) -> T { return scale * (u + D3); }); + } + { // here we declare phi_hat1, phi_hat2, and phi_hat3 + // and the precomputed data for the fseries kernel + using namespace cufinufft::common; + + std::array nuft_precomp_z{}; + std::array nuft_precomp_f{}; + thrust::device_vector d_nuft_precomp_z(3 * MAX_NQUAD); + thrust::device_vector d_nuft_precomp_f(3 * MAX_NQUAD); + thrust::device_vector phi_hat1, phi_hat2, phi_hat3; + if (d_plan->dim > 0) { + phi_hat1.resize(N); + } + if (d_plan->dim > 1) { + phi_hat2.resize(N); + } + if (d_plan->dim > 2) { + phi_hat3.resize(N); + } + onedim_nuft_kernel_precomp(nuft_precomp_f.data(), nuft_precomp_z.data(), + d_plan->spopts); + if (d_plan->dim > 1) { + onedim_nuft_kernel_precomp(nuft_precomp_f.data() + MAX_NQUAD, + nuft_precomp_z.data() + MAX_NQUAD, + d_plan->spopts); + } + if (d_plan->dim > 2) { + onedim_nuft_kernel_precomp(nuft_precomp_f.data() + 2 * MAX_NQUAD, + nuft_precomp_z.data() + 2 * MAX_NQUAD, + d_plan->spopts); + } + // copy the precomputed data to the device using thrust + thrust::copy(nuft_precomp_z.begin(), nuft_precomp_z.end(), d_nuft_precomp_z.begin()); + thrust::copy(nuft_precomp_f.begin(), nuft_precomp_f.end(), d_nuft_precomp_f.begin()); + // sync the stream before calling the kernel might be needed + if (nuft_kernel_compute(d_plan->dim, N, N, N, d_nuft_precomp_f.data().get(), + d_nuft_precomp_z.data().get(), d_plan->d_Sp, d_plan->d_Tp, + d_plan->d_Up, phi_hat1.data().get(), phi_hat2.data().get(), + phi_hat3.data().get(), d_plan->spopts.nspread, stream)) + goto finalize; + + const auto is_c_finite = std::isfinite(d_plan->type3_params.C1) && + std::isfinite(d_plan->type3_params.C2) && + std::isfinite(d_plan->type3_params.C3); + const auto is_c_nonzero = d_plan->type3_params.C1 != 0 || + d_plan->type3_params.C2 != 0 || + d_plan->type3_params.C3 != 0; + + const auto phi_hat_iterator = thrust::make_zip_iterator( + thrust::make_tuple(phi_hat1.begin(), + // to avoid out of bounds access, use phi_hat1 if dim < 2 + dim > 1 ? phi_hat2.begin() : phi_hat1.begin(), + // to avoid out of bounds access, use phi_hat1 if dim < 3 + dim > 2 ? phi_hat3.begin() : phi_hat1.begin())); + thrust::transform( + thrust::cuda::par.on(stream), phi_hat_iterator, phi_hat_iterator + N, + d_plan->deconv, + [dim] __host__ __device__(const thrust::tuple tuple) -> cuda_complex { + auto phiHat = thrust::get<0>(tuple); + // in case dim < 2 or dim < 3, multiply by 1 + phiHat *= (dim > 1) ? thrust::get<1>(tuple) : T(1); + phiHat *= (dim > 2) ? thrust::get<2>(tuple) : T(1); + return {T(1) / phiHat, T(0)}; + }); + + if (is_c_finite && is_c_nonzero) { + const auto c1 = d_plan->type3_params.C1; + const auto c2 = d_plan->type3_params.C2; + const auto c3 = d_plan->type3_params.C3; + const auto d1 = -d_plan->type3_params.D1; + const auto d2 = -d_plan->type3_params.D2; + const auto d3 = -d_plan->type3_params.D3; + const auto realsign = d_plan->iflag >= 0 ? T(1) : T(-1); + // passing d_s three times if dim == 1 because d_t and d_u are not allocated + // passing d_s and d_t if dim == 2 because d_u is not allocated + const auto phase_iterator = thrust::make_zip_iterator( + thrust::make_tuple(d_s, dim > 1 ? d_t : d_s, dim > 2 ? d_u : d_s)); + thrust::transform( + thrust::cuda::par.on(stream), phase_iterator, phase_iterator + N, + d_plan->deconv, d_plan->deconv, + [c1, c2, c3, d1, d2, d3, realsign] __host__ __device__( + const thrust::tuple tuple, + cuda_complex deconv) -> cuda_complex { + // d2 and d3 are 0 if dim < 2 and dim < 3 + const auto phase = c1 * (thrust::get<0>(tuple) + d1) + + c2 * (thrust::get<1>(tuple) + d2) + + c3 * (thrust::get<2>(tuple) + d3); + return cuda_complex{std::cos(phase), realsign * std::sin(phase)} * deconv; + }); + } + // exiting the block frees the memory allocated for phi_hat1, phi_hat2, and phi_hat3 + // and the precomputed data for the fseries kernel + // since GPU memory is expensive, we should free it as soon as possible + } + + using namespace cufinufft::memtransfer; + switch (d_plan->dim) { + case 1: { + if ((allocgpumem1d_plan(d_plan))) goto finalize; + } break; + case 2: { + if ((allocgpumem2d_plan(d_plan))) goto finalize; + } break; + case 3: { + if ((allocgpumem3d_plan(d_plan))) goto finalize; + } break; + } + if (cufinufft_setpts_12_impl(M, d_plan->kx, d_plan->ky, d_plan->kz, d_plan)) { + fprintf(stderr, "[%s] cufinufft_setpts_12_impl failed\n", __func__); + goto finalize; + } + { + int t2modes[] = {d_plan->nf1, d_plan->nf2, d_plan->nf3}; + cufinufft_opts t2opts = d_plan->opts; + t2opts.gpu_spreadinterponly = 0; + t2opts.gpu_method = 1; + // Safe to ignore the return value here? + if (d_plan->t2_plan) cufinufft_destroy_impl(d_plan->t2_plan); + // check that maxbatchsize is correct + if (cufinufft_makeplan_impl(2, dim, t2modes, d_plan->iflag, d_plan->batchsize, + d_plan->tol, &d_plan->t2_plan, &t2opts)) { + fprintf(stderr, "[%s] inner t2 plan cufinufft_makeplan failed\n", __func__); + goto finalize; + } + if (cufinufft_setpts_12_impl(N, d_plan->d_Sp, d_plan->d_Tp, d_plan->d_Up, + d_plan->t2_plan)) { + fprintf(stderr, "[%s] inner t2 plan cufinufft_setpts_12 failed\n", __func__); + goto finalize; + } + if (d_plan->t2_plan->spopts.spread_direction != 2) { + fprintf(stderr, "[%s] inner t2 plan cufinufft_setpts_12 wrong direction\n", + __func__); + goto finalize; + } + } + return 0; +finalize: + cufinufft_destroy_impl(d_plan); + return FINUFFT_ERR_CUDA_FAILURE; +} + template int cufinufft_execute_impl(cuda_complex *d_c, cuda_complex *d_fk, cufinufft_plan_t *d_plan) @@ -413,26 +840,17 @@ int cufinufft_execute_impl(cuda_complex *d_c, cuda_complex *d_fk, case 1: { if (type == 1) ier = cufinufft1d1_exec(d_c, d_fk, d_plan); if (type == 2) ier = cufinufft1d2_exec(d_c, d_fk, d_plan); - if (type == 3) { - std::cerr << "Not Implemented yet" << std::endl; - ier = FINUFFT_ERR_TYPE_NOTVALID; - } + if (type == 3) ier = cufinufft1d3_exec(d_c, d_fk, d_plan); } break; case 2: { if (type == 1) ier = cufinufft2d1_exec(d_c, d_fk, d_plan); if (type == 2) ier = cufinufft2d2_exec(d_c, d_fk, d_plan); - if (type == 3) { - std::cerr << "Not Implemented yet" << std::endl; - ier = FINUFFT_ERR_TYPE_NOTVALID; - } + if (type == 3) ier = cufinufft2d3_exec(d_c, d_fk, d_plan); } break; case 3: { if (type == 1) ier = cufinufft3d1_exec(d_c, d_fk, d_plan); if (type == 2) ier = cufinufft3d2_exec(d_c, d_fk, d_plan); - if (type == 3) { - std::cerr << "Not Implemented yet" << std::endl; - ier = FINUFFT_ERR_TYPE_NOTVALID; - } + if (type == 3) ier = cufinufft3d3_exec(d_c, d_fk, d_plan); } break; } @@ -447,20 +865,21 @@ int cufinufft_destroy_impl(cufinufft_plan_t *d_plan) In this stage, we (1) free all the memories that have been allocated on gpu (2) delete the cuFFT plan - - Also see ../docs/cppdoc.md for main user-facing documentation. */ { - cufinufft::utils::WithCudaDevice device_swapper(d_plan->opts.gpu_device_id); // Can't destroy a null pointer. if (!d_plan) return FINUFFT_ERR_PLAN_NOTVALID; + cufinufft::utils::WithCudaDevice device_swapper(d_plan->opts.gpu_device_id); + using namespace cufinufft::memtransfer; freegpumemory(d_plan); if (d_plan->fftplan) cufftDestroy(d_plan->fftplan); + if (d_plan->t2_plan) cufinufft_destroy_impl(d_plan->t2_plan); + /* free/destruct the plan */ delete d_plan; diff --git a/include/cufinufft/precision_independent.h b/include/cufinufft/precision_independent.h index 9fa48a07e..e8ef209c3 100644 --- a/include/cufinufft/precision_independent.h +++ b/include/cufinufft/precision_independent.h @@ -41,8 +41,6 @@ __global__ void calc_subprob_2d(int *bin_size, int *num_subprob, int maxsubprobs __global__ void map_b_into_subprob_2d(int *d_subprob_to_bin, int *d_subprobstartpts, int *d_numsubprob, int numbins); -__global__ void trivial_global_sort_index_2d(int M, int *index); - /* spreadinterp3d */ __global__ void calc_subprob_3d_v2(int *bin_size, int *num_subprob, int maxsubprobsize, int numbins); @@ -57,8 +55,6 @@ __global__ void calc_subprob_3d_v1(int binsperobinx, int binsperobiny, int binsp __global__ void map_b_into_subprob_3d_v1(int *d_subprob_to_obin, int *d_subprobstartpts, int *d_numsubprob, int numbins); -__global__ void trivial_global_sort_index_3d(int M, int *index); - __global__ void fill_ghost_bins(int binsperobinx, int binsperobiny, int binsperobinz, int nobinx, int nobiny, int nobinz, int *binsize); diff --git a/include/cufinufft/spreadinterp.h b/include/cufinufft/spreadinterp.h index 2963d381d..9efd094c8 100644 --- a/include/cufinufft/spreadinterp.h +++ b/include/cufinufft/spreadinterp.h @@ -10,7 +10,7 @@ namespace cufinufft { namespace spreadinterp { template -static __forceinline__ __device__ constexpr T fma(const T a, const T b, const T c) { +static __forceinline__ __device__ constexpr T cudaFMA(const T a, const T b, const T c) { if constexpr (std::is_same_v) { // fused multiply-add, round to nearest even return __fmaf_rn(a, b, c); @@ -20,9 +20,13 @@ static __forceinline__ __device__ constexpr T fma(const T a, const T b, const T } static_assert(std::is_same_v || std::is_same_v, "Only float and double are supported."); - return T{0}; + return std::fma(a, b, c); } +/** + * local NU coord fold+rescale macro: does the following affine transform to x: + * (x+PI) mod PI each to [0,N) + */ template constexpr __forceinline__ __host__ __device__ T fold_rescale(T x, int N) { constexpr auto x2pi = T(0.159154943091895345554011992339482617); @@ -30,14 +34,14 @@ constexpr __forceinline__ __host__ __device__ T fold_rescale(T x, int N) { #if defined(__CUDA_ARCH__) if constexpr (std::is_same_v) { // fused multiply-add, round to nearest even - auto result = __fmaf_rn(x, x2pi, half); + auto result = cudaFMA(x, x2pi, half); // subtract, round down result = __fsub_rd(result, floorf(result)); // multiply, round down return __fmul_rd(result, static_cast(N)); } else if constexpr (std::is_same_v) { // fused multiply-add, round to nearest even - auto result = __fma_rn(x, x2pi, half); + auto result = cudaFMA(x, x2pi, half); // subtract, round down result = __dsub_rd(result, floor(result)); // multiply, round down @@ -85,15 +89,15 @@ static __forceinline__ __device__ T evaluate_kernel(T x, T es_c, T es_beta, int } template -static __inline__ __device__ void eval_kernel_vec_horner(T *ker, const T x, const int w, - const double upsampfac) +static __device__ void eval_kernel_vec_horner(T *ker, const T x, const int w, + const double upsampfac) /* Fill ker[] with Horner piecewise poly approx to [-w/2,w/2] ES kernel eval at x_j = x + j, for j=0,..,w-1. Thus x in [-w/2,-w/2+1]. w is aka ns. This is the current evaluation method, since it's faster (except i7 w=16). Two upsampfacs implemented. Params must match ref formula. Barnett 4/24/18 */ { - const auto z = fma(T(2), x, T(w - 1)); // scale so local grid offset z in [-1,1] - // T z = 2 * x + w - 1.0; + // const T z = T(2) * x + T(w - 1); + const auto z = cudaFMA(T(2), x, T(w - 1)); // scale so local grid offset z in [-1,1] // insert the auto-generated code which expects z, w args, writes to ker... if (upsampfac == 2.0) { // floating point equality is fine here using FLT = T; diff --git a/include/cufinufft/types.h b/include/cufinufft/types.h index 16046c8ef..5b2fba790 100644 --- a/include/cufinufft/types.h +++ b/include/cufinufft/types.h @@ -8,20 +8,24 @@ #include #include -#include +#include #define CUFINUFFT_BIGINT int -// Ugly trick to map a template to a fixed type, here cuda_complex -template struct cuda_complex_impl; -template<> struct cuda_complex_impl { - using type = cuFloatComplex; -}; -template<> struct cuda_complex_impl { - using type = cuDoubleComplex; -}; - -template using cuda_complex = typename cuda_complex_impl::type; +// Marco Barbone 8/5/2924, replaced the ugly trick with std::conditional +// to define cuda_complex +// by using std::conditional and std::is_same, we can define cuda_complex +// if T is float, cuda_complex is cuFloatComplex +// if T is double, cuda_complex is cuDoubleComplex +// where cuFloatComplex and cuDoubleComplex are defined in cuComplex.h +// TODO: migrate to cuda/std/complex and remove this +// Issue: cufft seems not to support cuda::std::complex +// A reinterpret_cast should be enough +template +using cuda_complex = typename std::conditional< + std::is_same::value, cuFloatComplex, + typename std::conditional::value, cuDoubleComplex, + void>::type>::type; template struct cufinufft_plan_t { cufinufft_opts opts; @@ -37,7 +41,7 @@ template struct cufinufft_plan_t { CUFINUFFT_BIGINT mt; CUFINUFFT_BIGINT mu; int ntransf; - int maxbatchsize; + int batchsize; int iflag; int supports_pools; @@ -46,13 +50,39 @@ template struct cufinufft_plan_t { T *fwkerhalf2; T *fwkerhalf3; + // for type 1,2 it is a pointer to kx, ky, kz (no new allocs), for type 3 it + // for t3: allocated as "primed" (scaled) src pts x'_j, etc T *kx; T *ky; T *kz; + cuda_complex *CpBatch; // working array of prephased strengths + cuda_complex *fwbatch; + + // no allocs here cuda_complex *c; cuda_complex *fw; cuda_complex *fk; + // Type 3 specific + struct { + T X1, C1, S1, D1, h1, gam1; // x dim: X=halfwid C=center D=freqcen h,gam=rescale, + // s=interval + T X2, C2, S2, D2, h2, gam2; // y + T X3, C3, S3, D3, h3, gam3; // z + } type3_params; + int N; // number of NU freq pts (type 3 only) + CUFINUFFT_BIGINT nf; + T *d_Sp; + T *d_Tp; + T *d_Up; + T tol; + // inner type 2 plan for type 3 + cufinufft_plan_t *t2_plan; + // new allocs. + // FIXME: convert to device vectors to use resize + cuda_complex *prephase; // pre-phase, for all input NU pts + cuda_complex *deconv; // reciprocal of kernel FT, phase, all output NU pts + // Arrays that used in subprob method int *idxnupts; // length: #nupts, index of the nupts in the bin-sorted order int *sortidx; // length: #nupts, order inside the bin the nupt belongs to @@ -66,18 +96,14 @@ template struct cufinufft_plan_t { int *numnupts; int *subprob_to_nupts; - // Temporary variables to do fseries precomputation - std::complex fseries_precomp_a[3 * MAX_NQUAD]; - T fseries_precomp_f[3 * MAX_NQUAD]; - cufftHandle fftplan; cudaStream_t stream; }; -template static cufftType_t cufft_type(); -template<> inline cufftType_t cufft_type() { return CUFFT_C2C; } +template constexpr static inline cufftType_t cufft_type(); +template<> constexpr inline cufftType_t cufft_type() { return CUFFT_C2C; } -template<> inline cufftType_t cufft_type() { return CUFFT_Z2Z; } +template<> constexpr inline cufftType_t cufft_type() { return CUFFT_Z2Z; } static inline cufftResult cufft_ex(cufftHandle plan, cufftComplex *idata, cufftComplex *odata, int direction) { diff --git a/include/cufinufft/utils.h b/include/cufinufft/utils.h index 4bfaa801d..432711aae 100644 --- a/include/cufinufft/utils.h +++ b/include/cufinufft/utils.h @@ -3,7 +3,6 @@ // octave (mkoctfile) needs this otherwise it doesn't know what int64_t is! #include -#include #include #include @@ -15,6 +14,8 @@ #include #include +#include + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__) #else __inline__ __device__ double atomicAdd(double *address, double val) { @@ -34,19 +35,56 @@ __inline__ __device__ double atomicAdd(double *address, double val) { } #endif +/** + * It computes the stard and end point of the spreading window given the center x and the + * width ns. + */ +template __forceinline__ __device__ auto interval(const int ns, const T x) { + const auto xstart = int(std::ceil(x - T(ns) * T(.5))); + const auto xend = int(std::floor(x + T(ns) * T(.5))); + return int2{xstart, xend}; +} + +// Define a macro to check if NVCC version is >= 11.3 +#if defined(__CUDACC_VER_MAJOR__) && defined(__CUDACC_VER_MINOR__) +#if (__CUDACC_VER_MAJOR__ > 11) || \ + (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 3 && __CUDA_ARCH__ >= 600) +#define ALLOCA_SUPPORTED 1 +// windows compatibility +#if __has_include() +#include +#endif +#endif +#endif + +#if defined(__CUDA_ARCH__) +#if __CUDA_ARCH__ >= 900 +#define COMPUTE_CAPABILITY_90_OR_HIGHER 1 +#else +#define COMPUTE_CAPABILITY_90_OR_HIGHER 0 +#endif +#else +#define COMPUTE_CAPABILITY_90_OR_HIGHER 0 +#endif + namespace cufinufft { namespace utils { class WithCudaDevice { public: - WithCudaDevice(int device) { - cudaGetDevice(&orig_device_); + explicit WithCudaDevice(const int device) : orig_device_{get_orig_device()} { cudaSetDevice(device); } ~WithCudaDevice() { cudaSetDevice(orig_device_); } private: - int orig_device_; + const int orig_device_; + + static int get_orig_device() noexcept { + int device{}; + cudaGetDevice(&device); + return device; + } }; // jfm timer class @@ -72,49 +110,6 @@ template T infnorm(int n, std::complex *a) { return sqrt(nrm); } -#ifdef __CUDA_ARCH__ -__forceinline__ __device__ auto interval(const int ns, const float x) { - // float to int round up and fused multiply-add to round up - const auto xstart = __float2int_ru(__fmaf_ru(ns, -.5f, x)); - // float to int round down and fused multiply-add to round down - const auto xend = __float2int_rd(__fmaf_rd(ns, .5f, x)); - return int2{xstart, xend}; -} -__forceinline__ __device__ auto interval(const int ns, const double x) { - // same as above - const auto xstart = __double2int_ru(__fma_ru(ns, -.5, x)); - const auto xend = __double2int_rd(__fma_rd(ns, .5, x)); - return int2{xstart, xend}; -} -#endif - -// Define a macro to check if NVCC version is >= 11.3 -#if defined(__CUDACC_VER_MAJOR__) && defined(__CUDACC_VER_MINOR__) -#if (__CUDACC_VER_MAJOR__ > 11) || \ - (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 3 && __CUDA_ARCH__ >= 600) - -#define ALLOCA_SUPPORTED 1 -// windows compatibility -#if __has_include() -#include -#endif -#else -#define ALLOCA_SUPPORTED 0 -#endif -#else -#define ALLOCA_SUPPORTED 0 -#endif - -#if defined(__CUDA_ARCH__) -#if __CUDA_ARCH__ >= 900 -#define COMPUTE_CAPABILITY_90_OR_HIGHER 1 -#else -#define COMPUTE_CAPABILITY_90_OR_HIGHER 0 -#endif -#else -#define COMPUTE_CAPABILITY_90_OR_HIGHER 0 -#endif - /** * does a complex atomic add on a shared memory address * it adds the real and imaginary parts separately @@ -123,8 +118,8 @@ __forceinline__ __device__ auto interval(const int ns, const double x) { */ template -static __forceinline__ __device__ void atomicAddComplexShared( - cuda_complex *address, cuda_complex res) { +static __forceinline__ __device__ void atomicAddComplexShared(cuda_complex *address, + cuda_complex res) { const auto raw_address = reinterpret_cast(address); atomicAdd(raw_address, res.x); atomicAdd(raw_address + 1, res.y); @@ -136,8 +131,8 @@ static __forceinline__ __device__ void atomicAddComplexShared( * on shared memory are supported so we leverage them */ template -static __forceinline__ __device__ void atomicAddComplexGlobal( - cuda_complex *address, cuda_complex res) { +static __forceinline__ __device__ void atomicAddComplexGlobal(cuda_complex *address, + cuda_complex res) { if constexpr ( std::is_same_v, float2> && COMPUTE_CAPABILITY_90_OR_HIGHER) { atomicAdd(address, res); @@ -146,6 +141,68 @@ static __forceinline__ __device__ void atomicAddComplexGlobal( } } +template auto arrayrange(int n, T *a, cudaStream_t stream) { + const auto d_min_max = thrust::minmax_element(thrust::cuda::par.on(stream), a, a + n); + + // copy d_min and d_max to host + T min{}, max{}; + checkCudaErrors(cudaMemcpyAsync(&min, thrust::raw_pointer_cast(d_min_max.first), + sizeof(T), cudaMemcpyDeviceToHost, stream)); + checkCudaErrors(cudaMemcpyAsync(&max, thrust::raw_pointer_cast(d_min_max.second), + sizeof(T), cudaMemcpyDeviceToHost, stream)); + return std::make_tuple(min, max); +} + +// Writes out w = half-width and c = center of an interval enclosing all a[n]'s +// Only chooses a nonzero center if this increases w by less than fraction +// ARRAYWIDCEN_GROWFRAC defined in defs.h. +// This prevents rephasings which don't grow nf by much. 6/8/17 +// If n==0, w and c are not finite. +template auto arraywidcen(int n, T *a, cudaStream_t stream) { + const auto [lo, hi] = arrayrange(n, a, stream); + auto w = (hi - lo) / 2; + auto c = (hi + lo) / 2; + if (std::abs(c) < ARRAYWIDCEN_GROWFRAC * w) { + w += std::abs(c); + c = 0.0; + } + return std::make_tuple(w, c); +} + +template +auto set_nhg_type3(T S, T X, const cufinufft_opts &opts, + const finufft_spread_opts &spopts) +/* + * It implements the same function in finufft.cpp + * set_nhg_type3 in finufft.cpp for documentation + */ +{ + int nss = spopts.nspread + 1; // since ns may be odd + T Xsafe = X, Ssafe = S; // may be tweaked locally + if (X == 0.0) // logic ensures XS>=1, handle X=0 a/o S=0 + if (S == 0.0) { + Xsafe = 1.0; + Ssafe = 1.0; + } else + Xsafe = std::max(Xsafe, T(1) / S); + else + Ssafe = std::max(Ssafe, T(1) / X); + // use the safe X and S... + T nfd = 2.0 * opts.upsampfac * Ssafe * Xsafe / M_PI + nss; + if (!std::isfinite(nfd)) nfd = 0.0; // use FLT to catch inf + auto nf = (int)nfd; + // printf("initial nf=%lld, ns=%d\n",*nf,spopts.nspread); + // catch too small nf, and nan or +-inf, otherwise spread fails... + if (nf < 2 * spopts.nspread) nf = 2 * spopts.nspread; + if (nf < MAX_NF) // otherwise will fail anyway + nf = utils::next235beven(nf, 1); // expensive at huge nf + // Note: b is 1 because type 3 uses a type 2 plan, so it should not need the extra + // condition that seems to be used by Block Gather as type 2 are only GM-sort + auto h = 2 * T(M_PI) / nf; // upsampled grid spacing + auto gam = T(nf) / (2.0 * opts.upsampfac * Ssafe); // x scale fac to x' + return std::make_tuple(nf, h, gam); +} + } // namespace utils } // namespace cufinufft diff --git a/include/cufinufft_opts.h b/include/cufinufft_opts.h index c9898f3b7..743b3cf5c 100644 --- a/include/cufinufft_opts.h +++ b/include/cufinufft_opts.h @@ -29,6 +29,8 @@ typedef struct cufinufft_opts { // see cufinufft_default_opts() for defaults int modeord; // (type 1,2 only): 0 CMCL-style increasing mode order // 1 FFT-style mode order + + int debug; // 0: no debug, 1: debug } cufinufft_opts; #endif diff --git a/include/finufft_errors.h b/include/finufft_errors.h index 245e6d339..fb827557c 100644 --- a/include/finufft_errors.h +++ b/include/finufft_errors.h @@ -23,6 +23,7 @@ enum { FINUFFT_ERR_METHOD_NOTVALID = 17, FINUFFT_ERR_BINSIZE_NOTVALID = 18, FINUFFT_ERR_INSUFFICIENT_SHMEM = 19, - FINUFFT_ERR_NUM_NU_PTS_INVALID = 20 + FINUFFT_ERR_NUM_NU_PTS_INVALID = 20, + FINUFFT_ERR_INVALID_ARGUMENT = 21 }; #endif diff --git a/matlab/finufft.mw b/matlab/finufft.mw index 4758157a1..73c362e6a 100644 --- a/matlab/finufft.mw +++ b/matlab/finufft.mw @@ -150,7 +150,7 @@ classdef finufft_plan < handle plan.floatprec = opts.floatprec; end end - + n_modes = ones(3,1); % is dummy for type 3 if type==3 if length(n_modes_or_dim)~=1 @@ -179,7 +179,7 @@ classdef finufft_plan < handle plan.n_trans = n_trans; % Note the peculiarity that mwrap only accepts a double for n_trans, even % though it's declared int. It complains, also with int64 for nj, etc :( - + % replace in finufft_opts struct whichever fields are in incoming opts... # copy_finufft_opts(mxArray opts, finufft_opts* o); if strcmp(plan.floatprec,'double') diff --git a/perftest/cuda/CMakeLists.txt b/perftest/cuda/CMakeLists.txt index 8b1ad1c9b..92c870cc5 100644 --- a/perftest/cuda/CMakeLists.txt +++ b/perftest/cuda/CMakeLists.txt @@ -2,6 +2,8 @@ add_executable(cuperftest cuperftest.cu) target_include_directories(cuperftest PUBLIC ${CUFINUFFT_INCLUDE_DIRS}) target_link_libraries(cuperftest PUBLIC cufinufft) target_compile_features(cuperftest PRIVATE cxx_std_17) +target_compile_options(cuperftest + PRIVATE $<$:--extended-lambda>) set_target_properties( cuperftest PROPERTIES LINKER_LANGUAGE CUDA diff --git a/src/cuda/1d/cufinufft1d.cu b/src/cuda/1d/cufinufft1d.cu index a17b6f044..06389ef75 100644 --- a/src/cuda/1d/cufinufft1d.cu +++ b/src/cuda/1d/cufinufft1d.cu @@ -1,6 +1,9 @@ #include #include +#include + #include +#include #include #include @@ -34,17 +37,16 @@ int cufinufft1d1_exec(cuda_complex *d_c, cuda_complex *d_fk, int ier; cuda_complex *d_fkstart; cuda_complex *d_cstart; - for (int i = 0; i * d_plan->maxbatchsize < d_plan->ntransf; i++) { - int blksize = - std::min(d_plan->ntransf - i * d_plan->maxbatchsize, d_plan->maxbatchsize); - d_cstart = d_c + i * d_plan->maxbatchsize * d_plan->M; - d_fkstart = d_fk + i * d_plan->maxbatchsize * d_plan->ms; - d_plan->c = d_cstart; - d_plan->fk = d_fkstart; + for (int i = 0; i * d_plan->batchsize < d_plan->ntransf; i++) { + int blksize = std::min(d_plan->ntransf - i * d_plan->batchsize, d_plan->batchsize); + d_cstart = d_c + i * d_plan->batchsize * d_plan->M; + d_fkstart = d_fk + i * d_plan->batchsize * d_plan->ms; + d_plan->c = d_cstart; + d_plan->fk = d_fkstart; // this is needed if ((ier = checkCudaErrors(cudaMemsetAsync( - d_plan->fw, 0, d_plan->maxbatchsize * d_plan->nf1 * sizeof(cuda_complex), + d_plan->fw, 0, d_plan->batchsize * d_plan->nf1 * sizeof(cuda_complex), stream)))) return ier; @@ -88,11 +90,10 @@ int cufinufft1d2_exec(cuda_complex *d_c, cuda_complex *d_fk, int ier; cuda_complex *d_fkstart; cuda_complex *d_cstart; - for (int i = 0; i * d_plan->maxbatchsize < d_plan->ntransf; i++) { - int blksize = - std::min(d_plan->ntransf - i * d_plan->maxbatchsize, d_plan->maxbatchsize); - d_cstart = d_c + i * d_plan->maxbatchsize * d_plan->M; - d_fkstart = d_fk + i * d_plan->maxbatchsize * d_plan->ms; + for (int i = 0; i * d_plan->batchsize < d_plan->ntransf; i++) { + int blksize = std::min(d_plan->ntransf - i * d_plan->batchsize, d_plan->batchsize); + d_cstart = d_c + i * d_plan->batchsize * d_plan->M; + d_fkstart = d_fk + i * d_plan->batchsize * d_plan->ms; d_plan->c = d_cstart; d_plan->fk = d_fkstart; @@ -116,6 +117,65 @@ int cufinufft1d2_exec(cuda_complex *d_c, cuda_complex *d_fk, return 0; } +template +int cufinufft1d3_exec(cuda_complex *d_c, cuda_complex *d_fk, + cufinufft_plan_t *d_plan) { + /* + 1D Type-3 NUFFT + + This function is called in "exec" stage (See ../cufinufft.cu). + It includes (copied from doc in finufft library) + Step 0: pre-phase the input strengths + Step 1: spread data + Step 2: Type 2 NUFFT + Step 3: deconvolve (amplify) each Fourier mode, using kernel Fourier coeff + + Marco Barbone 08/14/2024 + */ + int ier; + cuda_complex *d_cstart; + cuda_complex *d_fkstart; + const auto stream = d_plan->stream; + for (int i = 0; i * d_plan->batchsize < d_plan->ntransf; i++) { + int blksize = min(d_plan->ntransf - i * d_plan->batchsize, d_plan->batchsize); + d_cstart = d_c + i * d_plan->batchsize * d_plan->M; + d_fkstart = d_fk + i * d_plan->batchsize * d_plan->N; + // setting input for spreader + d_plan->c = d_plan->CpBatch; + // setting output for spreader + d_plan->fk = d_plan->fw; + if ((ier = checkCudaErrors(cudaMemsetAsync( + d_plan->fw, 0, d_plan->batchsize * d_plan->nf * sizeof(cuda_complex), + stream)))) + return ier; + // NOTE: fw might need to be set to 0 + // Step 0: pre-phase the input strengths + for (int block = 0; block < blksize; block++) { + thrust::transform(thrust::cuda::par.on(stream), d_plan->prephase, + d_plan->prephase + d_plan->M, d_cstart + block * d_plan->M, + d_plan->c + block * d_plan->M, + thrust::multiplies>()); + } + // Step 1: Spread + if ((ier = cuspread1d(d_plan, blksize))) return ier; + // now d_plan->fk = d_plan->fw contains the spread values + // Step 2: Type 2 NUFFT + // type 2 goes from fk to c + // saving the results directly in the user output array d_fk + // it needs to do blksize transforms + d_plan->t2_plan->ntransf = blksize; + if ((ier = cufinufft1d2_exec(d_fkstart, d_plan->fw, d_plan->t2_plan))) return ier; + // Step 3: deconvolve + // now we need to d_fk = d_fk*d_plan->deconv + for (int i = 0; i < blksize; i++) { + thrust::transform(thrust::cuda::par.on(stream), d_plan->deconv, + d_plan->deconv + d_plan->N, d_fkstart + i * d_plan->N, + d_fkstart + i * d_plan->N, thrust::multiplies>()); + } + } + return 0; +} + template int cufinufft1d1_exec(cuda_complex *d_c, cuda_complex *d_fk, cufinufft_plan_t *d_plan); template int cufinufft1d1_exec(cuda_complex *d_c, @@ -126,3 +186,8 @@ template int cufinufft1d2_exec(cuda_complex *d_c, cuda_complex(cuda_complex *d_c, cuda_complex *d_fk, cufinufft_plan_t *d_plan); +template int cufinufft1d3_exec(cuda_complex *d_c, cuda_complex *d_fk, + cufinufft_plan_t *d_plan); +template int cufinufft1d3_exec(cuda_complex *d_c, + cuda_complex *d_fk, + cufinufft_plan_t *d_plan); diff --git a/src/cuda/2d/cufinufft2d.cu b/src/cuda/2d/cufinufft2d.cu index f7f7b1559..8c165edbf 100644 --- a/src/cuda/2d/cufinufft2d.cu +++ b/src/cuda/2d/cufinufft2d.cu @@ -2,7 +2,11 @@ #include #include #include + +#include + #include +#include #include #include @@ -34,17 +38,17 @@ int cufinufft2d1_exec(cuda_complex *d_c, cuda_complex *d_fk, cuda_complex *d_cstart; auto &stream = d_plan->stream; - for (int i = 0; i * d_plan->maxbatchsize < d_plan->ntransf; i++) { - int blksize = min(d_plan->ntransf - i * d_plan->maxbatchsize, d_plan->maxbatchsize); - d_cstart = d_c + i * d_plan->maxbatchsize * d_plan->M; - d_fkstart = d_fk + i * d_plan->maxbatchsize * d_plan->ms * d_plan->mt; + for (int i = 0; i * d_plan->batchsize < d_plan->ntransf; i++) { + int blksize = min(d_plan->ntransf - i * d_plan->batchsize, d_plan->batchsize); + d_cstart = d_c + i * d_plan->batchsize * d_plan->M; + d_fkstart = d_fk + i * d_plan->batchsize * d_plan->ms * d_plan->mt; d_plan->c = d_cstart; d_plan->fk = d_fkstart; // this is needed if ((ier = checkCudaErrors(cudaMemsetAsync( d_plan->fw, 0, - d_plan->maxbatchsize * d_plan->nf1 * d_plan->nf2 * sizeof(cuda_complex), + d_plan->batchsize * d_plan->nf1 * d_plan->nf2 * sizeof(cuda_complex), stream)))) return ier; @@ -88,10 +92,10 @@ int cufinufft2d2_exec(cuda_complex *d_c, cuda_complex *d_fk, int ier; cuda_complex *d_fkstart; cuda_complex *d_cstart; - for (int i = 0; i * d_plan->maxbatchsize < d_plan->ntransf; i++) { - int blksize = min(d_plan->ntransf - i * d_plan->maxbatchsize, d_plan->maxbatchsize); - d_cstart = d_c + i * d_plan->maxbatchsize * d_plan->M; - d_fkstart = d_fk + i * d_plan->maxbatchsize * d_plan->ms * d_plan->mt; + for (int i = 0; i * d_plan->batchsize < d_plan->ntransf; i++) { + int blksize = min(d_plan->ntransf - i * d_plan->batchsize, d_plan->batchsize); + d_cstart = d_c + i * d_plan->batchsize * d_plan->M; + d_fkstart = d_fk + i * d_plan->batchsize * d_plan->ms * d_plan->mt; d_plan->c = d_cstart; d_plan->fk = d_fkstart; @@ -115,6 +119,64 @@ int cufinufft2d2_exec(cuda_complex *d_c, cuda_complex *d_fk, return 0; } +template +int cufinufft2d3_exec(cuda_complex *d_c, cuda_complex *d_fk, + cufinufft_plan_t *d_plan) { + /* + 2D Type-3 NUFFT + + This function is called in "exec" stage (See ../cufinufft.cu). + It includes (copied from doc in finufft library) + Step 0: pre-phase the input strengths + Step 1: spread data + Step 2: Type 2 NUFFT + Step 3: deconvolve (amplify) each Fourier mode, using kernel Fourier coeff + + Marco Barbone 08/14/2024 + */ + int ier; + cuda_complex *d_cstart; + cuda_complex *d_fkstart; + const auto stream = d_plan->stream; + for (int i = 0; i * d_plan->batchsize < d_plan->ntransf; i++) { + int blksize = min(d_plan->ntransf - i * d_plan->batchsize, d_plan->batchsize); + d_cstart = d_c + i * d_plan->batchsize * d_plan->M; + d_fkstart = d_fk + i * d_plan->batchsize * d_plan->N; + // setting input for spreader + d_plan->c = d_plan->CpBatch; + // setting output for spreader + d_plan->fk = d_plan->fw; + if ((ier = checkCudaErrors(cudaMemsetAsync( + d_plan->fw, 0, d_plan->batchsize * d_plan->nf * sizeof(cuda_complex), + stream)))) + return ier; + // NOTE: fw might need to be set to 0 + // Step 0: pre-phase the input strengths + for (int i = 0; i < blksize; i++) { + thrust::transform(thrust::cuda::par.on(stream), d_plan->prephase, + d_plan->prephase + d_plan->M, d_cstart + i * d_plan->M, + d_plan->c + i * d_plan->M, thrust::multiplies>()); + } + // Step 1: Spread + if ((ier = cuspread2d(d_plan, blksize))) return ier; + // now d_plan->fk = d_plan->fw contains the spread values + // Step 2: Type 2 NUFFT + // type 2 goes from fk to c + // saving the results directly in the user output array d_fk + // it needs to do blksize transforms + d_plan->t2_plan->ntransf = blksize; + if ((ier = cufinufft2d2_exec(d_fkstart, d_plan->fw, d_plan->t2_plan))) return ier; + // Step 3: deconvolve + // now we need to d_fk = d_fk*d_plan->deconv + for (int i = 0; i < blksize; i++) { + thrust::transform(thrust::cuda::par.on(stream), d_plan->deconv, + d_plan->deconv + d_plan->N, d_fkstart + i * d_plan->N, + d_fkstart + i * d_plan->N, thrust::multiplies>()); + } + } + return 0; +} + template int cufinufft2d1_exec(cuda_complex *d_c, cuda_complex *d_fk, cufinufft_plan_t *d_plan); template int cufinufft2d1_exec(cuda_complex *d_c, @@ -125,3 +187,8 @@ template int cufinufft2d2_exec(cuda_complex *d_c, cuda_complex(cuda_complex *d_c, cuda_complex *d_fk, cufinufft_plan_t *d_plan); +template int cufinufft2d3_exec(cuda_complex *d_c, cuda_complex *d_fk, + cufinufft_plan_t *d_plan); +template int cufinufft2d3_exec(cuda_complex *d_c, + cuda_complex *d_fk, + cufinufft_plan_t *d_plan); diff --git a/src/cuda/2d/spread2d_wrapper.cu b/src/cuda/2d/spread2d_wrapper.cu index 80cf9f8e9..490c8eed1 100644 --- a/src/cuda/2d/spread2d_wrapper.cu +++ b/src/cuda/2d/spread2d_wrapper.cu @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -96,9 +97,7 @@ int cuspread2d_nuptsdriven_prop(int nf1, int nf2, int M, cufinufft_plan_t *d_ RETURN_IF_CUDA_ERROR } else { int *d_idxnupts = d_plan->idxnupts; - - trivial_global_sort_index_2d<<<(M + 1024 - 1) / 1024, 1024, 0, stream>>>(M, - d_idxnupts); + thrust::sequence(thrust::cuda::par.on(stream), d_idxnupts, d_idxnupts + M); RETURN_IF_CUDA_ERROR } diff --git a/src/cuda/2d/spreadinterp2d.cuh b/src/cuda/2d/spreadinterp2d.cuh index 53a243e7e..805e921aa 100644 --- a/src/cuda/2d/spreadinterp2d.cuh +++ b/src/cuda/2d/spreadinterp2d.cuh @@ -47,9 +47,9 @@ __global__ void spread_2d_nupts_driven( } for (auto yy = ystart; yy <= yend; yy++) { + const auto iy = yy < 0 ? yy + nf2 : (yy > nf2 - 1 ? yy - nf2 : yy); for (auto xx = xstart; xx <= xend; xx++) { const auto ix = xx < 0 ? xx + nf1 : (xx > nf1 - 1 ? xx - nf1 : xx); - const auto iy = yy < 0 ? yy + nf2 : (yy > nf2 - 1 ? yy - nf2 : yy); const auto outidx = ix + iy * nf1; const auto kervalue1 = ker1[xx - xstart]; const auto kervalue2 = ker2[yy - ystart]; @@ -240,9 +240,9 @@ __global__ void interp_2d_nupts_driven( cuda_complex cnow{0, 0}; for (int yy = ystart; yy <= yend; yy++) { const T kervalue2 = ker2[yy - ystart]; + const auto iy = yy < 0 ? yy + nf2 : (yy > nf2 - 1 ? yy - nf2 : yy); for (int xx = xstart; xx <= xend; xx++) { const auto ix = xx < 0 ? xx + nf1 : (xx > nf1 - 1 ? xx - nf1 : xx); - const auto iy = yy < 0 ? yy + nf2 : (yy > nf2 - 1 ? yy - nf2 : yy); const auto inidx = ix + iy * nf1; const auto kervalue1 = ker1[xx - xstart]; cnow.x += fw[inidx].x * kervalue1 * kervalue2; diff --git a/src/cuda/3d/cufinufft3d.cu b/src/cuda/3d/cufinufft3d.cu index 5977e6d5f..ce1f37c0e 100644 --- a/src/cuda/3d/cufinufft3d.cu +++ b/src/cuda/3d/cufinufft3d.cu @@ -3,11 +3,14 @@ #include #include +#include #include #include #include +#include + using namespace cufinufft::deconvolve; using namespace cufinufft::spreadinterp; using std::min; @@ -32,19 +35,17 @@ int cufinufft3d1_exec(cuda_complex *d_c, cuda_complex *d_fk, int ier; cuda_complex *d_fkstart; cuda_complex *d_cstart; - for (int i = 0; i * d_plan->maxbatchsize < d_plan->ntransf; i++) { - int blksize = min(d_plan->ntransf - i * d_plan->maxbatchsize, d_plan->maxbatchsize); - d_cstart = d_c + i * d_plan->maxbatchsize * d_plan->M; - d_fkstart = d_fk + i * d_plan->maxbatchsize * d_plan->ms * d_plan->mt * d_plan->mu; + for (int i = 0; i * d_plan->batchsize < d_plan->ntransf; i++) { + int blksize = min(d_plan->ntransf - i * d_plan->batchsize, d_plan->batchsize); + d_cstart = d_c + i * d_plan->batchsize * d_plan->M; + d_fkstart = d_fk + i * d_plan->batchsize * d_plan->ms * d_plan->mt * d_plan->mu; d_plan->c = d_cstart; d_plan->fk = d_fkstart; - if ((ier = checkCudaErrors( - cudaMemsetAsync(d_plan->fw, 0, - d_plan->maxbatchsize * d_plan->nf1 * d_plan->nf2 * - d_plan->nf3 * sizeof(cuda_complex), - stream)))) + if ((ier = checkCudaErrors(cudaMemsetAsync( + d_plan->fw, 0, d_plan->batchsize * d_plan->nf * sizeof(cuda_complex), + stream)))) return ier; // Step 1: Spread @@ -85,10 +86,10 @@ int cufinufft3d2_exec(cuda_complex *d_c, cuda_complex *d_fk, int ier; cuda_complex *d_fkstart; cuda_complex *d_cstart; - for (int i = 0; i * d_plan->maxbatchsize < d_plan->ntransf; i++) { - int blksize = min(d_plan->ntransf - i * d_plan->maxbatchsize, d_plan->maxbatchsize); - d_cstart = d_c + i * d_plan->maxbatchsize * d_plan->M; - d_fkstart = d_fk + i * d_plan->maxbatchsize * d_plan->ms * d_plan->mt * d_plan->mu; + for (int i = 0; i * d_plan->batchsize < d_plan->ntransf; i++) { + int blksize = min(d_plan->ntransf - i * d_plan->batchsize, d_plan->batchsize); + d_cstart = d_c + i * d_plan->batchsize * d_plan->M; + d_fkstart = d_fk + i * d_plan->batchsize * d_plan->ms * d_plan->mt * d_plan->mu; d_plan->c = d_cstart; d_plan->fk = d_fkstart; @@ -113,6 +114,66 @@ int cufinufft3d2_exec(cuda_complex *d_c, cuda_complex *d_fk, return 0; } +// TODO: in case data is centered, we could save GPU memory +template +int cufinufft3d3_exec(cuda_complex *d_c, cuda_complex *d_fk, + cufinufft_plan_t *d_plan) { + /* + 3D Type-3 NUFFT + + This function is called in "exec" stage (See ../cufinufft.cu). + It includes (copied from doc in finufft library) + Step 0: pre-phase the input strengths + Step 1: spread data + Step 2: Type 2 NUFFT + Step 3: deconvolve (amplify) each Fourier mode, using kernel Fourier coeff + + Marco Barbone 08/14/2024 + */ + int ier; + cuda_complex *d_cstart; + cuda_complex *d_fkstart; + const auto stream = d_plan->stream; + printf("[cufinufft] d_plan->ntransf = %d\n", d_plan->ntransf); + for (int i = 0; i * d_plan->batchsize < d_plan->ntransf; i++) { + int blksize = min(d_plan->ntransf - i * d_plan->batchsize, d_plan->batchsize); + d_cstart = d_c + i * d_plan->batchsize * d_plan->M; + d_fkstart = d_fk + i * d_plan->batchsize * d_plan->N; + // setting input for spreader + d_plan->c = d_plan->CpBatch; + // setting output for spreader + d_plan->fk = d_plan->fw; + // NOTE: fw might need to be set to 0 + if ((ier = checkCudaErrors(cudaMemsetAsync( + d_plan->fw, 0, d_plan->batchsize * d_plan->nf * sizeof(cuda_complex), + stream)))) + return ier; + // Step 0: pre-phase the input strengths + for (int i = 0; i < blksize; i++) { + thrust::transform(thrust::cuda::par.on(stream), d_plan->prephase, + d_plan->prephase + d_plan->M, d_cstart + i * d_plan->M, + d_plan->c + i * d_plan->M, thrust::multiplies>()); + } + // Step 1: Spread + if ((ier = cuspread3d(d_plan, blksize))) return ier; + // now d_plan->fk = d_plan->fw contains the spread values + // Step 2: Type 2 NUFFT + // type 2 goes from fk to c + // saving the results directly in the user output array d_fk + // it needs to do blksize transforms + d_plan->t2_plan->ntransf = blksize; + if ((ier = cufinufft3d2_exec(d_fkstart, d_plan->fw, d_plan->t2_plan))) return ier; + // Step 3: deconvolve + // now we need to d_fk = d_fk*d_plan->deconv + for (int i = 0; i < blksize; i++) { + thrust::transform(thrust::cuda::par.on(stream), d_plan->deconv, + d_plan->deconv + d_plan->N, d_fkstart + i * d_plan->N, + d_fkstart + i * d_plan->N, thrust::multiplies>()); + } + } + return 0; +} + template int cufinufft3d1_exec(cuda_complex *d_c, cuda_complex *d_fk, cufinufft_plan_t *d_plan); template int cufinufft3d1_exec(cuda_complex *d_c, @@ -124,3 +185,8 @@ template int cufinufft3d2_exec(cuda_complex *d_c, cuda_complex(cuda_complex *d_c, cuda_complex *d_fk, cufinufft_plan_t *d_plan); + +template int cufinufft3d3_exec(cuda_complex *d_c, cuda_complex *d_fk, + cufinufft_plan_t *d_plan); +template int cufinufft3d3_exec(cuda_complex *d_c, cuda_complex *d_fk, + cufinufft_plan_t *d_plan); diff --git a/src/cuda/3d/interp3d_wrapper.cu b/src/cuda/3d/interp3d_wrapper.cu index 91379d3ae..51c620756 100644 --- a/src/cuda/3d/interp3d_wrapper.cu +++ b/src/cuda/3d/interp3d_wrapper.cu @@ -50,10 +50,7 @@ int cuinterp3d(cufinufft_plan_t *d_plan, int blksize) template int cuinterp3d_nuptsdriven(int nf1, int nf2, int nf3, int M, cufinufft_plan_t *d_plan, int blksize) { - auto &stream = d_plan->stream; - - dim3 threadsPerBlock; - dim3 blocks; + const auto stream = d_plan->stream; int ns = d_plan->spopts.nspread; // psi's support in terms of number of cells T es_c = d_plan->spopts.ES_c; @@ -68,10 +65,8 @@ int cuinterp3d_nuptsdriven(int nf1, int nf2, int nf3, int M, cufinufft_plan_t cuda_complex *d_c = d_plan->c; cuda_complex *d_fw = d_plan->fw; - threadsPerBlock.x = 16; - threadsPerBlock.y = 1; - blocks.x = (M + threadsPerBlock.x - 1) / threadsPerBlock.x; - blocks.y = 1; + const dim3 threadsPerBlock{16, 1, 1}; + const dim3 blocks{(M + threadsPerBlock.x - 1) / threadsPerBlock.x, 1, 1}; if (d_plan->opts.gpu_kerevalmeth) { for (int t = 0; t < blksize; t++) { diff --git a/src/cuda/3d/spread3d_wrapper.cu b/src/cuda/3d/spread3d_wrapper.cu index 475a888ac..a0411c2b1 100644 --- a/src/cuda/3d/spread3d_wrapper.cu +++ b/src/cuda/3d/spread3d_wrapper.cu @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -72,9 +73,9 @@ int cuspread3d_nuptsdriven_prop(int nf1, int nf2, int nf3, int M, } int numbins[3]; - numbins[0] = ceil((T)nf1 / bin_size_x); - numbins[1] = ceil((T)nf2 / bin_size_y); - numbins[2] = ceil((T)nf3 / bin_size_z); + numbins[0] = (nf1 + bin_size_x - 1) / bin_size_x; + numbins[1] = (nf2 + bin_size_y - 1) / bin_size_y; + numbins[2] = (nf3 + bin_size_z - 1) / bin_size_z; T *d_kx = d_plan->kx; T *d_ky = d_plan->ky; @@ -105,9 +106,7 @@ int cuspread3d_nuptsdriven_prop(int nf1, int nf2, int nf3, int M, RETURN_IF_CUDA_ERROR } else { int *d_idxnupts = d_plan->idxnupts; - - trivial_global_sort_index_3d<<<(M + 1024 - 1) / 1024, 1024, 0, stream>>>(M, - d_idxnupts); + thrust::sequence(thrust::cuda::par.on(stream), d_idxnupts, d_idxnupts + M); RETURN_IF_CUDA_ERROR } diff --git a/src/cuda/3d/spreadinterp3d.cuh b/src/cuda/3d/spreadinterp3d.cuh index 59b4661ff..298ae4a43 100644 --- a/src/cuda/3d/spreadinterp3d.cuh +++ b/src/cuda/3d/spreadinterp3d.cuh @@ -445,6 +445,7 @@ __global__ void interp_3d_nupts_driven( T ker2[MAX_NSPREAD]; T ker3[MAX_NSPREAD]; #endif + cuda_complex cnow{}; for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < M; i += blockDim.x * gridDim.x) { const auto x_rescaled = fold_rescale(x[idxnupts[i]], nf1); @@ -455,11 +456,14 @@ __global__ void interp_3d_nupts_driven( const auto [ystart, yend] = interval(ns, y_rescaled); const auto [zstart, zend] = interval(ns, z_rescaled); - const auto x1 = T(xstart) - x_rescaled; - const auto y1 = T(ystart) - y_rescaled; - const auto z1 = T(zstart) - z_rescaled; + const T x1 = T(xstart) - x_rescaled; + const T y1 = T(ystart) - y_rescaled; + const T z1 = T(zstart) - z_rescaled; - cuda_complex cnow{0, 0}; + // having cnow allocated to 0 inside the loop breaks type 3 spread + // are we doing a buffer overflow somewhere? + cnow.x = T(0); + cnow.y = T(0); if constexpr (KEREVALMETH == 1) { eval_kernel_vec_horner(ker1, x1, ns, sigma); @@ -473,13 +477,13 @@ __global__ void interp_3d_nupts_driven( for (int zz = zstart; zz <= zend; zz++) { const auto kervalue3 = ker3[zz - zstart]; - int iz = zz < 0 ? zz + nf3 : (zz > nf3 - 1 ? zz - nf3 : zz); + const auto iz = zz < 0 ? zz + nf3 : (zz > nf3 - 1 ? zz - nf3 : zz); for (int yy = ystart; yy <= yend; yy++) { const auto kervalue2 = ker2[yy - ystart]; int iy = yy < 0 ? yy + nf2 : (yy > nf2 - 1 ? yy - nf2 : yy); for (int xx = xstart; xx <= xend; xx++) { - const int ix = xx < 0 ? xx + nf1 : (xx > nf1 - 1 ? xx - nf1 : xx); - const int inidx = ix + iy * nf1 + iz * nf2 * nf1; + const auto ix = xx < 0 ? xx + nf1 : (xx > nf1 - 1 ? xx - nf1 : xx); + const auto inidx = ix + iy * nf1 + iz * nf2 * nf1; const auto kervalue1 = ker1[xx - xstart]; cnow.x += fw[inidx].x * kervalue1 * kervalue2 * kervalue3; cnow.y += fw[inidx].y * kervalue1 * kervalue2 * kervalue3; diff --git a/src/cuda/CMakeLists.txt b/src/cuda/CMakeLists.txt index 6c78bab09..9f8d1344c 100644 --- a/src/cuda/CMakeLists.txt +++ b/src/cuda/CMakeLists.txt @@ -1,6 +1,7 @@ set(PRECISION_INDEPENDENT_SRC precision_independent.cu utils.cpp ${PROJECT_SOURCE_DIR}/contrib/legendre_rule_fast.cpp) + set(PRECISION_DEPENDENT_SRC spreadinterp.cpp 1d/cufinufft1d.cu @@ -23,19 +24,24 @@ set(CUFINUFFT_INCLUDE_DIRS $ $ $) + set(CUFINUFFT_INCLUDE_DIRS ${CUFINUFFT_INCLUDE_DIRS} PARENT_SCOPE) # flush denormals to zero and enable verbose PTXAS output set(FINUFFT_CUDA_FLAGS + $<$: + --extended-lambda -ftz=true -fmad=true -restrict --extra-device-vectorization $<$:-G -maxrregcount - 32>) + 64 + > + >) add_library(cufinufft_common_objects OBJECT ${PRECISION_INDEPENDENT_SRC}) target_include_directories(cufinufft_common_objects @@ -48,12 +54,7 @@ set_target_properties( CUDA_STANDARD 17 CUDA_STANDARD_REQUIRED ON) target_compile_features(cufinufft_common_objects PRIVATE cxx_std_17) -target_compile_options( - cufinufft_common_objects - PRIVATE $<$:${FINUFFT_CUDA_FLAGS}>) -target_compile_options( - cufinufft_common_objects - PRIVATE $<$:${FINUFFT_CUDA_FLAGS}>) +target_compile_options(cufinufft_common_objects PRIVATE ${FINUFFT_CUDA_FLAGS}) add_library(cufinufft_objects OBJECT ${PRECISION_DEPENDENT_SRC}) target_include_directories(cufinufft_objects PUBLIC ${CUFINUFFT_INCLUDE_DIRS}) @@ -65,8 +66,7 @@ set_target_properties( CUDA_STANDARD 17 CUDA_STANDARD_REQUIRED ON) target_compile_features(cufinufft_objects PRIVATE cxx_std_17) -target_compile_options( - cufinufft_objects PRIVATE $<$:${FINUFFT_CUDA_FLAGS}>) +target_compile_options(cufinufft_objects PRIVATE ${FINUFFT_CUDA_FLAGS}) if(FINUFFT_SHARED_LINKING) add_library(cufinufft SHARED $ @@ -87,7 +87,7 @@ set_target_properties( CUDA_STANDARD_REQUIRED ON ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}") target_compile_features(cufinufft PRIVATE cxx_std_17) - +target_compile_options(cufinufft PRIVATE ${FINUFFT_CUDA_FLAGS}) if(WIN32) target_link_libraries(cufinufft PUBLIC CUDA::cudart CUDA::cufft CUDA::nvToolsExt) diff --git a/src/cuda/common.cu b/src/cuda/common.cu index b19986520..31f0418e2 100644 --- a/src/cuda/common.cu +++ b/src/cuda/common.cu @@ -21,20 +21,26 @@ namespace common { using namespace cufinufft::spreadinterp; using std::max; -/* Kernel for computing approximations of exact Fourier series coeffs of - cnufftspread's real symmetric kernel. */ -// a , f are intermediate results from function onedim_fseries_kernel_precomp() -// (see cufinufft/contrib/common.cpp for description) +/** Kernel for computing approximations of exact Fourier series coeffs of + * cnufftspread's real symmetric kernel. + * phase, f are intermediate results from function onedim_fseries_kernel_precomp() + * (see cufinufft/contrib/common.cpp for description) + * this is the equispaced frequency case, used by type 1 & 2, matching + * onedim_fseries_kernel in CPU code + */ template -__global__ void fseries_kernel_compute(int nf1, int nf2, int nf3, T *f, - cuDoubleComplex *a, T *fwkerhalf1, T *fwkerhalf2, - T *fwkerhalf3, int ns) { +__global__ void cu_fseries_kernel_compute(int nf1, int nf2, int nf3, T *f, T *phase, + T *fwkerhalf1, T *fwkerhalf2, T *fwkerhalf3, + int ns) { T J2 = ns / 2.0; int q = (int)(2 + 3.0 * J2); int nf; - cuDoubleComplex *at = a + threadIdx.y * MAX_NQUAD; - T *ft = f + threadIdx.y * MAX_NQUAD; + T *phaset = phase + threadIdx.y * MAX_NQUAD; + T *ft = f + threadIdx.y * MAX_NQUAD; T *oarr; + // standard parallelism pattern in cuda. using a 2D grid, this allows to leverage more + // threads as the parallelism is x*y*z + // each thread check the y index to determine which array to use if (threadIdx.y == 0) { oarr = fwkerhalf1; nf = nf1; @@ -48,19 +54,62 @@ __global__ void fseries_kernel_compute(int nf1, int nf2, int nf3, T *f, for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < nf / 2 + 1; i += blockDim.x * gridDim.x) { - int brk = 0.5 + i; - T x = 0.0; + T x = 0.0; + for (int n = 0; n < q; n++) { + // in type 1/2 2*PI/nf -> k[i] + x += ft[n] * T(2) * std::cos(T(i) * phaset[n]); + } + oarr[i] = x * T(i % 2 ? -1 : 1); // signflip for the kernel origin being at PI + } +} + +/** Kernel for computing approximations of exact Fourier series coeffs of + * cnufftspread's real symmetric kernel. + * a , f are intermediate results from function onedim_fseries_kernel_precomp() + * (see cufinufft/contrib/common.cpp for description) + * this is the arbitrary frequency case (hence the extra kx, ky, kx arguments), used by + * type 3, matching onedim_nuft_kernel in CPU code + */ +template +__global__ void cu_nuft_kernel_compute(int nf1, int nf2, int nf3, T *f, T *z, T *kx, + T *ky, T *kz, T *fwkerhalf1, T *fwkerhalf2, + T *fwkerhalf3, int ns) { + T J2 = ns / 2.0; + int q = (int)(2 + 2.0 * J2); + int nf; + T *at = z + threadIdx.y * MAX_NQUAD; + T *ft = f + threadIdx.y * MAX_NQUAD; + T *oarr, *k; + // standard parallelism pattern in cuda. using a 2D grid, this allows to leverage more + // threads as the parallelism is x*y*z + // each thread check the y index to determine which array to use + if (threadIdx.y == 0) { + k = kx; + oarr = fwkerhalf1; + nf = nf1; + } else if (threadIdx.y == 1) { + k = ky; + oarr = fwkerhalf2; + nf = nf2; + } else { + k = kz; + oarr = fwkerhalf3; + nf = nf3; + } + for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < nf; + i += blockDim.x * gridDim.x) { + T x = 0.0; for (int n = 0; n < q; n++) { - x += ft[n] * 2 * (pow(cabs(at[n]), brk) * cos(brk * carg(at[n]))); + x += ft[n] * T(2) * std::cos(k[i] * at[n]); } oarr[i] = x; } } template -int cufserieskernelcompute(int dim, int nf1, int nf2, int nf3, T *d_f, - cuDoubleComplex *d_a, T *d_fwkerhalf1, T *d_fwkerhalf2, - T *d_fwkerhalf3, int ns, cudaStream_t stream) +int fseries_kernel_compute(int dim, int nf1, int nf2, int nf3, T *d_f, T *d_phase, + T *d_fwkerhalf1, T *d_fwkerhalf2, T *d_fwkerhalf3, int ns, + cudaStream_t stream) /* wrapper for approximation of Fourier series of real symmetric spreading kernel. @@ -73,8 +122,37 @@ int cufserieskernelcompute(int dim, int nf1, int nf2, int nf3, T *d_f, dim3 threadsPerBlock(16, dim); dim3 numBlocks((nout + 16 - 1) / 16, 1); - fseries_kernel_compute<<>>( - nf1, nf2, nf3, d_f, d_a, d_fwkerhalf1, d_fwkerhalf2, d_fwkerhalf3, ns); + cu_fseries_kernel_compute<<>>( + nf1, nf2, nf3, d_f, d_phase, d_fwkerhalf1, d_fwkerhalf2, d_fwkerhalf3, ns); + RETURN_IF_CUDA_ERROR + + return 0; +} + +template +int nuft_kernel_compute(int dim, int nf1, int nf2, int nf3, T *d_f, T *d_z, T *d_kx, + T *d_ky, T *d_kz, T *d_fwkerhalf1, T *d_fwkerhalf2, + T *d_fwkerhalf3, int ns, cudaStream_t stream) +/* + Approximates exact Fourier transform of cnufftspread's real symmetric + kernel, directly via q-node quadrature on Euler-Fourier formula, exploiting + narrowness of kernel. Evaluates at set of arbitrary freqs k in [-pi, pi), + for a kernel with x measured in grid-spacings. (See previous routine for + FT definition). + It implements onedim_nuft_kernel in CPU code. Except it combines up to three + onedimensional kernel evaluations at once (for efficiency). + + Marco Barbone 08/28/2024 +*/ +{ + int nout = max(max(nf1, nf2), nf3); + + dim3 threadsPerBlock(16, dim); + dim3 numBlocks((nout + 16 - 1) / 16, 1); + + cu_nuft_kernel_compute<<>>( + nf1, nf2, nf3, d_f, d_z, d_kx, d_ky, d_kz, d_fwkerhalf1, d_fwkerhalf2, d_fwkerhalf3, + ns); RETURN_IF_CUDA_ERROR return 0; @@ -104,39 +182,6 @@ void set_nf_type12(CUFINUFFT_BIGINT ms, cufinufft_opts opts, finufft_spread_opts } } -template -void onedim_fseries_kernel(CUFINUFFT_BIGINT nf, T *fwkerhalf, finufft_spread_opts opts) -/* - Approximates exact Fourier series coeffs of cnufftspread's real symmetric - kernel, directly via q-node quadrature on Euler-Fourier formula, exploiting - narrowness of kernel. Uses phase winding for cheap eval on the regular freq - grid. Note that this is also the Fourier transform of the non-periodized - kernel. The FT definition is f(k) = int e^{-ikx} f(x) dx. The output has an - overall prefactor of 1/h, which is needed anyway for the correction, and - arises because the quadrature weights are scaled for grid units not x units. - - Inputs: - nf - size of 1d uniform spread grid, must be even. - opts - spreading opts object, needed to eval kernel (must be already set up) - - Outputs: - fwkerhalf - real Fourier series coeffs from indices 0 to nf/2 inclusive, - divided by h = 2pi/n. - (should be allocated for at least nf/2+1 Ts) - - Compare onedim_dct_kernel which has same interface, but computes DFT of - sampled kernel, not quite the same object. - - Barnett 2/7/17. openmp (since slow vs fftw in 1D large-N case) 3/3/18 - Melody 2/20/22 separate into precomp & comp functions defined below. - */ -{ - T f[MAX_NQUAD]; - std::complex a[MAX_NQUAD]; - onedim_fseries_kernel_precomp(nf, f, a, opts); - onedim_fseries_kernel_compute(nf, f, a, fwkerhalf, opts); -} - /* Precomputation of approximations of exact Fourier series coeffs of cnufftspread's real symmetric kernel. @@ -144,58 +189,46 @@ void onedim_fseries_kernel(CUFINUFFT_BIGINT nf, T *fwkerhalf, finufft_spread_opt Inputs: nf - size of 1d uniform spread grid, must be even. opts - spreading opts object, needed to eval kernel (must be already set up) + phase_winding - if true (type 1-2), scaling for the equispaced case else (type 3) + scaling for the general kx,ky,kz case Outputs: - a - phase winding rates - f - funciton values at quadrature nodes multiplied with quadrature weights - (a, f are provided as the inputs of onedim_fseries_kernel_compute() defined below) + a - vector of phases to be used for cosines on the GPU; + f - function values at quadrature nodes multiplied with quadrature weights (a, f are + provided as the inputs of onedim_fseries_kernel_compute() defined below) */ + template -void onedim_fseries_kernel_precomp(CUFINUFFT_BIGINT nf, T *f, std::complex *a, +void onedim_fseries_kernel_precomp(CUFINUFFT_BIGINT nf, T *f, T *phase, finufft_spread_opts opts) { T J2 = opts.nspread / 2.0; // J/2, half-width of ker z-support // # quadr nodes in z (from 0 to J/2; reflections will be added)... - int q = (int)(2 + 3.0 * J2); // not sure why so large? cannot exceed MAX_NQUAD + const auto q = (int)(2 + 3.0 * J2); // matches CPU code double z[2 * MAX_NQUAD]; double w[2 * MAX_NQUAD]; - finufft::quadrature::legendre_compute_glr(2 * q, z, w); // only half the nodes used, - // eg on (0,1) - for (int n = 0; n < q; ++n) { // set up nodes z_n and vals f_n - z[n] *= J2; // rescale nodes - f[n] = J2 * w[n] * evaluate_kernel((T)z[n], opts); // vals & quadr wei - a[n] = exp((T)(2.0 * M_PI) * std::complex(0.0, 1.0) * (T)(nf / 2 - z[n]) / - (T)nf); // phase winding rates + // eg on (0,1) + for (int n = 0; n < q; ++n) { // set up nodes z_n and vals f_n + z[n] *= J2; // rescale nodes + f[n] = J2 * w[n] * evaluate_kernel((T)z[n], opts); // vals & quadr wei + phase[n] = T(2.0 * M_PI * z[n] / T(nf)); // phase winding rates } } template -void onedim_fseries_kernel_compute(CUFINUFFT_BIGINT nf, T *f, std::complex *a, - T *fwkerhalf, finufft_spread_opts opts) { - T J2 = opts.nspread / 2.0; // J/2, half-width of ker z-support - int q = (int)(2 + 3.0 * J2); // not sure why so large? cannot exceed MAX_NQUAD - CUFINUFFT_BIGINT nout = nf / 2 + 1; // how many values we're writing to - int nt = std::min(nout, MY_OMP_GET_MAX_THREADS()); // how many chunks - std::vector brk(nt + 1); // start indices for each thread - for (int t = 0; t <= nt; ++t) // split nout mode indices btw threads - brk[t] = (CUFINUFFT_BIGINT)(0.5 + nout * t / (double)nt); -#pragma omp parallel - { - int t = MY_OMP_GET_THREAD_NUM(); - if (t < nt) { // could be nt < actual # threads - std::complex aj[MAX_NQUAD]; // phase rotator for this thread - for (int n = 0; n < q; ++n) - aj[n] = pow(a[n], (T)brk[t]); // init phase factors for chunk - for (CUFINUFFT_BIGINT j = brk[t]; j < brk[t + 1]; ++j) { // loop along output - // array - T x = 0.0; // accumulator for answer at this j - for (int n = 0; n < q; ++n) { - x += f[n] * 2 * real(aj[n]); // include the negative freq - aj[n] *= a[n]; // wind the phases - } - fwkerhalf[j] = x; - } - } +void onedim_nuft_kernel_precomp(T *f, T *z, finufft_spread_opts opts) { + // it implements the first half of onedim_nuft_kernel in CPU code + T J2 = opts.nspread / 2.0; // J/2, half-width of ker z-support + // # quadr nodes in z (from 0 to J/2; reflections will be added)... + int q = (int)(2 + 2.0 * J2); // matches CPU code + double z_local[2 * MAX_NQUAD]; + double w_local[2 * MAX_NQUAD]; + finufft::quadrature::legendre_compute_glr(2 * q, z_local, w_local); // only half the + // nodes used, eg on + // (0,1) + for (int n = 0; n < q; ++n) { // set up nodes z_n and vals f_n + z[n] = J2 * T(z_local[n]); // rescale nodes + f[n] = J2 * w_local[n] * evaluate_kernel(z[n], opts); // vals & quadr wei } } @@ -312,34 +345,32 @@ void cufinufft_setup_binsize(int type, int ns, int dim, cufinufft_opts *opts) { } } -template void onedim_fseries_kernel_compute(CUFINUFFT_BIGINT nf, float *f, - std::complex *a, float *fwkerhalf, - finufft_spread_opts opts); -template void onedim_fseries_kernel_compute(CUFINUFFT_BIGINT nf, double *f, - std::complex *a, double *fwkerhalf, - finufft_spread_opts opts); - template int setup_spreader_for_nufft(finufft_spread_opts &spopts, float eps, cufinufft_opts opts); template int setup_spreader_for_nufft(finufft_spread_opts &spopts, double eps, cufinufft_opts opts); -template void onedim_fseries_kernel_precomp( - CUFINUFFT_BIGINT nf, float *f, std::complex *a, finufft_spread_opts opts); -template void onedim_fseries_kernel_precomp( - CUFINUFFT_BIGINT nf, double *f, std::complex *a, finufft_spread_opts opts); -template int cufserieskernelcompute(int dim, int nf1, int nf2, int nf3, float *d_f, - cuDoubleComplex *d_a, float *d_fwkerhalf1, - float *d_fwkerhalf2, float *d_fwkerhalf3, int ns, - cudaStream_t stream); -template int cufserieskernelcompute(int dim, int nf1, int nf2, int nf3, double *d_f, - cuDoubleComplex *d_a, double *d_fwkerhalf1, - double *d_fwkerhalf2, double *d_fwkerhalf3, int ns, - cudaStream_t stream); - -template void onedim_fseries_kernel(CUFINUFFT_BIGINT nf, float *fwkerhalf, - finufft_spread_opts opts); -template void onedim_fseries_kernel(CUFINUFFT_BIGINT nf, double *fwkerhalf, - finufft_spread_opts opts); +template void onedim_fseries_kernel_precomp(CUFINUFFT_BIGINT nf, float *f, + float *a, finufft_spread_opts opts); +template void onedim_fseries_kernel_precomp(CUFINUFFT_BIGINT nf, double *f, + double *a, finufft_spread_opts opts); +template void onedim_nuft_kernel_precomp(float *f, float *a, + finufft_spread_opts opts); +template void onedim_nuft_kernel_precomp(double *f, double *a, + finufft_spread_opts opts); +template int fseries_kernel_compute(int dim, int nf1, int nf2, int nf3, float *d_f, + float *d_a, float *d_fwkerhalf1, float *d_fwkerhalf2, + float *d_fwkerhalf3, int ns, cudaStream_t stream); +template int fseries_kernel_compute( + int dim, int nf1, int nf2, int nf3, double *d_f, double *d_a, double *d_fwkerhalf1, + double *d_fwkerhalf2, double *d_fwkerhalf3, int ns, cudaStream_t stream); +template int nuft_kernel_compute(int dim, int nf1, int nf2, int nf3, float *d_f, + float *d_a, float *d_kx, float *d_ky, float *d_kz, + float *d_fwkerhalf1, float *d_fwkerhalf2, + float *d_fwkerhalf3, int ns, cudaStream_t stream); +template int nuft_kernel_compute( + int dim, int nf1, int nf2, int nf3, double *d_f, double *d_a, double *d_kx, + double *d_ky, double *d_kz, double *d_fwkerhalf1, double *d_fwkerhalf2, + double *d_fwkerhalf3, int ns, cudaStream_t stream); template std::size_t shared_memory_required(int dim, int ns, int bin_size_x, int bin_size_y, int bin_size_z); diff --git a/src/cuda/cufinufft.cu b/src/cuda/cufinufft.cu index c00bf8eba..534fa5358 100644 --- a/src/cuda/cufinufft.cu +++ b/src/cuda/cufinufft.cu @@ -112,7 +112,7 @@ void cufinufft_default_opts(cufinufft_opts *opts) opts->gpu_method = 0; opts->gpu_sort = 1; opts->gpu_kerevalmeth = 1; - opts->upsampfac = 2.0; + opts->upsampfac = 0; opts->gpu_maxsubprobsize = 1024; opts->gpu_obinsizex = 0; opts->gpu_obinsizey = 0; @@ -121,6 +121,7 @@ void cufinufft_default_opts(cufinufft_opts *opts) opts->gpu_binsizey = 0; opts->gpu_binsizez = 0; opts->gpu_maxbatchsize = 0; + opts->debug = 0; opts->gpu_stream = cudaStreamDefault; // sphinx tag (don't remove): @gpu_defopts_end } diff --git a/src/cuda/deconvolve_wrapper.cu b/src/cuda/deconvolve_wrapper.cu index 94eb6b4c8..38a4f0da9 100644 --- a/src/cuda/deconvolve_wrapper.cu +++ b/src/cuda/deconvolve_wrapper.cu @@ -235,7 +235,7 @@ int cudeconvolve1d(cufinufft_plan_t *d_plan, int blksize) int ms = d_plan->ms; int nf1 = d_plan->nf1; int nmodes = ms; - int maxbatchsize = d_plan->maxbatchsize; + int maxbatchsize = d_plan->batchsize; if (d_plan->spopts.spread_direction == 1) { for (int t = 0; t < blksize; t++) { @@ -268,7 +268,7 @@ int cudeconvolve2d(cufinufft_plan_t *d_plan, int blksize) int nf1 = d_plan->nf1; int nf2 = d_plan->nf2; int nmodes = ms * mt; - int maxbatchsize = d_plan->maxbatchsize; + int maxbatchsize = d_plan->batchsize; if (d_plan->spopts.spread_direction == 1) { for (int t = 0; t < blksize; t++) { @@ -305,7 +305,7 @@ int cudeconvolve3d(cufinufft_plan_t *d_plan, int blksize) int nf2 = d_plan->nf2; int nf3 = d_plan->nf3; int nmodes = ms * mt * mu; - int maxbatchsize = d_plan->maxbatchsize; + int maxbatchsize = d_plan->batchsize; if (d_plan->spopts.spread_direction == 1) { for (int t = 0; t < blksize; t++) { deconvolve_3d<<<(nmodes + 256 - 1) / 256, 256, 0, stream>>>( diff --git a/src/cuda/memtransfer_wrapper.cu b/src/cuda/memtransfer_wrapper.cu index b83b9d2d5..e5308e8bc 100644 --- a/src/cuda/memtransfer_wrapper.cu +++ b/src/cuda/memtransfer_wrapper.cu @@ -20,11 +20,11 @@ int allocgpumem1d_plan(cufinufft_plan_t *d_plan) */ { utils::WithCudaDevice device_swapper(d_plan->opts.gpu_device_id); - auto &stream = d_plan->stream; + const auto stream = d_plan->stream; - int ier; + int ier{0}; int nf1 = d_plan->nf1; - int maxbatchsize = d_plan->maxbatchsize; + int maxbatchsize = d_plan->batchsize; switch (d_plan->opts.gpu_method) { case 1: { @@ -59,6 +59,7 @@ int allocgpumem1d_plan(cufinufft_plan_t *d_plan) goto finalize; } break; default: + ier = FINUFFT_ERR_METHOD_NOTVALID; std::cerr << "err: invalid method " << std::endl; } @@ -90,8 +91,8 @@ int allocgpumem1d_nupts(cufinufft_plan_t *d_plan) */ { utils::WithCudaDevice device_swapper(d_plan->opts.gpu_device_id); - auto &stream = d_plan->stream; - int ier; + const auto stream = d_plan->stream; + int ier{0}; int M = d_plan->M; CUDA_FREE_AND_NULL(d_plan->sortidx, stream, d_plan->supports_pools); @@ -135,12 +136,12 @@ int allocgpumem2d_plan(cufinufft_plan_t *d_plan) */ { utils::WithCudaDevice device_swapper(d_plan->opts.gpu_device_id); - auto &stream = d_plan->stream; - int ier; + const auto stream = d_plan->stream; + int ier{0}; int nf1 = d_plan->nf1; int nf2 = d_plan->nf2; - int maxbatchsize = d_plan->maxbatchsize; + int maxbatchsize = d_plan->batchsize; switch (d_plan->opts.gpu_method) { case 1: { @@ -180,6 +181,7 @@ int allocgpumem2d_plan(cufinufft_plan_t *d_plan) goto finalize; } break; default: + ier = FINUFFT_ERR_METHOD_NOTVALID; std::cerr << "[allocgpumem2d_plan] error: invalid method\n"; } @@ -213,8 +215,8 @@ int allocgpumem2d_nupts(cufinufft_plan_t *d_plan) */ { utils::WithCudaDevice device_swapper(d_plan->opts.gpu_device_id); - auto &stream = d_plan->stream; - int ier; + const auto stream = d_plan->stream; + int ier{0}; const int M = d_plan->M; @@ -240,6 +242,7 @@ int allocgpumem2d_nupts(cufinufft_plan_t *d_plan) goto finalize; } break; default: + ier = FINUFFT_ERR_METHOD_NOTVALID; std::cerr << "[allocgpumem2d_nupts] error: invalid method\n"; } @@ -258,13 +261,13 @@ int allocgpumem3d_plan(cufinufft_plan_t *d_plan) */ { utils::WithCudaDevice device_swapper(d_plan->opts.gpu_device_id); - auto &stream = d_plan->stream; - int ier; + const auto stream = d_plan->stream; + int ier{0}; int nf1 = d_plan->nf1; int nf2 = d_plan->nf2; int nf3 = d_plan->nf3; - int maxbatchsize = d_plan->maxbatchsize; + int maxbatchsize = d_plan->batchsize; switch (d_plan->opts.gpu_method) { case 1: { @@ -337,6 +340,7 @@ int allocgpumem3d_plan(cufinufft_plan_t *d_plan) goto finalize; } break; default: + ier = FINUFFT_ERR_METHOD_NOTVALID; std::cerr << "[allocgpumem3d_plan] error: invalid method\n"; } @@ -360,7 +364,11 @@ int allocgpumem3d_plan(cufinufft_plan_t *d_plan) } finalize: - if (ier) freegpumemory(d_plan); + if (ier) { + std::cerr << "[allocgpumem3d_plan] error:" + << cudaGetErrorString(static_cast(ier)) << std::endl; + freegpumemory(d_plan); + } return ier; } @@ -374,8 +382,8 @@ int allocgpumem3d_nupts(cufinufft_plan_t *d_plan) */ { utils::WithCudaDevice device_swapper(d_plan->opts.gpu_device_id); - auto &stream = d_plan->stream; - int ier; + const auto stream = d_plan->stream; + int ier{0}; int M = d_plan->M; CUDA_FREE_AND_NULL(d_plan->sortidx, stream, d_plan->supports_pools); @@ -405,6 +413,7 @@ int allocgpumem3d_nupts(cufinufft_plan_t *d_plan) goto finalize; } break; default: + ier = FINUFFT_ERR_METHOD_NOTVALID; std::cerr << "[allocgpumem3d_nupts] error: invalid method\n"; } @@ -441,6 +450,21 @@ void freegpumemory(cufinufft_plan_t *d_plan) CUDA_FREE_AND_NULL(d_plan->numnupts, stream, d_plan->supports_pools); CUDA_FREE_AND_NULL(d_plan->numsubprob, stream, d_plan->supports_pools); + + if (d_plan->type != 3) { + return; + } + + CUDA_FREE_AND_NULL(d_plan->kx, stream, d_plan->supports_pools); + CUDA_FREE_AND_NULL(d_plan->d_Sp, stream, d_plan->supports_pools); + CUDA_FREE_AND_NULL(d_plan->ky, stream, d_plan->supports_pools); + CUDA_FREE_AND_NULL(d_plan->d_Tp, stream, d_plan->supports_pools); + CUDA_FREE_AND_NULL(d_plan->kz, stream, d_plan->supports_pools); + CUDA_FREE_AND_NULL(d_plan->d_Up, stream, d_plan->supports_pools); + CUDA_FREE_AND_NULL(d_plan->prephase, stream, d_plan->supports_pools); + CUDA_FREE_AND_NULL(d_plan->deconv, stream, d_plan->supports_pools); + CUDA_FREE_AND_NULL(d_plan->fwbatch, stream, d_plan->supports_pools); + CUDA_FREE_AND_NULL(d_plan->CpBatch, stream, d_plan->supports_pools); } template int allocgpumem1d_plan(cufinufft_plan_t *d_plan); diff --git a/src/cuda/precision_independent.cu b/src/cuda/precision_independent.cu index b2c0c292f..7b199220a 100644 --- a/src/cuda/precision_independent.cu +++ b/src/cuda/precision_independent.cu @@ -71,13 +71,6 @@ __global__ void map_b_into_subprob_2d(int *d_subprob_to_bin, int *d_subprobstart } } -__global__ void trivial_global_sort_index_2d(int M, int *index) { - for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < M; - i += gridDim.x * blockDim.x) { - index[i] = i; - } -} - /* spreadinterp3d */ __global__ void calc_subprob_3d_v2(int *bin_size, int *num_subprob, int maxsubprobsize, int numbins) { @@ -121,13 +114,6 @@ __global__ void map_b_into_subprob_3d_v1(int *d_subprob_to_obin, int *d_subprobs } } -__global__ void trivial_global_sort_index_3d(int M, int *index) { - for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < M; - i += gridDim.x * blockDim.x) { - index[i] = i; - } -} - __global__ void fill_ghost_bins(int binsperobinx, int binsperobiny, int binsperobinz, int nobinx, int nobiny, int nobinz, int *binsize) { int binx = threadIdx.x + blockIdx.x * blockDim.x; diff --git a/src/cuda/spreadinterp.cpp b/src/cuda/spreadinterp.cpp index 98b5382bc..646ffa434 100644 --- a/src/cuda/spreadinterp.cpp +++ b/src/cuda/spreadinterp.cpp @@ -1,10 +1,7 @@ #include #include -#include -#include #include -#include #include #include @@ -44,8 +41,7 @@ int setup_spreader(finufft_spread_opts &opts, T eps, T upsampfac, int kerevalmet opts.upsampfac = upsampfac; // as in FINUFFT v2.0, allow too-small-eps by truncating to eps_mach... - int ier = 0; - + int ier = 0; constexpr T EPSILON = std::numeric_limits::epsilon(); if (eps < EPSILON) { fprintf(stderr, "setup_spreader: warning, increasing tol=%.3g to eps_mach=%.3g.\n", diff --git a/src/finufft.cpp b/src/finufft.cpp index ce4ada36d..6c592396f 100644 --- a/src/finufft.cpp +++ b/src/finufft.cpp @@ -267,7 +267,6 @@ void onedim_nuft_kernel(BIGINT nk, FLT *k, FLT *phihat, finufft_spread_opts opts for (int n = 0; n < q; ++n) { z[n] *= (FLT)J2; // quadr nodes for [0,J/2] f[n] = J2 * (FLT)w[n] * evaluate_kernel((FLT)z[n], opts); // w/ quadr weights - // printf("f[%d] = %.3g\n",n,f[n]); } #pragma omp parallel for num_threads(opts.nthreads) for (BIGINT j = 0; j < nk; ++j) { // loop along output array @@ -723,6 +722,7 @@ int FINUFFT_MAKEPLAN(int type, int dim, BIGINT *n_modes, int iflag, int ntrans, fprintf(stderr, "[%s] fwBatch would be bigger than MAX_NF, not attempting malloc!\n", __func__); + // FIXME: this error causes memory leaks. We should free phiHat1, phiHat2, phiHat3 return FINUFFT_ERR_MAXNALLOC; } @@ -751,12 +751,12 @@ int FINUFFT_MAKEPLAN(int type, int dim, BIGINT *n_modes, int iflag, int ntrans, // idist, ot, onembed, ostride, odist, sign, flags { std::lock_guard lock(fftw_lock); - // FFTW_PLAN_TH sets all future fftw_plan calls to use nthr_fft threads. // FIXME: Since this might override what the user wants for fftw, we'd like to // set it just for our one plan and then revert to the user value. // Unfortunately fftw_planner_nthreads wasn't introduced until fftw 3.3.9, and // there isn't a convenient mechanism to probe the version + // there is fftw_version which returns a string, but that's not compile time FFTW_PLAN_TH(nthr_fft); p->fftwPlan = FFTW_PLAN_MANY_DFT(dim, ns, p->batchSize, (FFTW_CPX *)p->fwBatch, NULL, 1, p->nf, (FFTW_CPX *)p->fwBatch, NULL, 1, @@ -874,14 +874,14 @@ int FINUFFT_SETPTS(FINUFFT_PLAN p, BIGINT nj, FLT *xj, FLT *yj, FLT *zj, BIGINT if (p->opts.debug) { // report on choices of shifts, centers, etc... printf("\tM=%lld N=%lld\n", (long long)nj, (long long)nk); - printf("\tX1=%.3g C1=%.3g S1=%.3g D1=%.3g gam1=%g nf1=%lld\t\n", p->t3P.X1, - p->t3P.C1, S1, p->t3P.D1, p->t3P.gam1, (long long)p->nf1); + printf("\tX1=%.3g C1=%.3g S1=%.3g D1=%.3g gam1=%g nf1=%lld h1=%.3g\t\n", p->t3P.X1, + p->t3P.C1, S1, p->t3P.D1, p->t3P.gam1, (long long)p->nf1, p->t3P.h1); if (d > 1) - printf("\tX2=%.3g C2=%.3g S2=%.3g D2=%.3g gam2=%g nf2=%lld\n", p->t3P.X2, - p->t3P.C2, S2, p->t3P.D2, p->t3P.gam2, (long long)p->nf2); + printf("\tX2=%.3g C2=%.3g S2=%.3g D2=%.3g gam2=%g nf2=%lld h2=%.3g\n", p->t3P.X2, + p->t3P.C2, S2, p->t3P.D2, p->t3P.gam2, (long long)p->nf2, p->t3P.h2); if (d > 2) - printf("\tX3=%.3g C3=%.3g S3=%.3g D3=%.3g gam3=%g nf3=%lld\n", p->t3P.X3, - p->t3P.C3, S3, p->t3P.D3, p->t3P.gam3, (long long)p->nf3); + printf("\tX3=%.3g C3=%.3g S3=%.3g D3=%.3g gam3=%g nf3=%lld h3=%.3g\n", p->t3P.X3, + p->t3P.C3, S3, p->t3P.D3, p->t3P.gam3, (long long)p->nf3, p->t3P.h3); } p->nf = p->nf1 * p->nf2 * p->nf3; // fine grid total number of points if (p->nf * p->batchSize > MAX_NF) { @@ -913,6 +913,7 @@ int FINUFFT_SETPTS(FINUFFT_PLAN p, BIGINT nj, FLT *xj, FLT *yj, FLT *zj, BIGINT // printf("fwbatch, cpbatch ptrs: %llx %llx\n",p->fwBatch,p->CpBatch); // alloc rescaled NU src pts x'_j (in X etc), rescaled NU targ pts s'_k ... + // FIXME: should use realloc if (p->X) free(p->X); if (p->Sp) free(p->Sp); p->X = (FLT *)malloc(sizeof(FLT) * nj); @@ -970,7 +971,6 @@ int FINUFFT_SETPTS(FINUFFT_PLAN p, BIGINT nj, FLT *xj, FLT *yj, FLT *zj, BIGINT p->Up[k] = p->t3P.h3 * p->t3P.gam3 * (u[k] - p->t3P.D3); // so |u'_k| < // pi/R } - // (old STEP 3a) Compute deconvolution post-factors array (per targ pt)... // (exploits that FT separates because kernel is prod of 1D funcs) if (p->deconv) free(p->deconv); @@ -1158,8 +1158,9 @@ int FINUFFT_EXECUTE(FINUFFT_PLAN p, CPX *cj, CPX *fk) { #pragma omp parallel for num_threads(p->opts.nthreads) // or p->batchSize? for (int i = 0; i < thisBatchSize; i++) { BIGINT ioff = i * p->nj; - for (BIGINT j = 0; j < p->nj; ++j) + for (BIGINT j = 0; j < p->nj; ++j) { p->CpBatch[ioff + j] = p->prephase[j] * cjb[ioff + j]; + } } t_pre += timer.elapsedsec(); @@ -1169,10 +1170,6 @@ int FINUFFT_EXECUTE(FINUFFT_PLAN p, CPX *cj, CPX *fk) { spreadinterpSortedBatch(thisBatchSize, p, p->CpBatch); // p->X are primed t_spr += timer.elapsedsec(); - // for (int j=0;jnf1;++j) - // printf("fw[%d]=%.3g+%.3gi\n",j,p->fwBatch[j][0],p->fwBatch[j][1]); // - // debug - // STEP 2: type 2 NUFFT from fw batch to user output fk array batch... timer.restart(); // illegal possible shrink of ntrans *after* plan for smaller last batch: @@ -1181,7 +1178,6 @@ int FINUFFT_EXECUTE(FINUFFT_PLAN p, CPX *cj, CPX *fk) { still the same size, as Andrea explained; just wastes a few flops) */ FINUFFT_EXECUTE(p->innerT2plan, fkb, p->fwBatch); t_t2 += timer.elapsedsec(); - // STEP 3: apply deconvolve (precomputed 1/phiHat(targ_k), phasing too)... timer.restart(); #pragma omp parallel for num_threads(p->opts.nthreads) diff --git a/test/cuda/CMakeLists.txt b/test/cuda/CMakeLists.txt index 87cbe8205..1cadb7569 100644 --- a/test/cuda/CMakeLists.txt +++ b/test/cuda/CMakeLists.txt @@ -5,14 +5,16 @@ foreach(srcfile ${test_src}) get_filename_component(executable ${executable} NAME) add_executable(${executable} ${srcfile}) target_include_directories(${executable} PUBLIC ${CUFINUFFT_INCLUDE_DIRS}) + target_compile_options(${executable} + PUBLIC $<$:--extended-lambda>) find_library(MathLib m) if(MathLib) target_link_libraries(${executable} PUBLIC cufinufft ${MathLib}) endif() target_compile_features(${executable} PUBLIC cxx_std_17) set_target_properties( - ${executable} PROPERTIES LINKER_LANGUAGE CUDA CUDA_ARCHITECTURES - "${FINUFFT_CUDA_ARCHITECTURES}") + ${executable} PROPERTIES LINKER_LANGUAGE CUDA + CUDA_ARCHITECTURES "${FINUFFT_CUDA_ARCHITECTURES}") message(STATUS "Adding test ${executable}" " with CUDA_ARCHITECTURES=${FINUFFT_CUDA_ARCHITECTURES}" " and INCLUDE=${CUFINUFFT_INCLUDE_DIRS}") @@ -30,6 +32,9 @@ function(add_tests PREC REQ_TOL CHECK_TOL UPSAMP) add_test(NAME cufinufft1d2_test_GM_${PREC}_${UPSAMP} COMMAND cufinufft1d_test 1 2 1e2 2e2 ${REQ_TOL} ${CHECK_TOL} ${PREC} ${UPSAMP}) + add_test(NAME cufinufft1d3_test_GM_${PREC}_${UPSAMP} + COMMAND cufinufft1d_test 1 3 1e2 2e2 ${REQ_TOL} ${CHECK_TOL} ${PREC} + ${UPSAMP}) add_test(NAME cufinufft2d1_test_GM_${PREC}_${UPSAMP} COMMAND cufinufft2d_test 1 1 1e2 2e2 2e4 ${REQ_TOL} ${CHECK_TOL} @@ -39,6 +44,14 @@ function(add_tests PREC REQ_TOL CHECK_TOL UPSAMP) COMMAND cufinufft2d_test 2 1 1e2 2e2 2e4 ${REQ_TOL} ${CHECK_TOL} ${PREC} ${UPSAMP}) + add_test(NAME cufinufft2d2_test_SM_${PREC}_${UPSAMP} + COMMAND cufinufft2d_test 2 2 1e2 2e2 2e4 ${REQ_TOL} ${CHECK_TOL} + ${PREC} ${UPSAMP}) + + add_test(NAME cufinufft2d3_test_SM_${PREC}_${UPSAMP} + COMMAND cufinufft2d_test 2 3 1e2 2e2 2e4 ${REQ_TOL} ${CHECK_TOL} + ${PREC} ${UPSAMP}) + add_test(NAME cufinufft2d1many_test_GM_${PREC}_${UPSAMP} COMMAND cufinufft2dmany_test 1 1 1e2 2e2 5 0 2e4 ${REQ_TOL} ${CHECK_TOL} ${PREC} ${UPSAMP}) @@ -55,6 +68,14 @@ function(add_tests PREC REQ_TOL CHECK_TOL UPSAMP) COMMAND cufinufft2dmany_test 2 2 1e2 2e2 5 0 2e4 ${REQ_TOL} ${CHECK_TOL} ${PREC} ${UPSAMP}) + add_test(NAME cufinufft2d3many_test_GM_${PREC}_${UPSAMP} + COMMAND cufinufft2dmany_test 1 3 1e2 2e2 5 0 2e4 ${REQ_TOL} + ${CHECK_TOL} ${PREC} ${UPSAMP}) + + add_test(NAME cufinufft2d3many_test_SM_${PREC}_${UPSAMP} + COMMAND cufinufft2dmany_test 2 3 1e2 2e2 5 0 2e4 ${REQ_TOL} + ${CHECK_TOL} ${PREC} ${UPSAMP}) + add_test(NAME cufinufft3d1_test_GM_${PREC}_${UPSAMP} COMMAND cufinufft3d_test 1 1 2 5 10 20 ${REQ_TOL} ${CHECK_TOL} ${PREC} ${UPSAMP}) @@ -71,17 +92,34 @@ function(add_tests PREC REQ_TOL CHECK_TOL UPSAMP) add_test(NAME cufinufft3d2_test_SM_${PREC}_${UPSAMP} COMMAND cufinufft3d_test 2 2 2 5 10 20 ${REQ_TOL} ${CHECK_TOL} ${PREC} ${UPSAMP}) + + add_test(NAME cufinufft3d3_test_SM_${PREC}_${UPSAMP} + COMMAND cufinufft3d_test 2 3 2 5 10 30 ${REQ_TOL} ${CHECK_TOL} + ${PREC} ${UPSAMP}) endif() add_test(NAME cufinufft3d2_test_GM_${PREC}_${UPSAMP} COMMAND cufinufft3d_test 1 2 2 5 10 20 ${REQ_TOL} ${CHECK_TOL} ${PREC} ${UPSAMP}) + + add_test(NAME cufinufft3d3_test_GM_${PREC}_${UPSAMP} + COMMAND cufinufft3d_test 1 3 2 3 7 20 ${REQ_TOL} ${CHECK_TOL}*100 + ${PREC} ${UPSAMP}) endfunction() +add_test(NAME cufinufft_public_api COMMAND public_api_test) +add_test(NAME cufinufft_makeplan COMMAND test_makeplan) +add_test(NAME cufinufft_math_test COMMAND cufinufft_math_test) + add_tests(float 1e-5 2e-4 2.0) add_tests(double 1e-12 1e-11 2.0) add_tests(float 1e-5 2e-4 1.25) add_tests(double 1e-8 1e-7 1.25) - -add_test(NAME cufinufft_public_api COMMAND public_api_test) -add_test(NAME cufinufft_makeplan COMMAND test_makeplan) +# the upsamp is appended to the testname, ctest does not allows multiple tests +# to share the same testname hence we use the trick 0. and 0.f to differentiate +# the tests and allow them to run in the future we should add the precision to +# the test (f +add_tests(float 1e-5 2e-4 0.f) +add_tests(double 1e-12 1e-11 0.f) +add_tests(float 1e-5 2e-4 0.) +add_tests(double 1e-8 1e-7 0.) diff --git a/test/cuda/cufinufft1d_test.cu b/test/cuda/cufinufft1d_test.cu index dbd6260ac..52d40ca0e 100644 --- a/test/cuda/cufinufft1d_test.cu +++ b/test/cuda/cufinufft1d_test.cu @@ -1,12 +1,12 @@ #include #include -#include #include #include #include #include +#include #include #include @@ -22,11 +22,11 @@ int run_test(int method, int type, int N1, int M, T tol, T checktol, int iflag, std::cout << std::scientific << std::setprecision(3); int ier; - thrust::host_vector x(M); + thrust::host_vector x(M), s{}; thrust::host_vector> c(M); thrust::host_vector> fk(N1); - thrust::device_vector d_x(M); + thrust::device_vector d_x(M), d_s{}; thrust::device_vector> d_c(M); thrust::device_vector> d_fk(N1); @@ -40,6 +40,7 @@ int run_test(int method, int type, int N1, int M, T tol, T checktol, int iflag, for (int i = 0; i < M; i++) { x[i] = M_PI * randm11(); // x in [-pi,pi) } + if (type == 1) { for (int i = 0; i < M; i++) { c[i].real(randm11()); @@ -50,6 +51,16 @@ int run_test(int method, int type, int N1, int M, T tol, T checktol, int iflag, fk[i].real(randm11()); fk[i].imag(randm11()); } + } else if (type == 3) { + for (int i = 0; i < M; i++) { + c[i].real(randm11()); + c[i].imag(randm11()); + } + s.resize(N1); + for (int i = 0; i < N1; i++) { + s[i] = N1 / 2 * randm11(); + } + d_s = s; } else { std::cerr << "Invalid type " << type << " supplied\n"; return 1; @@ -60,6 +71,8 @@ int run_test(int method, int type, int N1, int M, T tol, T checktol, int iflag, d_c = c; else if (type == 2) d_fk = fk; + else if (type == 3) + d_c = c; cudaEvent_t start, stop; float milliseconds = 0; @@ -107,8 +120,8 @@ int run_test(int method, int type, int N1, int M, T tol, T checktol, int iflag, printf("[time ] cufinufft plan:\t\t %.3g s\n", milliseconds / 1000); cudaEventRecord(start); - ier = cufinufft_setpts_impl(M, d_x.data().get(), NULL, NULL, 0, NULL, NULL, NULL, - dplan); + ier = cufinufft_setpts_impl(M, d_x.data().get(), NULL, NULL, N1, d_s.data().get(), + NULL, NULL, dplan); if (ier != 0) { printf("err: cufinufft_setpts\n"); @@ -153,9 +166,15 @@ int run_test(int method, int type, int N1, int M, T tol, T checktol, int iflag, opts.gpu_method, N1, M, totaltime / 1000, M / totaltime * 1000); printf("\t\t\t\t\t(exec-only thoughput: %.3g NU pts/s)\n", M / exec_ms * 1000); + if (type == 1) + fk = d_fk; + else if (type == 2) + c = d_c; + else if (type == 3) + fk = d_fk; + T rel_error = std::numeric_limits::max(); if (type == 1) { - fk = d_fk; int nt1 = 0.37 * N1; // choose some mode index to check thrust::complex Ft = thrust::complex(0, 0), J = thrust::complex(0.0, iflag); for (int j = 0; j < M; ++j) Ft += c[j] * exp(J * (nt1 * x[j])); // crude direct @@ -164,8 +183,6 @@ int run_test(int method, int type, int N1, int M, T tol, T checktol, int iflag, rel_error = abs(Ft - fk[it]) / infnorm(N1, (std::complex *)fk.data()); printf("[gpu ] one mode: rel err in F[%d] is %.3g\n", nt1, rel_error); } else if (type == 2) { - c = d_c; - int jt = M / 2; // check arbitrary choice of one targ pt thrust::complex J = thrust::complex(0, iflag); thrust::complex ct = thrust::complex(0, 0); @@ -174,6 +191,16 @@ int run_test(int method, int type, int N1, int M, T tol, T checktol, int iflag, ct += fk[m++] * exp(J * (m1 * x[jt])); // crude direct rel_error = abs(c[jt] - ct) / infnorm(M, (std::complex *)c.data()); printf("[gpu ] one targ: rel err in c[%d] is %.3g\n", jt, rel_error); + } else if (type == 3) { + int jt = (N1) / 2; // check arbitrary choice of one targ pt + thrust::complex J = thrust::complex(0, iflag); + thrust::complex Ft = thrust::complex(0, 0); + + for (int j = 0; j < M; ++j) { + Ft += c[j] * exp(J * (x[j] * s[jt])); + } + rel_error = abs(Ft - fk[jt]) / infnorm(N1, (std::complex *)fk.data()); + printf("[gpu ] one mode: rel err in F[%d] is %.3g\n", jt, rel_error); } return std::isnan(rel_error) || rel_error > checktol; @@ -185,7 +212,7 @@ int main(int argc, char *argv[]) { "Arguments:\n" " method: One of\n" " 1: nupts driven\n" - " type: Type of transform (1, 2)\n" + " type: Type of transform (1, 2, 3)\n" " N1: Number of fourier modes\n" " M: The number of non-uniform points\n" " tol: NUFFT tolerance\n" diff --git a/test/cuda/cufinufft2d_test.cu b/test/cuda/cufinufft2d_test.cu index f3b767f2e..549508d26 100644 --- a/test/cuda/cufinufft2d_test.cu +++ b/test/cuda/cufinufft2d_test.cu @@ -22,10 +22,10 @@ int run_test(int method, int type, int N1, int N2, int M, T tol, T checktol, int double upsampfac) { std::cout << std::scientific << std::setprecision(3); - thrust::host_vector x(M), y(M); + thrust::host_vector x(M), y(M), s{}, t{}; thrust::host_vector> c(M), fk(N1 * N2); - thrust::device_vector d_x(M), d_y(M); + thrust::device_vector d_x(M), d_y(M), d_s{}, d_t{}; thrust::device_vector> d_c(M), d_fk(N1 * N2); std::default_random_engine eng(1); @@ -49,6 +49,19 @@ int run_test(int method, int type, int N1, int N2, int M, T tol, T checktol, int fk[i].real(randm11()); fk[i].imag(randm11()); } + } else if (type == 3) { + for (int i = 0; i < M; i++) { + c[i].real(randm11()); + c[i].imag(randm11()); + } + s.resize(N1 * N2); + t.resize(N1 * N2); + for (int i = 0; i < N1 * N2; i++) { + s[i] = (N1 / 2) * randm11(); + t[i] = (N2 / 2) * randm11(); + } + d_s = s; + d_t = t; } else { std::cerr << "Invalid type " << type << " supplied\n"; return 1; @@ -60,6 +73,8 @@ int run_test(int method, int type, int N1, int N2, int M, T tol, T checktol, int d_c = c; else if (type == 2) d_fk = fk; + else if (type == 3) + d_c = c; cudaEvent_t start, stop; float milliseconds = 0; @@ -106,8 +121,8 @@ int run_test(int method, int type, int N1, int N2, int M, T tol, T checktol, int printf("[time ] cufinufft plan:\t\t %.3g s\n", milliseconds / 1000); cudaEventRecord(start); - ier = cufinufft_setpts_impl(M, d_x.data().get(), d_y.data().get(), nullptr, 0, - nullptr, nullptr, nullptr, dplan); + ier = cufinufft_setpts_impl(M, d_x.data().get(), d_y.data().get(), nullptr, N1 * N2, + d_s.data().get(), d_t.data().get(), nullptr, dplan); if (ier != 0) { printf("err: cufinufft_setpts\n"); return ier; @@ -144,6 +159,8 @@ int run_test(int method, int type, int N1, int N2, int M, T tol, T checktol, int fk = d_fk; else if (type == 2) c = d_c; + else if (type == 3) + fk = d_fk; printf("[Method %d] %d NU pts to %d U pts in %.3g s: %.3g NU pts/s\n", opts.gpu_method, M, N1 * N2, totaltime / 1000, M / totaltime * 1000); @@ -173,8 +190,17 @@ int run_test(int method, int type, int N1, int N2, int M, T tol, T checktol, int rel_error = abs(c[jt] - ct) / infnorm(M, (std::complex *)c.data()); printf("[gpu ] one targ: rel err in c[%d] is %.3g\n", jt, rel_error); - } + } else if (type == 3) { + int jt = (N1 * N2) / 2; // check arbitrary choice of one targ pt + thrust::complex J = thrust::complex(0, iflag); + thrust::complex Ft = thrust::complex(0, 0); + for (int j = 0; j < M; ++j) { + Ft += c[j] * exp(J * (x[j] * s[jt] + y[j] * t[jt])); + } + rel_error = abs(Ft - fk[jt]) / infnorm(N1 * N2, (std::complex *)fk.data()); + printf("[gpu ] one mode: rel err in F[%d] is %.3g\n", jt, rel_error); + } return std::isnan(rel_error) || rel_error > checktol; } @@ -185,7 +211,7 @@ int main(int argc, char *argv[]) { " method: One of\n" " 1: nupts driven,\n" " 2: sub-problem, or\n" - " type: Type of transform (1, 2)" + " type: Type of transform (1, 2, 3)" " N1, N2: The size of the 2D array\n" " M: The number of non-uniform points\n" " tol: NUFFT tolerance\n" diff --git a/test/cuda/cufinufft2dmany_test.cu b/test/cuda/cufinufft2dmany_test.cu index 4afcd97dd..02658b671 100644 --- a/test/cuda/cufinufft2dmany_test.cu +++ b/test/cuda/cufinufft2dmany_test.cu @@ -26,10 +26,10 @@ int run_test(int method, int type, int N1, int N2, int ntransf, int maxbatchsize const int N = N1 * N2; printf("#modes = %d, #inputs = %d, #NUpts = %d\n", N, ntransf, M); - thrust::host_vector x(M), y(M); + thrust::host_vector x(M), y(M), s{}, t{}; thrust::host_vector> c(M * ntransf), fk(ntransf * N1 * N2); - thrust::device_vector d_x(M), d_y(M); + thrust::device_vector d_x(M), d_y(M), d_s{}, d_t{}; thrust::device_vector> d_c(M * ntransf), d_fk(ntransf * N1 * N2); std::default_random_engine eng(1); @@ -53,6 +53,19 @@ int run_test(int method, int type, int N1, int N2, int ntransf, int maxbatchsize fk[i].real(randm11()); fk[i].imag(randm11()); } + } else if (type == 3) { + for (int i = 0; i < ntransf * M; i++) { + c[i].real(randm11()); + c[i].imag(randm11()); + } + s.resize(N1 * N2); + t.resize(N1 * N2); + for (int i = 0; i < N1 * N2; i++) { + s[i] = M_PI * randm11(); + t[i] = M_PI * randm11(); + } + d_s = s; + d_t = t; } else { std::cerr << "Invalid type " << type << " supplied\n"; return 1; @@ -64,6 +77,8 @@ int run_test(int method, int type, int N1, int N2, int ntransf, int maxbatchsize d_c = c; else if (type == 2) d_fk = fk; + else if (type == 3) + d_c = c; cudaEvent_t start, stop; float milliseconds = 0; @@ -109,8 +124,8 @@ int run_test(int method, int type, int N1, int N2, int ntransf, int maxbatchsize printf("[time ] cufinufft plan:\t\t %.3g s\n", milliseconds / 1000); cudaEventRecord(start); - ier = cufinufft_setpts_impl(M, d_x.data().get(), d_y.data().get(), NULL, 0, NULL, - NULL, NULL, dplan); + ier = cufinufft_setpts_impl(M, d_x.data().get(), d_y.data().get(), nullptr, N1 * N2, + d_s.data().get(), d_t.data().get(), nullptr, dplan); if (ier != 0) { printf("err: cufinufft_setpts\n"); return ier; @@ -137,6 +152,10 @@ int run_test(int method, int type, int N1, int N2, int ntransf, int maxbatchsize cudaEventRecord(start); ier = cufinufft_destroy_impl(dplan); + if (ier != 0) { + printf("err: cufinufft3d_destroy\n"); + return ier; + } cudaEventRecord(stop); cudaEventSynchronize(stop); cudaEventElapsedTime(&milliseconds, start, stop); @@ -147,6 +166,8 @@ int run_test(int method, int type, int N1, int N2, int ntransf, int maxbatchsize fk = d_fk; else if (type == 2) c = d_c; + else if (type == 3) + fk = d_fk; T rel_error = std::numeric_limits::max(); if (type == 1) { @@ -175,6 +196,18 @@ int run_test(int method, int type, int N1, int N2, int ntransf, int maxbatchsize rel_error = abs(cstart[jt] - ct) / infnorm(M, (std::complex *)c.data()); printf("[gpu ] %dth data one targ: rel err in c[%d] is %.3g\n", t, jt, rel_error); + } else if (type == 3) { + int jt = (N1 * N2) / 2; // check arbitrary choice of one targ pt + thrust::complex J = thrust::complex(0, iflag); + thrust::complex Ft = thrust::complex(0, 0); + thrust::complex *fkstart = fk.data() + (ntransf - 1) * N1 * N2; + const thrust::complex *cstart = c.data() + (ntransf - 1) * M; + + for (int j = 0; j < M; ++j) { + Ft += cstart[j] * exp(J * (x[j] * s[jt] + y[j] * t[jt])); + } + rel_error = abs(Ft - fkstart[jt]) / infnorm(N1 * N2, (std::complex *)fk.data()); + printf("[gpu ] one mode: rel err in F[%d] is %.3g\n", jt, rel_error); } printf("[totaltime] %.3g us, speed %.3g NUpts/s\n", totaltime * 1000, @@ -193,7 +226,7 @@ int main(int argc, char *argv[]) { " method: One of\n" " 1: nupts driven,\n" " 2: sub-problem, or\n" - " type: Type of transform (1, 2)\n" + " type: Type of transform (1, 2, 3)\n" " N1, N2: The size of the 2D array\n" " ntransf: Number of inputs\n" " maxbatchsize: Number of simultaneous transforms (or 0 for default)\n" diff --git a/test/cuda/cufinufft3d_test.cu b/test/cuda/cufinufft3d_test.cu index 67818c2b2..65b0d7a0c 100644 --- a/test/cuda/cufinufft3d_test.cu +++ b/test/cuda/cufinufft3d_test.cu @@ -23,10 +23,10 @@ int run_test(int method, int type, int N1, int N2, int N3, int M, T tol, T check std::cout << std::scientific << std::setprecision(3); int ier; - thrust::host_vector x(M), y(M), z(M); + thrust::host_vector x(M), y(M), z(M), s{}, t{}, u{}; thrust::host_vector> c(M), fk(N1 * N2 * N3); - thrust::device_vector d_x(M), d_y(M), d_z(M); + thrust::device_vector d_x(M), d_y(M), d_z(M), d_s{}, d_t{}, d_u{}; thrust::device_vector> d_c(M), d_fk(N1 * N2 * N3); std::default_random_engine eng(1); @@ -51,6 +51,22 @@ int run_test(int method, int type, int N1, int N2, int N3, int M, T tol, T check fk[i].real(randm11()); fk[i].imag(randm11()); } + } else if (type == 3) { + for (int i = 0; i < M; i++) { + c[i].real(randm11()); + c[i].imag(randm11()); + } + s.resize(N1 * N2 * N3); + t.resize(N1 * N2 * N3); + u.resize(N1 * N2 * N3); + for (int i = 0; i < N1 * N2 * N3; i++) { + s[i] = (N1 / 2) * randm11(); + t[i] = (N2 / 2) * randm11(); + u[i] = (N3 / 2) * randm11(); + } + d_s = s; + d_t = t; + d_u = u; } else { std::cerr << "Invalid type " << type << " supplied\n"; return 1; @@ -64,6 +80,8 @@ int run_test(int method, int type, int N1, int N2, int N3, int M, T tol, T check d_c = c; else if (type == 2) d_fk = fk; + else if (type == 3) + d_c = c; cudaEvent_t start, stop; float milliseconds = 0; @@ -112,7 +130,8 @@ int run_test(int method, int type, int N1, int N2, int N3, int M, T tol, T check cudaEventRecord(start); ier = cufinufft_setpts_impl(M, d_x.data().get(), d_y.data().get(), d_z.data().get(), - 0, nullptr, nullptr, nullptr, dplan); + N1 * N2 * N3, d_s.data().get(), d_t.data().get(), + d_u.data().get(), dplan); if (ier != 0) { printf("err: cufinufft_setpts\n"); return ier; @@ -149,6 +168,8 @@ int run_test(int method, int type, int N1, int N2, int N3, int M, T tol, T check fk = d_fk; else if (type == 2) c = d_c; + else if (type == 3) + fk = d_fk; printf("[Method %d] %d NU pts to %d U pts in %.3g s:\t%.3g NU pts/s\n", opts.gpu_method, M, N1 * N2 * N3, totaltime / 1000, M / totaltime * 1000); @@ -184,6 +205,17 @@ int run_test(int method, int type, int N1, int N2, int N3, int M, T tol, T check rel_error = abs(c[jt] - ct) / infnorm(M, (std::complex *)c.data()); printf("[gpu ] one targ: rel err in c[%ld] is %.3g\n", (int64_t)jt, rel_error); + } else if (type == 3) { + + int jt = (N1 * N2 * N3) / 2; // check arbitrary choice of one targ pt + thrust::complex J = thrust::complex(0, iflag); + thrust::complex Ft = thrust::complex(0, 0); + + for (int j = 0; j < M; ++j) { + Ft += c[j] * exp(J * (x[j] * s[jt] + y[j] * t[jt] + z[j] * u[jt])); + } + rel_error = abs(Ft - fk[jt]) / infnorm(N1 * N2 * N3, (std::complex *)fk.data()); + printf("[gpu ] one mode: rel err in F[%d] is %.3g\n", jt, rel_error); } return std::isnan(rel_error) || rel_error > checktol; @@ -198,7 +230,7 @@ int main(int argc, char *argv[]) { " 1: nupts driven,\n" " 2: sub-problem, or\n" " 4: block gather.\n" - " type: Type of transform (1, 2)" + " type: Type of transform (1, 2, 3)" " N1, N2, N3: The size of the 3D array\n" " M: The number of non-uniform points\n" " tol: NUFFT tolerance\n" diff --git a/test/cuda/cufinufft_math_test.cu b/test/cuda/cufinufft_math_test.cu new file mode 100644 index 000000000..1588abb23 --- /dev/null +++ b/test/cuda/cufinufft_math_test.cu @@ -0,0 +1,137 @@ +#include +#include +#include +#include + +// Include the custom operators for cuComplex +#include +#include + +// Helper function to create cuComplex +template cuda_complex make_cuda_complex(T real, T imag) { + return cuda_complex{real, imag}; +} + +// Helper function to compare cuComplex with std::complex using 1 - ratio as error +template +bool compareComplexRel(const cuda_complex a, const std::complex b, + const std::string &operation, + T epsilon = std::numeric_limits::epsilon()) { + const auto std_a = std::complex(a.x, a.y); + const auto err = std::abs(std_a - b) / std::abs(std_a); + const auto tol = epsilon * T(10); // factor to allow for rounding error + if (err > tol) { + std::cout << "Comparison failed in operation: " << operation << "\n"; + std::cout << "cuComplex: (" << a.x << ", " << a.y << ")\n"; + std::cout << "std::complex: (" << b.real() << ", " << b.imag() << ")\n"; + std::cout << "RelError: " << err << "\n"; + } + return err <= tol; +} + +template +bool compareComplexAbs(const cuda_complex a, const std::complex b, + const std::string &operation, + T epsilon = std::numeric_limits::epsilon()) { + const auto std_a = std::complex(a.x, a.y); + const auto err = std::abs(std_a - b); + const auto tol = epsilon * T(10); // factor to allow for rounding error + if (err > tol) { + std::cout << "Comparison failed in operation: " << operation << "\n"; + std::cout << "cuComplex: (" << a.x << ", " << a.y << ")\n"; + std::cout << "std::complex: (" << b.real() << ", " << b.imag() << ")\n"; + std::cout << "AbsError: " << err << "\n"; + } + return err <= tol; +} + +template int testRandomOperations() { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dis(-1.0, 1.0); + + for (int i = 0; i < 1000; ++i) { + T real1 = dis(gen); + T imag1 = dis(gen); + T real2 = dis(gen); + T imag2 = dis(gen); + T scalar = dis(gen); + + cuda_complex a = make_cuda_complex(real1, imag1); + cuda_complex b = make_cuda_complex(real2, imag2); + std::complex std_a(real1, imag1); + std::complex std_b(real2, imag2); + + // Test addition + cuda_complex result_add = a + b; + std::complex expected_add = std_a + std_b; + if (!compareComplexAbs(result_add, expected_add, + "add complex<" + std::string(typeid(T).name()) + "> complex<" + + std::string(typeid(T).name()) + ">")) + return 1; + + // Test subtraction + cuda_complex result_sub = a - b; + std::complex expected_sub = std_a - std_b; + if (!compareComplexAbs(result_sub, expected_sub, + "sub complex<" + std::string(typeid(T).name()) + "> complex<" + + std::string(typeid(T).name()) + ">")) + return 1; + + // Test multiplication + cuda_complex result_mul = a * b; + std::complex expected_mul = std_a * std_b; + if (!compareComplexRel(result_mul, expected_mul, + "mul complex<" + std::string(typeid(T).name()) + "> complex<" + + std::string(typeid(T).name()) + ">")) + return 1; + + // Test division + cuda_complex result_div = a / b; + std::complex expected_div = std_a / std_b; + if (!compareComplexRel(result_div, expected_div, + "div complex<" + std::string(typeid(T).name()) + "> complex<" + + std::string(typeid(T).name()) + ">")) + return 1; + + // Test addition with scalar + cuda_complex result_add_scalar = a + scalar; + std::complex expected_add_scalar = std_a + scalar; + if (!compareComplexRel(result_add_scalar, expected_add_scalar, + "add complex<" + std::string(typeid(T).name()) + "> scalar<" + + std::string(typeid(T).name()) + ">")) + return 1; + + // Test subtraction with scalar + cuda_complex result_sub_scalar = a - scalar; + std::complex expected_sub_scalar = std_a - scalar; + if (!compareComplexRel(result_sub_scalar, expected_sub_scalar, + "sub complex<" + std::string(typeid(T).name()) + "> scalar<" + + std::string(typeid(T).name()) + ">")) + return 1; + + // Test multiplication with scalar + cuda_complex result_mul_scalar = a * scalar; + std::complex expected_mul_scalar = std_a * scalar; + if (!compareComplexRel(result_mul_scalar, expected_mul_scalar, + "mul complex<" + std::string(typeid(T).name()) + "> scalar<" + + std::string(typeid(T).name()) + ">")) + return 1; + + cuda_complex result_div_scalar = a / scalar; + std::complex expected_div_scalar = std_a / scalar; + if (!compareComplexRel(result_div_scalar, expected_div_scalar, + "div complex<" + std::string(typeid(T).name()) + "> scalar<" + + std::string(typeid(T).name()) + ">")) + return 1; + } + return 0; +} + +int main() { + if (testRandomOperations()) return 1; + if (testRandomOperations()) return 1; + + std::cout << "All tests passed!" << std::endl; + return 0; +} diff --git a/test/cuda/fseries_kernel_test.cu b/test/cuda/fseries_kernel_test.cu deleted file mode 100644 index 7f18ee21c..000000000 --- a/test/cuda/fseries_kernel_test.cu +++ /dev/null @@ -1,158 +0,0 @@ -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -using namespace cufinufft::common; -using namespace cufinufft::spreadinterp; -using namespace cufinufft::utils; - -template int run_test(int nf1, int dim, T eps, int gpu, int nf2, int nf3) { - - finufft_spread_opts opts; - T *fwkerhalf1, *fwkerhalf2, *fwkerhalf3; - T *d_fwkerhalf1, *d_fwkerhalf2, *d_fwkerhalf3; - checkCudaErrors(cudaMalloc(&d_fwkerhalf1, sizeof(T) * (nf1 / 2 + 1))); - if (dim > 1) checkCudaErrors(cudaMalloc(&d_fwkerhalf2, sizeof(T) * (nf2 / 2 + 1))); - if (dim > 2) checkCudaErrors(cudaMalloc(&d_fwkerhalf3, sizeof(T) * (nf3 / 2 + 1))); - - int ier = setup_spreader(opts, (T)eps, (T)2.0, 0); - - cudaEvent_t start, stop; - cudaEventCreate(&start); - cudaEventCreate(&stop); - - float milliseconds = 0; - float gputime = 0; - float cputime = 0; - - CNTime timer; - if (!gpu) { - timer.start(); - fwkerhalf1 = (T *)malloc(sizeof(T) * (nf1 / 2 + 1)); - if (dim > 1) fwkerhalf2 = (T *)malloc(sizeof(T) * (nf2 / 2 + 1)); - if (dim > 2) fwkerhalf3 = (T *)malloc(sizeof(T) * (nf3 / 2 + 1)); - - onedim_fseries_kernel(nf1, fwkerhalf1, opts); - if (dim > 1) onedim_fseries_kernel(nf2, fwkerhalf2, opts); - if (dim > 2) onedim_fseries_kernel(nf3, fwkerhalf3, opts); - cputime = timer.elapsedsec(); - cudaEventRecord(start); - { - checkCudaErrors(cudaMemcpy(d_fwkerhalf1, fwkerhalf1, sizeof(T) * (nf1 / 2 + 1), - cudaMemcpyHostToDevice)); - if (dim > 1) - checkCudaErrors(cudaMemcpy(d_fwkerhalf2, fwkerhalf2, sizeof(T) * (nf2 / 2 + 1), - cudaMemcpyHostToDevice)); - if (dim > 2) - checkCudaErrors(cudaMemcpy(d_fwkerhalf3, fwkerhalf3, sizeof(T) * (nf3 / 2 + 1), - cudaMemcpyHostToDevice)); - } - cudaEventRecord(stop); - cudaEventSynchronize(stop); - cudaEventElapsedTime(&milliseconds, start, stop); - gputime = milliseconds; - printf("[time ] dim=%d, nf1=%8d, ns=%2d, CPU: %6.2f ms\n", dim, nf1, opts.nspread, - gputime + cputime * 1000); - free(fwkerhalf1); - if (dim > 1) free(fwkerhalf2); - if (dim > 2) free(fwkerhalf3); - } else { - timer.start(); - std::complex a[dim * MAX_NQUAD]; - T f[dim * MAX_NQUAD]; - onedim_fseries_kernel_precomp(nf1, f, a, opts); - if (dim > 1) onedim_fseries_kernel_precomp(nf2, f + MAX_NQUAD, a + MAX_NQUAD, opts); - if (dim > 2) - onedim_fseries_kernel_precomp(nf3, f + 2 * MAX_NQUAD, a + 2 * MAX_NQUAD, opts); - cputime = timer.elapsedsec(); - - cuDoubleComplex *d_a; - T *d_f; - cudaEventRecord(start); - { - checkCudaErrors(cudaMalloc(&d_a, dim * MAX_NQUAD * sizeof(cuDoubleComplex))); - checkCudaErrors(cudaMalloc(&d_f, dim * MAX_NQUAD * sizeof(T))); - checkCudaErrors(cudaMemcpy(d_a, a, dim * MAX_NQUAD * sizeof(cuDoubleComplex), - cudaMemcpyHostToDevice)); - checkCudaErrors( - cudaMemcpy(d_f, f, dim * MAX_NQUAD * sizeof(T), cudaMemcpyHostToDevice)); - ier = - cufserieskernelcompute(dim, nf1, nf2, nf3, d_f, d_a, d_fwkerhalf1, d_fwkerhalf2, - d_fwkerhalf3, opts.nspread, cudaStreamDefault); - } - cudaEventRecord(stop); - cudaEventSynchronize(stop); - cudaEventElapsedTime(&milliseconds, start, stop); - gputime = milliseconds; - printf("[time ] dim=%d, nf1=%8d, ns=%2d, GPU: %6.2f ms\n", dim, nf1, opts.nspread, - gputime + cputime * 1000); - cudaFree(d_a); - cudaFree(d_f); - } - - fwkerhalf1 = (T *)malloc(sizeof(T) * (nf1 / 2 + 1)); - if (dim > 1) fwkerhalf2 = (T *)malloc(sizeof(T) * (nf2 / 2 + 1)); - if (dim > 2) fwkerhalf3 = (T *)malloc(sizeof(T) * (nf3 / 2 + 1)); - - checkCudaErrors(cudaMemcpy(fwkerhalf1, d_fwkerhalf1, sizeof(T) * (nf1 / 2 + 1), - cudaMemcpyDeviceToHost)); - if (dim > 1) - checkCudaErrors(cudaMemcpy(fwkerhalf2, d_fwkerhalf2, sizeof(T) * (nf2 / 2 + 1), - cudaMemcpyDeviceToHost)); - if (dim > 2) - checkCudaErrors(cudaMemcpy(fwkerhalf3, d_fwkerhalf3, sizeof(T) * (nf3 / 2 + 1), - cudaMemcpyDeviceToHost)); - for (int i = 0; i < nf1 / 2 + 1; i++) printf("%10.8e ", fwkerhalf1[i]); - printf("\n"); - if (dim > 1) - for (int i = 0; i < nf2 / 2 + 1; i++) printf("%10.8e ", fwkerhalf2[i]); - printf("\n"); - if (dim > 2) - for (int i = 0; i < nf3 / 2 + 1; i++) printf("%10.8e ", fwkerhalf3[i]); - printf("\n"); - - return 0; -} - -int main(int argc, char *argv[]) { - if (argc < 3) { - fprintf(stderr, - "Usage: onedim_fseries_kernel_test prec nf1 [dim [tol [gpuversion [nf2 " - "[nf3]]]]]\n" - "Arguments:\n" - " prec: 'f' or 'd' (float/double)\n" - " nf1: The size of the upsampled fine grid size in x.\n" - " dim: Dimension of the nuFFT.\n" - " tol: NUFFT tolerance (default 1e-6).\n" - " gpuversion: Use gpu version or not (default True).\n" - " nf2: The size of the upsampled fine grid size in y. (default nf1)\n" - " nf3: The size of the upsampled fine grid size in z. (default nf3)\n"); - return 1; - } - char prec = argv[1][0]; - int nf1 = std::atof(argv[2]); - int dim = 1; - double eps = 1e-6; - int gpu = 1; - int nf2 = nf1; - int nf3 = nf1; - if (argc > 3) dim = std::atoi(argv[3]); - if (argc > 4) eps = std::atof(argv[4]); - if (argc > 5) gpu = std::atoi(argv[5]); - if (argc > 6) nf2 = std::atoi(argv[6]); - if (argc > 7) nf3 = std::atoi(argv[7]); - - if (prec == 'f') - return run_test(nf1, dim, eps, gpu, nf2, nf3); - else if (prec == 'd') - return run_test(nf1, dim, eps, gpu, nf2, nf3); - else - return -1; -} diff --git a/test/cuda/fseriesperf.sh b/test/cuda/fseriesperf.sh deleted file mode 100755 index 36af42276..000000000 --- a/test/cuda/fseriesperf.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/bash -# basic perf test of compute fseries for 1d, single/double -# Melody 02/20/22 - -BINDIR=./ - -BIN=$BINDIR/fseries_kernel_test -DIM=1 - -echo "Double.............................................." -for N in 1e2 5e2 1e3 2e3 5e3 1e4 5e4 1e5 5e5 -do - for TOL in 1e-8 - do - $BIN $N $DIM $TOL 0 - $BIN $N $DIM $TOL 1 - done -done - -BIN=$BINDIR/fseries_kernel_testf -echo "Single.............................................." -for N in 1e2 5e2 1e3 2e3 5e3 1e4 5e4 1e5 5e5 -do - for TOL in 1e-6 - do - $BIN $N $DIM $TOL 0 - $BIN $N $DIM $TOL 1 - done -done