Skip to content

Commit

Permalink
MusicGen experiment #82 (#106)
Browse files Browse the repository at this point in the history
* music gen init

* musicgen mock fn

* init class

* llm -> tta todo

* mv musicgen experiment to audio_experiments dir

* MusicGen experiment

* fix init

* protect from librosa import
  • Loading branch information
HashemAlsaket authored Mar 14, 2024
1 parent 5a80732 commit 7f47fd8
Show file tree
Hide file tree
Showing 9 changed files with 406 additions and 1 deletion.
182 changes: 182 additions & 0 deletions examples/notebooks/audio_experiments/MusicGenExperiment.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/hashem/.local/lib/python3.10/site-packages/torch/cuda/__init__.py:546: UserWarning: Can't initialize NVML\n",
" warnings.warn(\"Can't initialize NVML\")\n"
]
}
],
"source": [
"from prompttools.experiment import MusicGenExperiment\n",
"from prompttools.utils.similarity import cos_similarity"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"compare_audio_paths = [\n",
" \"sample_audio_files/80s_billy_joel.wav\",\n",
" \"sample_audio_files/80s_billy_joel.wav\",\n",
"]\n",
"\n",
"experiment = MusicGenExperiment(\n",
" repo_id=[\"facebook/musicgen-small\"],\n",
" prompt=[\"80s Rock n Roll\", \"90s R&B\"],\n",
" duration=[5],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'repo_id': ['facebook/musicgen-small'],\n",
" 'duration': [5],\n",
" 'prompt': ['80s Rock n Roll', '90s R&B']}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"experiment.all_args"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.01226953137665987 maximum scale: 2.283313274383545\n",
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 1.1718750101863407e-05 maximum scale: 1.1627463102340698\n",
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 1.5625000742147677e-05 maximum scale: 1.009731411933899\n",
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.009933593682944775 maximum scale: 1.5949103832244873\n",
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.0016132812015712261 maximum scale: 1.474196434020996\n",
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.006332031451165676 maximum scale: 1.7936652898788452\n",
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.00017968750034924597 maximum scale: 1.1629440784454346\n",
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.0023125000298023224 maximum scale: 1.9037144184112549\n",
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.05552734434604645 maximum scale: 2.8524105548858643\n",
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.0077851563692092896 maximum scale: 1.6531202793121338\n",
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.005824218969792128 maximum scale: 1.2873204946517944\n",
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.0003867187479045242 maximum scale: 1.2601758241653442\n",
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.000714843743480742 maximum scale: 1.5760105848312378\n",
"CLIPPING generated_audio_files/80s Rock n Roll happening with proba (a bit of clipping is okay): 0.006160156335681677 maximum scale: 1.741112470626831\n",
"CLIPPING generated_audio_files/90s R&B happening with proba (a bit of clipping is okay): 0.009101562201976776 maximum scale: 2.25307035446167\n",
"CLIPPING generated_audio_files/90s R&B happening with proba (a bit of clipping is okay): 0.0023593748919665813 maximum scale: 1.4188467264175415\n",
"CLIPPING generated_audio_files/90s R&B happening with proba (a bit of clipping is okay): 0.004226562567055225 maximum scale: 1.790489912033081\n",
"CLIPPING generated_audio_files/90s R&B happening with proba (a bit of clipping is okay): 0.012875000014901161 maximum scale: 2.996934652328491\n",
"CLIPPING generated_audio_files/90s R&B happening with proba (a bit of clipping is okay): 0.0063593750819563866 maximum scale: 1.5094847679138184\n",
"CLIPPING generated_audio_files/90s R&B happening with proba (a bit of clipping is okay): 0.0074609373696148396 maximum scale: 2.400330066680908\n",
"CLIPPING generated_audio_files/90s R&B happening with proba (a bit of clipping is okay): 0.0032187500037252903 maximum scale: 1.4209964275360107\n"
]
}
],
"source": [
"experiment.run()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>prompt</th>\n",
" <th>response</th>\n",
" <th>latency</th>\n",
" <th>cos_similarity</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>80s Rock n Roll</td>\n",
" <td>audio file generated</td>\n",
" <td>1.795397</td>\n",
" <td>0.653711</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>90s R&B</td>\n",
" <td>audio file generated</td>\n",
" <td>0.029558</td>\n",
" <td>0.698776</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"experiment.evaluate(\n",
" \"cos_similarity\",\n",
" cos_similarity,\n",
" expected=compare_audio_paths,\n",
" audio_experiment=True,\n",
")\n",
"\n",
"experiment.visualize()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.12 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Binary file not shown.
2 changes: 2 additions & 0 deletions prompttools/experiment/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .experiments.replicate_experiment import ReplicateExperiment
from .experiments.qdrant_experiment import QdrantExperiment
from .experiments.pinecone_experiment import PineconeExperiment
from .experiments.musicgen_experiment import MusicGenExperiment

__all__ = [
"AnthropicCompletionExperiment",
Expand All @@ -37,6 +38,7 @@
"HuggingFaceHubExperiment",
"MistralChatCompletionExperiment",
"MindsDBExperiment",
"MusicGenExperiment",
"OpenAIChatExperiment",
"OpenAICompletionExperiment",
"PineconeExperiment",
Expand Down
6 changes: 6 additions & 0 deletions prompttools/experiment/experiments/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self):
self.argument_combos: list[dict] = []
self.full_df, self.partial_df, self.score_df = None, None, None
self.image_experiment = False
self.audio_experiment = False
self._experiment_id = None
self._revision_id = None
try:
Expand Down Expand Up @@ -314,6 +315,9 @@ def visualize(self, get_all_cols: bool = False, pivot: bool = False, pivot_colum
table["response"] = table["response"].map(lambda x: self.cv2_image_to_base64(x))
table["response"] = table["response"].apply(self.display_image_html)
display.display(display.HTML(table.to_html(escape=False)))
elif is_interactive() and self.audio_experiment:
table["response"] = "audio file generated"
display.display(display.HTML(table.to_html(escape=False)))
elif is_interactive():
display.display(table)
else:
Expand All @@ -338,6 +342,7 @@ def evaluate(
eval_fn: Callable,
static_eval_fn_kwargs: dict = {},
image_experiment: bool = False,
audio_experiment: bool = False,
**eval_fn_kwargs,
) -> None:
"""
Expand All @@ -360,6 +365,7 @@ def evaluate(
>>> static_eval_fn_kwargs={"response_column_name": "response"})
"""
self.image_experiment = image_experiment
self.audio_experiment = audio_experiment
if metric_name in self.score_df.columns:
logging.warning(metric_name + " is already present, skipping.")
return
Expand Down
160 changes: 160 additions & 0 deletions prompttools/experiment/experiments/musicgen_experiment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright (c) Hegel AI, Inc.
# All rights reserved.
#
# This source code's license can be found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Any, Dict, List, Union
import itertools

from time import perf_counter
import logging

try:
import librosa
except ImportError:
librosa = None

try:
from audiocraft.models import MusicGen
music_gen = MusicGen.get_pretrained
from audiocraft.data.audio import audio_write
except ImportError:
music_gen = None

from prompttools.selector.prompt_selector import PromptSelector
from prompttools.mock.mock import mock_music_gen_completion_fn

from .experiment import Experiment
from .error import PromptExperimentException


class MusicGenExperiment(Experiment):
    r"""
    Experiment for MusicGen's API.

    It accepts lists for each argument passed into MusicGen's API,
    then creates a cartesian product of those arguments, and gets results for each.

    For every (model, prompt) combination an audio clip is generated and written to
    ``generated_audio_files/<prompt>.wav``; the recorded "response" is the flattened
    MFCC feature vector of that file.

    Note:
        - All arguments here should be a ``list``, even if you want to keep the argument frozen
          (i.e. ``duration=[5]``), because the experiment will try all possible combination
          of the input arguments. For example, ``kwargs`` should have string keys,
          with ``list``s being the values.

    Args:
        repo_id (List[str]): IDs of repository (e.g. [`facebook/musicgen-small`]).
        prompt (List[str] | List[PromptSelector]): list of prompts to test
        duration (List[int]): Durations passed to ``set_generation_params`` for each model
            combination. Defaults to ``[5]``.
        **kwargs (Dict[str, list[object]]): Additional call parameters. The values should be ``list``s.
    """

    # NOTE(review): "task" is listed here but is not a constructor parameter, while
    # "duration" is stored in ``model_params``. Kept unchanged because the base class
    # may read this list - confirm the intended contents.
    MODEL_PARAMETERS = ["repo_id", "task"]

    CALL_PARAMETERS = ["prompt"]

    def __init__(
        self,
        repo_id: List[str],
        prompt: Union[List[str], List[PromptSelector]],
        duration: Union[List[int], None] = None,
        **kwargs: Dict[str, list[object]],
    ):
        if music_gen is None:
            raise ModuleNotFoundError(
                "Package `audiocraft` is required to be installed to use this experiment. "
                "Please use `pip install audiocraft` to install the package"
            )
        if librosa is None:
            raise ModuleNotFoundError(
                "Package `librosa` is required to be installed to use this experiment. "
                "Please use `pip install librosa` to install the package"
            )
        # Create the output directory without a separate existence check.
        os.makedirs("generated_audio_files", exist_ok=True)

        # ``None`` stands in for the default so no list object is shared between instances.
        self.duration = [5] if duration is None else duration
        self.completion_fn = self.music_gen_completion_fn
        if os.getenv("DEBUG", default=False):
            self.completion_fn = mock_music_gen_completion_fn
        self.model_params = dict(repo_id=repo_id, duration=self.duration)

        # If we are using a prompt selector, we need to render
        # messages, as well as create prompt_keys to map the messages
        # to corresponding prompts in other models.
        if isinstance(prompt[0], PromptSelector):
            self.prompt_keys = {selector.for_music_gen(): selector.for_music_gen() for selector in prompt}
            prompt = [selector.for_music_gen() for selector in prompt]
        else:
            self.prompt_keys = prompt

        self.call_params = dict(prompt=prompt)
        self.call_params.update(kwargs)
        # Shadow the class-level list with a per-instance copy; appending to the class
        # attribute would leak kwargs names into every other instance.
        self.CALL_PARAMETERS = [*type(self).CALL_PARAMETERS, *kwargs]

        self.all_args = self.model_params | self.call_params
        super().__init__()

    def prepare(self) -> None:
        r"""
        Creates argument combinations by taking the cartesian product of all inputs.

        Populates ``self.model_argument_combos`` and ``self.call_argument_combos``.
        """
        self.model_argument_combos = [
            dict(zip(self.model_params, val)) for val in itertools.product(*self.model_params.values())
        ]
        self.call_argument_combos = [
            dict(zip(self.call_params, val)) for val in itertools.product(*self.call_params.values())
        ]

    def music_gen_completion_fn(
        self,
        **params: Dict[str, Any],
    ):
        r"""
        Load the audio file previously generated for ``params["prompt"]`` and return its
        MFCC features as a flattened array.

        The file is read from ``generated_audio_files/<prompt>.wav``, so ``run`` must have
        written it beforehand. Any other entries in ``params`` are ignored.
        """
        signal, sr = librosa.load(f'generated_audio_files/{params["prompt"]}.wav')
        # Extract relevant features, for example, Mel-frequency cepstral coefficients (MFCCs)
        mfccs = librosa.feature.mfcc(y=signal, sr=sr)
        return mfccs.flatten()

    def run(
        self,
        runs: int = 1,
    ) -> None:
        r"""
        Create tuples of input and output for every possible combination of arguments.
        For each combination, it will execute `runs` times, default to 1.

        The audio for a combination is generated once; ``self.completion_fn`` is then
        invoked ``runs`` times on it.

        Raises:
            PromptExperimentException: if no results were produced.

        # TODO This can be done with an async queue
        """
        if not self.argument_combos:
            logging.info("Preparing first...")
            self.prepare()
        results = []
        latencies = []
        for model_combo in self.model_argument_combos:
            client = music_gen(
                name=model_combo["repo_id"],
            )
            # Honour the requested duration instead of a hard-coded value.
            client.set_generation_params(duration=model_combo["duration"])
            for call_combo in self.call_argument_combos:
                # ``generate`` takes a list of descriptions. A bare string would be
                # iterated character by character, producing one clip per character.
                wav = client.generate([call_combo["prompt"]])
                for one_wav in wav:
                    audio_write(
                        f'generated_audio_files/{call_combo["prompt"]}',
                        one_wav.cpu(),
                        client.sample_rate,
                        strategy="loudness",
                    )
                for _ in range(runs):
                    call_combo["client"] = client
                    # NOTE(review): the timed region covers only ``completion_fn``,
                    # not the generation above - confirm this is the intended latency.
                    start = perf_counter()
                    res = self.completion_fn(**call_combo)
                    latencies.append(perf_counter() - start)
                    results.append(res)
                    self.argument_combos.append(model_combo | call_combo)
        if not results:
            logging.error("No results. Something went wrong.")
            raise PromptExperimentException
        self._construct_result_dfs(self.argument_combos, results, latencies, extract_response_equal_full_result=True)

    @staticmethod
    def _extract_responses(output: List[Dict[str, object]]) -> List[float]:
        r"""
        Return ``output`` unchanged.
        """
        return output
Loading

0 comments on commit 7f47fd8

Please sign in to comment.