From d0c613ae8489e8b120c22ede0a83ddde5c7ba3fc Mon Sep 17 00:00:00 2001 From: Weiqun Zhang Date: Sat, 12 Aug 2023 16:48:42 -0400 Subject: [PATCH] hip wip --- Src/Base/AMReX_GpuError.H | 5 +++++ Src/Base/AMReX_Random.cpp | 20 ++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/Src/Base/AMReX_GpuError.H b/Src/Base/AMReX_GpuError.H index 5242f5dc161..ce3ac188a85 100644 --- a/Src/Base/AMReX_GpuError.H +++ b/Src/Base/AMReX_GpuError.H @@ -95,6 +95,11 @@ namespace Gpu { + " " + hipGetErrorString(amrex_i_err)); \ amrex::Abort(errStr); \ }} + +#define AMREX_HIPRAND_SAFE_CALL(x) do { if((x)!=HIPRAND_STATUS_SUCCESS) { \ + std::string errStr(std::string("HIPRAND error in file ") + __FILE__ \ + + " line " + std::to_string(__LINE__)); \ + amrex::Abort(errStr); }} while(0) #endif #define AMREX_GPU_ERROR_CHECK() amrex::Gpu::ErrorCheck(__FILE__, __LINE__) diff --git a/Src/Base/AMReX_Random.cpp b/Src/Base/AMReX_Random.cpp index fadc8486a5e..f9a15f34239 100644 --- a/Src/Base/AMReX_Random.cpp +++ b/Src/Base/AMReX_Random.cpp @@ -31,6 +31,8 @@ namespace bool generator_initialized = false; #if defined(AMREX_USE_CUDA) curandGenerator_t cuda_rand_gen; +#elif defined(AMREX_USE_HIP) + hiprandGenerator_t hip_rand_gen; #endif } #endif @@ -231,17 +233,35 @@ void FillRandomNormal (Real* p, Long N, Real mean, Real stddev) #ifdef AMREX_USE_GPU #if defined(AMREX_USE_CUDA) + if (! generator_initialized) { AMREX_CURAND_SAFE_CALL(curandCreateGenerator(&cuda_rand_gen, CURAND_RNG_PSEUDO_DEFAULT)); AMREX_CURAND_SAFE_CALL(curandSetPseudoRandomGeneratorSeed(cuda_rand_gen, 1234ULL)); + generator_initialized = true; } #ifdef BL_USE_FLOAT AMREX_CURAND_SAFE_CALL(curandGenerateNormal(cuda_rand_gen, p, N, mean, stddev)); #else AMREX_CURAND_SAFE_CALL(curandGenerateNormalDouble(cuda_rand_gen, p, N, mean, stddev)); #endif + +#elif defined(AMREX_USE_HIP) + + if (! generator_initialized) { + AMREX_HIPRAND_SAFE_CALL(hiprandCreateGenerator(&hip_rand_gen, + HIPRAND_RNG_PSEUDO_DEFAULT)); + AMREX_HIPRAND_SAFE_CALL(hiprandSetPseudoRandomGeneratorSeed(hip_rand_gen, + 1234ULL)); + generator_initialized = true; + } +#ifdef BL_USE_FLOAT + AMREX_HIPRAND_SAFE_CALL(hiprandGenerateNormal(hip_rand_gen, p, N, mean, stddev)); +#else + AMREX_HIPRAND_SAFE_CALL(hiprandGenerateNormalDouble(hip_rand_gen, p, N, mean, stddev)); +#endif + #endif Gpu::streamSynchronize();