From df4bee6da899afe65aca19b32e4956db25e1ad42 Mon Sep 17 00:00:00 2001 From: Stephan Tulkens Date: Fri, 18 Oct 2024 09:26:11 +0200 Subject: [PATCH] Add delete function (#71) --- reach/reach.py | 32 ++++++++++++++++++++++++++++++-- tests/test_vectorize.py | 24 ++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/reach/reach.py b/reach/reach.py index 281e996..b10fecb 100644 --- a/reach/reach.py +++ b/reach/reach.py @@ -104,9 +104,9 @@ def indices(self) -> dict[int, str]: return self._indices @property - def sorted_items(self) -> Tokens: + def sorted_items(self) -> list[str]: """The items, sorted by index.""" - items: Tokens = [item for item, _ in sorted(self.items.items(), key=lambda x: x[1])] + items: list[str] = [item for item, _ in sorted(self.items.items(), key=lambda x: x[1])] return items @property @@ -181,6 +181,34 @@ def insert(self, tokens: Sequence[str], vectors: npt.NDArray | None = None) -> N self.indices[len(self.items) - 1] = token self.vectors = np.concatenate([self.vectors, vectors], 0) + def delete(self, tokens: Sequence[str]) -> None: + """ + Delete tokens from the vector space. + + The removal of tokens is done in place. If the tokens are not in the vector space, + a ValueError is raised. + + :param tokens: A list of tokens to remove from the vector space. + :raises ValueError: If any passed tokens are not in the vector space. + """ + try: + curr_indices = [self.items[token] for token in tokens] + except KeyError as exc: + raise ValueError(f"Token {exc} was not in the vector space.") from exc + + tokens_set = set(tokens) + vectors = np.delete(self.vectors, curr_indices, axis=0) + new_items: dict[str, int] = {} + for item in self.items: + if item in tokens_set: + tokens_set.remove(item) + continue + new_items[item] = len(new_items) + + self._items = new_items + self._indices = {idx: item for item, idx in self.items.items()} + self.vectors = vectors + @classmethod def load_word2vec_format( cls, diff --git a/tests/test_vectorize.py b/tests/test_vectorize.py index ae0cfb9..d7b011c 100644 --- a/tests/test_vectorize.py +++ b/tests/test_vectorize.py @@ -120,3 +120,27 @@ def test_mean_pool_unk(self) -> None: vec = reach.mean_pool([], safeguard=False) self.assertTrue(np.allclose(vec, np.zeros_like(vec))) + + def test_delete(self) -> None: + """Test delete method.""" + words, vectors = self.data() + reach = Reach(vectors, words) + + reach.delete(["donatello"]) + self.assertNotIn("donatello", reach.items) + + with self.assertRaises(ValueError): + reach.delete("donatello") + + def test_delete_intersect(self) -> None: + """Test delete method workaround.""" + words, vectors = self.data() + reach = Reach(vectors, words) + + tokens_to_delete = ["donatello"] + items_to_keep = set(reach.items) - set(tokens_to_delete) + new_r = reach.intersect(items_to_keep) + self.assertNotIn("donatello", new_r.items) + + with self.assertRaises(ValueError): + new_r.delete("donatello")