Merge pull request #1 from micvitc/dev/main
added text + label generation
ameen-91 authored Nov 16, 2024
2 parents ae538c5 + 437db9a commit 606f2a0
Showing 13 changed files with 605 additions and 106 deletions.
4 changes: 2 additions & 2 deletions .gitignore
@@ -165,8 +165,8 @@ sad-utils-env/
notebooks/

*.ipynb

*.csv
main.py
test.py

*.prof
.ruff_cache/
4 changes: 2 additions & 2 deletions README.md
@@ -1,8 +1,8 @@
# Mic Toolkit

![License](https://img.shields.io/badge/license-MIT-blue.svg)
![PyPI - Version](https://img.shields.io/pypi/v/mic-toolkit)
![Python](https://img.shields.io/badge/python-3.12-blue.svg)
![Poetry](https://img.shields.io/badge/poetry-1.1.0-blue.svg)
![PyPI - Downloads](https://img.shields.io/pypi/dm/mic-toolkit)
[![CI](https://github.com/micvitc/mic-toolkit/actions/workflows/ci.yaml/badge.svg)](https://github.com/micvitc/mic-toolkit/actions/workflows/ci.yaml)


106 changes: 86 additions & 20 deletions docs/index.md
@@ -1,40 +1,106 @@
# Welcome to the mic-toolkit
# Mic Toolkit

Utilities for internal MIC projects

## Installation
![PyPI - Version](https://img.shields.io/pypi/v/mic-toolkit)
![Python](https://img.shields.io/badge/python-3.12-blue.svg)
![PyPI - Downloads](https://img.shields.io/pypi/dm/mic-toolkit)
[![CI](https://github.com/micvitc/mic-toolkit/actions/workflows/ci.yaml/badge.svg)](https://github.com/micvitc/mic-toolkit/actions/workflows/ci.yaml)

Make sure you have Python 3.12 or later installed.

``` bash

pip install mic-toolkit
## Overview

Simple synthetic data generation using LLMs.

## Features

- **Text Generation**
- **Label Generation**

## Installation

To install Mic Toolkit, use the following command:

```sh
pip install mic-toolkit
```

## Current Utilities

### Simple Data Generation Pipeline
### Simple Text Generation

``` py title="text-gen-sample.py"


import pandas as pd
from mic_toolkit.synthetic.generation import Generator


generator = Generator(endpoint="http://localhost:11434", model="llama3.2:3b-instruct-q4_0")

``` py title="sample.py"
data = pd.read_csv("data.csv")

from mic_toolkit.synthetic.generation import TextGenerationPipeline
from datasets import Dataset
output = generator.generate_text(data, system_prompt="Translate the following text to French")

dataset = Dataset.from_dict(
{"instruction": ["Write a Python program to multiply two numbers."]}
)
print(output)

```
Output

```
text
0 He heard the crack echo in the late afternoon ...
1 There wasn't a bird in the sky, but that was n...
2 The choice was red, green, or blue. It didn't ...
3 What was beyond the bend in the stream was unk...
4 I guess we could discuss the implications of t...
5 There were about twenty people on the dam. Mos...
output
0 Il entendit le bruit de craquement échoer à la...
1 Il n'y avait pas d'oiseau dans le ciel, mais c...
2 La choix était rouge, vert ou bleu. Il n'a pas...
3 Ce qui se trouvait au-delà de la courbe du rui...
4 Quelles implications auraient la phrase "c'est...
5 Il y avait environ vingt personnes sur la barr...
```

### Label Generation

pipeline = TextGenerationPipeline(
model_name="model_name",
api_key="api_key",
base_url="base_url",
)

``` py title="text-gen-sample.py"

distiset = pipeline.run_pipeline(dataset=dataset)

print(distiset["default"]["train"][0]["generation"])
import pandas as pd
from mic_toolkit.synthetic.generation import Generator


generator = Generator(endpoint="http://localhost:11434", model="llama3.2:3b-instruct-q4_0")

data = pd.read_csv("data.csv")

output = generator.generate_labels(data=data, labels=["Europe", "Asia", "Africa", "America", "Oceania"], query="Which continent does the following country belong to?")

print(output)

```

Output

```
country label
0 France Europe
1 Argentina America
2 United States America
3 Canada America
4 Mexico America
5 Brazil America
6 United Kingdom Europe
7 Germany Europe
8 Italy Europe
9 Spain Europe
10 Australia Oceania
11 Japan Asia
```
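
For the label example above, the classification prompt is built from the label list and the optional query. The sketch below reproduces the prompt template from `create_system_prompt` in this commit as a standalone function (the method's `self` is dropped for illustration) so the resulting text can be inspected without a running model.

```python
def create_system_prompt(labels: list[str], query: str = "") -> str:
    """Standalone copy of the prompt template added in this commit."""
    labels_str = ", ".join(labels)
    if query:
        return (
            "Classify the following text into one of the following categories: "
            f"{labels_str} based on {query}. Just answer with the label. "
            "Absolutely no context is needed."
        )
    return (
        "Classify the following text into one of the following categories: "
        f"{labels_str}. Just answer with the label. "
        "Absolutely no context is needed."
    )


prompt = create_system_prompt(
    ["Europe", "Asia", "Africa", "America", "Oceania"],
    query="Which continent does the following country belong to?",
)
print(prompt)
```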
Empty file removed mic_toolkit/dpo/__init__.py
Empty file.
2 changes: 0 additions & 2 deletions mic_toolkit/dpo/dpo_train.py

This file was deleted.

168 changes: 102 additions & 66 deletions mic_toolkit/synthetic/generation.py
@@ -1,84 +1,120 @@
from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration
from datasets import Dataset
from ollama import Client
import pandas as pd
from tqdm import tqdm

tqdm.pandas()

class TextGenerationPipeline:
"""
Simple Text Generation Pipeline.
"""

def __init__(self, model_name: str, base_url: str, api_key: str):
"""Setup pipeline with LLM parameters.
class Generator:
"""Generator class for synthetic data generation."""

def __init__(
self,
endpoint: str,
model: str,
):
"""Initializes the LLM Client and model.
Args:
model_name (str): Name of the model.
base_url (str): URL endpoint for the model.
api_key (str): API key.
endpoint (str): Endpoint for the LLM API. For Ollama it is usually "http://localhost:11434".
model (str): Name of the model to use for generation. Find it using 'ollama list'.
"""
self.model_name = model_name
self.base_url = base_url
self.api_key = api_key
self.pipeline = self.create_pipeline()
self.client = Client(endpoint)
self.model = model

def generate_text(
self,
data: pd.DataFrame,
system_prompt: str = "You are a helpful AI assistant. Please provide a response to the following user query:",
max_tokens: int | None = None,
) -> pd.DataFrame:
"""Generates a response for each row of the input dataframe.
def create_pipeline(self) -> Pipeline:
"""Create the text generation pipeline.
Args:
data (pd.DataFrame): Dataframe with a single column of text data.
system_prompt (str, optional): System prompt. Defaults to "You are a helpful AI assistant. Please provide a response to the following user query:".
max_tokens (int, optional): max output tokens. Defaults to None.
Returns:
Pipeline: Text Generation Pipeline.
pd.DataFrame: Output dataframe with generated text.
"""
with Pipeline(
name="simple-text-generation-pipeline",
description="A simple text generation pipeline",
) as pipeline:
TextGeneration(
name="text_generation",
llm=OpenAILLM(
model=self.model_name,
base_url=self.base_url,
api_key=self.api_key,
),
)

return pipeline

def run_pipeline(
self, dataset: Dataset, temperature: float = 0.7, max_new_tokens: int = 512
) -> Dataset:
"""
Executes the text generation pipeline on the input dataset.

def generate_response(text):
options = {}
if max_tokens is not None:
options["num_predict"] = max_tokens
return self.client.chat(
model=self.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": text},
],
options=options,
)["message"]["content"]

data["output"] = data[data.columns[0]].progress_apply(generate_response)

return data

def create_system_prompt(self, labels: list[str], query: str = "") -> str:
labels_str = ", ".join(labels)
if query:
return f"Classify the following text into one of the following categories: {labels_str} based on {query}. Just answer with the label. Absolutely no context is needed."
else:
return f"Classify the following text into one of the following categories: {labels_str}. Just answer with the label. Absolutely no context is needed."

def generate_labels(
self,
labels: list[str],
data: pd.DataFrame,
query: str = "",
max_tokens: int | None = None,
max_tries: int = 5,
) -> pd.DataFrame:
"""Classifies each row of the input dataframe into one of the given labels.
Args:
dataset: The input dataset to process.
temperature: The temperature for text generation.
max_new_tokens: Maximum number of tokens to generate.
labels (list[str]): List of labels to classify the data into.
data (pd.DataFrame): Dataframe with a single column of text data.
query (str, optional): Classification query. Defaults to "".
max_tokens (int, optional): max output tokens. Defaults to None.
max_tries (int, optional): max tries to get the correct label. Defaults to 5.
Returns:
Dataset with generated text.
pd.DataFrame: Input dataframe with an added "label" column.
"""
try:
distiset = self.pipeline.run(
dataset=dataset,
parameters={
"text_generation": {
"llm": {
"generation_kwargs": {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
}
}
},
},
)
return distiset

except Exception as e:
raise e


def sqr(x: int) -> int:
return x**2
system_prompt = self.create_system_prompt(labels, query)

def classify_text(text):
options = {}
if max_tokens is not None:
options["num_predict"] = max_tokens
response = self.client.chat(
model=self.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": text},
],
# Apply the same token limit on the first attempt as on retries.
options=options,
)["message"]["content"]
tries = max_tries
while response not in labels and tries > 0:
response = self.client.chat(
model=self.model,
messages=[
{
"role": "system",
"content": "You did not respond with just the label. Please respond again with the label only, without any context or explanation. "
+ system_prompt,
},
{"role": "user", "content": text},
],
options=options,
)["message"]["content"]
tries -= 1
return response

data["label"] = data[data.columns[0]].progress_apply(classify_text)
return data


if __name__ == "__main__":
14 changes: 14 additions & 0 deletions mic_toolkit/utils.py
@@ -0,0 +1,14 @@
import cProfile


def profiler(func):
def wrapper(*args, **kwargs):
profile = cProfile.Profile()
profile.enable()
result = func(*args, **kwargs)
profile.disable()
profile.print_stats(sort="cumtime")
profile.dump_stats(f"{func.__name__}.prof")
return result

return wrapper
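
A short usage sketch for the decorator above, assuming it is applied to an ordinary function. The decorator is repeated here so the example runs on its own. `functools.wraps` is an addition not present in the commit, and `sum_of_squares` is a made-up workload. The dumped `.prof` file is then reloaded with the standard-library `pstats` module.

```python
import cProfile
import functools
import os
import pstats
import tempfile


def profiler(func):
    """Copy of the commit's decorator, plus functools.wraps (an addition)."""

    @functools.wraps(func)  # keeps the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        profile = cProfile.Profile()
        profile.enable()
        result = func(*args, **kwargs)
        profile.disable()
        profile.print_stats(sort="cumtime")
        profile.dump_stats(f"{func.__name__}.prof")
        return result

    return wrapper


@profiler
def sum_of_squares(n: int) -> int:
    """Made-up workload used only to have something to profile."""
    return sum(i * i for i in range(n))


os.chdir(tempfile.mkdtemp())  # keep the .prof file out of the working tree
total = sum_of_squares(1000)

# Reload the dumped statistics and show the three most expensive entries.
stats = pstats.Stats("sum_of_squares.prof")
stats.sort_stats("tottime").print_stats(3)
```

Since the file is named after the function, two decorated functions with the same name would overwrite each other's output.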
13 changes: 13 additions & 0 deletions mkdocs.yml
@@ -3,14 +3,25 @@ site_url: https://micvitc.github.io/mic-toolkit/
nav:
- Home: index.md
- Synthetic: synthetic.md
repo_url: https://github.com/micvitc/mic-toolkit
repo_name: micvitc/mic-toolkit

theme:
name: material
features:
- content.code.copy
- content.code.select
palette:
primary: black
accent: indigo
version:
provider: pip
package: mic-toolkit
plugins:
- mkdocstrings
- social:
enabled: !ENV [CI, false]
cards: true

markdown_extensions:
- pymdownx.highlight:
@@ -22,3 +33,5 @@ markdown_extensions:
- pymdownx.superfences



