Commit
added parameters to adjust max size of inf_subset and to avoid duplicates in it

stolzenp committed Apr 18, 2024
1 parent 37c57d4 commit cdf6c7c
Showing 2 changed files with 27 additions and 4 deletions.
4 changes: 3 additions & 1 deletion src/small_model_training/config.json
@@ -9,5 +9,7 @@
   "evaluation_strategy":"epoch",
   "save_strategy":"epoch",
   "load_best_model_at_end":true,
-  "push_to_hub":false
+  "push_to_hub":false,
+  "avoid_duplicate_fewshots": true,
+  "max_inf_subset_size": 5
 }
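Note: with `max_inf_subset_size` set to 5 here, the selection size in `get_influential_subset` becomes `min(ceil(len(test)/2), 5)` (see the second file below). A quick, hedged sanity check of that arithmetic; the helper name is illustrative, not part of the commit:

```python
import math

# Candidate count as computed in the updated code: half the test split,
# capped by the new config value. `capped_candidates` is a hypothetical
# helper used only to illustrate the formula.
def capped_candidates(test_set_size: int, max_inf_subset_size: int) -> int:
    return min(math.ceil(test_set_size / 2), max_inf_subset_size)

print(capped_candidates(20, 5))  # before this commit: 10; now capped to 5
print(capped_candidates(6, 5))   # ceil(6/2) = 3 < 5, so the cap is inactive
```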
27 changes: 24 additions & 3 deletions src/small_model_training/text_classification.py
@@ -13,6 +13,15 @@ class ModelArguments:
         metadata={"help": "Path to pretrained model or model from huggingface.co/models"}
     )
 
+@dataclass
+class FewshotArguments:
+    avoid_duplicate_fewshots: bool = field(
+        metadata={"help": "Decides whether to avoid duplicate fewshot examples or not"}
+    )
+    max_inf_subset_size: int = field(
+        metadata={"help": "Limit for the size of the influential subset"}
+    )
+
 def get_influential_subset(dataset):
     id2label = {}
     label2id = {}
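Note: neither `FewshotArguments` field declares a default, so `HfArgumentParser` treats both as required and `config.json` must supply them. A minimal, self-contained sketch of how the dataclass gets populated, using `parse_dict` instead of the file-based path taken in this commit:

```python
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class FewshotArguments:
    avoid_duplicate_fewshots: bool = field(
        metadata={"help": "Decides whether to avoid duplicate fewshot examples or not"}
    )
    max_inf_subset_size: int = field(
        metadata={"help": "Limit for the size of the influential subset"}
    )

# parse_dict mirrors what parse_json_file does after loading the JSON.
parser = HfArgumentParser(FewshotArguments)
(fewshot_args,) = parser.parse_dict(
    {"avoid_duplicate_fewshots": True, "max_inf_subset_size": 5}
)
print(fewshot_args.max_inf_subset_size)  # 5
```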
@@ -28,8 +37,8 @@ def get_influential_subset(dataset):
         counter += 1
 
     # get parameters from config
-    parser = HfArgumentParser((ModelArguments, TrainingArguments))
-    model_args, training_args = parser.parse_json_file('config.json')
+    parser = HfArgumentParser((ModelArguments, TrainingArguments, FewshotArguments))
+    model_args, training_args, fewshot_args = parser.parse_json_file('config.json')
 
     # setup preprocessing
     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
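Note: `parse_json_file` feeds one flat JSON file to every dataclass passed to the parser and returns the instances in the same order, which is why the unpacking above gains a third target. A hedged, standalone sketch of that contract with stand-in dataclasses (not the project's real ones):

```python
import json
import tempfile
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class StandInModelArgs:
    model_name_or_path: str = field(default="bert-base-uncased")

@dataclass
class StandInFewshotArgs:
    max_inf_subset_size: int = field(default=5)

# Write a flat config; each dataclass picks out the keys it declares.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"model_name_or_path": "distilbert-base-uncased",
               "max_inf_subset_size": 3}, f)

model_args, fewshot_args = HfArgumentParser(
    (StandInModelArgs, StandInFewshotArgs)
).parse_json_file(f.name)
print(model_args.model_name_or_path, fewshot_args.max_inf_subset_size)
# distilbert-base-uncased 3
```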
@@ -78,13 +87,25 @@ def compute_metrics(eval_pred):
 
     inf_subset = []
 
-    indices = calculate_candidates_entropy_knn(outputs, math.ceil(len(dataset["test"])/2))
+    indices = calculate_candidates_entropy_knn(outputs, min(math.ceil(len(dataset["test"])/2),fewshot_args.max_inf_subset_size))
 
     for elem in indices:
         inf_subset.append(dataset["test"][elem.item()])
 
     inf_subset = Dataset.from_list(inf_subset)
 
+    if fewshot_args.avoid_duplicate_fewshots:
+        unique_samples = set()  # Set to store unique samples
+        deduplicated_inf_subset = []  # List to store deduplicated samples
+
+        # Iterate through the dataset and filter out duplicates
+        for example in inf_subset:
+            text = example['text']
+            if text not in unique_samples:
+                unique_samples.add(text)
+                deduplicated_inf_subset.append(example)
+        inf_subset = Dataset.from_list(deduplicated_inf_subset)
+
     return inf_subset
 
 #TO-DO: dynamic variant for calculating inf_subset
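Note: the de-duplication keeps the first occurrence of each `text` value and preserves order, at the cost of rebuilding the `Dataset`. A self-contained sketch of the same logic on toy data (example rows are invented for illustration):

```python
from datasets import Dataset

inf_subset = Dataset.from_list([
    {"text": "great movie", "label": 1},
    {"text": "terrible", "label": 0},
    {"text": "great movie", "label": 1},  # duplicate; dropped below
])

# Same first-occurrence filter as in the commit.
unique_samples = set()
deduplicated = []
for example in inf_subset:  # iterating a Dataset yields dict rows
    if example["text"] not in unique_samples:
        unique_samples.add(example["text"])
        deduplicated.append(example)
inf_subset = Dataset.from_list(deduplicated)

print(inf_subset["text"])  # ['great movie', 'terrible']
```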
