From 31ef92e08a4e8106596954dcc2ad43ee6ab8589e Mon Sep 17 00:00:00 2001
From: Andrea Soria
Date: Fri, 8 Nov 2024 21:47:30 -0400
Subject: [PATCH 1/7] Adding generate_to_hf method on FastData class

---
 examples/push_to_hf.py |  44 ++++++++
 fastdata/_modidx.py    |   5 +-
 fastdata/core.py       | 202 +++++++++++++++++++++++++++++++----
 nbs/00_core.ipynb      | 233 ++++++++++++++++++++++++++++++++++++-----
 nbs/sidebar.yml        |   5 +
 settings.ini           |   2 +-
 6 files changed, 440 insertions(+), 51 deletions(-)
 create mode 100644 examples/push_to_hf.py
 create mode 100644 nbs/sidebar.yml

diff --git a/examples/push_to_hf.py b/examples/push_to_hf.py
new file mode 100644
index 0000000..9d6b8a1
--- /dev/null
+++ b/examples/push_to_hf.py
@@ -0,0 +1,44 @@
+from fastcore.utils import *
+from fastdata.core import FastData
+
+
+class Translation:
+    "Translation from an English phrase to a Spanish phrase"
+
+    def __init__(self, english: str, spanish: str):
+        self.english = english
+        self.spanish = spanish
+
+    def __repr__(self):
+        return f"{self.english} ➡ *{self.spanish}*"
+
+
+prompt_template = """\
+Generate English and Spanish translations on the following topic:
+{topic}
+"""
+
+inputs = [
+    {"topic": "I am going to the beach this weekend"},
+    {"topic": "I am going to the gym after work"},
+    {"topic": "I am going to the park with my kids"},
+    {"topic": "I am going to the movies with my friends"},
+    {"topic": "I am going to the store to buy some groceries"},
+    {"topic": "I am going to the library to read some books"},
+    {"topic": "I am going to the zoo to see the animals"},
+    {"topic": "I am going to the museum to see the art"},
+    {"topic": "I am going to the restaurant to eat some food"},
+]
+
+fast_data = FastData(model="claude-3-haiku-20240307")
+dataset_name = "my_dataset"
+
+repo_id, translations = fast_data.generate_to_hf(
+    prompt_template=prompt_template,
+    inputs=inputs,
+    schema=Translation,
+    repo_id=dataset_name,
+    save_interval=4,
+)
+print(f"A new repository has been created at {repo_id}")
+print(translations)
diff --git a/fastdata/_modidx.py b/fastdata/_modidx.py
index d6b2231..d04a6c5 100644
--- a/fastdata/_modidx.py
+++ b/fastdata/_modidx.py
@@ -7,5 +7,8 @@
                 'lib_path': 'fastdata'},
  'syms': { 'fastdata.core': { 'fastdata.core.FastData': ('core.html#fastdata', 'fastdata/core.py'),
                              'fastdata.core.FastData.__init__': ('core.html#fastdata.__init__', 'fastdata/core.py'),
+                             'fastdata.core.FastData._process_input': ('core.html#fastdata._process_input', 'fastdata/core.py'),
+                             'fastdata.core.FastData._save_results': ('core.html#fastdata._save_results', 'fastdata/core.py'),
+                             'fastdata.core.FastData._set_rate_limit': ('core.html#fastdata._set_rate_limit', 'fastdata/core.py'),
                              'fastdata.core.FastData.generate': ('core.html#fastdata.generate', 'fastdata/core.py'),
-                             'fastdata.core.FastData.set_rate_limit': ('core.html#fastdata.set_rate_limit', 'fastdata/core.py')}}}
+                             'fastdata.core.FastData.generate_to_hf': ('core.html#fastdata.generate_to_hf', 'fastdata/core.py')}}}
diff --git a/fastdata/core.py b/fastdata/core.py
index 53d0080..f25fc0c 100644
--- a/fastdata/core.py
+++ b/fastdata/core.py
@@ -3,26 +3,68 @@
 # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_core.ipynb.
 
# %% auto 0 -__all__ = ['FastData'] +__all__ = ['DATASET_CARD_TEMPLATE', 'FastData'] # %% ../nbs/00_core.ipynb 3 -from claudette import * +import concurrent.futures +import json +import shutil +from pathlib import Path +from uuid import uuid4 +from typing import Optional, Union + +from tqdm import tqdm from fastcore.utils import * from ratelimit import limits, sleep_and_retry -from tqdm import tqdm - -import concurrent.futures +from huggingface_hub import CommitScheduler, DatasetCard +from claudette import * # %% ../nbs/00_core.ipynb 4 +DATASET_CARD_TEMPLATE = """ +--- +tags: +- fastdata +- synthetic +--- + +# {title} + +_Note: This is an AI-generated dataset, so its content may be inaccurate or false._ + +**Source of the data:** + +The dataset was generated using [Fastdata](https://github.com/AnswerDotAI/fastdata) library and {model_id} with the following input: + +## System Prompt + +``` +{system_prompt} +``` + +## Prompt Template + +``` +{prompt_template} +``` + +## Sample Input + +```json +{sample_input} +``` + +""" + + class FastData: def __init__(self, model: str = "claude-3-haiku-20240307", calls: int = 100, period: int = 60): self.cli = Client(model) - self.set_rate_limit(calls, period) + self._set_rate_limit(calls, period) - def set_rate_limit(self, calls: int, period: int): + def _set_rate_limit(self, calls: int, period: int): """Set a new rate limit.""" @sleep_and_retry @limits(calls=calls, period=period) @@ -35,6 +77,22 @@ def rate_limited_call(prompt: str, schema, temp: float, sp: str): self._rate_limited_call = rate_limited_call + def _process_input(self, prompt_template, schema, temp, sp, input_data): + try: + prompt = prompt_template.format(**input_data) + return self._rate_limited_call( + prompt=prompt, schema=schema, temp=temp, sp=sp + ) + except Exception as e: + print(f"Error processing input {input_data}: {e}") + return None + + def _save_results(self, results: list[dict], save_path: Path) -> None: + with open(save_path, "w") as f: + for res in results: + obj_dict = getattr(res, "__stored_args__", res.__dict__) + f.write(json.dumps(obj_dict) + "\n") + def generate(self, prompt_template: str, inputs: list[dict], @@ -44,23 +102,123 @@ def generate(self, max_workers: int = 64) -> list[dict]: "For every input in INPUTS, fill PROMPT_TEMPLATE and generate a value fitting SCHEMA" - def process_input(input_data): - try: - prompt = prompt_template.format(**input_data) - return self._rate_limited_call( - prompt=prompt, - schema=schema, - temp=temp, - sp=sp - ) - except Exception as e: - print(f"Error processing input: {e}") - return None - - results = [] with tqdm(total=len(inputs)) as pbar: with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [executor.submit(process_input, input_data) for input_data in inputs] + futures = [ + executor.submit( + self._process_input, + prompt_template, + schema, + temp, + sp, + input_data, + ) + for input_data in inputs + ] + for completed_future in concurrent.futures.as_completed(futures): pbar.update(1) return [f.result() for f in futures] + + def generate_to_hf( + self, + prompt_template: str, + inputs: list[dict], + schema, + repo_id: str, + temp: float = 1.0, + sp: str = "You are a helpful assistant.", + max_workers: int = 64, + save_interval: int = 100, + commit_every: Union[int, float] = 5, + private: bool = False, + token: Optional[str] = None, + delete_files_after: bool = True, + ) -> tuple[str, list[dict]]: + """ + Generate data based on a prompt template and schema, and save it to Hugging Face 
dataset repository. + + Args: + prompt_template (str): The template for generating prompts. + inputs (list[dict]): A list of input dictionaries to be processed. + schema: The schema to parse the generated data. + repo_id (str): The HuggingFace dataset name. + temp (float, optional): The temperature for generation. Defaults to 1.0. + sp (str, optional): The system prompt for the assistant. Defaults to "You are a helpful assistant.". + max_workers (int, optional): The maximum number of worker threads. Defaults to 64. + save_interval (int, optional): The batch size at which to save the results. Defaults to 100. + commit_every (Union[int, float], optional): The number of minutes between each commit. Defaults to 5. + private (bool, optional): Whether the repository is private. Defaults to False. + token (Optional[str], optional): The token to use to commit to the repo. Defaults to the token saved on the machine. + delete_files_after (bool, optional): Whether to delete files after processing. Defaults to True. + + Returns: + tuple[str, list[dict]]: A tuple with the generated repo_id and the list of generated data dictionaries. + """ + dataset_dir = Path(repo_id.replace("/", "_")) + dataset_dir.mkdir(parents=True, exist_ok=True) + data_dir = dataset_dir / "data" + data_dir.mkdir(exist_ok=True) + + try: + scheduler = CommitScheduler( + repo_id=repo_id, + repo_type="dataset", + folder_path=dataset_dir, + every=commit_every, + private=private, + token=token, + ) + DatasetCard( + DATASET_CARD_TEMPLATE.format( + title=repo_id, + model_id=self.cli.model, + system_prompt=sp, + prompt_template=prompt_template, + sample_input=inputs[:2], + ) + ).save(dataset_dir / "README.md") + + results = [] + total_inputs = len(inputs) + + with tqdm(total=total_inputs) as pbar: + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers + ) as executor: + futures = [ + executor.submit( + self._process_input, + prompt_template, + schema, + temp, + sp, + input_data, + ) + for input_data in inputs + ] + completed = 0 + + for completed_future in concurrent.futures.as_completed(futures): + result = completed_future.result() + if result is not None: + results.append(result) + completed += 1 + pbar.update(1) + + if completed % save_interval == 0 or completed == total_inputs: + if results: + with scheduler.lock: + self._save_results( + results, data_dir / f"train-{uuid4()}.jsonl" + ) + results.clear() + scheduler._push_to_hub() + scheduler.stop() + except Exception as e: + raise e + finally: + if delete_files_after: + shutil.rmtree(dataset_dir) + + return scheduler.repo_id, [f.result() for f in futures if f.done()] diff --git a/nbs/00_core.ipynb b/nbs/00_core.ipynb index e20301a..f9be864 100644 --- a/nbs/00_core.ipynb +++ b/nbs/00_core.ipynb @@ -35,12 +35,18 @@ "outputs": [], "source": [ "#| export\n", - "from claudette import *\n", + "import concurrent.futures\n", + "import json\n", + "import shutil\n", + "from pathlib import Path\n", + "from uuid import uuid4\n", + "from typing import Optional, Union\n", + "\n", + "from tqdm import tqdm\n", "from fastcore.utils import *\n", "from ratelimit import limits, sleep_and_retry\n", - "from tqdm import tqdm\n", - "\n", - "import concurrent.futures" + "from huggingface_hub import CommitScheduler, DatasetCard\n", + "from claudette import *" ] }, { @@ -50,15 +56,51 @@ "outputs": [], "source": [ "#| export\n", + "DATASET_CARD_TEMPLATE = \"\"\"\n", + "---\n", + "tags:\n", + "- fastdata\n", + "- synthetic\n", + "---\n", + "\n", + "# {title}\n", + "\n", + "_Note: This is an 
AI-generated dataset, so its content may be inaccurate or false._\n", + "\n", + "**Source of the data:**\n", + "\n", + "The dataset was generated using [Fastdata](https://github.com/AnswerDotAI/fastdata) library and {model_id} with the following input:\n", + "\n", + "## System Prompt\n", + "\n", + "```\n", + "{system_prompt}\n", + "```\n", + "\n", + "## Prompt Template\n", + "\n", + "```\n", + "{prompt_template}\n", + "```\n", + "\n", + "## Sample Input\n", + "\n", + "```json\n", + "{sample_input}\n", + "```\n", + "\n", + "\"\"\"\n", + "\n", + "\n", "class FastData:\n", " def __init__(self,\n", " model: str = \"claude-3-haiku-20240307\",\n", " calls: int = 100,\n", " period: int = 60):\n", " self.cli = Client(model)\n", - " self.set_rate_limit(calls, period)\n", + " self._set_rate_limit(calls, period)\n", "\n", - " def set_rate_limit(self, calls: int, period: int):\n", + " def _set_rate_limit(self, calls: int, period: int):\n", " \"\"\"Set a new rate limit.\"\"\"\n", " @sleep_and_retry\n", " @limits(calls=calls, period=period)\n", @@ -71,6 +113,22 @@ " \n", " self._rate_limited_call = rate_limited_call\n", "\n", + " def _process_input(self, prompt_template, schema, temp, sp, input_data):\n", + " try:\n", + " prompt = prompt_template.format(**input_data)\n", + " return self._rate_limited_call(\n", + " prompt=prompt, schema=schema, temp=temp, sp=sp\n", + " )\n", + " except Exception as e:\n", + " print(f\"Error processing input {input_data}: {e}\")\n", + " return None\n", + "\n", + " def _save_results(self, results: list[dict], save_path: Path) -> None:\n", + " with open(save_path, \"w\") as f:\n", + " for res in results:\n", + " obj_dict = getattr(res, \"__stored_args__\", res.__dict__)\n", + " f.write(json.dumps(obj_dict) + \"\\n\")\n", + "\n", " def generate(self, \n", " prompt_template: str, \n", " inputs: list[dict], \n", @@ -80,26 +138,126 @@ " max_workers: int = 64) -> list[dict]:\n", " \"For every input in INPUTS, fill PROMPT_TEMPLATE and generate a value fitting SCHEMA\"\n", " \n", - " def process_input(input_data):\n", - " try:\n", - " prompt = prompt_template.format(**input_data)\n", - " return self._rate_limited_call(\n", - " prompt=prompt,\n", - " schema=schema,\n", - " temp=temp,\n", - " sp=sp\n", - " )\n", - " except Exception as e:\n", - " print(f\"Error processing input: {e}\")\n", - " return None\n", - "\n", - " results = []\n", " with tqdm(total=len(inputs)) as pbar:\n", " with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:\n", - " futures = [executor.submit(process_input, input_data) for input_data in inputs]\n", + " futures = [\n", + " executor.submit(\n", + " self._process_input,\n", + " prompt_template,\n", + " schema,\n", + " temp,\n", + " sp,\n", + " input_data,\n", + " )\n", + " for input_data in inputs\n", + " ]\n", + "\n", " for completed_future in concurrent.futures.as_completed(futures):\n", " pbar.update(1)\n", - " return [f.result() for f in futures]" + " return [f.result() for f in futures]\n", + "\n", + " def generate_to_hf(\n", + " self,\n", + " prompt_template: str,\n", + " inputs: list[dict],\n", + " schema,\n", + " repo_id: str,\n", + " temp: float = 1.0,\n", + " sp: str = \"You are a helpful assistant.\",\n", + " max_workers: int = 64,\n", + " save_interval: int = 100,\n", + " commit_every: Union[int, float] = 5,\n", + " private: bool = False,\n", + " token: Optional[str] = None,\n", + " delete_files_after: bool = True,\n", + " ) -> tuple[str, list[dict]]:\n", + " \"\"\"\n", + " Generate data based on a prompt template and 
schema, and save it to Hugging Face dataset repository.\n", + "\n", + " Args:\n", + " prompt_template (str): The template for generating prompts.\n", + " inputs (list[dict]): A list of input dictionaries to be processed.\n", + " schema: The schema to parse the generated data.\n", + " repo_id (str): The HuggingFace dataset name.\n", + " temp (float, optional): The temperature for generation. Defaults to 1.0.\n", + " sp (str, optional): The system prompt for the assistant. Defaults to \"You are a helpful assistant.\".\n", + " max_workers (int, optional): The maximum number of worker threads. Defaults to 64.\n", + " save_interval (int, optional): The batch size at which to save the results. Defaults to 100.\n", + " commit_every (Union[int, float], optional): The number of minutes between each commit. Defaults to 5.\n", + " private (bool, optional): Whether the repository is private. Defaults to False.\n", + " token (Optional[str], optional): The token to use to commit to the repo. Defaults to the token saved on the machine.\n", + " delete_files_after (bool, optional): Whether to delete files after processing. Defaults to True.\n", + "\n", + " Returns:\n", + " tuple[str, list[dict]]: A tuple with the generated repo_id and the list of generated data dictionaries.\n", + " \"\"\"\n", + " dataset_dir = Path(repo_id.replace(\"/\", \"_\"))\n", + " dataset_dir.mkdir(parents=True, exist_ok=True)\n", + " data_dir = dataset_dir / \"data\"\n", + " data_dir.mkdir(exist_ok=True)\n", + "\n", + " try:\n", + " scheduler = CommitScheduler(\n", + " repo_id=repo_id,\n", + " repo_type=\"dataset\",\n", + " folder_path=dataset_dir,\n", + " every=commit_every,\n", + " private=private,\n", + " token=token,\n", + " )\n", + " DatasetCard(\n", + " DATASET_CARD_TEMPLATE.format(\n", + " title=repo_id,\n", + " model_id=self.cli.model,\n", + " system_prompt=sp,\n", + " prompt_template=prompt_template,\n", + " sample_input=inputs[:2],\n", + " )\n", + " ).save(dataset_dir / \"README.md\")\n", + "\n", + " results = []\n", + " total_inputs = len(inputs)\n", + "\n", + " with tqdm(total=total_inputs) as pbar:\n", + " with concurrent.futures.ThreadPoolExecutor(\n", + " max_workers=max_workers\n", + " ) as executor:\n", + " futures = [\n", + " executor.submit(\n", + " self._process_input,\n", + " prompt_template,\n", + " schema,\n", + " temp,\n", + " sp,\n", + " input_data,\n", + " )\n", + " for input_data in inputs\n", + " ]\n", + " completed = 0\n", + "\n", + " for completed_future in concurrent.futures.as_completed(futures):\n", + " result = completed_future.result()\n", + " if result is not None:\n", + " results.append(result)\n", + " completed += 1\n", + " pbar.update(1)\n", + "\n", + " if completed % save_interval == 0 or completed == total_inputs:\n", + " if results:\n", + " with scheduler.lock:\n", + " self._save_results(\n", + " results, data_dir / f\"train-{uuid4()}.jsonl\"\n", + " )\n", + " results.clear()\n", + " scheduler._push_to_hub()\n", + " scheduler.stop()\n", + " except Exception as e:\n", + " raise e\n", + " finally:\n", + " if delete_files_after:\n", + " shutil.rmtree(dataset_dir)\n", + "\n", + " return scheduler.repo_id, [f.result() for f in futures if f.done()]" ] }, { @@ -327,6 +485,29 @@ "show(translations)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate translations and push results to Hugging Face Hub as a dataset\n", + "# Be sure to have the HF_TOKEN environment variable set to your Hugging Face API token\n", + "fast_data = 
FastData(model=\"claude-3-haiku-20240307\")\n", + "repo_id, translations = fast_data.generate_to_hf(\n", + " prompt_template=prompt_template,\n", + " inputs=[{\"persona\": persona, \"examples\": examples} for persona in personas],\n", + " schema=Translation,\n", + " sp=sp,\n", + " repo_id=f\"personas-translation-{uuid4()}\",\n", + " save_interval=2, # It will create a local file each 2 translations \n", + ")\n", + "assert len(translations) == len(personas)\n", + "\n", + "new_dataset = load_dataset(repo_id)\n", + "assert len(new_dataset['train']) == len(personas)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -492,9 +673,7 @@ } ], "source": [ - "import uuid\n", - "str(uuid.uuid4())\n", - "print(prompt_template.format(**dict(datum=str(uuid.uuid4()))))" + "print(prompt_template.format(**dict(datum=str(uuid4()))))" ] }, { @@ -514,7 +693,7 @@ } ], "source": [ - "Datum(str(uuid.uuid4()))" + "Datum(str(uuid4()))" ] }, { @@ -538,7 +717,7 @@ } ], "source": [ - "in_vals = [{\"datum\":str(uuid.uuid4())} for _ in range(100)]\n", + "in_vals = [{\"datum\":str(uuid4())} for _ in range(100)]\n", "out_vals = fast_data.generate(\n", " prompt_template=prompt_template,\n", " inputs=in_vals,\n", diff --git a/nbs/sidebar.yml b/nbs/sidebar.yml new file mode 100644 index 0000000..45222d2 --- /dev/null +++ b/nbs/sidebar.yml @@ -0,0 +1,5 @@ +website: + sidebar: + contents: + - index.ipynb + - 00_core.ipynb diff --git a/settings.ini b/settings.ini index bdd1400..0751f10 100644 --- a/settings.ini +++ b/settings.ini @@ -26,7 +26,7 @@ keywords = nbdev jupyter notebook python language = English status = 3 user = AnswerDotAI -requirements = claudette fastcore ratelimit tqdm +requirements = claudette fastcore ratelimit tqdm huggingface_hub dev_requirements = black datasets ipykernel nbdev readme_nb = index.ipynb allowed_metadata_keys = From 2297849acffc5d2784f1211e3312621756d47229 Mon Sep 17 00:00:00 2001 From: Andrea Francis Soria Jimenez Date: Wed, 20 Nov 2024 16:38:14 -0400 Subject: [PATCH 2/7] Update fastdata/core.py Co-authored-by: Lucain --- fastdata/core.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fastdata/core.py b/fastdata/core.py index f25fc0c..5dc3cfd 100644 --- a/fastdata/core.py +++ b/fastdata/core.py @@ -215,8 +215,6 @@ def generate_to_hf( results.clear() scheduler._push_to_hub() scheduler.stop() - except Exception as e: - raise e finally: if delete_files_after: shutil.rmtree(dataset_dir) From aff52f9c48800fd538d90808888f410cdece36d0 Mon Sep 17 00:00:00 2001 From: Andrea Soria Date: Wed, 20 Nov 2024 16:40:25 -0400 Subject: [PATCH 3/7] Addressing @Wauplin comments --- nbs/00_core.ipynb | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/nbs/00_core.ipynb b/nbs/00_core.ipynb index f9be864..10c172e 100644 --- a/nbs/00_core.ipynb +++ b/nbs/00_core.ipynb @@ -205,15 +205,19 @@ " private=private,\n", " token=token,\n", " )\n", - " DatasetCard(\n", - " DATASET_CARD_TEMPLATE.format(\n", - " title=repo_id,\n", - " model_id=self.cli.model,\n", - " system_prompt=sp,\n", - " prompt_template=prompt_template,\n", - " sample_input=inputs[:2],\n", - " )\n", - " ).save(dataset_dir / \"README.md\")\n", + "\n", + " readme_path = dataset_dir / \"README.md\"\n", + "\n", + " if not readme_path.exists():\n", + " DatasetCard(\n", + " DATASET_CARD_TEMPLATE.format(\n", + " title=repo_id,\n", + " model_id=self.cli.model,\n", + " system_prompt=sp,\n", + " prompt_template=prompt_template,\n", + " sample_input=inputs[:2],\n", + " )\n", + " ).save(readme_path)\n", "\n", " 
results = []\n",
 "            total_inputs = len(inputs)\n",
@@ -251,8 +255,6 @@
 "                                results.clear()\n",
 "            scheduler._push_to_hub()\n",
 "            scheduler.stop()\n",
-"        except Exception as e:\n",
-"            raise e\n",
 "        finally:\n",
 "            if delete_files_after:\n",
 "                shutil.rmtree(dataset_dir)\n",

From ff61f56c713fffec8b9cbf0397376bcac955f644 Mon Sep 17 00:00:00 2001
From: Andrea Soria
Date: Wed, 20 Nov 2024 16:46:33 -0400
Subject: [PATCH 4/7] nbdev_prepare execution

---
 fastdata/core.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/fastdata/core.py b/fastdata/core.py
index 5dc3cfd..f9e1dce 100644
--- a/fastdata/core.py
+++ b/fastdata/core.py
@@ -169,15 +169,19 @@ def generate_to_hf(
                 private=private,
                 token=token,
             )
-            DatasetCard(
-                DATASET_CARD_TEMPLATE.format(
-                    title=repo_id,
-                    model_id=self.cli.model,
-                    system_prompt=sp,
-                    prompt_template=prompt_template,
-                    sample_input=inputs[:2],
-                )
-            ).save(dataset_dir / "README.md")
+
+            readme_path = dataset_dir / "README.md"
+
+            if not readme_path.exists():
+                DatasetCard(
+                    DATASET_CARD_TEMPLATE.format(
+                        title=repo_id,
+                        model_id=self.cli.model,
+                        system_prompt=sp,
+                        prompt_template=prompt_template,
+                        sample_input=inputs[:2],
+                    )
+                ).save(readme_path)
 
             results = []
             total_inputs = len(inputs)

From 29ebb42fc1be309f86f83d8c46aad8259264354d Mon Sep 17 00:00:00 2001
From: Andrea Soria
Date: Wed, 20 Nov 2024 21:17:58 -0400
Subject: [PATCH 5/7] Addressing @Wauplin comments and adding doc for unordered output

---
 examples/push_to_hf.py |  2 +-
 fastdata/core.py       | 24 ++++++++++----------
 nbs/00_core.ipynb      | 31 ++++++++++++------------------
 3 files changed, 23 insertions(+), 34 deletions(-)

diff --git a/examples/push_to_hf.py b/examples/push_to_hf.py
index 9d6b8a1..43ec3ba 100644
--- a/examples/push_to_hf.py
+++ b/examples/push_to_hf.py
@@ -38,7 +38,7 @@ def __repr__(self):
     inputs=inputs,
     schema=Translation,
     repo_id=dataset_name,
-    save_interval=4,
+    max_items_per_file=4,
 )
 print(f"A new repository has been created at {repo_id}")
 print(translations)
diff --git a/fastdata/core.py b/fastdata/core.py
index f9e1dce..3b3306e 100644
--- a/fastdata/core.py
+++ b/fastdata/core.py
@@ -129,7 +129,7 @@ def generate_to_hf(
         temp: float = 1.0,
         sp: str = "You are a helpful assistant.",
         max_workers: int = 64,
-        save_interval: int = 100,
+        max_items_per_file: int = 100,
         commit_every: Union[int, float] = 5,
         private: bool = False,
         token: Optional[str] = None,
@@ -137,6 +137,8 @@ def generate_to_hf(
     ) -> tuple[str, list[dict]]:
         """
         Generate data based on a prompt template and schema, and save it to Hugging Face dataset repository.
+        This function writes the generated records to multiple files, each containing a maximum of `max_items_per_file` records.
+        Due to the multi-threaded execution of the function, the order of the records in the files is not guaranteed to match the order of the input data.
 
         Args:
             prompt_template (str): The template for generating prompts.
@@ -146,7 +148,7 @@ def generate_to_hf(
             temp (float, optional): The temperature for generation. Defaults to 1.0.
             sp (str, optional): The system prompt for the assistant. Defaults to "You are a helpful assistant.".
             max_workers (int, optional): The maximum number of worker threads. Defaults to 64.
-            save_interval (int, optional): The batch size at which to save the results. Defaults to 100.
+            max_items_per_file (int, optional): The maximum number of items to save in each file. Defaults to 100.
            commit_every (Union[int, float], optional): The number of minutes between each commit. Defaults to 5.
private (bool, optional): Whether the repository is private. Defaults to False. token (Optional[str], optional): The token to use to commit to the repo. Defaults to the token saved on the machine. @@ -201,24 +203,18 @@ def generate_to_hf( ) for input_data in inputs ] - completed = 0 + current_file = data_dir / f"train-{uuid4()}.jsonl" for completed_future in concurrent.futures.as_completed(futures): result = completed_future.result() if result is not None: results.append(result) - completed += 1 + with scheduler.lock: + self._save_results(results, current_file) pbar.update(1) - - if completed % save_interval == 0 or completed == total_inputs: - if results: - with scheduler.lock: - self._save_results( - results, data_dir / f"train-{uuid4()}.jsonl" - ) - results.clear() - scheduler._push_to_hub() - scheduler.stop() + if len(results) >= max_items_per_file: + current_file = data_dir / f"train-{uuid4()}.jsonl" + results.clear() finally: if delete_files_after: shutil.rmtree(dataset_dir) diff --git a/nbs/00_core.ipynb b/nbs/00_core.ipynb index 10c172e..960d9fe 100644 --- a/nbs/00_core.ipynb +++ b/nbs/00_core.ipynb @@ -165,7 +165,7 @@ " temp: float = 1.0,\n", " sp: str = \"You are a helpful assistant.\",\n", " max_workers: int = 64,\n", - " save_interval: int = 100,\n", + " max_items_per_file: int = 100,\n", " commit_every: Union[int, float] = 5,\n", " private: bool = False,\n", " token: Optional[str] = None,\n", @@ -173,6 +173,8 @@ " ) -> tuple[str, list[dict]]:\n", " \"\"\"\n", " Generate data based on a prompt template and schema, and save it to Hugging Face dataset repository.\n", + " This function writes the generated records to multiple files, each containing a maximum of `max_items_per_file` records. \n", + " Due to the multi-threaded execution of the function, the order of the records in the files is not guaranteed to match the order of the input data. \n", "\n", " Args:\n", " prompt_template (str): The template for generating prompts.\n", @@ -182,7 +184,7 @@ " temp (float, optional): The temperature for generation. Defaults to 1.0.\n", " sp (str, optional): The system prompt for the assistant. Defaults to \"You are a helpful assistant.\".\n", " max_workers (int, optional): The maximum number of worker threads. Defaults to 64.\n", - " save_interval (int, optional): The batch size at which to save the results. Defaults to 100.\n", + " max_items_per_file (int, optional): The maximum number of items to save in each file. Defaults to 100.\n", " commit_every (Union[int, float], optional): The number of minutes between each commit. Defaults to 5.\n", " private (bool, optional): Whether the repository is private. Defaults to False.\n", " token (Optional[str], optional): The token to use to commit to the repo. 
Defaults to the token saved on the machine.\n",
 "            delete_files_after (bool, optional): Whether to delete files after processing. Defaults to True.\n",
@@ -237,24 +239,18 @@
 "                        )\n",
 "                        for input_data in inputs\n",
 "                    ]\n",
-"                    completed = 0\n",
 "\n",
+"                    current_file = data_dir / f\"train-{uuid4()}.jsonl\"\n",
 "                    for completed_future in concurrent.futures.as_completed(futures):\n",
 "                        result = completed_future.result()\n",
 "                        if result is not None:\n",
 "                            results.append(result)\n",
-"                            completed += 1\n",
+"                            with scheduler.lock:\n",
+"                                self._save_results(results, current_file)\n",
 "                        pbar.update(1)\n",
-"\n",
-"                        if completed % save_interval == 0 or completed == total_inputs:\n",
-"                            if results:\n",
-"                                with scheduler.lock:\n",
-"                                    self._save_results(\n",
-"                                        results, data_dir / f\"train-{uuid4()}.jsonl\"\n",
-"                                    )\n",
-"                                results.clear()\n",
-"            scheduler._push_to_hub()\n",
-"            scheduler.stop()\n",
+"                        if len(results) >= max_items_per_file:\n",
+"                            current_file = data_dir / f\"train-{uuid4()}.jsonl\"\n",
+"                            results.clear()\n",
 "        finally:\n",
 "            if delete_files_after:\n",
 "                shutil.rmtree(dataset_dir)\n",
@@ -502,12 +498,9 @@
 "    schema=Translation,\n",
 "    sp=sp,\n",
 "    repo_id=f\"personas-translation-{uuid4()}\",\n",
-"    save_interval=2, # It will create a local file each 2 translations \n",
+"    max_items_per_file=2, # It will create a local file every 2 translations \n",
 ")\n",
-"assert len(translations) == len(personas)\n",
-"\n",
-"new_dataset = load_dataset(repo_id)\n",
-"assert len(new_dataset['train']) == len(personas)"
+"assert len(translations) == len(personas)"
 ]
 },
 {

From eb58f8394613c0152cc1bf87f7b4701afebdf281 Mon Sep 17 00:00:00 2001
From: Andrea Francis Soria Jimenez
Date: Thu, 21 Nov 2024 07:55:41 -0400
Subject: [PATCH 6/7] Update fastdata/core.py

Co-authored-by: Lucain

---
 fastdata/core.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/fastdata/core.py b/fastdata/core.py
index 3b3306e..8eba273 100644
--- a/fastdata/core.py
+++ b/fastdata/core.py
@@ -216,6 +216,7 @@ def generate_to_hf(
                             current_file = data_dir / f"train-{uuid4()}.jsonl"
                             results.clear()
         finally:
+            scheduler.trigger().result()  # force upload last result
             if delete_files_after:
                 shutil.rmtree(dataset_dir)

From e889f4523ff15f26bcdc4663dad868be247474bf Mon Sep 17 00:00:00 2001
From: Andrea Soria
Date: Thu, 21 Nov 2024 08:05:47 -0400
Subject: [PATCH 7/7] Fix scheduler trigger

---
 nbs/00_core.ipynb | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/nbs/00_core.ipynb b/nbs/00_core.ipynb
index 960d9fe..da4bf86 100644
--- a/nbs/00_core.ipynb
+++ b/nbs/00_core.ipynb
@@ -252,6 +252,7 @@
 "                            current_file = data_dir / f\"train-{uuid4()}.jsonl\"\n",
 "                            results.clear()\n",
 "        finally:\n",
+"            scheduler.trigger().result() # force upload last result\n",
 "            if delete_files_after:\n",
 "                shutil.rmtree(dataset_dir)\n",
 "\n",
@@ -500,6 +501,9 @@
 "    repo_id=f\"personas-translation-{uuid4()}\",\n",
 "    max_items_per_file=2, # It will create a local file every 2 translations \n",
 ")\n",
-"assert len(translations) == len(personas)"
+"assert len(translations) == len(personas)\n",
+"\n",
+"new_dataset = load_dataset(repo_id)\n",
+"assert len(new_dataset['train']) == len(personas)"
 ]
 },
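
Usage sketch for the series above, assuming `HF_TOKEN` holds a write token, the `datasets` library is installed, `Translation`, `prompt_template`, and `inputs` are defined as in examples/push_to_hf.py, and "username/my-translations" is a placeholder repo id:

    import os
    from datasets import load_dataset
    from fastdata.core import FastData

    fast_data = FastData(model="claude-3-haiku-20240307")

    # generate_to_hf returns the final repo id plus the parsed results; the
    # CommitScheduler commits JSONL shards under data/ every `commit_every`
    # minutes while generation runs.
    repo_id, translations = fast_data.generate_to_hf(
        prompt_template=prompt_template,
        inputs=inputs,
        schema=Translation,
        repo_id="username/my-translations",  # placeholder id
        max_items_per_file=4,                # small shards, easy to inspect
        token=os.environ["HF_TOKEN"],
    )

    # The shards in data/ load as the train split. Record order is not
    # guaranteed to match `inputs` because generation is multi-threaded.
    ds = load_dataset(repo_id, split="train")
    assert len(ds) == len(translations)  # holds when no generation errors occurred
    print(ds[0])  # e.g. {'english': ..., 'spanish': ...}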