Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

STL random engine for MacOS #1099

Open
wants to merge 23 commits into
base: branch-24.03
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
dd05748
Branch-off distributions on STL. Group 1.
aschaffer Nov 14, 2023
5c344e7
Added an additional code insulation layer to simplify portability to …
aschaffer Nov 15, 2023
50b2f31
Fixes to STL insulation layer.
aschaffer Nov 15, 2023
a2492be
Distributions STL split. Group 2.
aschaffer Nov 15, 2023
f2198ce
Normal distribution refactored.
aschaffer Nov 16, 2023
3df844a
Distributions STL split. Group 3.
aschaffer Nov 16, 2023
e55921a
Better engine_uniform() function template.
aschaffer Nov 16, 2023
d34ecb5
Refactoring for curand abstraction replacements. Step 1.
aschaffer Nov 21, 2023
3b99848
Refactoring for curand abstraction replacements. Step 2.
aschaffer Nov 21, 2023
3331e33
Refactoring for curand abstraction replacements. Step 3.
aschaffer Nov 21, 2023
b343896
Refactoring for curand abstraction replacements. Step 4.
aschaffer Nov 22, 2023
231ea92
Refactoring for curand abstraction replacements. Step 5.
aschaffer Nov 22, 2023
9077427
Clean-up of some headers.
aschaffer Nov 23, 2023
c2f75e0
Branch-off STL code. Builds, but tests fail, including on device.
aschaffer Nov 28, 2023
72a5df0
Branch-off STL code. Builds, preliminary tests pass. Device fail.
aschaffer Nov 29, 2023
40b8092
Merge branch 'nv-legate:branch-24.01' into enh-ext-rng
aschaffer Dec 1, 2023
fe0c6b6
Merge branch 'branch-24.01' of github.com:nv-legate/cunumeric into en…
aschaffer Dec 4, 2023
4a4ca08
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 5, 2023
770de2f
MACOS macro guards.
aschaffer Dec 20, 2023
382d370
Merge branch 'nv-legate:branch-24.01' into enh-ext-rng
aschaffer Jan 4, 2024
8d85f33
Merge branch 'nv-legate:branch-24.01' into enh-ext-rng
aschaffer Feb 5, 2024
6039dfc
Merge branch 'nv-legate:branch-24.01' into enh-ext-rng
aschaffer Feb 14, 2024
51424ac
Merge branch 'nv-legate:branch-24.03' into enh-ext-rng
aschaffer Mar 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions src/cunumeric/random/bitgenerator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,17 @@
*
*/

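// macOS host variant: select the STL-based random engine. The macro is defined
// ahead of the includes below, presumably so those headers can key off it.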
#if defined(__APPLE__) && defined(__MACH__)
#define USE_STL_RANDOM_ENGINE_
#endif

#include "cunumeric/random/bitgenerator.h"
#include "cunumeric/random/bitgenerator_template.inl"
#include "cunumeric/random/bitgenerator_util.h"

#include "cunumeric/random/curand_help.h"
#include "cunumeric/random/rnd_types.h"
#include "cunumeric/random/randutil/randutil.h"

#include "cunumeric/random/bitgenerator_curand.inl"
Expand All @@ -31,6 +37,16 @@ static Logger log_curand("cunumeric.random");

Logger& randutil_log() { return log_curand; }

#ifdef USE_STL_RANDOM_ENGINE_
// Status check compiled in when USE_STL_RANDOM_ENGINE_ is defined.
//
// A zero `error` means success and the call returns immediately. Any other
// value is written to the random-utility logger at fatal severity, together
// with the reporting location (`file`, `line`), and then assert(false) is hit
// (which is a no-op when NDEBUG is defined).
void randutil_check_status(rnd_status_t error, const char* file, int line)
{
  if (!error) { return; }

  // Kept as one full expression so the log record is completed before the assert.
  randutil_log().fatal() << "Internal random engine failure with error " << static_cast<int>(error)
                         << " in file " << file << " at line " << line;
  assert(false);
}
#else
void randutil_check_curand(curandStatus_t error, const char* file, int line)
{
if (error != CURAND_STATUS_SUCCESS) {
Expand All @@ -39,15 +55,16 @@ void randutil_check_curand(curandStatus_t error, const char* file, int line)
assert(false);
}
}
#endif

struct CPUGenerator : public CURANDGenerator {
CPUGenerator(BitGeneratorType gentype, uint64_t seed, uint64_t generatorId, uint32_t flags)
: CURANDGenerator(gentype, seed, generatorId)
{
CHECK_CURAND(::randutilCreateGeneratorHost(&gen_, type_, seed, generatorId, flags));
CHECK_RND_ENGINE(::randutilCreateGeneratorHost(&gen_, type_, seed, generatorId, flags));
}

virtual ~CPUGenerator() { CHECK_CURAND(::randutilDestroyGenerator(gen_)); }
virtual ~CPUGenerator() { CHECK_RND_ENGINE(::randutilDestroyGenerator(gen_)); }
};

template <>
Expand Down
15 changes: 13 additions & 2 deletions src/cunumeric/random/bitgenerator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,30 @@ namespace cunumeric {

using namespace legate;

// Required by CHECK_CURAND_DEVICE.
//
// Returns immediately when `error` is CURAND_STATUS_SUCCESS. Otherwise the
// numeric status and the reporting location (`file`, `line`) are written to
// the random-utility logger at fatal severity, and then assert(false) is hit
// (which is a no-op when NDEBUG is defined).
void randutil_check_curand_device(curandStatus_t error, const char* file, int line)
{
  if (error == CURAND_STATUS_SUCCESS) { return; }

  // Kept as one full expression so the log record is completed before the assert.
  randutil_log().fatal() << "Internal CURAND failure with error " << static_cast<int>(error)
                         << " in file " << file << " at line " << line;
  assert(false);
}

struct GPUGenerator : public CURANDGenerator {
cudaStream_t stream_;
GPUGenerator(BitGeneratorType gentype, uint64_t seed, uint64_t generatorId, uint32_t flags)
: CURANDGenerator(gentype, seed, generatorId)
{
CHECK_CUDA(::cudaStreamCreate(&stream_));
CHECK_CURAND(::randutilCreateGenerator(&gen_, type_, seed, generatorId, flags, stream_));
CHECK_CURAND_DEVICE(::randutilCreateGenerator(&gen_, type_, seed, generatorId, flags, stream_));
}

virtual ~GPUGenerator()
{
CHECK_CUDA(::cudaStreamSynchronize(stream_));
CHECK_CURAND(::randutilDestroyGenerator(gen_));
CHECK_CURAND_DEVICE(::randutilDestroyGenerator(gen_));
}
};

Expand Down
113 changes: 57 additions & 56 deletions src/cunumeric/random/bitgenerator_curand.inl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#include "cunumeric/random/bitgenerator_template.inl"
#include "cunumeric/random/bitgenerator_util.h"

#include "cunumeric/random/curand_help.h"
#include "cunumeric/random/rnd_types.h"
#include "cunumeric/random/randutil/randutil.h"

namespace cunumeric {
Expand All @@ -41,11 +41,11 @@ struct CURANDGenerator {
randutilGenerator_t gen_;
uint64_t seed_;
uint64_t generatorId_;
curandRngType type_;
randRngType type_;

protected:
CURANDGenerator(BitGeneratorType gentype, uint64_t seed, uint64_t generatorId)
: type_(get_curandRngType(gentype)), seed_(seed), generatorId_(generatorId)
: type_(get_rndRngType(gentype)), seed_(seed), generatorId_(generatorId)
{
randutil_log().debug() << "CURANDGenerator::create";
}
Expand All @@ -57,217 +57,218 @@ struct CURANDGenerator {

void generate_raw(uint64_t count, uint32_t* out)
{
CHECK_CURAND(::randutilGenerateRawUInt32(gen_, out, count));
CHECK_RND_ENGINE(::randutilGenerateRawUInt32(gen_, out, count));
}
void generate_integer_64(uint64_t count, int64_t* out, int64_t low, int64_t high)
{
CHECK_CURAND(::randutilGenerateIntegers64(gen_, out, count, low, high));
CHECK_RND_ENGINE(::randutilGenerateIntegers64(gen_, out, count, low, high));
}
void generate_integer_16(uint64_t count, int16_t* out, int16_t low, int16_t high)
{
CHECK_CURAND(::randutilGenerateIntegers16(gen_, out, count, low, high));
CHECK_RND_ENGINE(::randutilGenerateIntegers16(gen_, out, count, low, high));
}
void generate_integer_32(uint64_t count, int32_t* out, int32_t low, int32_t high)
{
CHECK_CURAND(::randutilGenerateIntegers32(gen_, out, count, low, high));
CHECK_RND_ENGINE(::randutilGenerateIntegers32(gen_, out, count, low, high));
}
void generate_uniform_64(uint64_t count, double* out, double low, double high)
{
CHECK_CURAND(::randutilGenerateUniformDoubleEx(gen_, out, count, low, high));
CHECK_RND_ENGINE(::randutilGenerateUniformDoubleEx(gen_, out, count, low, high));
}
void generate_uniform_32(uint64_t count, float* out, float low, float high)
{
CHECK_CURAND(::randutilGenerateUniformEx(gen_, out, count, low, high));
CHECK_RND_ENGINE(::randutilGenerateUniformEx(gen_, out, count, low, high));
}
void generate_lognormal_64(uint64_t count, double* out, double mean, double stdev)
{
CHECK_CURAND(::randutilGenerateLogNormalDoubleEx(gen_, out, count, mean, stdev));
CHECK_RND_ENGINE(::randutilGenerateLogNormalDoubleEx(gen_, out, count, mean, stdev));
}
void generate_lognormal_32(uint64_t count, float* out, float mean, float stdev)
{
CHECK_CURAND(::randutilGenerateLogNormalEx(gen_, out, count, mean, stdev));
CHECK_RND_ENGINE(::randutilGenerateLogNormalEx(gen_, out, count, mean, stdev));
}
void generate_normal_64(uint64_t count, double* out, double mean, double stdev)
{
CHECK_CURAND(::randutilGenerateNormalDoubleEx(gen_, out, count, mean, stdev));
CHECK_RND_ENGINE(::randutilGenerateNormalDoubleEx(gen_, out, count, mean, stdev));
}
void generate_normal_32(uint64_t count, float* out, float mean, float stdev)
{
CHECK_CURAND(::randutilGenerateNormalEx(gen_, out, count, mean, stdev));
CHECK_RND_ENGINE(::randutilGenerateNormalEx(gen_, out, count, mean, stdev));
}
void generate_poisson(uint64_t count, uint32_t* out, double lam)
{
CHECK_CURAND(::randutilGeneratePoissonEx(gen_, out, count, lam));
CHECK_RND_ENGINE(::randutilGeneratePoissonEx(gen_, out, count, lam));
}
void generate_exponential_64(uint64_t count, double* out, double scale)
{
CHECK_CURAND(::randutilGenerateExponentialDoubleEx(gen_, out, count, scale));
CHECK_RND_ENGINE(::randutilGenerateExponentialDoubleEx(gen_, out, count, scale));
}
void generate_exponential_32(uint64_t count, float* out, float scale)
{
CHECK_CURAND(::randutilGenerateExponentialEx(gen_, out, count, scale));
CHECK_RND_ENGINE(::randutilGenerateExponentialEx(gen_, out, count, scale));
}
void generate_gumbel_64(uint64_t count, double* out, double mu, double beta)
{
CHECK_CURAND(::randutilGenerateGumbelDoubleEx(gen_, out, count, mu, beta));
CHECK_RND_ENGINE(::randutilGenerateGumbelDoubleEx(gen_, out, count, mu, beta));
}
void generate_gumbel_32(uint64_t count, float* out, float mu, float beta)
{
CHECK_CURAND(::randutilGenerateGumbelEx(gen_, out, count, mu, beta));
CHECK_RND_ENGINE(::randutilGenerateGumbelEx(gen_, out, count, mu, beta));
}
void generate_laplace_64(uint64_t count, double* out, double mu, double beta)
{
CHECK_CURAND(::randutilGenerateLaplaceDoubleEx(gen_, out, count, mu, beta));
CHECK_RND_ENGINE(::randutilGenerateLaplaceDoubleEx(gen_, out, count, mu, beta));
}
void generate_laplace_32(uint64_t count, float* out, float mu, float beta)
{
CHECK_CURAND(::randutilGenerateLaplaceEx(gen_, out, count, mu, beta));
CHECK_RND_ENGINE(::randutilGenerateLaplaceEx(gen_, out, count, mu, beta));
}
void generate_logistic_64(uint64_t count, double* out, double mu, double beta)
{
CHECK_CURAND(::randutilGenerateLogisticDoubleEx(gen_, out, count, mu, beta));
CHECK_RND_ENGINE(::randutilGenerateLogisticDoubleEx(gen_, out, count, mu, beta));
}
void generate_logistic_32(uint64_t count, float* out, float mu, float beta)
{
CHECK_CURAND(::randutilGenerateLogisticEx(gen_, out, count, mu, beta));
CHECK_RND_ENGINE(::randutilGenerateLogisticEx(gen_, out, count, mu, beta));
}
void generate_pareto_64(uint64_t count, double* out, double alpha)
{
CHECK_CURAND(::randutilGenerateParetoDoubleEx(gen_, out, count, 1.0, alpha));
CHECK_RND_ENGINE(::randutilGenerateParetoDoubleEx(gen_, out, count, 1.0, alpha));
}
void generate_pareto_32(uint64_t count, float* out, float alpha)
{
CHECK_CURAND(::randutilGenerateParetoEx(gen_, out, count, 1.0f, alpha));
CHECK_RND_ENGINE(::randutilGenerateParetoEx(gen_, out, count, 1.0f, alpha));
}
void generate_power_64(uint64_t count, double* out, double alpha)
{
CHECK_CURAND(::randutilGeneratePowerDoubleEx(gen_, out, count, alpha));
CHECK_RND_ENGINE(::randutilGeneratePowerDoubleEx(gen_, out, count, alpha));
}
void generate_power_32(uint64_t count, float* out, float alpha)
{
CHECK_CURAND(::randutilGeneratePowerEx(gen_, out, count, alpha));
CHECK_RND_ENGINE(::randutilGeneratePowerEx(gen_, out, count, alpha));
}
void generate_rayleigh_64(uint64_t count, double* out, double sigma)
{
CHECK_CURAND(::randutilGenerateRayleighDoubleEx(gen_, out, count, sigma));
CHECK_RND_ENGINE(::randutilGenerateRayleighDoubleEx(gen_, out, count, sigma));
}
void generate_rayleigh_32(uint64_t count, float* out, float sigma)
{
CHECK_CURAND(::randutilGenerateRayleighEx(gen_, out, count, sigma));
CHECK_RND_ENGINE(::randutilGenerateRayleighEx(gen_, out, count, sigma));
}
void generate_cauchy_64(uint64_t count, double* out, double x0, double gamma)
{
CHECK_CURAND(::randutilGenerateCauchyDoubleEx(gen_, out, count, x0, gamma));
CHECK_RND_ENGINE(::randutilGenerateCauchyDoubleEx(gen_, out, count, x0, gamma));
}
void generate_cauchy_32(uint64_t count, float* out, float x0, float gamma)
{
CHECK_CURAND(::randutilGenerateCauchyEx(gen_, out, count, x0, gamma));
CHECK_RND_ENGINE(::randutilGenerateCauchyEx(gen_, out, count, x0, gamma));
}
void generate_triangular_64(uint64_t count, double* out, double a, double b, double c)
{
CHECK_CURAND(::randutilGenerateTriangularDoubleEx(gen_, out, count, a, b, c));
CHECK_RND_ENGINE(::randutilGenerateTriangularDoubleEx(gen_, out, count, a, b, c));
}
void generate_triangular_32(uint64_t count, float* out, float a, float b, float c)
{
CHECK_CURAND(::randutilGenerateTriangularEx(gen_, out, count, a, b, c));
CHECK_RND_ENGINE(::randutilGenerateTriangularEx(gen_, out, count, a, b, c));
}
void generate_weibull_64(uint64_t count, double* out, double lam, double k)
{
CHECK_CURAND(::randutilGenerateWeibullDoubleEx(gen_, out, count, lam, k));
CHECK_RND_ENGINE(::randutilGenerateWeibullDoubleEx(gen_, out, count, lam, k));
}
void generate_weibull_32(uint64_t count, float* out, float lam, float k)
{
CHECK_CURAND(::randutilGenerateWeibullEx(gen_, out, count, lam, k));
CHECK_RND_ENGINE(::randutilGenerateWeibullEx(gen_, out, count, lam, k));
}
void generate_beta_64(uint64_t count, double* out, double a, double b)
{
CHECK_CURAND(::randutilGenerateBetaDoubleEx(gen_, out, count, a, b));
CHECK_RND_ENGINE(::randutilGenerateBetaDoubleEx(gen_, out, count, a, b));
}
void generate_beta_32(uint64_t count, float* out, float a, float b)
{
CHECK_CURAND(::randutilGenerateBetaEx(gen_, out, count, a, b));
CHECK_RND_ENGINE(::randutilGenerateBetaEx(gen_, out, count, a, b));
}
void generate_f_64(uint64_t count, double* out, double dfnum, double dfden)
{
CHECK_CURAND(::randutilGenerateFisherSnedecorDoubleEx(gen_, out, count, dfnum, dfden));
CHECK_RND_ENGINE(::randutilGenerateFisherSnedecorDoubleEx(gen_, out, count, dfnum, dfden));
}
void generate_f_32(uint64_t count, float* out, float dfnum, float dfden)
{
CHECK_CURAND(::randutilGenerateFisherSnedecorEx(gen_, out, count, dfnum, dfden));
CHECK_RND_ENGINE(::randutilGenerateFisherSnedecorEx(gen_, out, count, dfnum, dfden));
}
void generate_logseries(uint64_t count, uint32_t* out, double p)
{
CHECK_CURAND(::randutilGenerateLogSeriesEx(gen_, out, count, p));
CHECK_RND_ENGINE(::randutilGenerateLogSeriesEx(gen_, out, count, p));
}
void generate_noncentral_f_64(
uint64_t count, double* out, double dfnum, double dfden, double nonc)
{
CHECK_CURAND(::randutilGenerateFisherSnedecorDoubleEx(gen_, out, count, dfnum, dfden, nonc));
CHECK_RND_ENGINE(
::randutilGenerateFisherSnedecorDoubleEx(gen_, out, count, dfnum, dfden, nonc));
}
void generate_noncentral_f_32(uint64_t count, float* out, float dfnum, float dfden, float nonc)
{
CHECK_CURAND(::randutilGenerateFisherSnedecorEx(gen_, out, count, dfnum, dfden, nonc));
CHECK_RND_ENGINE(::randutilGenerateFisherSnedecorEx(gen_, out, count, dfnum, dfden, nonc));
}
void generate_chisquare_64(uint64_t count, double* out, double df, double nonc)
{
CHECK_CURAND(::randutilGenerateChiSquareDoubleEx(gen_, out, count, df, nonc));
CHECK_RND_ENGINE(::randutilGenerateChiSquareDoubleEx(gen_, out, count, df, nonc));
}
void generate_chisquare_32(uint64_t count, float* out, float df, float nonc)
{
CHECK_CURAND(::randutilGenerateChiSquareEx(gen_, out, count, df, nonc));
CHECK_RND_ENGINE(::randutilGenerateChiSquareEx(gen_, out, count, df, nonc));
}
void generate_gamma_64(uint64_t count, double* out, double k, double theta)
{
CHECK_CURAND(::randutilGenerateGammaDoubleEx(gen_, out, count, k, theta));
CHECK_RND_ENGINE(::randutilGenerateGammaDoubleEx(gen_, out, count, k, theta));
}
void generate_gamma_32(uint64_t count, float* out, float k, float theta)
{
CHECK_CURAND(::randutilGenerateGammaEx(gen_, out, count, k, theta));
CHECK_RND_ENGINE(::randutilGenerateGammaEx(gen_, out, count, k, theta));
}
void generate_standard_t_64(uint64_t count, double* out, double df)
{
CHECK_CURAND(::randutilGenerateStandardTDoubleEx(gen_, out, count, df));
CHECK_RND_ENGINE(::randutilGenerateStandardTDoubleEx(gen_, out, count, df));
}
void generate_standard_t_32(uint64_t count, float* out, float df)
{
CHECK_CURAND(::randutilGenerateStandardTEx(gen_, out, count, df));
CHECK_RND_ENGINE(::randutilGenerateStandardTEx(gen_, out, count, df));
}
void generate_hypergeometric(
uint64_t count, uint32_t* out, int64_t ngood, int64_t nbad, int64_t nsample)
{
CHECK_CURAND(::randutilGenerateHyperGeometricEx(gen_, out, count, ngood, nbad, nsample));
CHECK_RND_ENGINE(::randutilGenerateHyperGeometricEx(gen_, out, count, ngood, nbad, nsample));
}
void generate_vonmises_64(uint64_t count, double* out, double mu, double kappa)
{
CHECK_CURAND(::randutilGenerateVonMisesDoubleEx(gen_, out, count, mu, kappa));
CHECK_RND_ENGINE(::randutilGenerateVonMisesDoubleEx(gen_, out, count, mu, kappa));
}
void generate_vonmises_32(uint64_t count, float* out, float mu, float kappa)
{
CHECK_CURAND(::randutilGenerateVonMisesEx(gen_, out, count, mu, kappa));
CHECK_RND_ENGINE(::randutilGenerateVonMisesEx(gen_, out, count, mu, kappa));
}
void generate_zipf(uint64_t count, uint32_t* out, double a)
{
CHECK_CURAND(::randutilGenerateZipfEx(gen_, out, count, a));
CHECK_RND_ENGINE(::randutilGenerateZipfEx(gen_, out, count, a));
}
void generate_geometric(uint64_t count, uint32_t* out, double p)
{
CHECK_CURAND(::randutilGenerateGeometricEx(gen_, out, count, p));
CHECK_RND_ENGINE(::randutilGenerateGeometricEx(gen_, out, count, p));
}
void generate_wald_64(uint64_t count, double* out, double mean, double scale)
{
CHECK_CURAND(::randutilGenerateWaldDoubleEx(gen_, out, count, mean, scale));
CHECK_RND_ENGINE(::randutilGenerateWaldDoubleEx(gen_, out, count, mean, scale));
}
void generate_wald_32(uint64_t count, float* out, float mean, float scale)
{
CHECK_CURAND(::randutilGenerateWaldEx(gen_, out, count, mean, scale));
CHECK_RND_ENGINE(::randutilGenerateWaldEx(gen_, out, count, mean, scale));
}
void generate_binomial(uint64_t count, uint32_t* out, uint32_t ntrials, double p)
{
CHECK_CURAND(::randutilGenerateBinomialEx(gen_, out, count, ntrials, p));
CHECK_RND_ENGINE(::randutilGenerateBinomialEx(gen_, out, count, ntrials, p));
}
void generate_negative_binomial(uint64_t count, uint32_t* out, uint32_t ntrials, double p)
{
CHECK_CURAND(::randutilGenerateNegativeBinomialEx(gen_, out, count, ntrials, p));
CHECK_RND_ENGINE(::randutilGenerateNegativeBinomialEx(gen_, out, count, ntrials, p));
}
};

Expand Down
Loading
Loading