Skip to content

Commit

Permalink
make path to small model training modules and config.json dynamic
Browse files Browse the repository at this point in the history
  • Loading branch information
stolzenp committed Apr 19, 2024
1 parent ffa807d commit d793040
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/fabricator/dataset_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def _inner_generate_loop(
if train_small_model_every_x_generations is not None and train_small_model_every_x_generations > 0:
if prompt_call_idx % train_small_model_every_x_generations == 0:
logger.info("Commencing small model training.")
small_model = import_module(small_model_training,"src.small_model_training")
small_model = import_module("src.small_model_training." + small_model_training, __package__)
inf_subset = small_model.get_influential_subset(generated_dataset)
fewshot_dataset = inf_subset
logger.info("Continuing generation.")
Expand Down
24 changes: 19 additions & 5 deletions src/small_model_training/text_classification.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
import torch.nn
import numpy as np
import os
from dataclasses import dataclass, field
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
Expand Down Expand Up @@ -31,14 +32,27 @@ def get_influential_subset(dataset):

# trace labels
for entry in dataset:
if entry['label'] not in label2id:
label2id[entry['label']] = counter
id2label[counter] = entry['label']
if entry["label"] not in label2id:
label2id[entry["label"]] = counter
id2label[counter] = entry["label"]
counter += 1

# get path to config
def get_project_root(start_dir=None, marker="README.md") -> str:
    """Return the project root directory.

    Walks upward from ``start_dir`` (default: this file's directory) until it
    finds a directory containing ``marker`` (default: ``README.md``), which is
    taken to be the repository root.

    Args:
        start_dir: Directory to start the upward search from. Defaults to the
            directory containing this module.
        marker: Filename whose presence identifies the project root.

    Returns:
        Absolute path of the closest ancestor directory containing ``marker``.

    Raises:
        FileNotFoundError: If no ancestor up to the filesystem root contains
            ``marker`` (the original code looped forever in this case, because
            ``os.path.dirname`` of the root returns the root itself).
    """
    current_dir = os.path.abspath(start_dir if start_dir is not None else os.path.dirname(__file__))
    while not os.path.isfile(os.path.join(current_dir, marker)):
        parent = os.path.dirname(current_dir)
        if parent == current_dir:
            # Reached the filesystem root without finding the marker file.
            raise FileNotFoundError(
                f"Could not locate {marker!r} in any ancestor directory"
            )
        current_dir = parent
    return current_dir

def get_path_to_json_file(relative_path_to_json_file) -> str:
    """Resolve a path given relative to the project root into an absolute path.

    Args:
        relative_path_to_json_file: Path of the JSON file, relative to the
            repository root (the directory found by ``get_project_root``).

    Returns:
        The joined absolute path to the JSON file.
    """
    return os.path.join(get_project_root(), relative_path_to_json_file)

path_to_json_file = get_path_to_json_file("src/small_model_training/config.json")

# get parameters from config
parser = HfArgumentParser((ModelArguments, TrainingArguments, FewshotArguments))
model_args, training_args, fewshot_args = parser.parse_json_file('config.json')
model_args, training_args, fewshot_args = parser.parse_json_file(path_to_json_file)

# setup preprocessing
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
Expand Down Expand Up @@ -100,7 +114,7 @@ def compute_metrics(eval_pred):

# Iterate through the dataset and filter out duplicates
for example in inf_subset:
text = example['text']
text = example["text"]
if text not in unique_samples:
unique_samples.add(text)
deduplicated_inf_subset.append(example)
Expand Down

0 comments on commit d793040

Please sign in to comment.