diff --git a/setup.py b/setup.py
index a9f2c42c..20f1dc7b 100644
--- a/setup.py
+++ b/setup.py
@@ -35,6 +35,9 @@
         "scikit-learn>=1.0.2",
         "scipy>=1.8.0",
     ],
+    extras_require={
+        "cuda12": ["jax[cuda12]>=0.4.16"],
+    },
     dependency_links=[
         "https://storage.googleapis.com/jax-releases/jax_releases.html",
     ],