From bf1cfd4049357920cc7180463099483fc726179e Mon Sep 17 00:00:00 2001 From: Ian Simon Date: Fri, 20 Sep 2024 16:24:13 -0700 Subject: [PATCH] switch jax version "extra" to cuda12 PiperOrigin-RevId: 677009602 --- mt3/colab/music_transcription_with_transformers.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mt3/colab/music_transcription_with_transformers.ipynb b/mt3/colab/music_transcription_with_transformers.ipynb index 7357b06..3df608a 100644 --- a/mt3/colab/music_transcription_with_transformers.ipynb +++ b/mt3/colab/music_transcription_with_transformers.ipynb @@ -67,7 +67,7 @@ "# install mt3\n", "!git clone --branch=main https://github.com/magenta/mt3\n", "!mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp\n", - "!python3 -m pip install jax[cuda12_local] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", + "!python3 -m pip install jax[cuda12] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", "\n", "# copy checkpoints\n", "!gsutil -q -m cp -r gs://mt3/checkpoints .\n",