diff --git a/mt3/colab/music_transcription_with_transformers.ipynb b/mt3/colab/music_transcription_with_transformers.ipynb index a13f4d9..7357b06 100644 --- a/mt3/colab/music_transcription_with_transformers.ipynb +++ b/mt3/colab/music_transcription_with_transformers.ipynb @@ -67,7 +67,7 @@ "# install mt3\n", "!git clone --branch=main https://github.com/magenta/mt3\n", "!mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp\n", - "!python3 -m pip install jax[cuda11_local] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", + "!python3 -m pip install jax[cuda12_local] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", "\n", "# copy checkpoints\n", "!gsutil -q -m cp -r gs://mt3/checkpoints .\n",