Skip to content

Commit

Permalink
added text + label generation
Browse files Browse the repository at this point in the history
  • Loading branch information
ameen-91 committed Nov 16, 2024
1 parent ae538c5 commit 4b7b737
Show file tree
Hide file tree
Showing 10 changed files with 285 additions and 79 deletions.
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ sad-utils-env/
notebooks/

*.ipynb

*.csv
main.py
test.py

*.prof
.ruff_cache/
Empty file removed mic_toolkit/dpo/__init__.py
Empty file.
2 changes: 0 additions & 2 deletions mic_toolkit/dpo/dpo_train.py

This file was deleted.

168 changes: 102 additions & 66 deletions mic_toolkit/synthetic/generation.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,120 @@
from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration
from datasets import Dataset
from ollama import Client
import pandas as pd
from tqdm import tqdm

tqdm.pandas()

class TextGenerationPipeline:
"""
Simple Text Generation Pipeline.
"""

def __init__(self, model_name: str, base_url: str, api_key: str):
"""Setup pipeline with LLM paramters.
class Generator:
"""Generator class for synthetic data generation."""

def __init__(
self,
endpoint: str,
model: str,
):
"""Initializes the LLM Client and model.
Args:
model_name (str): Name of the model.
base_url (str): URL endpoint for the model.
api_key (str): API key.
endpoint (str): Endpoint for the LLM API. For Ollama it is usually "http://localhost:11434".
model (str): Name of the model to use for generation. Find it using 'ollama list'.
"""
self.model_name = model_name
self.base_url = base_url
self.api_key = api_key
self.pipeline = self.create_pipeline()
self.client = Client(endpoint)
self.model = model

def generate_text(
self,
data: pd.DataFrame,
system_prompt: str = "You are a helpful AI assistant. Please provide a response to the following user query:",
max_tokens: int = None,
) -> pd.DataFrame:
"""_summary_
def create_pipeline(self) -> Pipeline:
"""Create the text generation pipeline.
Args:
data (pd.DataFrame): Dataframe with a single column of text data.
system_prompt (_type_, optional): optional System prompt. Defaults to "You are a helpful AI assistant. Please provide a response to the following user query:".
max_tokens (int, optional): max output tokens. Defaults to None.
Returns:
Pipeline: Text Generation Pipeline.
pd.DataFrame: Output dataframe with generated text.
"""
with Pipeline(
name="simple-text-generation-pipeline",
description="A simple text generation pipeline",
) as pipeline:
TextGeneration(
name="text_generation",
llm=OpenAILLM(
model=self.model_name,
base_url=self.base_url,
api_key=self.api_key,
),
)

return pipeline

def run_pipeline(
self, dataset: Dataset, temperature: float = 0.7, max_new_tokens: int = 512
) -> Dataset:
"""
Executes the text generation pipeline on the input dataset.

def generate_response(text):
    """Send one text to the chat model and return the reply content.

    Args:
        text: The user message taken from one row of the input data.

    Returns:
        The ``["message"]["content"]`` field of the chat response.
    """
    # Only set ``num_predict`` when the caller asked for an output cap.
    options = {}
    if max_tokens is not None:
        options["num_predict"] = max_tokens
    reply = self.client.chat(
        model=self.model,
        messages=[
            # A chat message needs explicit "role"/"content" keys; the
            # previous {"system": ...} form is not a valid message, so the
            # system prompt was never delivered as one.
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": text},
        ],
        options=options,
    )
    return reply["message"]["content"]

data["output"] = data[data.columns[0]].progress_apply(generate_response)

return data

def create_system_prompt(self, labels: list[str], query: str = "") -> str:
    """Build the classification system prompt.

    Args:
        labels: Category names; joined with ", " in the prompt.
        query: Optional classification criterion. When non-empty it is
            inserted as " based on <query>" after the label list.

    Returns:
        The instruction string asking the model to answer with a label only.
    """
    joined = ", ".join(labels)
    # The two prompt variants differ only by this optional clause.
    criterion = f" based on {query}" if query else ""
    return (
        "Classify the following text into one of the following categories: "
        f"{joined}{criterion}. Just answer with the label. "
        "Absolutely no context is needed."
    )

def generate_labels(
self,
labels: list[str],
data: pd.DataFrame,
query: str = "",
max_tokens: int = None,
max_tries: int = 5,
) -> pd.DataFrame:
"""_summary_
Args:
dataset: The input dataset to process.
temperature: The temperature for text generation.
max_new_tokens: Maximum number of tokens to generate.
labels (list[str]): List of labels to classify the data into.
data (pd.DataFrame): Dataframe with a single column of text data.
query (str, optional): Classification query. Defaults to "".
max_tokens (int, optional): max output tokens. Defaults to None.
max_tries (int, optional): max tries to get the correct label. Defaults to 5.
Returns:
Dataset with generated text.
pd.DataFrame: _description_
"""
try:
distiset = self.pipeline.run(
dataset=dataset,
parameters={
"text_generation": {
"llm": {
"generation_kwargs": {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
}
}
},
},
)
return distiset

except Exception as e:
raise e


def sqr(x: int) -> int:
return x**2
system_prompt = self.create_system_prompt(labels, query)

def classify_text(text):
    """Ask the model for a label, re-asking until the reply is a known label.

    Args:
        text: The text to classify, taken from one row of the input data.

    Returns:
        The model's reply with surrounding whitespace removed. This is one
        of ``labels`` unless every attempt (the first call plus up to
        ``max_tries`` retries) failed, in which case the last reply is
        returned as-is.
    """
    # Only set ``num_predict`` when the caller asked for an output cap.
    options = {}
    if max_tokens is not None:
        options["num_predict"] = max_tokens

    def ask(prompt):
        # Every attempt, including the first, now honours ``options``.
        reply = self.client.chat(
            model=self.model,
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": text},
            ],
            options=options,
        )
        # Strip so a reply like "label\n" still matches an entry in labels.
        return reply["message"]["content"].strip()

    # Separator added so the reminder does not run into the base prompt.
    retry_prompt = (
        "You did not respond with just the label please respond again with the label only. Without any context or explanation"
        + ". "
        + system_prompt
    )

    response = ask(system_prompt)
    for _ in range(max_tries):
        if response in labels:
            break
        response = ask(retry_prompt)
    return response

data["label"] = data[data.columns[0]].progress_apply(classify_text)
return data


if __name__ == "__main__":
Expand Down
14 changes: 14 additions & 0 deletions mic_toolkit/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import cProfile


def profiler(func):
    """Decorate ``func`` so that every call is profiled with cProfile.

    After each successful call the collected statistics are printed sorted
    by cumulative time and written to ``<func.__name__>.prof`` in the
    current working directory.

    Args:
        func: The callable to profile.

    Returns:
        A wrapper with the same call signature that returns ``func``'s result.
    """
    # Imported here so the block does not depend on a new file-level import.
    import functools

    @functools.wraps(func)  # keep the wrapped function's __name__/__doc__
    def wrapper(*args, **kwargs):
        profile = cProfile.Profile()
        profile.enable()
        try:
            result = func(*args, **kwargs)
        finally:
            # Stop collecting even when func raises, so a failed call does
            # not leave the profiler enabled.
            profile.disable()
        profile.print_stats(sort="cumtime")
        profile.dump_stats(f"{func.__name__}.prof")
        return result

    return wrapper
125 changes: 121 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 4b7b737

Please sign in to comment.