From 4b7b737c7f000e382cbc016b07168485726e2a4a Mon Sep 17 00:00:00 2001
From: ameen-91
Date: Sat, 16 Nov 2024 18:17:40 +0530
Subject: [PATCH] added text + label generation

---
 .gitignore                          |   4 +-
 mic_toolkit/dpo/__init__.py         |   0
 mic_toolkit/dpo/dpo_train.py        |   2 -
 mic_toolkit/synthetic/generation.py | 168 +++++++++++++++++-----------
 mic_toolkit/utils.py                |  14 +++
 poetry.lock                         | 125 ++++++++++++++++++++-
 pyproject.toml                      |   3 +
 tests/dpo_test.py                   |   0
 tests/synthetic_test.py             |   5 -
 tests/test_generation.py            |  43 +++++++
 10 files changed, 285 insertions(+), 79 deletions(-)
 delete mode 100644 mic_toolkit/dpo/__init__.py
 delete mode 100644 mic_toolkit/dpo/dpo_train.py
 create mode 100644 mic_toolkit/utils.py
 delete mode 100644 tests/dpo_test.py
 delete mode 100644 tests/synthetic_test.py
 create mode 100644 tests/test_generation.py

diff --git a/.gitignore b/.gitignore
index 22cf49c..11f21e8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,8 +165,8 @@ sad-utils-env/
 
 notebooks/
 *.ipynb
-
+*.csv
 main.py
 test.py
-
+*.prof
 .ruff_cache/
\ No newline at end of file
diff --git a/mic_toolkit/dpo/__init__.py b/mic_toolkit/dpo/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/mic_toolkit/dpo/dpo_train.py b/mic_toolkit/dpo/dpo_train.py
deleted file mode 100644
index ed444e3..0000000
--- a/mic_toolkit/dpo/dpo_train.py
+++ /dev/null
@@ -1,2 +0,0 @@
-if __name__ == "__main__":
-    NotImplementedError()
diff --git a/mic_toolkit/synthetic/generation.py b/mic_toolkit/synthetic/generation.py
index d4c2c2c..e84983f 100644
--- a/mic_toolkit/synthetic/generation.py
+++ b/mic_toolkit/synthetic/generation.py
@@ -1,84 +1,120 @@
-from distilabel.llms import OpenAILLM
-from distilabel.pipeline import Pipeline
-from distilabel.steps.tasks import TextGeneration
-from datasets import Dataset
+from ollama import Client
+import pandas as pd
+from tqdm import tqdm
+tqdm.pandas()
 
 
-class TextGenerationPipeline:
-    """
-    Simple Text Generation Pipeline.
-    """
-    def __init__(self, model_name: str, base_url: str, api_key: str):
-        """Setup pipeline with LLM paramters.
+class Generator:
+    """Generator class for synthetic data generation."""
+
+    def __init__(
+        self,
+        endpoint: str,
+        model: str,
+    ):
+        """Initializes the LLM client and model.
 
         Args:
-            model_name (str): Name of the model.
-            base_url (str): URL endpoint for the model.
-            api_key (str): API key.
+            endpoint (str): Endpoint for the LLM API. For Ollama it is usually "http://localhost:11434".
+            model (str): Name of the model to use for generation. Find it using 'ollama list'.
         """
-        self.model_name = model_name
-        self.base_url = base_url
-        self.api_key = api_key
-        self.pipeline = self.create_pipeline()
+        self.client = Client(endpoint)
+        self.model = model
+
+    def generate_text(
+        self,
+        data: pd.DataFrame,
+        system_prompt: str = "You are a helpful AI assistant. Please provide a response to the following user query:",
+        max_tokens: int | None = None,
+    ) -> pd.DataFrame:
+        """Generate a response for every row of a single-column text dataframe.
 
-    def create_pipeline(self) -> Pipeline:
-        """Create the text generation pipeline.
+        Args:
+            data (pd.DataFrame): Dataframe with a single column of text data.
+            system_prompt (str, optional): System prompt. Defaults to "You are a helpful AI assistant. Please provide a response to the following user query:".
+            max_tokens (int, optional): Maximum number of output tokens. Defaults to None.
 
         Returns:
-            Pipeline: Text Generation Pipeline.
+            pd.DataFrame: Output dataframe with generated text.
""" - with Pipeline( - name="simple-text-generation-pipeline", - description="A simple text generation pipeline", - ) as pipeline: - TextGeneration( - name="text_generation", - llm=OpenAILLM( - model=self.model_name, - base_url=self.base_url, - api_key=self.api_key, - ), - ) - - return pipeline - - def run_pipeline( - self, dataset: Dataset, temperature: float = 0.7, max_new_tokens: int = 512 - ) -> Dataset: - """ - Executes the text generation pipeline on the input dataset. + + def generate_response(text): + options = {} + if max_tokens is not None: + options["num_predict"] = max_tokens + return self.client.chat( + model=self.model, + messages=[ + {"system": system_prompt}, + {"role": "user", "content": text}, + ], + options=options, + )["message"]["content"] + + data["output"] = data[data.columns[0]].progress_apply(generate_response) + + return data + + def create_system_prompt(self, labels: list[str], query: str = "") -> str: + labels_str = ", ".join(labels) + if query: + return f"Classify the following text into one of the following categories: {labels_str} based on {query}. Just answer with the label. Absolutely no context is needed." + else: + return f"Classify the following text into one of the following categories: {labels_str}. Just answer with the label. Absolutely no context is needed." + + def generate_labels( + self, + labels: list[str], + data: pd.DataFrame, + query: str = "", + max_tokens: int = None, + max_tries: int = 5, + ) -> pd.DataFrame: + """_summary_ Args: - dataset: The input dataset to process. - temperature: The temperature for text generation. - max_new_tokens: Maximum number of tokens to generate. + labels (list[str]): List of labels to classify the data into. + data (pd.DataFrame): Dataframe with a single column of text data. + query (str, optional): Classification query. Defaults to "". + max_tokens (int, optional): max output tokens. Defaults to None. + max_tries (int, optional): max tries to get the correct label. Defaults to 5. Returns: - Dataset with generated text. + pd.DataFrame: _description_ """ - try: - distiset = self.pipeline.run( - dataset=dataset, - parameters={ - "text_generation": { - "llm": { - "generation_kwargs": { - "temperature": temperature, - "max_new_tokens": max_new_tokens, - } - } - }, - }, - ) - return distiset - - except Exception as e: - raise e - - -def sqr(x: int) -> int: - return x**2 + system_prompt = self.create_system_prompt(labels, query) + + def classify_text(text): + options = {} + if max_tokens is not None: + options["num_predict"] = max_tokens + response = self.client.chat( + model=self.model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": text}, + ], + )["message"]["content"] + tries = max_tries + while response not in labels and tries > 0: + response = self.client.chat( + model=self.model, + messages=[ + { + "role": "system", + "content": "You did not respond with just the label please respond again with the label only. 
+                            + system_prompt,
+                        },
+                        {"role": "user", "content": text},
+                    ],
+                    options=options,
+                )["message"]["content"]
+                tries -= 1
+            return response
+
+        data["label"] = data[data.columns[0]].progress_apply(classify_text)
+        return data
 
 
 if __name__ == "__main__":
diff --git a/mic_toolkit/utils.py b/mic_toolkit/utils.py
new file mode 100644
index 0000000..a1be851
--- /dev/null
+++ b/mic_toolkit/utils.py
@@ -0,0 +1,14 @@
+import cProfile
+
+
+def profiler(func):
+    def wrapper(*args, **kwargs):
+        profile = cProfile.Profile()
+        profile.enable()
+        result = func(*args, **kwargs)
+        profile.disable()
+        profile.print_stats(sort="cumtime")
+        profile.dump_stats(f"{func.__name__}.prof")
+        return result
+
+    return wrapper
diff --git a/poetry.lock b/poetry.lock
index d5f74a5..3793d80 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -136,6 +136,26 @@ files = [
 [package.dependencies]
 frozenlist = ">=1.1.0"
 
+[[package]]
+name = "anyio"
+version = "4.6.2.post1"
+description = "High level compatibility layer for multiple asynchronous event loop implementations"
+optional = false
+python-versions = ">=3.9"
+files = [
+    {file = "anyio-4.6.2.post1-py3-none-any.whl", hash = "sha256:6d170c36fba3bdd840c73d3868c1e777e33676a69c3a72cf0a0d5d6d8009b61d"},
+    {file = "anyio-4.6.2.post1.tar.gz", hash = "sha256:4c8bc31ccdb51c7f7bd251f51c609e038d63e34219b44aa86e47576389880b4c"},
+]
+
+[package.dependencies]
+idna = ">=2.8"
+sniffio = ">=1.1"
+
+[package.extras]
+doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"]
+test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21.0b1)"]
+trio = ["trio (>=0.26.1)"]
+
 [[package]]
 name = "appnope"
 version = "0.1.4"
@@ -739,6 +759,63 @@ files = [
 [package.dependencies]
 colorama = ">=0.4"
 
+[[package]]
+name = "h11"
+version = "0.14.0"
+description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"},
+    {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"},
+]
+
+[[package]]
+name = "httpcore"
+version = "1.0.7"
+description = "A minimal low-level HTTP client."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "httpcore-1.0.7-py3-none-any.whl", hash = "sha256:a3fff8f43dc260d5bd363d9f9cf1830fa3a458b332856f34282de498ed420edd"},
+    {file = "httpcore-1.0.7.tar.gz", hash = "sha256:8551cb62a169ec7162ac7be8d4817d561f60e08eaa485234898414bb5a8a0b4c"},
+]
+
+[package.dependencies]
+certifi = "*"
+h11 = ">=0.13,<0.15"
+
+[package.extras]
+asyncio = ["anyio (>=4.0,<5.0)"]
+http2 = ["h2 (>=3,<5)"]
+socks = ["socksio (==1.*)"]
+trio = ["trio (>=0.22.0,<1.0)"]
+
+[[package]]
+name = "httpx"
+version = "0.27.2"
+description = "The next generation HTTP client."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"},
+    {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"},
+]
+
+[package.dependencies]
+anyio = "*"
+certifi = "*"
+httpcore = "==1.*"
+idna = "*"
+sniffio = "*"
+
+[package.extras]
+brotli = ["brotli", "brotlicffi"]
+cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
+http2 = ["h2 (>=3,<5)"]
+socks = ["socksio (==1.*)"]
+zstd = ["zstandard (>=0.18.0)"]
+
 [[package]]
 name = "huggingface-hub"
 version = "0.24.7"
@@ -1405,6 +1482,20 @@ files = [
     {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
 ]
 
+[[package]]
+name = "ollama"
+version = "0.3.3"
+description = "The official Python client for Ollama."
+optional = false
+python-versions = "<4.0,>=3.8"
+files = [
+    {file = "ollama-0.3.3-py3-none-any.whl", hash = "sha256:ca6242ce78ab34758082b7392df3f9f6c2cb1d070a9dede1a4c545c929e16dba"},
+    {file = "ollama-0.3.3.tar.gz", hash = "sha256:f90a6d61803117f40b0e8ff17465cab5e1eb24758a473cfe8101aff38bc13b51"},
+]
+
+[package.dependencies]
+httpx = ">=0.27.0,<0.28.0"
+
 [[package]]
 name = "packaging"
 version = "24.1"
@@ -2151,6 +2242,31 @@ files = [
     {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
 ]
 
+[[package]]
+name = "snakeviz"
+version = "2.2.2"
+description = "A web-based viewer for Python profiler output"
+optional = false
+python-versions = ">=3.9"
+files = [
+    {file = "snakeviz-2.2.2-py3-none-any.whl", hash = "sha256:77e7b9c82f6152edc330040319b97612351cd9b48c706434c535c2df31d10ac5"},
+    {file = "snakeviz-2.2.2.tar.gz", hash = "sha256:08028c6f8e34a032ff14757a38424770abb8662fb2818985aeea0d9bc13a7d83"},
+]
+
+[package.dependencies]
+tornado = ">=2.0"
+
+[[package]]
+name = "sniffio"
+version = "1.3.1"
+description = "Sniff out which async library your code is running under"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"},
+    {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"},
+]
+
 [[package]]
 name = "stack-data"
 version = "0.6.3"
@@ -2192,13 +2308,13 @@ files = [
 
 [[package]]
 name = "tqdm"
-version = "4.66.5"
+version = "4.67.0"
 description = "Fast, Extensible Progress Meter"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "tqdm-4.66.5-py3-none-any.whl", hash = "sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd"},
-    {file = "tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad"},
+    {file = "tqdm-4.67.0-py3-none-any.whl", hash = "sha256:0cd8af9d56911acab92182e88d763100d4788bdf421d251616040cc4d44863be"},
+    {file = "tqdm-4.67.0.tar.gz", hash = "sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a"},
 ]
 
 [package.dependencies]
 colorama = {version = "*", markers = "platform_system == \"Windows\""}
 
 [package.extras]
 dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"]
+discord = ["requests"]
 notebook = ["ipywidgets (>=6)"]
 slack = ["slack-sdk"]
 telegram = ["requests"]
@@ -2577,4 +2694,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.12"
-content-hash = "2e004ec01ae176d6269c2765e6c50c06a5f03a3132dd58aad699d0883785d835"
"2e004ec01ae176d6269c2765e6c50c06a5f03a3132dd58aad699d0883785d835" +content-hash = "8c33af7a2fdeb9f3e65aa8c4fff3569f995f79faf3d4a201ecb8c4a176e98119" diff --git a/pyproject.toml b/pyproject.toml index 1ac2e28..e459e4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,8 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.12" datasets = "^3.0.0" +ollama = "^0.3.3" +tqdm = "^4.67.0" [tool.poetry.group.dev.dependencies] @@ -18,6 +20,7 @@ mkdocs = "^1.6.1" pre-commit = "^3.8.0" ipykernel = "^6.29.5" pytest = "^8.3.3" +snakeviz = "^2.2.2" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/tests/dpo_test.py b/tests/dpo_test.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/synthetic_test.py b/tests/synthetic_test.py deleted file mode 100644 index 3e92131..0000000 --- a/tests/synthetic_test.py +++ /dev/null @@ -1,5 +0,0 @@ -from mic_toolkit.synthetic.generation import sqr - - -def test_sqr(): - assert sqr(5) == 25 diff --git a/tests/test_generation.py b/tests/test_generation.py new file mode 100644 index 0000000..ddda5ad --- /dev/null +++ b/tests/test_generation.py @@ -0,0 +1,43 @@ +import pytest +import pandas as pd +from unittest.mock import MagicMock +from mic_toolkit.synthetic.generation import Generator + + +@pytest.fixture +def generator(): + endpoint = "http://localhost:11434" + model = "model" + return Generator(endpoint, model) + + +def test_classify_text_correct_label(generator): + labels = ["label1", "label2"] + text = "sample text" + generator.create_system_prompt = MagicMock(return_value="system prompt") + generator.client.chat = MagicMock(return_value={"message": {"content": "label1"}}) + result = generator.generate_labels(labels, pd.DataFrame([text]), max_tries=5) + assert result["label"].iloc[0] == "label1" + + +def test_classify_text_exceeds_max_tries(generator): + labels = ["label1", "label2"] + text = "sample text" + generator.create_system_prompt = MagicMock(return_value="system prompt") + generator.client.chat = MagicMock( + side_effect=[{"message": {"content": "wrong label"}}] * 5 + + [{"message": {"content": "wrong label"}}] + ) + result = generator.generate_labels( + labels, pd.DataFrame({"text": [text]}), max_tries=5 + ) + assert result["label"].iloc[0] == "wrong label" + + +def test_generate_text(generator): + text = "sample text" + generator.client.chat = MagicMock( + return_value={"message": {"content": "generated response"}} + ) + result = generator.generate_text(pd.DataFrame({"text": [text]})) + assert result["output"].iloc[0] == "generated response"