diff --git a/examples/push_to_hf.py b/examples/push_to_hf.py
new file mode 100644
index 0000000..43ec3ba
--- /dev/null
+++ b/examples/push_to_hf.py
@@ -0,0 +1,44 @@
+from fastcore.utils import *
+from fastdata.core import FastData
+
+
+class Translation:
+    "Translation from an English phrase to a Spanish phrase"
+
+    def __init__(self, english: str, spanish: str):
+        self.english = english
+        self.spanish = spanish
+
+    def __repr__(self):
+        return f"{self.english} ➡ *{self.spanish}*"
+
+
+prompt_template = """\
+Generate English and Spanish translations on the following topic:
+{topic}
+"""
+
+inputs = [
+    {"topic": "I am going to the beach this weekend"},
+    {"topic": "I am going to the gym after work"},
+    {"topic": "I am going to the park with my kids"},
+    {"topic": "I am going to the movies with my friends"},
+    {"topic": "I am going to the store to buy some groceries"},
+    {"topic": "I am going to the library to read some books"},
+    {"topic": "I am going to the zoo to see the animals"},
+    {"topic": "I am going to the museum to see the art"},
+    {"topic": "I am going to the restaurant to eat some food"},
+]
+
+fast_data = FastData(model="claude-3-haiku-20240307")
+dataset_name = "my_dataset"
+
+repo_id, translations = fast_data.generate_to_hf(
+    prompt_template=prompt_template,
+    inputs=inputs,
+    schema=Translation,
+    repo_id=dataset_name,
+    max_items_per_file=4,
+)
+print(f"A new repository has been created at {repo_id}")
+print(translations)
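The example script returns the final `repo_id`, so the pushed dataset can be loaded straight back for a sanity check. A minimal sketch, assuming the run above completed and that the `datasets` library is installed (it is already listed under `dev_requirements` in settings.ini); the `"train"` split name follows from the `train-*.jsonl` shard naming:

```python
from datasets import load_dataset

# repo_id is the value returned by fast_data.generate_to_hf above.
dataset = load_dataset(repo_id)

# Each record carries the fields of the Translation schema.
for row in dataset["train"]:
    print(f"{row['english']} -> {row['spanish']}")
```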
diff --git a/fastdata/_modidx.py b/fastdata/_modidx.py
index d6b2231..d04a6c5 100644
--- a/fastdata/_modidx.py
+++ b/fastdata/_modidx.py
@@ -7,5 +7,8 @@
                 'lib_path': 'fastdata'},
  'syms': { 'fastdata.core': { 'fastdata.core.FastData': ('core.html#fastdata', 'fastdata/core.py'),
                               'fastdata.core.FastData.__init__': ('core.html#fastdata.__init__', 'fastdata/core.py'),
+                              'fastdata.core.FastData._process_input': ('core.html#fastdata._process_input', 'fastdata/core.py'),
+                              'fastdata.core.FastData._save_results': ('core.html#fastdata._save_results', 'fastdata/core.py'),
+                              'fastdata.core.FastData._set_rate_limit': ('core.html#fastdata._set_rate_limit', 'fastdata/core.py'),
                               'fastdata.core.FastData.generate': ('core.html#fastdata.generate', 'fastdata/core.py'),
-                              'fastdata.core.FastData.set_rate_limit': ('core.html#fastdata.set_rate_limit', 'fastdata/core.py')}}}
+                              'fastdata.core.FastData.generate_to_hf': ('core.html#fastdata.generate_to_hf', 'fastdata/core.py')}}}
diff --git a/fastdata/core.py b/fastdata/core.py
index 53d0080..8eba273 100644
--- a/fastdata/core.py
+++ b/fastdata/core.py
@@ -3,26 +3,68 @@
 # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_core.ipynb.

 # %% auto 0
-__all__ = ['FastData']
+__all__ = ['DATASET_CARD_TEMPLATE', 'FastData']

 # %% ../nbs/00_core.ipynb 3
-from claudette import *
+import concurrent.futures
+import json
+import shutil
+from pathlib import Path
+from uuid import uuid4
+from typing import Optional, Union
+
+from tqdm import tqdm
 from fastcore.utils import *
 from ratelimit import limits, sleep_and_retry
-from tqdm import tqdm
-
-import concurrent.futures
+from huggingface_hub import CommitScheduler, DatasetCard
+from claudette import *

 # %% ../nbs/00_core.ipynb 4
+DATASET_CARD_TEMPLATE = """
+---
+tags:
+- fastdata
+- synthetic
+---
+
+# {title}
+
+_Note: This is an AI-generated dataset, so its content may be inaccurate or false._
+
+**Source of the data:**
+
+The dataset was generated using the [Fastdata](https://github.com/AnswerDotAI/fastdata) library and {model_id} with the following input:
+
+## System Prompt
+
+```
+{system_prompt}
+```
+
+## Prompt Template
+
+```
+{prompt_template}
+```
+
+## Sample Input
+
+```json
+{sample_input}
+```
+
+"""
+
+
 class FastData:
     def __init__(self,
                  model: str = "claude-3-haiku-20240307",
                  calls: int = 100,
                  period: int = 60):
         self.cli = Client(model)
-        self.set_rate_limit(calls, period)
+        self._set_rate_limit(calls, period)

-    def set_rate_limit(self, calls: int, period: int):
+    def _set_rate_limit(self, calls: int, period: int):
         """Set a new rate limit."""
         @sleep_and_retry
         @limits(calls=calls, period=period)
@@ -35,6 +77,22 @@ def rate_limited_call(prompt: str, schema, temp: float, sp: str):

         self._rate_limited_call = rate_limited_call

+    def _process_input(self, prompt_template, schema, temp, sp, input_data):
+        try:
+            prompt = prompt_template.format(**input_data)
+            return self._rate_limited_call(
+                prompt=prompt, schema=schema, temp=temp, sp=sp
+            )
+        except Exception as e:
+            print(f"Error processing input {input_data}: {e}")
+            return None
+
+    def _save_results(self, results: list[dict], save_path: Path) -> None:
+        with open(save_path, "w") as f:
+            for res in results:
+                obj_dict = getattr(res, "__stored_args__", res.__dict__)
+                f.write(json.dumps(obj_dict) + "\n")
+
     def generate(self,
                  prompt_template: str,
                  inputs: list[dict],
@@ -44,23 +102,122 @@ def generate(self,
                  max_workers: int = 64) -> list[dict]:
         "For every input in INPUTS, fill PROMPT_TEMPLATE and generate a value fitting SCHEMA"

-        def process_input(input_data):
-            try:
-                prompt = prompt_template.format(**input_data)
-                return self._rate_limited_call(
-                    prompt=prompt,
-                    schema=schema,
-                    temp=temp,
-                    sp=sp
-                )
-            except Exception as e:
-                print(f"Error processing input: {e}")
-                return None
-
-        results = []
         with tqdm(total=len(inputs)) as pbar:
             with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
-                futures = [executor.submit(process_input, input_data) for input_data in inputs]
+                futures = [
+                    executor.submit(
+                        self._process_input,
+                        prompt_template,
+                        schema,
+                        temp,
+                        sp,
+                        input_data,
+                    )
+                    for input_data in inputs
+                ]
+
                 for completed_future in concurrent.futures.as_completed(futures):
                     pbar.update(1)
         return [f.result() for f in futures]
+
+    def generate_to_hf(
+        self,
+        prompt_template: str,
+        inputs: list[dict],
+        schema,
+        repo_id: str,
+        temp: float = 1.0,
+        sp: str = "You are a helpful assistant.",
+        max_workers: int = 64,
+        max_items_per_file: int = 100,
+        commit_every: Union[int, float] = 5,
+        private: bool = False,
+        token: Optional[str] = None,
+        delete_files_after: bool = True,
+    ) -> tuple[str, list[dict]]:
+        """
+        Generate data based on a prompt template and schema, and save it to a Hugging Face dataset repository.
+        This function writes the generated records to multiple files, each containing a maximum of `max_items_per_file` records.
+        Due to the multi-threaded execution of the function, the order of the records in the files is not guaranteed to match the order of the input data.
+
+        Args:
+            prompt_template (str): The template for generating prompts.
+            inputs (list[dict]): A list of input dictionaries to be processed.
+            schema: The schema to parse the generated data.
+            repo_id (str): The Hugging Face dataset name.
+            temp (float, optional): The temperature for generation. Defaults to 1.0.
+            sp (str, optional): The system prompt for the assistant. Defaults to "You are a helpful assistant.".
+            max_workers (int, optional): The maximum number of worker threads. Defaults to 64.
+            max_items_per_file (int, optional): The maximum number of items to save in each file. Defaults to 100.
+            commit_every (Union[int, float], optional): The number of minutes between each commit. Defaults to 5.
+            private (bool, optional): Whether the repository is private. Defaults to False.
+            token (Optional[str], optional): The token to use to commit to the repo. Defaults to the token saved on the machine.
+            delete_files_after (bool, optional): Whether to delete the local files after processing. Defaults to True.
+
+        Returns:
+            tuple[str, list[dict]]: A tuple with the generated repo_id and the list of generated data dictionaries.
+        """
+        dataset_dir = Path(repo_id.replace("/", "_"))
+        dataset_dir.mkdir(parents=True, exist_ok=True)
+        data_dir = dataset_dir / "data"
+        data_dir.mkdir(exist_ok=True)
+
+        try:
+            scheduler = CommitScheduler(
+                repo_id=repo_id,
+                repo_type="dataset",
+                folder_path=dataset_dir,
+                every=commit_every,
+                private=private,
+                token=token,
+            )
+
+            readme_path = dataset_dir / "README.md"
+
+            if not readme_path.exists():
+                DatasetCard(
+                    DATASET_CARD_TEMPLATE.format(
+                        title=repo_id,
+                        model_id=self.cli.model,
+                        system_prompt=sp,
+                        prompt_template=prompt_template,
+                        sample_input=inputs[:2],
+                    )
+                ).save(readme_path)
+
+            results = []
+            total_inputs = len(inputs)
+
+            with tqdm(total=total_inputs) as pbar:
+                with concurrent.futures.ThreadPoolExecutor(
+                    max_workers=max_workers
+                ) as executor:
+                    futures = [
+                        executor.submit(
+                            self._process_input,
+                            prompt_template,
+                            schema,
+                            temp,
+                            sp,
+                            input_data,
+                        )
+                        for input_data in inputs
+                    ]
+
+                    current_file = data_dir / f"train-{uuid4()}.jsonl"
+                    for completed_future in concurrent.futures.as_completed(futures):
+                        result = completed_future.result()
+                        if result is not None:
+                            results.append(result)
+                            with scheduler.lock:
+                                self._save_results(results, current_file)
+                        pbar.update(1)
+                        if len(results) >= max_items_per_file:
+                            current_file = data_dir / f"train-{uuid4()}.jsonl"
+                            results.clear()
+        finally:
+            scheduler.trigger().result()  # force upload of the last results
+            if delete_files_after:
+                shutil.rmtree(dataset_dir)
+
+        return scheduler.repo_id, [f.result() for f in futures if f.done()]
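For context on the record format used by `_save_results` above: each result object is serialized as one JSON line, preferring fastcore's `__stored_args__` (populated by `store_attr`) and falling back to `__dict__`. A small illustration with a stand-in result object (hypothetical values, not part of the patch):

```python
import json

# Stand-in for a model-generated result; real results come from claudette.
class Translation:
    def __init__(self, english: str, spanish: str):
        self.english = english
        self.spanish = spanish

res = Translation("Good morning", "Buenos días")

# Same serialization as FastData._save_results: one JSON object per line.
line = json.dumps(getattr(res, "__stored_args__", res.__dict__))
print(line)  # {"english": "Good morning", "spanish": "Buenos d\u00edas"}
```

Each `data/train-<uuid>.jsonl` file accumulates at most `max_items_per_file` such lines before a fresh file is started.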
+ "from huggingface_hub import CommitScheduler, DatasetCard\n", + "from claudette import *" ] }, { @@ -50,15 +56,51 @@ "outputs": [], "source": [ "#| export\n", + "DATASET_CARD_TEMPLATE = \"\"\"\n", + "---\n", + "tags:\n", + "- fastdata\n", + "- synthetic\n", + "---\n", + "\n", + "# {title}\n", + "\n", + "_Note: This is an AI-generated dataset, so its content may be inaccurate or false._\n", + "\n", + "**Source of the data:**\n", + "\n", + "The dataset was generated using [Fastdata](https://github.com/AnswerDotAI/fastdata) library and {model_id} with the following input:\n", + "\n", + "## System Prompt\n", + "\n", + "```\n", + "{system_prompt}\n", + "```\n", + "\n", + "## Prompt Template\n", + "\n", + "```\n", + "{prompt_template}\n", + "```\n", + "\n", + "## Sample Input\n", + "\n", + "```json\n", + "{sample_input}\n", + "```\n", + "\n", + "\"\"\"\n", + "\n", + "\n", "class FastData:\n", " def __init__(self,\n", " model: str = \"claude-3-haiku-20240307\",\n", " calls: int = 100,\n", " period: int = 60):\n", " self.cli = Client(model)\n", - " self.set_rate_limit(calls, period)\n", + " self._set_rate_limit(calls, period)\n", "\n", - " def set_rate_limit(self, calls: int, period: int):\n", + " def _set_rate_limit(self, calls: int, period: int):\n", " \"\"\"Set a new rate limit.\"\"\"\n", " @sleep_and_retry\n", " @limits(calls=calls, period=period)\n", @@ -71,6 +113,22 @@ " \n", " self._rate_limited_call = rate_limited_call\n", "\n", + " def _process_input(self, prompt_template, schema, temp, sp, input_data):\n", + " try:\n", + " prompt = prompt_template.format(**input_data)\n", + " return self._rate_limited_call(\n", + " prompt=prompt, schema=schema, temp=temp, sp=sp\n", + " )\n", + " except Exception as e:\n", + " print(f\"Error processing input {input_data}: {e}\")\n", + " return None\n", + "\n", + " def _save_results(self, results: list[dict], save_path: Path) -> None:\n", + " with open(save_path, \"w\") as f:\n", + " for res in results:\n", + " obj_dict = getattr(res, \"__stored_args__\", res.__dict__)\n", + " f.write(json.dumps(obj_dict) + \"\\n\")\n", + "\n", " def generate(self, \n", " prompt_template: str, \n", " inputs: list[dict], \n", @@ -80,26 +138,125 @@ " max_workers: int = 64) -> list[dict]:\n", " \"For every input in INPUTS, fill PROMPT_TEMPLATE and generate a value fitting SCHEMA\"\n", " \n", - " def process_input(input_data):\n", - " try:\n", - " prompt = prompt_template.format(**input_data)\n", - " return self._rate_limited_call(\n", - " prompt=prompt,\n", - " schema=schema,\n", - " temp=temp,\n", - " sp=sp\n", - " )\n", - " except Exception as e:\n", - " print(f\"Error processing input: {e}\")\n", - " return None\n", - "\n", - " results = []\n", " with tqdm(total=len(inputs)) as pbar:\n", " with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:\n", - " futures = [executor.submit(process_input, input_data) for input_data in inputs]\n", + " futures = [\n", + " executor.submit(\n", + " self._process_input,\n", + " prompt_template,\n", + " schema,\n", + " temp,\n", + " sp,\n", + " input_data,\n", + " )\n", + " for input_data in inputs\n", + " ]\n", + "\n", " for completed_future in concurrent.futures.as_completed(futures):\n", " pbar.update(1)\n", - " return [f.result() for f in futures]" + " return [f.result() for f in futures]\n", + "\n", + " def generate_to_hf(\n", + " self,\n", + " prompt_template: str,\n", + " inputs: list[dict],\n", + " schema,\n", + " repo_id: str,\n", + " temp: float = 1.0,\n", + " sp: str = \"You are a helpful 
assistant.\",\n", + " max_workers: int = 64,\n", + " max_items_per_file: int = 100,\n", + " commit_every: Union[int, float] = 5,\n", + " private: bool = False,\n", + " token: Optional[str] = None,\n", + " delete_files_after: bool = True,\n", + " ) -> tuple[str, list[dict]]:\n", + " \"\"\"\n", + " Generate data based on a prompt template and schema, and save it to Hugging Face dataset repository.\n", + " This function writes the generated records to multiple files, each containing a maximum of `max_items_per_file` records. \n", + " Due to the multi-threaded execution of the function, the order of the records in the files is not guaranteed to match the order of the input data. \n", + "\n", + " Args:\n", + " prompt_template (str): The template for generating prompts.\n", + " inputs (list[dict]): A list of input dictionaries to be processed.\n", + " schema: The schema to parse the generated data.\n", + " repo_id (str): The HuggingFace dataset name.\n", + " temp (float, optional): The temperature for generation. Defaults to 1.0.\n", + " sp (str, optional): The system prompt for the assistant. Defaults to \"You are a helpful assistant.\".\n", + " max_workers (int, optional): The maximum number of worker threads. Defaults to 64.\n", + " max_items_per_file (int, optional): The maximum number of items to save in each file. Defaults to 100.\n", + " commit_every (Union[int, float], optional): The number of minutes between each commit. Defaults to 5.\n", + " private (bool, optional): Whether the repository is private. Defaults to False.\n", + " token (Optional[str], optional): The token to use to commit to the repo. Defaults to the token saved on the machine.\n", + " delete_files_after (bool, optional): Whether to delete files after processing. Defaults to True.\n", + "\n", + " Returns:\n", + " tuple[str, list[dict]]: A tuple with the generated repo_id and the list of generated data dictionaries.\n", + " \"\"\"\n", + " dataset_dir = Path(repo_id.replace(\"/\", \"_\"))\n", + " dataset_dir.mkdir(parents=True, exist_ok=True)\n", + " data_dir = dataset_dir / \"data\"\n", + " data_dir.mkdir(exist_ok=True)\n", + "\n", + " try:\n", + " scheduler = CommitScheduler(\n", + " repo_id=repo_id,\n", + " repo_type=\"dataset\",\n", + " folder_path=dataset_dir,\n", + " every=commit_every,\n", + " private=private,\n", + " token=token,\n", + " )\n", + "\n", + " readme_path = dataset_dir / \"README.md\"\n", + "\n", + " if not readme_path.exists():\n", + " DatasetCard(\n", + " DATASET_CARD_TEMPLATE.format(\n", + " title=repo_id,\n", + " model_id=self.cli.model,\n", + " system_prompt=sp,\n", + " prompt_template=prompt_template,\n", + " sample_input=inputs[:2],\n", + " )\n", + " ).save(readme_path)\n", + "\n", + " results = []\n", + " total_inputs = len(inputs)\n", + "\n", + " with tqdm(total=total_inputs) as pbar:\n", + " with concurrent.futures.ThreadPoolExecutor(\n", + " max_workers=max_workers\n", + " ) as executor:\n", + " futures = [\n", + " executor.submit(\n", + " self._process_input,\n", + " prompt_template,\n", + " schema,\n", + " temp,\n", + " sp,\n", + " input_data,\n", + " )\n", + " for input_data in inputs\n", + " ]\n", + "\n", + " current_file = data_dir / f\"train-{uuid4()}.jsonl\"\n", + " for completed_future in concurrent.futures.as_completed(futures):\n", + " result = completed_future.result()\n", + " if result is not None:\n", + " results.append(result)\n", + " with scheduler.lock:\n", + " self._save_results(results, current_file)\n", + " pbar.update(1)\n", + " if len(results) >= max_items_per_file:\n", 
+ " current_file = data_dir / f\"train-{uuid4()}.jsonl\"\n", + " results.clear()\n", + " finally:\n", + " scheduler.trigger().result() # force upload last result\n", + " if delete_files_after:\n", + " shutil.rmtree(dataset_dir)\n", + "\n", + " return scheduler.repo_id, [f.result() for f in futures if f.done()]" ] }, { @@ -327,6 +484,29 @@ "show(translations)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate translations and push results to Hugging Face Hub as a dataset\n", + "# Be sure to have the HF_TOKEN environment variable set to your Hugging Face API token\n", + "fast_data = FastData(model=\"claude-3-haiku-20240307\")\n", + "repo_id, translations = fast_data.generate_to_hf(\n", + " prompt_template=prompt_template,\n", + " inputs=[{\"persona\": persona, \"examples\": examples} for persona in personas],\n", + " schema=Translation,\n", + " sp=sp,\n", + " repo_id=f\"personas-translation-{uuid4()}\",\n", + " max_items_per_file=2, # It will create a local file each 2 translations \n", + ")\n", + "assert len(translations) == len(personas)\n", + "\n", + "new_dataset = load_dataset(repo_id)\n", + "assert len(translations) == len(personas)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -492,9 +672,7 @@ } ], "source": [ - "import uuid\n", - "str(uuid.uuid4())\n", - "print(prompt_template.format(**dict(datum=str(uuid.uuid4()))))" + "print(prompt_template.format(**dict(datum=str(uuid4()))))" ] }, { @@ -514,7 +692,7 @@ } ], "source": [ - "Datum(str(uuid.uuid4()))" + "Datum(str(uuid4()))" ] }, { @@ -538,7 +716,7 @@ } ], "source": [ - "in_vals = [{\"datum\":str(uuid.uuid4())} for _ in range(100)]\n", + "in_vals = [{\"datum\":str(uuid4())} for _ in range(100)]\n", "out_vals = fast_data.generate(\n", " prompt_template=prompt_template,\n", " inputs=in_vals,\n", diff --git a/nbs/sidebar.yml b/nbs/sidebar.yml new file mode 100644 index 0000000..45222d2 --- /dev/null +++ b/nbs/sidebar.yml @@ -0,0 +1,5 @@ +website: + sidebar: + contents: + - index.ipynb + - 00_core.ipynb diff --git a/settings.ini b/settings.ini index bdd1400..0751f10 100644 --- a/settings.ini +++ b/settings.ini @@ -26,7 +26,7 @@ keywords = nbdev jupyter notebook python language = English status = 3 user = AnswerDotAI -requirements = claudette fastcore ratelimit tqdm +requirements = claudette fastcore ratelimit tqdm huggingface_hub dev_requirements = black datasets ipykernel nbdev readme_nb = index.ipynb allowed_metadata_keys =