diff --git a/src/fabricator/dataset_generator.py b/src/fabricator/dataset_generator.py
index 0a46420..d9f1b6a 100644
--- a/src/fabricator/dataset_generator.py
+++ b/src/fabricator/dataset_generator.py
@@ -1,6 +1,7 @@
 import json
 import time
 
+from importlib import import_module
 from pathlib import Path
 from collections import defaultdict
 from typing import Any, Callable, Dict, Optional, Union, Tuple, List
@@ -16,7 +17,6 @@
 from .samplers import single_label_stratified_sample
 from .utils import log_dir, create_timestamp_path
 
-
 class DatasetGenerator:
     """The DatasetGenerator class is the main class of the fabricator package.
     It generates datasets based on a prompt template. The main function is generate()."""
@@ -59,6 +59,8 @@ def generate(
         return_unlabeled_dataset: bool = False,
         max_prompt_calls: int = 10,
         num_samples_to_generate: int = 10,
+        small_model_training: Optional[str] = None,
+        train_small_model_every_X_generations: Optional[int] = None,
         timeout_per_prompt: Optional[int] = None,
         log_every_n_api_calls: int = 25,
         dummy_response: Optional[Union[str, Callable]] = None
@@ -82,6 +84,9 @@ def generate(
             return_unlabeled_dataset (bool, optional): Whether to return the original dataset. Defaults to False.
             max_prompt_calls (int, optional): Maximum number of prompt calls. Defaults to 10.
             num_samples_to_generate (int, optional): Number of samples to generate. Defaults to 10.
+            small_model_training (str, optional): Task to perform small model training on. Defaults to None.
+            train_small_model_every_X_generations (int, optional): Number of generations between small model
+                training iterations. Defaults to None.
             timeout_per_prompt (Optional[int], optional): Timeout per prompt call. Defaults to None.
             log_every_n_api_calls (int, optional): Log every n api calls. Defaults to 25.
             dummy_response (Optional[Union[str, Callable]], optional): Dummy response for dry runs. Defaults to None.
@@ -99,6 +104,9 @@ def generate(
         if fewshot_dataset and not fewshot_sampling_column:
             fewshot_sampling_column = prompt_template.generate_data_for_column[0]
 
+        assert small_model_training in [None, "text-classification"], \
+            "Task for small model training must be available in 'src/small_model_training' e.g. 'text-classification'"
+
         generated_dataset, original_dataset = self._inner_generate_loop(
             prompt_template,
             fewshot_dataset,
@@ -109,6 +117,8 @@ def generate(
             return_unlabeled_dataset,
             max_prompt_calls,
             num_samples_to_generate,
+            small_model_training,
+            train_small_model_every_X_generations,
             timeout_per_prompt,
             log_every_n_api_calls,
             dummy_response
@@ -170,6 +180,8 @@ def _inner_generate_loop(
         return_unlabeled_dataset: bool,
         max_prompt_calls: int,
         num_samples_to_generate: int,
+        small_model_training: Optional[str],
+        train_small_model_every_x_generations: Optional[int],
         timeout_per_prompt: Optional[int],
         log_every_n_api_calls: int = 25,
         dummy_response: Optional[Union[str, Callable]] = None
@@ -284,6 +296,20 @@ def _inner_generate_loop(
             if timeout_per_prompt is not None:
                 time.sleep(timeout_per_prompt)
 
+            # Small model training is optional: both parameters default to None, and comparing
+            # None with `>` raises a TypeError, so check for None before using them.
+            if (
+                small_model_training is not None
+                and train_small_model_every_x_generations is not None
+                and train_small_model_every_x_generations > 0
+                and prompt_call_idx % train_small_model_every_x_generations == 0
+            ):
+                # import_module() only consults `package` for relative names (leading dot);
+                # an absolute name would be looked up as a top-level module instead.
+                small_model = import_module(f".{small_model_training}", "src.small_model_training")
+                inf_subset = small_model.get_influential_subset(generated_dataset)
+                fewshot_dataset = inf_subset
+
         generated_dataset = Dataset.from_dict(generated_dataset)
 
         if return_unlabeled_dataset: