Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] added augmentation yielding evenly distributed classes #768

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions textattack/augmentation/augmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
===================
"""
import random
from collections import Counter

import tqdm

Expand Down Expand Up @@ -230,6 +231,98 @@ def augment_text_with_ids(self, text_list, id_list, show_progress=True):
all_id_list.extend([_id] * (1 + len(augmented_texts)))
return all_text_list, all_id_list

def augment_text_with_ids_evenly(
    self,
    text_list,
    id_list,
    additional_examples=0,
    perfectly_even=True,
    show_progress=True,
):
    """Supplement a list of texts so that every ID is (about) equally frequent.

    Each ID from ``id_list`` is topped up with augmented texts until it
    occurs as often as the most frequent ID plus ``additional_examples``.
    The original examples are always kept and the output preserves the
    order in which examples were produced (each original is followed by
    the texts augmented from it; top-up examples come last).

    Args:
        text_list: Texts to augment.
        id_list: ID (label) of each text; must be as long as ``text_list``.
        additional_examples: Non-negative number of examples every ID
            should receive on top of the count of the most frequent ID.
        perfectly_even: If ``True`` (recommended, but slightly slower),
            missing examples are added one by one and surplus augmented
            examples are dropped so that every ID occurs exactly the same
            number of times. This can only fall short when ``self.augment``
            is unable to produce any new text for an ID.
        show_progress: Whether to display ``tqdm`` progress bars.

    Returns:
        A tuple ``(all_text_list, all_id_list)`` of equal length holding
        the original and augmented texts and their corresponding IDs.

    Raises:
        ValueError: If the two lists differ in length or if
            ``additional_examples`` is negative.
    """
    if len(text_list) != len(id_list):
        raise ValueError("List of text must be same length as list of IDs")
    if additional_examples < 0:
        raise ValueError("Additional examples must be non-negative")

    examples_per_id = Counter(id_list)
    if not examples_per_id:
        # Nothing to balance; avoids calling max() on an empty sequence.
        return [], []

    max_examples = max(examples_per_id.values()) + additional_examples
    # Materialize once so later passes do not re-iterate a tqdm wrapper.
    pairs = list(zip(text_list, id_list))

    all_text_list = []
    all_id_list = []
    # Parallel flags so that trimming only ever drops augmented entries.
    is_augmented = []

    original_transformations_per_example = self.transformations_per_example
    try:
        iterator = pairs
        if show_progress:
            iterator = tqdm.tqdm(pairs, desc="Augmenting data...")
        for text, _id in iterator:
            all_text_list.append(text)
            all_id_list.append(_id)
            is_augmented.append(False)

            # Distribute augmentation of the original documents evenly.
            id_count = examples_per_id[_id]
            per_text = (max_examples - id_count) // id_count
            if per_text <= 0:
                continue
            self.transformations_per_example = per_text
            augmented_texts = []
            while len(augmented_texts) < per_text:
                new_texts = self.augment(text)
                if not new_texts:
                    # The augmenter cannot transform this text; stop
                    # instead of looping forever.
                    break
                augmented_texts.extend(new_texts)
            all_text_list.extend(augmented_texts)
            all_id_list.extend([_id] * len(augmented_texts))
            is_augmented.extend([True] * len(augmented_texts))

        if perfectly_even:
            self.transformations_per_example = 1

            # (1) add missing examples, based on what was actually produced
            produced = Counter(all_id_list)
            missing = {k: max_examples - produced[k] for k in examples_per_id}
            if show_progress:
                added = tqdm.tqdm(
                    desc="Adding additional examples...",
                    total=sum(n for n in missing.values() if n > 0),
                )
            progressed = True
            # Stop as soon as a whole pass yields nothing new, so an
            # augmenter that keeps returning no text cannot hang the loop.
            while progressed and any(n > 0 for n in missing.values()):
                progressed = False
                for text, _id in pairs:
                    need = missing[_id]
                    if need <= 0:
                        continue
                    new_texts = list(self.augment(text))[:need]
                    if not new_texts:
                        continue
                    # Keep texts and IDs aligned: one ID per added text.
                    all_text_list.extend(new_texts)
                    all_id_list.extend([_id] * len(new_texts))
                    is_augmented.extend([True] * len(new_texts))
                    missing[_id] = need - len(new_texts)
                    progressed = True
                    if show_progress:
                        added.update(len(new_texts))
            if show_progress:
                added.close()

            # (2) remove excess
            excess = {
                k: v - max_examples for k, v in Counter(all_id_list).items()
            }
            to_be_removed = sum(e for e in excess.values() if e > 0)
            if to_be_removed:
                if show_progress:
                    removed = tqdm.tqdm(
                        desc="Removing abundant examples...",
                        total=to_be_removed,
                    )
                keep = [True] * len(all_id_list)
                # Walk backwards so the newest augmented entries go first;
                # originals are never dropped (their count per ID cannot
                # exceed max_examples).
                for i in reversed(range(len(all_id_list))):
                    _id = all_id_list[i]
                    if is_augmented[i] and excess[_id] > 0:
                        keep[i] = False
                        excess[_id] -= 1
                        if show_progress:
                            removed.update(1)
                if show_progress:
                    removed.close()
                all_text_list = [t for t, k in zip(all_text_list, keep) if k]
                all_id_list = [i for i, k in zip(all_id_list, keep) if k]
    finally:
        # Restore the caller's setting even if augmentation fails.
        self.transformations_per_example = original_transformations_per_example

    return all_text_list, all_id_list

def __repr__(self):
main_str = "Augmenter" + "("
lines = []
Expand Down
Loading