diff --git a/setup.py b/setup.py index a9f2c42c..20f1dc7b 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,9 @@ "scikit-learn>=1.0.2", "scipy>=1.8.0", ], + extras_require={ + "cuda12": ["jax[cuda12]>=0.4.16"], + }, dependency_links=[ "https://storage.googleapis.com/jax-releases/jax_releases.html", ],