-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9 from AndreaFrancis/main
Add generate_to_hf Method to FastData Class
- Loading branch information
Showing
6 changed files
with
438 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from fastcore.utils import * | ||
from fastdata.core import FastData | ||
|
||
|
||
class Translation: | ||
"Translation from an English phrase to a Spanish phrase" | ||
|
||
def __init__(self, english: str, spanish: str): | ||
self.english = english | ||
self.spanish = spanish | ||
|
||
def __repr__(self): | ||
return f"{self.english} ➡ *{self.spanish}*" | ||
|
||
|
||
prompt_template = """\ | ||
Generate English and Spanish translations on the following topic: | ||
<topic>{topic}</topic> | ||
""" | ||
|
||
inputs = [ | ||
{"topic": "I am going to the beach this weekend"}, | ||
{"topic": "I am going to the gym after work"}, | ||
{"topic": "I am going to the park with my kids"}, | ||
{"topic": "I am going to the movies with my friends"}, | ||
{"topic": "I am going to the store to buy some groceries"}, | ||
{"topic": "I am going to the library to read some books"}, | ||
{"topic": "I am going to the zoo to see the animals"}, | ||
{"topic": "I am going to the museum to see the art"}, | ||
{"topic": "I am going to the restaurant to eat some food"}, | ||
] | ||
|
||
fast_data = FastData(model="claude-3-haiku-20240307") | ||
dataset_name = "my_dataset" | ||
|
||
repo_id, translations = fast_data.generate_to_hf( | ||
prompt_template=prompt_template, | ||
inputs=inputs, | ||
schema=Translation, | ||
repo_id=dataset_name, | ||
max_items_per_file=4, | ||
) | ||
print(f"A new repository has been create on {repo_id}") | ||
print(translations) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.