Skip to content

Commit

Permalink
Modify dataset_generator.py to support small model training
Browse files — browse the repository at this point in the history
  • Loading branch information
stolzenp committed Feb 1, 2024
1 parent 0805cd8 commit e6d20d5
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion src/fabricator/dataset_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import time

from importlib import import_module
from pathlib import Path
from collections import defaultdict
from typing import Any, Callable, Dict, Optional, Union, Tuple, List
Expand All @@ -16,7 +17,6 @@
from .samplers import single_label_stratified_sample
from .utils import log_dir, create_timestamp_path


class DatasetGenerator:
"""The DatasetGenerator class is the main class of the fabricator package.
It generates datasets based on a prompt template. The main function is generate()."""
Expand Down Expand Up @@ -59,6 +59,8 @@ def generate(
return_unlabeled_dataset: bool = False,
max_prompt_calls: int = 10,
num_samples_to_generate: int = 10,
small_model_training: Optional[str] = None,
train_small_model_every_X_generations: Optional[int] = None,
timeout_per_prompt: Optional[int] = None,
log_every_n_api_calls: int = 25,
dummy_response: Optional[Union[str, Callable]] = None
Expand All @@ -82,6 +84,9 @@ def generate(
return_unlabeled_dataset (bool, optional): Whether to return the original dataset. Defaults to False.
max_prompt_calls (int, optional): Maximum number of prompt calls. Defaults to 10.
num_samples_to_generate (int, optional): Number of samples to generate. Defaults to 10.
small_model_training (str, optional): Task to perform small model training on. Defaults to None.
train_small_model_every_X_generations (int, optional): Number of generations between small model
training iterations. Defaults to None.
timeout_per_prompt (Optional[int], optional): Timeout per prompt call. Defaults to None.
log_every_n_api_calls (int, optional): Log every n api calls. Defaults to 25.
dummy_response (Optional[Union[str, Callable]], optional): Dummy response for dry runs. Defaults to None.
Expand All @@ -99,6 +104,9 @@ def generate(
if fewshot_dataset and not fewshot_sampling_column:
fewshot_sampling_column = prompt_template.generate_data_for_column[0]

assert small_model_training in [None, "text-classification"], \
"Task for small model training must be available in 'src/small_model_training' e.g. 'text-classification'"

generated_dataset, original_dataset = self._inner_generate_loop(
prompt_template,
fewshot_dataset,
Expand All @@ -109,6 +117,8 @@ def generate(
return_unlabeled_dataset,
max_prompt_calls,
num_samples_to_generate,
small_model_training,
train_small_model_every_X_generations,
timeout_per_prompt,
log_every_n_api_calls,
dummy_response
Expand Down Expand Up @@ -170,6 +180,8 @@ def _inner_generate_loop(
return_unlabeled_dataset: bool,
max_prompt_calls: int,
num_samples_to_generate: int,
small_model_training: Optional[str],
train_small_model_every_x_generations: Optional[int],
timeout_per_prompt: Optional[int],
log_every_n_api_calls: int = 25,
dummy_response: Optional[Union[str, Callable]] = None
Expand Down Expand Up @@ -284,6 +296,12 @@ def _inner_generate_loop(
if timeout_per_prompt is not None:
time.sleep(timeout_per_prompt)

if train_small_model_every_x_generations > 0:
if prompt_call_idx % train_small_model_every_x_generations == 0:
small_model = import_module(small_model_training,"src.small_model_training")
inf_subset = small_model.get_influential_subset(generated_dataset)
fewshot_dataset = inf_subset

generated_dataset = Dataset.from_dict(generated_dataset)

if return_unlabeled_dataset:
Expand Down

0 comments on commit e6d20d5

Please sign in to comment.