diff --git a/textaugment/word2vec.py b/textaugment/word2vec.py index aee6f69..4752df2 100644 --- a/textaugment/word2vec.py +++ b/textaugment/word2vec.py @@ -106,51 +106,36 @@ def geometric(self, data): return data[first_trial] def augment(self, data: str, top_n: int = 10): - """ - The method to replace words with similar words. - - :type data: str - :param data: Input data - :type top_n: int - :param top_n: top_n of most similar words to randomly choose from - - :rtype: str - :return: The augmented data - """ + if not isinstance(top_n, int) or not isinstance(data, str): + raise TypeError("Only integers and strings are supported") + + data_tokens = data.lower().split() + + if self.v: + for _ in range(self.runs): + for index in range(len(data_tokens)): + try: + similar_words = [syn for syn, t in self.model.wv.most_similar(data_tokens[index], topn=top_n)] + r = random.randrange(len(similar_words)) + data_tokens[index] = similar_words[r].lower() + except KeyError: + pass + else: + for _ in range(self.runs): + data_tokens_idx = [[x, y] for (x, y) in enumerate(data_tokens)] + words = self.geometric(data=data_tokens_idx).tolist() + for w in words: + try: + similar_words_and_weights = [(syn, t) for syn, t in self.model.wv.most_similar(w[1])] + similar_words = [word for word, t in similar_words_and_weights] + similar_words_weights = [t for word, t in similar_words_and_weights] + word = random.choices(similar_words, similar_words_weights, k=1) + data_tokens[int(w[0])] = word[0].lower() + except KeyError: + pass + + return " ".join(data_tokens) - # Avoid nulls and other unsupported types - if type(top_n) is not int: - raise TypeError("Only integers are supported") - if type(data) is not str: - raise TypeError("Only strings are supported") - # Lower case and split - data_tokens = data.lower().split() - - # Verbose = True then replace all the words. - if self.v: - for _ in range(self.runs): - for index in range(len(data_tokens)): # Index from 0 to length of data_tokens - try: - similar_words = [syn for syn, t in self.model.wv.most_similar(data_tokens[index], topn=top_n)] - r = random.randrange(len(similar_words)) - data_tokens[index] = similar_words[r].lower() # Replace with random synonym from 10 synonyms - except KeyError: - pass # For words not in the word2vec model - else: # Randomly replace some words - for _ in range(self.runs): - data_tokens_idx = [[x, y] for (x, y) in enumerate(data_tokens)] # Enumerate data - words = self.geometric(data=data_tokens_idx).tolist() # List of words indexed - for w in words: - try: - similar_words_and_weights = [(syn, t) for syn, t in self.model.wv.most_similar(w[1])] - similar_words = [word for word, t in similar_words_and_weights] - similar_words_weights = [t for word, t in similar_words_and_weights] - word = random.choices(similar_words, similar_words_weights, k=1) - data_tokens[int(w[0])] = word[0].lower() # Replace with random synonym from 10 synonyms - except KeyError: - pass - return " ".join(data_tokens) - return " ".join(data_tokens) class Fasttext(Word2vec):