Skip to content

Commit

Permalink
hip wip
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang committed Aug 12, 2023
1 parent 7a7e4ae commit d0c613a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
5 changes: 5 additions & 0 deletions Src/Base/AMReX_GpuError.H
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ namespace Gpu {
+ " " + hipGetErrorString(amrex_i_err)); \
amrex::Abort(errStr); \
}}

#define AMREX_HIPRAND_SAFE_CALL(x) do { if((x)!=HIPRAND_STATUS_SUCCESS) { \
std::string errStr(std::string("HIPRAND error in file ") + __FILE__ \
+ " line " + std::to_string(__LINE__)); \
amrex::Abort(errStr); }} while(0)
#endif

#define AMREX_GPU_ERROR_CHECK() amrex::Gpu::ErrorCheck(__FILE__, __LINE__)
Expand Down
20 changes: 20 additions & 0 deletions Src/Base/AMReX_Random.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ namespace
bool generator_initialized = false;
#if defined(AMREX_USE_CUDA)
curandGenerator_t cuda_rand_gen;
#elif defined(AMREX_USE_HIP)
hiprandGenerator_t hip_rand_gen;
#endif
}
#endif
Expand Down Expand Up @@ -231,17 +233,35 @@ void FillRandomNormal (Real* p, Long N, Real mean, Real stddev)
#ifdef AMREX_USE_GPU

#if defined(AMREX_USE_CUDA)

if (! generator_initialized) {
AMREX_CURAND_SAFE_CALL(curandCreateGenerator(&cuda_rand_gen,
CURAND_RNG_PSEUDO_DEFAULT));
AMREX_CURAND_SAFE_CALL(curandSetPseudoRandomGeneratorSeed(cuda_rand_gen,
1234ULL));
generator_initialized = true;
}
#ifdef BL_USE_FLOAT
AMREX_CURAND_SAFE_CALL(curandGenerateNormal(cuda_rand_gen, p, N, mean, stddev));
#else
AMREX_CURAND_SAFE_CALL(curandGenerateNormalDouble(cuda_rand_gen, p, N, mean, stddev));
#endif

#elif defined(AMREX_USE_HIP)

if (! generator_initialized) {
AMREX_HIPRAND_SAFE_CALL(hiprandCreateGenerator(&hip_rand_gen,
HIPRAND_RNG_PSEUDO_DEFAULT));
AMREX_HIPRAND_SAFE_CALL(hiprandSetPseudoRandomGeneratorSeed(hip_rand_gen,
1234ULL));
generator_initialized = true;
}
#ifdef BL_USE_FLOAT
AMREX_HIPRAND_SAFE_CALL(hiprandGenerateNormal(hip_rand_gen, p, N, mean, stddev));
#else
AMREX_HIPRAND_SAFE_CALL(hiprandGenerateNormalDouble(hip_rand_gen, p, N, mean, stddev));
#endif

#endif

Gpu::streamSynchronize();
Expand Down

0 comments on commit d0c613a

Please sign in to comment.