diff --git a/textattack/augmentation/augmenter.py b/textattack/augmentation/augmenter.py index 1b8200e2..0fa1a60b 100644 --- a/textattack/augmentation/augmenter.py +++ b/textattack/augmentation/augmenter.py @@ -3,6 +3,7 @@ =================== """ import random +from collections import Counter import tqdm @@ -230,6 +231,98 @@ def augment_text_with_ids(self, text_list, id_list, show_progress=True): all_id_list.extend([_id] * (1 + len(augmented_texts))) return all_text_list, all_id_list + def augment_text_with_ids_evenly( + self, + text_list, + id_list, + additional_examples=0, + perfectly_even=True, + show_progress=True, + ): + """Supplements a list of text with more text data so that there is approximately + the same number of sentences for each label. + Each ID from `id_list` will be represented the same number of times + as the most frequent ID plus `additional_examples`. + If `perfectly_even` is set to `True`, every ID will be occurring exactly the same number of times (recommended, + but slightly slower). + + Returns the augmented text along with the corresponding IDs for + each augmented example. + """ + if len(text_list) != len(id_list): + raise ValueError("List of text must be same length as list of IDs") + if additional_examples < 0: + raise ValueError("Additional examples must be non-negative") + all_text_list = [] + all_id_list = [] + examples_per_id = Counter(id_list) + max_examples = max(examples_per_id.values()) + additional_examples + diff_per_example = {k: max_examples - v for k, v in examples_per_id.items()} + original_transformations_per_example = self.transformations_per_example + remainders = {} + if show_progress: + text_list = tqdm.tqdm(text_list, desc="Augmenting data...") + for text, _id in zip(text_list, id_list): + # distribute augmentation of the original documents evenly + self.transformations_per_example = ( + diff_per_example[_id] // examples_per_id[_id] + ) + remainders[_id] = diff_per_example[_id] % examples_per_id[_id] + all_text_list.append(text) + all_id_list.append(_id) + if self.transformations_per_example > 0: + augmented_texts = [] + while len(augmented_texts) < self.transformations_per_example: + augmented_texts.extend(self.augment(text)) + all_text_list.extend(augmented_texts) + all_id_list.extend([_id] * len(augmented_texts)) + + if perfectly_even: + self.transformations_per_example = 1 + # (1) add missing examples: + if show_progress: + added = tqdm.tqdm( + desc="Adding additional examples...", total=sum(remainders.values()) + ) + while any(remainders.values()): + for text, _id in zip(text_list, id_list): + if remainders[_id] > 0: + # add missing elements one-by-one + remainders[_id] -= 1 + if show_progress: + added.update(1) + augmented_texts = self.augment(text) + all_text_list.extend(augmented_texts) + all_id_list.append(_id) + if show_progress: + added.close() + # (2) remove excess: + excess = {k: v - max_examples for k, v in Counter(all_id_list).items()} + new_id_list = [] + new_text_list = [] + if show_progress: + to_be_removed = int(sum([e > 0 for e in excess.values()])) + removed = tqdm.tqdm( + desc="Removing abundant examples...", total=to_be_removed + ) + # count backwards so that the newer elements (most probably being augmented) are deleted first + for i in range(len(all_id_list) - 1, -1, -1): + if excess[all_id_list[i]] <= 0: + new_id_list.append(all_id_list[i]) + new_text_list.append(all_text_list[i]) + else: + # skip entry for new id and text list + excess[all_id_list[i]] -= 1 + if show_progress: + removed.update(1) + all_id_list = new_id_list + all_text_list = new_text_list + if show_progress: + removed.close() + + self.transformations_per_example = original_transformations_per_example + return all_text_list, all_id_list + def __repr__(self): main_str = "Augmenter" + "(" lines = []