diff --git a/mt3/colab/music_transcription_with_transformers.ipynb b/mt3/colab/music_transcription_with_transformers.ipynb index 7357b06..3df608a 100644 --- a/mt3/colab/music_transcription_with_transformers.ipynb +++ b/mt3/colab/music_transcription_with_transformers.ipynb @@ -67,7 +67,7 @@ "# install mt3\n", "!git clone --branch=main https://github.com/magenta/mt3\n", "!mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp\n", - "!python3 -m pip install jax[cuda12_local] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", + "!python3 -m pip install jax[cuda12] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", "\n", "# copy checkpoints\n", "!gsutil -q -m cp -r gs://mt3/checkpoints .\n",