adjustments for running small model during generation
stolzenp committed Apr 6, 2024
1 parent be26ac6 commit 707d47e
Showing 4 changed files with 37 additions and 24 deletions.
25 changes: 14 additions & 11 deletions src/fabricator/dataset_generator.py
@@ -210,16 +210,16 @@ def _inner_generate_loop(
             # require a second parameter for sample from label options and not from fewshot examples
             prompt_labels = prompt_template.label_options
 
-            if fewshot_dataset:
+            # label-conditioned generation with label options
+            if fewshot_dataset is None and isinstance(prompt_labels,list) or prompt_call_idx > train_small_model_every_x_generations:
+                prompt_labels = choice(prompt_labels, 1)[0]
+
+            if fewshot_dataset and prompt_labels in fewshot_dataset['label']:
                 prompt_labels, fewshot_examples = self._sample_fewshot_examples(
-                    prompt_template, fewshot_dataset, fewshot_sampling_strategy, fewshot_examples_per_class,
+                    prompt_labels, fewshot_dataset, fewshot_sampling_strategy, fewshot_examples_per_class,
                     fewshot_sampling_column
                 )
 
-            # label-conditioned generation with label options
-            if fewshot_dataset is None and isinstance(prompt_labels, list):
-                prompt_labels = choice(prompt_labels,1)[0]
-
             prompt_text = prompt_template.get_prompt_text(prompt_labels, fewshot_examples)
 
             if unlabeled_dataset:
@@ -302,9 +302,11 @@ def _inner_generate_loop(
 
             if train_small_model_every_x_generations is not None and train_small_model_every_x_generations > 0:
                 if prompt_call_idx % train_small_model_every_x_generations == 0:
+                    logger.info("Commencing small model training.")
                     small_model = import_module(small_model_training,"src.small_model_training")
                     inf_subset = small_model.get_influential_subset(generated_dataset)
                     fewshot_dataset = inf_subset
+                    logger.info("Continuing generation.")
 
         generated_dataset = Dataset.from_dict(generated_dataset)

@@ -339,29 +341,30 @@ def _convert_prediction(self, prediction: str, target_type: type) -> Any:
 
     @staticmethod
     def _sample_fewshot_examples(
-            prompt_template: BasePrompt,
+            prompt_labels,
             fewshot_dataset: Dataset,
             fewshot_sampling_strategy: str,
            fewshot_examples_per_class: int,
             fewshot_sampling_column: str
     ) -> Tuple[Union[List[str], str], Dataset]:
 
         if fewshot_sampling_strategy == "uniform":
-            prompt_labels = choice(prompt_template.label_options, 1)[0]
+            prompt_labels = choice(prompt_labels, 1)[0]
             fewshot_examples = fewshot_dataset.filter(
                 lambda example: example[fewshot_sampling_column] == prompt_labels
-            ).shuffle().select(range(fewshot_examples_per_class))
+            )
+            fewshot_examples = fewshot_examples.shuffle().select(
+                range(fewshot_examples_per_class) if fewshot_examples_per_class is not None else range(len(fewshot_examples))
+            )
 
         elif fewshot_sampling_strategy == "stratified":
-            prompt_labels = prompt_template.label_options
             fewshot_examples = single_label_stratified_sample(
                 fewshot_dataset,
                 fewshot_sampling_column,
                 fewshot_examples_per_class
             )
 
         else:
-            prompt_labels = prompt_template.label_options if prompt_template.label_options else None
             if fewshot_examples_per_class:
                 fewshot_examples = fewshot_dataset.shuffle().select(range(fewshot_examples_per_class))
             else:
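For orientation, the loop below is a condensed, self-contained sketch of the control flow these changes set up in _inner_generate_loop: label-conditioned prompting, plus a small-model training step every train_small_model_every_x_generations calls whose influential subset replaces the fewshot pool. call_llm and get_influential_subset are hypothetical stand-ins, not fabricator APIs.

from random import choice

def generation_loop(prompt_labels, call_llm, get_influential_subset,
                    max_prompt_calls=100, train_small_model_every_x_generations=10):
    generated = {"text": [], "label": []}
    fewshot_pool = None                                    # only filled after the first training round
    for prompt_call_idx in range(1, max_prompt_calls + 1):
        label = choice(prompt_labels)                      # label-conditioned generation
        generated["text"].append(call_llm(label, fewshot_pool))
        generated["label"].append(label)
        if prompt_call_idx % train_small_model_every_x_generations == 0:
            fewshot_pool = get_influential_subset(generated)   # feed the small model's picks back in
    return generated

The point mirrored here is that the fewshot pool stays empty until the first training round, which is why get_prompt_text (next file) now derives its fewshot template lazily.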
6 changes: 6 additions & 0 deletions src/fabricator/prompts/base.py
@@ -190,6 +190,12 @@ def get_prompt_text(self, labels: Union[str, List[str]] = None, examples: Option
         task_description = self.task_description
 
         if examples:
+            # due to small_model training fewshot_examples appear after some iterations
+            if self.relevant_columns_for_fewshot_examples is None:
+                self.relevant_columns_for_fewshot_examples = examples.column_names
+                self.fewshot_prompt = self.inner_fewshot_example_separator.join(
+                    [f"{var}: {{{var}}}" for var in self.relevant_columns_for_fewshot_examples]
+                )
             examples = self.filter_examples_by_columns(examples, self.relevant_columns_for_fewshot_examples)
             formatted_examples = [self.fewshot_prompt.format(**example) for example in examples]
             prompt_text = self.fewshot_example_separator.join(
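The template construction added above can be seen in isolation: column names become "name: {name}" slots joined by the inner separator, and each fewshot example then fills those slots. The separator values and toy examples below are assumptions for the demo, not the prompt class defaults.

inner_fewshot_example_separator = "\n"
fewshot_example_separator = "\n\n"

column_names = ["text", "label"]                      # stands in for examples.column_names
fewshot_prompt = inner_fewshot_example_separator.join(
    f"{var}: {{{var}}}" for var in column_names
)                                                     # -> "text: {text}\nlabel: {label}"

examples = [
    {"text": "Great movie!", "label": "positive"},
    {"text": "Waste of time.", "label": "negative"},
]
print(fewshot_example_separator.join(fewshot_prompt.format(**e) for e in examples))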
@@ -20,8 +20,7 @@
     max_prompt_calls=100,
     num_samples_to_generate=100,
     small_model_training='text_classification',
-    train_small_model_every_X_generations=10
-
+    train_small_model_every_X_generations=10,
 )
 
 generated_dataset.push_to_hub("your-first-generated-dataset")
27 changes: 16 additions & 11 deletions src/small_model_training/text_classification.py
@@ -1,3 +1,4 @@
+import math
 import torch.nn
 import numpy as np
 from dataclasses import dataclass, field
@@ -20,8 +21,15 @@ def get_influential_subset(dataset):
     # setup preprocessing
     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
 
-    def preprocess_function(examples):
-        return tokenizer(examples["text"], truncation=True, padding=True)
+    # define labels
+    # TO-DO: dynamic labels
+    id2label = {0: "negative", 1: "positive"}
+    label2id = {"negative": 0, "positive": 1}
+
+    def preprocess_function(batch):
+        preprocessed_tokens = tokenizer(batch["text"], truncation=True, padding=True)
+        preprocessed_tokens["label"] = [label2id[label] for label in batch["label"]]
+        return preprocessed_tokens
 
     # setup compute_metrics
     accuracy = evaluate.load("accuracy")
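A minimal sketch of how the new batched preprocess_function is typically applied, assuming a Hugging Face tokenizer and a datasets.Dataset with string "text"/"label" columns; the checkpoint name and toy data are placeholders.

from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # placeholder checkpoint
label2id = {"negative": 0, "positive": 1}

def preprocess_function(batch):
    tokens = tokenizer(batch["text"], truncation=True, padding=True)
    tokens["label"] = [label2id[label] for label in batch["label"]]   # map string labels to ids
    return tokens

raw = Dataset.from_dict({
    "text": ["Great movie!", "Terrible plot."],
    "label": ["positive", "negative"],
})
encoded = raw.map(preprocess_function, batched=True)                  # batched call, as in the commit
print(encoded.column_names)                                           # ['text', 'label', 'input_ids', 'attention_mask']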
@@ -43,11 +51,6 @@ def compute_metrics(eval_pred):
 
     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
 
-    # define labels
-    # TO-DO: dynamic labels
-    id2label = {0: "negative", 1: "positive"}
-    label2id = {"negative": 0, "positive": 1}
-
     model = AutoModelForSequenceClassification.from_pretrained(
         model_args.model_name_or_path, num_labels=2, id2label=id2label, label2id=label2id
     )
@@ -68,16 +71,17 @@ def compute_metrics(eval_pred):
 
     inf_subset = []
 
-    indices = calculate_candidates_knn(outputs)
+    indices = calculate_candidates_knn(outputs, math.ceil(len(dataset["test"])/2))
 
     for elem in indices:
         inf_subset.append(dataset["test"][elem.item()])
 
+    # TO-DO: check for pre-processing
     inf_subset = Dataset.from_list(inf_subset)
 
     return inf_subset
 
 #TO-DO: dynamic variant for calculating inf_subset
-def calculate_candidates_knn(model_outputs):
+def calculate_candidates_knn(model_outputs, num_candidates):
     logits = model_outputs[0]
 
     logits = torch.from_numpy(logits)
@@ -88,6 +92,7 @@ def calculate_candidates_knn(model_outputs):
     second_values = scores[:, 1]
 
     distance = torch.abs(first_values - second_values)
-    _, knn_indices = distance.topk(5, largest=False)
-    #make dynamic
+    _, knn_indices = distance.topk(num_candidates, largest=False)
 
     return knn_indices
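To see the selection rule in calculate_candidates_knn end to end: softmax the logits, take the gap between the two class scores, and keep the num_candidates rows with the smallest gap, i.e. the examples the small model is least sure about. The tensor below is made up, and scores is assumed here to be the softmaxed logits.

import math
import torch

logits = torch.tensor([[ 2.0, -1.0],   # confident -> large gap
                       [ 0.1,  0.0],   # uncertain -> small gap
                       [-0.2,  0.3],   # uncertain -> smallish gap
                       [ 3.0, -2.0]])  # confident -> large gap

scores = torch.softmax(logits, dim=-1)
distance = torch.abs(scores[:, 0] - scores[:, 1])       # margin between the two classes

num_candidates = math.ceil(len(logits) / 2)             # half the pool, as in the commit
_, knn_indices = distance.topk(num_candidates, largest=False)
print(knn_indices)                                      # tensor([1, 2]): the least confident rows

Passing math.ceil(len(dataset["test"])/2) from get_influential_subset makes the size of the influential subset scale with the evaluation split instead of the fixed 5 used before.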
