Commit

Merge pull request #9 from AndreaFrancis/main
Add generate_to_hf Method to FastData Class
ncoop57 authored Nov 21, 2024
2 parents e4f7a5d + e889f45 commit f6c40a7
Showing 6 changed files with 438 additions and 51 deletions.
44 changes: 44 additions & 0 deletions examples/push_to_hf.py
@@ -0,0 +1,44 @@
from fastcore.utils import *
from fastdata.core import FastData


class Translation:
"Translation from an English phrase to a Spanish phrase"

def __init__(self, english: str, spanish: str):
self.english = english
self.spanish = spanish

def __repr__(self):
return f"{self.english} ➡ *{self.spanish}*"


prompt_template = """\
Generate English and Spanish translations on the following topic:
<topic>{topic}</topic>
"""

inputs = [
{"topic": "I am going to the beach this weekend"},
{"topic": "I am going to the gym after work"},
{"topic": "I am going to the park with my kids"},
{"topic": "I am going to the movies with my friends"},
{"topic": "I am going to the store to buy some groceries"},
{"topic": "I am going to the library to read some books"},
{"topic": "I am going to the zoo to see the animals"},
{"topic": "I am going to the museum to see the art"},
{"topic": "I am going to the restaurant to eat some food"},
]

fast_data = FastData(model="claude-3-haiku-20240307")
dataset_name = "my_dataset"

repo_id, translations = fast_data.generate_to_hf(
prompt_template=prompt_template,
inputs=inputs,
schema=Translation,
repo_id=dataset_name,
max_items_per_file=4,
)
print(f"A new repository has been create on {repo_id}")
print(translations)
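
Assuming the model fills the `Translation` schema as written above, each line that `generate_to_hf` writes to the `train-*.jsonl` shards (four records per file here, per `max_items_per_file=4`) is the JSON-serialized attributes of one generated object, so a record would look roughly like this hypothetical sample:

{"english": "I am going to the beach this weekend", "spanish": "Voy a la playa este fin de semana"}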
5 changes: 4 additions & 1 deletion fastdata/_modidx.py
@@ -7,5 +7,8 @@
'lib_path': 'fastdata'},
'syms': { 'fastdata.core': { 'fastdata.core.FastData': ('core.html#fastdata', 'fastdata/core.py'),
'fastdata.core.FastData.__init__': ('core.html#fastdata.__init__', 'fastdata/core.py'),
'fastdata.core.FastData._process_input': ('core.html#fastdata._process_input', 'fastdata/core.py'),
'fastdata.core.FastData._save_results': ('core.html#fastdata._save_results', 'fastdata/core.py'),
'fastdata.core.FastData._set_rate_limit': ('core.html#fastdata._set_rate_limit', 'fastdata/core.py'),
'fastdata.core.FastData.generate': ('core.html#fastdata.generate', 'fastdata/core.py'),
'fastdata.core.FastData.set_rate_limit': ('core.html#fastdata.set_rate_limit', 'fastdata/core.py')}}}
'fastdata.core.FastData.generate_to_hf': ('core.html#fastdata.generate_to_hf', 'fastdata/core.py')}}}
201 changes: 179 additions & 22 deletions fastdata/core.py
@@ -3,26 +3,68 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_core.ipynb.

# %% auto 0
__all__ = ['FastData']
__all__ = ['DATASET_CARD_TEMPLATE', 'FastData']

# %% ../nbs/00_core.ipynb 3
from claudette import *
import concurrent.futures
import json
import shutil
from pathlib import Path
from uuid import uuid4
from typing import Optional, Union

from tqdm import tqdm
from fastcore.utils import *
from ratelimit import limits, sleep_and_retry
from tqdm import tqdm

import concurrent.futures
from huggingface_hub import CommitScheduler, DatasetCard
from claudette import *

# %% ../nbs/00_core.ipynb 4
DATASET_CARD_TEMPLATE = """
---
tags:
- fastdata
- synthetic
---
# {title}
_Note: This is an AI-generated dataset, so its content may be inaccurate or false._
**Source of the data:**
The dataset was generated using the [Fastdata](https://github.com/AnswerDotAI/fastdata) library and {model_id} with the following input:
## System Prompt
```
{system_prompt}
```
## Prompt Template
```
{prompt_template}
```
## Sample Input
```json
{sample_input}
```
"""


class FastData:
def __init__(self,
model: str = "claude-3-haiku-20240307",
calls: int = 100,
period: int = 60):
self.cli = Client(model)
self.set_rate_limit(calls, period)
self._set_rate_limit(calls, period)

def set_rate_limit(self, calls: int, period: int):
def _set_rate_limit(self, calls: int, period: int):
"""Set a new rate limit."""
@sleep_and_retry
@limits(calls=calls, period=period)
@@ -35,6 +77,22 @@ def rate_limited_call(prompt: str, schema, temp: float, sp: str):

self._rate_limited_call = rate_limited_call

def _process_input(self, prompt_template, schema, temp, sp, input_data):
try:
prompt = prompt_template.format(**input_data)
return self._rate_limited_call(
prompt=prompt, schema=schema, temp=temp, sp=sp
)
except Exception as e:
print(f"Error processing input {input_data}: {e}")
return None

def _save_results(self, results: list[dict], save_path: Path) -> None:
with open(save_path, "w") as f:
for res in results:
obj_dict = getattr(res, "__stored_args__", res.__dict__)
f.write(json.dumps(obj_dict) + "\n")

def generate(self,
prompt_template: str,
inputs: list[dict],
@@ -44,23 +102,122 @@ def generate(self,
max_workers: int = 64) -> list[dict]:
"For every input in INPUTS, fill PROMPT_TEMPLATE and generate a value fitting SCHEMA"

def process_input(input_data):
try:
prompt = prompt_template.format(**input_data)
return self._rate_limited_call(
prompt=prompt,
schema=schema,
temp=temp,
sp=sp
)
except Exception as e:
print(f"Error processing input: {e}")
return None

results = []
with tqdm(total=len(inputs)) as pbar:
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(process_input, input_data) for input_data in inputs]
futures = [
executor.submit(
self._process_input,
prompt_template,
schema,
temp,
sp,
input_data,
)
for input_data in inputs
]

for completed_future in concurrent.futures.as_completed(futures):
pbar.update(1)
return [f.result() for f in futures]

def generate_to_hf(
self,
prompt_template: str,
inputs: list[dict],
schema,
repo_id: str,
temp: float = 1.0,
sp: str = "You are a helpful assistant.",
max_workers: int = 64,
max_items_per_file: int = 100,
commit_every: Union[int, float] = 5,
private: bool = False,
token: Optional[str] = None,
delete_files_after: bool = True,
) -> tuple[str, list[dict]]:
"""
Generate data based on a prompt template and schema, and save it to a Hugging Face dataset repository.
This function writes the generated records to multiple files, each containing a maximum of `max_items_per_file` records.
Due to the multi-threaded execution of the function, the order of the records in the files is not guaranteed to match the order of the input data.
Args:
prompt_template (str): The template for generating prompts.
inputs (list[dict]): A list of input dictionaries to be processed.
schema: The schema to parse the generated data.
repo_id (str): The Hugging Face dataset repository ID.
temp (float, optional): The temperature for generation. Defaults to 1.0.
sp (str, optional): The system prompt for the assistant. Defaults to "You are a helpful assistant.".
max_workers (int, optional): The maximum number of worker threads. Defaults to 64.
max_items_per_file (int, optional): The maximum number of items to save in each file. Defaults to 100.
commit_every (Union[int, float], optional): The number of minutes between each commit. Defaults to 5.
private (bool, optional): Whether the repository is private. Defaults to False.
token (Optional[str], optional): The token to use to commit to the repo. Defaults to the token saved on the machine.
delete_files_after (bool, optional): Whether to delete files after processing. Defaults to True.
Returns:
tuple[str, list[dict]]: A tuple with the generated repo_id and the list of generated data dictionaries.
"""
dataset_dir = Path(repo_id.replace("/", "_"))
dataset_dir.mkdir(parents=True, exist_ok=True)
data_dir = dataset_dir / "data"
data_dir.mkdir(exist_ok=True)

try:
scheduler = CommitScheduler(
repo_id=repo_id,
repo_type="dataset",
folder_path=dataset_dir,
every=commit_every,
private=private,
token=token,
)

readme_path = dataset_dir / "README.md"

if not readme_path.exists():
DatasetCard(
DATASET_CARD_TEMPLATE.format(
title=repo_id,
model_id=self.cli.model,
system_prompt=sp,
prompt_template=prompt_template,
sample_input=inputs[:2],
)
).save(readme_path)

results = []
total_inputs = len(inputs)

with tqdm(total=total_inputs) as pbar:
with concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers
) as executor:
futures = [
executor.submit(
self._process_input,
prompt_template,
schema,
temp,
sp,
input_data,
)
for input_data in inputs
]

current_file = data_dir / f"train-{uuid4()}.jsonl"
for completed_future in concurrent.futures.as_completed(futures):
result = completed_future.result()
if result is not None:
results.append(result)
with scheduler.lock:
self._save_results(results, current_file)
pbar.update(1)
if len(results) >= max_items_per_file:
current_file = data_dir / f"train-{uuid4()}.jsonl"
results.clear()
finally:
scheduler.trigger().result() # force upload last result
if delete_files_after:
shutil.rmtree(dataset_dir)

return scheduler.repo_id, [f.result() for f in futures if f.done()]
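
After the CommitScheduler pushes the final shards, the generated repo can be read back with the standard `datasets` library. A minimal sketch, not part of this commit, assuming `load_dataset` discovers the `train-*.jsonl` files under `data/` in the repo returned by `generate_to_hf`:

from datasets import load_dataset

# repo_id is the first element of the tuple returned by generate_to_hf
ds = load_dataset(repo_id)   # picks up the train-*.jsonl shards pushed under data/
print(ds["train"][0])        # one generated record as a plain dict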