diff --git a/setup.cfg b/setup.cfg index 4d02276c6..ef448801f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -128,8 +128,7 @@ jax_cpu = # JAX GPU jax_gpu = %(jax_core_deps)s - jax[cuda]==0.4.10 - jaxlib==0.4.10+cuda11.cudnn86 + jax[cuda12_local]==0.4.10 # PyTorch CPU pytorch_cpu =