Skip to content

Commit

Permalink
Add delete function (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephantul authored Oct 18, 2024
1 parent 0be3e52 commit df4bee6
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
32 changes: 30 additions & 2 deletions reach/reach.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ def indices(self) -> dict[int, str]:
return self._indices

@property
def sorted_items(self) -> Tokens:
def sorted_items(self) -> list[str]:
"""The items, sorted by index."""
items: Tokens = [item for item, _ in sorted(self.items.items(), key=lambda x: x[1])]
items: list[str] = [item for item, _ in sorted(self.items.items(), key=lambda x: x[1])]
return items

@property
Expand Down Expand Up @@ -181,6 +181,34 @@ def insert(self, tokens: Sequence[str], vectors: npt.NDArray | None = None) -> N
self.indices[len(self.items) - 1] = token
self.vectors = np.concatenate([self.vectors, vectors], 0)

def delete(self, tokens: Sequence[str]) -> None:
    """
    Delete tokens from the vector space.

    The removal of tokens is done in place. All tokens are looked up before
    anything is modified, so if any token is missing a ValueError is raised
    and the vector space is left untouched.

    :param tokens: A sequence of tokens to remove from the vector space.
        Pass a list even for a single token: a bare string is iterated
        character by character.
    :raises ValueError: If any passed tokens are not in the vector space.
    """
    try:
        curr_indices = [self.items[token] for token in tokens]
    except KeyError as exc:
        raise ValueError(f"Token {exc} was not in the vector space.") from exc

    tokens_set = set(tokens)
    # np.delete returns a new array with the selected rows dropped.
    vectors = np.delete(self.vectors, curr_indices, axis=0)

    # Walk the items in *index* order (not dict order) so that the i-th
    # surviving token is paired with the i-th surviving row of `vectors`.
    kept = [item for item in self.sorted_items if item not in tokens_set]
    new_items: dict[str, int] = {item: idx for idx, item in enumerate(kept)}

    self._items = new_items
    self._indices = {idx: item for item, idx in new_items.items()}
    self.vectors = vectors

@classmethod
def load_word2vec_format(
cls,
Expand Down
24 changes: 24 additions & 0 deletions tests/test_vectorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,27 @@ def test_mean_pool_unk(self) -> None:

vec = reach.mean_pool([], safeguard=False)
self.assertTrue(np.allclose(vec, np.zeros_like(vec)))

def test_delete(self) -> None:
    """Test that delete removes a token and rejects tokens that are absent."""
    words, vectors = self.data()
    reach = Reach(vectors, words)

    reach.delete(["donatello"])
    self.assertNotIn("donatello", reach.items)
    # The reverse mapping and the vectors must stay aligned with the items.
    self.assertNotIn("donatello", reach.indices.values())
    self.assertEqual(len(reach.vectors), len(reach.items))

    # Pass a list: a bare string would be iterated character by character and
    # the ValueError would come from "d", not from the deleted token.
    with self.assertRaises(ValueError):
        reach.delete(["donatello"])

def test_delete_intersect(self) -> None:
    """Test that removing a token via intersect leaves it absent for delete."""
    words, vectors = self.data()
    reach = Reach(vectors, words)

    tokens_to_delete = ["donatello"]
    items_to_keep = set(reach.items) - set(tokens_to_delete)
    new_r = reach.intersect(items_to_keep)
    self.assertNotIn("donatello", new_r.items)

    # Pass a list: a bare string would be iterated character by character and
    # the ValueError would come from "d", not from the removed token.
    with self.assertRaises(ValueError):
        new_r.delete(tokens_to_delete)

0 comments on commit df4bee6

Please sign in to comment.