-
Notifications
You must be signed in to change notification settings - Fork 230
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* music gen init * musicgen mock fn * init class * llm -> tta todo * mv musicgen experiment to audio_experiments dir * MusicGen experiment * fix init * protect from librosa import
- Loading branch information
1 parent
5a80732
commit 7f47fd8
Showing
9 changed files
with
406 additions
and
1 deletion.
There are no files selected for viewing
182 changes: 182 additions & 0 deletions
182
examples/notebooks/audio_experiments/MusicGenExperiment.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/home/hashem/.local/lib/python3.10/site-packages/torch/cuda/__init__.py:546: UserWarning: Can't initialize NVML\n", | ||
" warnings.warn(\"Can't initialize NVML\")\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from prompttools.experiment import MusicGenExperiment\n", | ||
"from prompttools.utils.similarity import cos_similarity" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"compare_audio_paths = [\n", | ||
" \"sample_audio_files/80s_billy_joel.wav\",\n", | ||
" \"sample_audio_files/80s_billy_joel.wav\",\n", | ||
"]\n", | ||
"\n", | ||
"experiment = MusicGenExperiment(\n", | ||
" repo_id=[\"facebook/musicgen-small\"],\n", | ||
" prompt=[\"80s Rock n Roll\", \"90s R&B\"],\n", | ||
" duration=[5],\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{'repo_id': ['facebook/musicgen-small'],\n", | ||
" 'duration': [5],\n", | ||
" 'prompt': ['80s Rock n Roll', '90s R&B']}" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"experiment.all_args" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.01226953137665987 maximum scale: 2.283313274383545\n", | ||
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 1.1718750101863407e-05 maximum scale: 1.1627463102340698\n", | ||
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 1.5625000742147677e-05 maximum scale: 1.009731411933899\n", | ||
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.009933593682944775 maximum scale: 1.5949103832244873\n", | ||
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.0016132812015712261 maximum scale: 1.474196434020996\n", | ||
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.006332031451165676 maximum scale: 1.7936652898788452\n", | ||
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.00017968750034924597 maximum scale: 1.1629440784454346\n", | ||
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.0023125000298023224 maximum scale: 1.9037144184112549\n", | ||
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.05552734434604645 maximum scale: 2.8524105548858643\n", | ||
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.0077851563692092896 maximum scale: 1.6531202793121338\n", | ||
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.005824218969792128 maximum scale: 1.2873204946517944\n", | ||
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.0003867187479045242 maximum scale: 1.2601758241653442\n", | ||
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.000714843743480742 maximum scale: 1.5760105848312378\n", | ||
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.006160156335681677 maximum scale: 1.741112470626831\n", | ||
"CLIPPING generated_audio_files/90s R&B happening with proba (a bit of clipping is okay): 0.009101562201976776 maximum scale: 2.25307035446167\n", | ||
"CLIPPING generated_audio_files/90s R&B happening with proba (a bit of clipping is okay): 0.0023593748919665813 maximum scale: 1.4188467264175415\n", | ||
"CLIPPING generated_audio_files/90s R&B happening with proba (a bit of clipping is okay): 0.004226562567055225 maximum scale: 1.790489912033081\n", | ||
"CLIPPING generated_audio_files/90s R&B happening with proba (a bit of clipping is okay): 0.012875000014901161 maximum scale: 2.996934652328491\n", | ||
"CLIPPING generated_audio_files/90s R&B happening with proba (a bit of clipping is okay): 0.0063593750819563866 maximum scale: 1.5094847679138184\n", | ||
"CLIPPING generated_audio_files/90s R&B happening with proba (a bit of clipping is okay): 0.0074609373696148396 maximum scale: 2.400330066680908\n", | ||
"CLIPPING generated_audio_files/90s R&B happening with proba (a bit of clipping is okay): 0.0032187500037252903 maximum scale: 1.4209964275360107\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"experiment.run()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"<table border=\"1\" class=\"dataframe\">\n", | ||
" <thead>\n", | ||
" <tr style=\"text-align: right;\">\n", | ||
" <th></th>\n", | ||
" <th>prompt</th>\n", | ||
" <th>response</th>\n", | ||
" <th>latency</th>\n", | ||
" <th>cos_similarity</th>\n", | ||
" </tr>\n", | ||
" </thead>\n", | ||
" <tbody>\n", | ||
" <tr>\n", | ||
" <th>0</th>\n", | ||
" <td>80s Rock n Roll</td>\n", | ||
" <td>audio file generated</td>\n", | ||
" <td>1.795397</td>\n", | ||
" <td>0.653711</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>1</th>\n", | ||
" <td>90s R&B</td>\n", | ||
" <td>audio file generated</td>\n", | ||
" <td>0.029558</td>\n", | ||
" <td>0.698776</td>\n", | ||
" </tr>\n", | ||
" </tbody>\n", | ||
"</table>" | ||
], | ||
"text/plain": [ | ||
"<IPython.core.display.HTML object>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"experiment.evaluate(\n", | ||
" \"cos_similarity\",\n", | ||
" cos_similarity,\n", | ||
" expected=compare_audio_paths,\n", | ||
" audio_experiment=True,\n", | ||
")\n", | ||
"\n", | ||
"experiment.visualize()" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3.10.12 64-bit", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.12" | ||
}, | ||
"orig_nbformat": 4, | ||
"vscode": { | ||
"interpreter": { | ||
"hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Binary file added
BIN
+500 KB
examples/notebooks/audio_experiments/sample_audio_files/80s_billy_joel.wav
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
160 changes: 160 additions & 0 deletions
160
prompttools/experiment/experiments/musicgen_experiment.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
# Copyright (c) Hegel AI, Inc. | ||
# All rights reserved. | ||
# | ||
# This source code's license can be found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import os | ||
from typing import Any, Dict, List, Union | ||
import itertools | ||
|
||
from time import perf_counter | ||
import logging | ||
|
||
try: | ||
import librosa | ||
except ImportError: | ||
librosa = None | ||
|
||
try: | ||
from audiocraft.models import MusicGen | ||
music_gen = MusicGen.get_pretrained | ||
from audiocraft.data.audio import audio_write | ||
except ImportError: | ||
music_gen = None | ||
|
||
from prompttools.selector.prompt_selector import PromptSelector | ||
from prompttools.mock.mock import mock_music_gen_completion_fn | ||
|
||
from .experiment import Experiment | ||
from .error import PromptExperimentException | ||
|
||
|
||
class MusicGenExperiment(Experiment): | ||
r""" | ||
Experiment for MusicGen's API. | ||
It accepts lists for each argument passed into MusicGen's API, | ||
then creates a cartesian product of those arguments, and gets results for each. | ||
Note: | ||
- All arguments here should be a ``list``, even if you want to keep the argument frozen | ||
(i.e. ``temperature=[1.0]``), because the experiment will try all possible combination | ||
of the input arguments. For example, ``kwargs`` should have string keys, | ||
with ``list``s being the values. | ||
Args: | ||
repo_id (List[str]): IDs of repository (e.g. [`facebook/musicgen-small`]). | ||
prompt (List[str] | List[PromptSelector]): list of prompts to test | ||
task (List[str]): List of tasks in strings. Determines whether to force a task instead of using task | ||
specified in the repository. | ||
**kwargs (Dict[str, list[object]]): Keyword parameters used in the call to ``MusicGen``. | ||
The values should be ``list``s. | ||
""" | ||
|
||
MODEL_PARAMETERS = ["repo_id", "task"] | ||
|
||
CALL_PARAMETERS = ["prompt"] | ||
|
||
def __init__( | ||
self, | ||
repo_id: List[str], | ||
prompt: Union[List[str], List[PromptSelector]], | ||
duration: List[int] = [5], | ||
**kwargs: Dict[str, list[object]], | ||
): | ||
if music_gen is None: | ||
raise ModuleNotFoundError( | ||
"Package `audiocraft` is required to be installed to use this experiment." | ||
"Please use `pip install audiocraft` to install the package" | ||
) | ||
if librosa is None: | ||
raise ModuleNotFoundError( | ||
"Package `librosa` is required to be installed to use this experiment." | ||
"Please use `pip install librosa` to install the package" | ||
) | ||
if "generated_audio_files" not in os.listdir(): | ||
os.mkdir("generated_audio_files") | ||
self.duration = duration | ||
self.completion_fn = self.music_gen_completion_fn | ||
if os.getenv("DEBUG", default=False): | ||
self.completion_fn = mock_music_gen_completion_fn | ||
self.model_params = dict(repo_id=repo_id, duration=self.duration) | ||
|
||
# If we are using a prompt selector, we need to render | ||
# messages, as well as create prompt_keys to map the messages | ||
# to corresponding prompts in other models. | ||
if isinstance(prompt[0], PromptSelector): | ||
self.prompt_keys = {selector.for_music_gen(): selector.for_music_gen() for selector in prompt} | ||
prompt = [selector.for_music_gen() for selector in prompt] | ||
else: | ||
self.prompt_keys = prompt | ||
|
||
self.call_params = dict(prompt=prompt) | ||
for k, v in kwargs.items(): | ||
self.CALL_PARAMETERS.append(k) | ||
self.call_params[k] = v | ||
|
||
self.all_args = self.model_params | self.call_params | ||
super().__init__() | ||
|
||
def prepare(self) -> None: | ||
r""" | ||
Creates argument combinations by taking the cartesian product of all inputs. | ||
""" | ||
self.model_argument_combos = [ | ||
dict(zip(self.model_params, val)) for val in itertools.product(*self.model_params.values()) | ||
] | ||
self.call_argument_combos = [ | ||
dict(zip(self.call_params, val)) for val in itertools.product(*self.call_params.values()) | ||
] | ||
|
||
def music_gen_completion_fn( | ||
self, | ||
**params: Dict[str, Any], | ||
): | ||
r""" | ||
Local model helper function to make request | ||
""" | ||
signal, sr = librosa.load(f'generated_audio_files/{params["prompt"]}.wav') | ||
# Extract relevant features, for example, Mel-frequency cepstral coefficients (MFCCs) | ||
mfccs = librosa.feature.mfcc(y=signal, sr=sr) | ||
return mfccs.flatten() | ||
|
||
def run( | ||
self, | ||
runs: int = 1, | ||
) -> None: | ||
r""" | ||
Create tuples of input and output for every possible combination of arguments. | ||
For each combination, it will execute `runs` times, default to 1. | ||
# TODO This can be done with an async queue | ||
""" | ||
if not self.argument_combos: | ||
logging.info("Preparing first...") | ||
self.prepare() | ||
results = [] | ||
latencies = [] | ||
for model_combo in self.model_argument_combos: | ||
client = music_gen( | ||
name=model_combo["repo_id"], | ||
) | ||
client.set_generation_params(duration=8) | ||
for call_combo in self.call_argument_combos: | ||
wav = client.generate(call_combo["prompt"]) | ||
for _, one_wav in enumerate(wav): | ||
audio_write(f'generated_audio_files/{call_combo["prompt"]}', one_wav.cpu(), client.sample_rate, strategy="loudness") | ||
for _ in range(runs): | ||
call_combo["client"] = client | ||
start = perf_counter() | ||
res = self.completion_fn(**call_combo) | ||
latencies.append(perf_counter() - start) | ||
results.append(res) | ||
self.argument_combos.append(model_combo | call_combo) | ||
if len(results) == 0: | ||
logging.error("No results. Something went wrong.") | ||
raise PromptExperimentException | ||
self._construct_result_dfs(self.argument_combos, results, latencies, extract_response_equal_full_result=True) | ||
|
||
@staticmethod | ||
def _extract_responses(output: List[Dict[str, object]]) -> List[float]: | ||
return output |
Oops, something went wrong.