diff --git a/algorithmic_efficiency/pytorch_utils.py b/algorithmic_efficiency/pytorch_utils.py index f1d29b11f..4c5529ae1 100644 --- a/algorithmic_efficiency/pytorch_utils.py +++ b/algorithmic_efficiency/pytorch_utils.py @@ -29,8 +29,8 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: # Only use CPU for Jax to avoid memory issues. # Setting the corresponding environment variable here has no effect; it has to # be done before jax and tensorflow (!) are imported for the first time. - os.environ['JAX_PLATFORMS'] = 'cpu' - jax.config.update('jax_platforms', 'cpu') + # os.environ['JAX_PLATFORMS'] = 'cpu' + jax.config.update('jax_platform_name', 'cpu') # From the docs: "(...) causes cuDNN to benchmark multiple convolution # algorithms and select the fastest." torch.backends.cudnn.benchmark = True