diff --git a/pyproject.toml b/pyproject.toml index d229537..9cc9d30 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,19 @@ +[project] +name = "reach" +readme = "README.md" + [build-system] requires = ["setuptools >= 40.6.0", "wheel"] build-backend = "setuptools.build_meta" [tool.ruff] +exclude = [".venv/"] target-version = "py311" [[tool.mypy.overrides]] module = [ "tqdm.*", "ahocorasick.*", + "setuptools.*", ] ignore_missing_imports = true diff --git a/reach/autoreach.py b/reach/autoreach.py index b77ce1f..a7c24da 100644 --- a/reach/autoreach.py +++ b/reach/autoreach.py @@ -61,10 +61,9 @@ def __init__( items: list[str], lowercase: str | bool = "auto", name: str = "", - unk_index: int | None = None, ) -> None: """Initialize a Reach instance with an array and list of strings.""" - super().__init__(vectors, items, name, unk_index) + super().__init__(vectors, items, name) self.automaton = Automaton() if not all(isinstance(item, str) for item in self.items): raise ValueError("All your items should be strings.") diff --git a/reach/reach.py b/reach/reach.py index 5286a5c..ffcf118 100644 --- a/reach/reach.py +++ b/reach/reach.py @@ -41,15 +41,9 @@ class Reach: name : string, optional, default '' A string giving the name of the current reach. Only useful if you have multiple spaces and want to keep track of them. - unk_index : int or None, optional, default None - The index of the UNK item. If this is None, any attempts at vectorizing - OOV items will throw an error. Attributes ---------- - unk_index : int - The integer index of your unknown glyph. This glyph will be inserted - into your BoW space whenever an unknown item is encountered. name : string The name of the Reach instance. @@ -60,7 +54,6 @@ def __init__( vectors: Matrix, items: list[str], name: str = "", - unk_index: int | None = None, ) -> None: """Initialize a Reach instance with an array and list of items.""" if len(items) != len(vectors): @@ -79,8 +72,27 @@ def __init__( self._items: dict[str, int] = {w: idx for idx, w in enumerate(items)} self._indices: dict[int, str] = {idx: w for w, idx in self.items.items()} self.vectors = np.asarray(vectors) - self.unk_index = unk_index self.name = name + self._unk_token: str | None = None + self._unk_index: int | None = None + + @property + def unk_token(self) -> str | None: + """The unknown token.""" + return self._unk_token + + @unk_token.setter + def unk_token(self, token: str | None) -> None: + if token is None: + if self.unk_token is not None: + logger.info(f"Setting unk token from {self.unk_token} to None.") + self._unk_token = None + self._unk_index = None + else: + if token not in self.items: + self.insert([token]) + self._unk_token = token + self._unk_index = self.items[token] def __len__(self) -> int: """The number of the items in the vector space.""" @@ -154,6 +166,36 @@ def _normalize_or_copy(vectors: npt.NDArray) -> npt.NDArray: return vectors return Reach.normalize(vectors, norms) + def insert(self, tokens: list[str], vectors: npt.NDArray | None = None) -> None: + """ + Insert new items into the vector space. + + Parameters + ---------- + tokens : list + A list of items to insert into the vector space. + vectors : numpy array, optional, default None + The vectors to insert into the vector space. If this is None, + the vectors will be set to zero. + + """ + if vectors is None: + vectors = np.zeros((len(tokens), self.size), dtype=self.vectors.dtype) + else: + vectors = np.asarray(vectors, dtype=self.vectors.dtype) + + if len(tokens) != len(vectors): + raise ValueError( + f"Your tokens and vectors are not the same length: {len(tokens)} != {len(vectors)}" + ) + + for token in tokens: + if token in self.items: + raise ValueError(f"Token {token} is already in the vector space.") + self.items[token] = len(self.items) + self.indices[len(self.items) - 1] = token + self.vectors = np.concatenate([self.vectors, vectors], 0) + @classmethod def load( cls, @@ -161,7 +203,7 @@ def load( wordlist: tuple[str, ...] | None = None, num_to_load: int | None = None, truncate_embeddings: int | None = None, - unk_word: str | None = None, + unk_token: str | None = None, sep: str = " ", recover_from_errors: bool = False, desired_dtype: Dtype = "float32", @@ -190,8 +232,8 @@ def load( truncate_embeddings : int, optional, default None If this value is not None, the vectors in the vector space will be truncated to the number of dimensions indicated by this value. - unk_word : object - The object to treat as UNK in your vector space. If this is not + unk_token : str + The string to treat as UNK in your vector space. If this is not in your items dictionary after loading, we add it with a zero vector. recover_from_errors : bool @@ -232,25 +274,21 @@ def load( if came_from_path: file_handle.close() - if unk_word is not None: - if unk_word not in items: - unk_vec = np.zeros((1, vectors.shape[1]), dtype=desired_dtype) - vectors = np.concatenate([unk_vec, vectors], 0) - items = [unk_word] + items - unk_index = 0 - else: - unk_index = items.index(unk_word) - else: - unk_index = None - # NOTE: we use type: ignore because we pass a list of strings, which is hashable - return cls( + instance = cls( vectors, - items, # type: ignore + items, name=name, - unk_index=unk_index, ) + if unk_token is not None: + if unk_token not in items: + logger.info(f"Adding unk token {unk_token} to the vocabulary.") + instance.insert([unk_token]) + instance.unk_token = unk_token + + return instance + @staticmethod def _load( file_handle: TextIOWrapper, @@ -491,7 +529,7 @@ def bow(self, tokens: Tokens, remove_oov: bool = False) -> list[int]: except KeyError as exc: if remove_oov: continue - if self.unk_index is None: + if self._unk_index is None: raise ValueError( "You supplied OOV items but didn't " "provide the index of the replacement " @@ -499,7 +537,7 @@ def bow(self, tokens: Tokens, remove_oov: bool = False) -> list[int]: "or set unk_index to the index of the " "item which replaces any OOV items." ) from exc - out.append(self.unk_index) + out.append(self._unk_index) return out @@ -865,13 +903,14 @@ def intersect(self, itemlist: Tokens) -> Reach: itemlist = list(set(self.items) & set(itemlist)) # Get indices of intersection. indices = sorted([self.items[item] for item in itemlist]) - # Set unk_index to None if it is None or if it is not in indices - unk_index = self.unk_index if self.unk_index in indices else None # Index vectors vectors = self.vectors[indices] # Index words itemlist = [self.indices[index] for index in indices] - return Reach(vectors, itemlist, unk_index=unk_index, name=self.name) + instance = Reach(vectors, itemlist, name=self.name) + instance.unk_token = self.unk_token + + return instance def union(self, other: Reach, check: bool = True) -> Reach: """ @@ -946,7 +985,7 @@ def save_fast_format(self, filename: str) -> None: """ items, _ = zip(*sorted(self.items.items(), key=lambda x: x[1])) - items_dict = {"items": items, "unk_index": self.unk_index, "name": self.name} + items_dict = {"items": items, "unk_token": self.unk_token, "name": self.name} with open(f"{filename}_items.json", "w") as file_handle: json.dump(items_dict, file_handle) @@ -975,12 +1014,15 @@ def load_fast_format( """ with open(f"{filename}_items.json") as file_handle: items = json.load(file_handle) - words, unk_index, name = items["items"], items["unk_index"], items["name"] + words, unk_token, name = items["items"], items["unk_token"], items["name"] with open(f"{filename}_vectors.npy", "rb") as file_handle: vectors: npt.NDArray = np.load(file_handle) vectors = vectors.astype(desired_dtype) - return cls(vectors, words, unk_index=unk_index, name=name) + instance = cls(vectors, words, name=name) + instance.unk_token = unk_token + + return instance def normalize(vectors: npt.NDArray, norms: npt.NDArray | None = None) -> npt.NDArray: diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_auto.py b/tests/test_auto.py index 8e79e12..8c9baf6 100644 --- a/tests/test_auto.py +++ b/tests/test_auto.py @@ -1,5 +1,4 @@ import unittest -from typing import Hashable, List, Tuple import numpy as np @@ -7,8 +6,8 @@ class TestAuto(unittest.TestCase): - def data(self) -> Tuple[List[Hashable], np.ndarray]: - words: List[Hashable] = [ + def data(self) -> tuple[list[str], np.ndarray]: + words: list[str] = [ "donatello", "leonardo", "raphael", diff --git a/tests/test_init.py b/tests/test_init.py index d3101a0..9ec0446 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -1,5 +1,4 @@ import unittest -from typing import Hashable, List, Tuple import numpy as np @@ -7,8 +6,8 @@ class TestInit(unittest.TestCase): - def data(self) -> Tuple[List[Hashable], np.ndarray]: - words: List[Hashable] = [ + def data(self) -> tuple[list[str], np.ndarray]: + words: list[str] = [ "donatello", "leonardo", "raphael", @@ -48,8 +47,7 @@ def test_init(self) -> None: instance = Reach(vectors, words, name="sensei") self.assertEqual(instance.name, "sensei") - instance = Reach(vectors, words, unk_index=1) - self.assertEqual(instance.unk_index, 1) + instance = Reach(vectors, words) self.assertEqual(list(instance.sorted_items), words) with self.assertRaises(AttributeError): diff --git a/tests/test_io.py b/tests/test_io.py index 510fe39..f94e837 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -61,20 +61,20 @@ def test_unk(self) -> None: lines = self.lines() tempfile.write(lines) tempfile.seek(0) - instance = Reach.load(tempfile.name, unk_word=None) - self.assertEqual(instance.unk_index, None) + instance = Reach.load(tempfile.name, unk_token=None) + self.assertEqual(instance._unk_index, None) desired_dtype = "float32" instance = Reach.load( - tempfile.name, unk_word="[UNK]", desired_dtype=desired_dtype + tempfile.name, unk_token="[UNK]", desired_dtype=desired_dtype ) - self.assertEqual(instance.unk_index, 0) - self.assertEqual(instance.items["[UNK]"], instance.unk_index) + self.assertEqual(instance._unk_index, 6) + self.assertEqual(instance.items["[UNK]"], instance._unk_index) self.assertEqual(instance.vectors.dtype, desired_dtype) - instance = Reach.load(tempfile.name, unk_word="splinter") - self.assertEqual(instance.unk_index, 2) - self.assertEqual(instance.items["splinter"], instance.unk_index) + instance = Reach.load(tempfile.name, unk_token="splinter") + self.assertEqual(instance._unk_index, 2) + self.assertEqual(instance.items["splinter"], instance._unk_index) def test_limit(self) -> None: with NamedTemporaryFile(mode="w+") as tempfile: @@ -240,7 +240,7 @@ def test_save_load_fast_format(self) -> None: self.assertEqual(instance.size, instance_2.size) self.assertEqual(len(instance), len(instance_2)) self.assertTrue(np.allclose(instance.vectors, instance_2.vectors)) - self.assertEqual(instance.unk_index, instance_2.unk_index) + self.assertEqual(instance._unk_index, instance_2._unk_index) self.assertEqual(instance.name, instance_2.name) def test_save_load(self) -> None: @@ -257,5 +257,5 @@ def test_save_load(self) -> None: self.assertEqual(instance.size, instance_2.size) self.assertEqual(len(instance), len(instance_2)) self.assertTrue(np.allclose(instance.vectors, instance_2.vectors)) - self.assertEqual(instance.unk_index, instance_2.unk_index) + self.assertEqual(instance._unk_index, instance_2._unk_index) self.assertEqual(instance.name, instance_2.name) diff --git a/tests/test_similarity.py b/tests/test_similarity.py index cb6381c..a302245 100644 --- a/tests/test_similarity.py +++ b/tests/test_similarity.py @@ -1,9 +1,10 @@ import logging import unittest from itertools import combinations -from typing import Hashable, List, Tuple +from typing import cast import numpy as np +import numpy.typing as npt from reach import Reach, normalize @@ -11,8 +12,8 @@ class TestSimilarity(unittest.TestCase): - def data(self) -> Tuple[List[Hashable], np.ndarray]: - words: List[Hashable] = [ + def data(self) -> tuple[list[str], np.ndarray]: + words: list[str] = [ "donatello", "leonardo", "raphael", @@ -47,7 +48,7 @@ def test_normalize_vector(self) -> None: def test_normalize_norm(self) -> None: x = np.arange(10) result = Reach.normalize(x) - result_2 = Reach.normalize(x, np.linalg.norm(x)) + result_2 = Reach.normalize(x, cast(npt.NDArray, np.linalg.norm(x))) self.assertTrue(np.allclose(result, result_2)) @@ -91,7 +92,7 @@ def test_ranking(self) -> None: argsorted_matrix = np.flip(np.argsort(sim_matrix, axis=1), axis=1)[:, 1:] for idx, w in enumerate(instance.items): - similar_words: List[Hashable] = [ + similar_words: list[str] = [ x[0] for x in instance.most_similar([w], num=10)[0] ] indices = [instance.items[word] for word in similar_words] @@ -142,7 +143,7 @@ def test_threshold(self) -> None: threshold = 0.0 for index, w in enumerate(instance.items): - above_threshold_1: List[Hashable] = [ + above_threshold_1: list[str] = [ x[0] for x in instance.threshold([w], threshold=threshold)[0] ] indices_1 = [instance.items[word] for word in above_threshold_1] @@ -155,7 +156,7 @@ def test_threshold(self) -> None: threshold = 0.9 for w in instance.items: - above_threshold_2: List[Hashable] = [ + above_threshold_2: list[str] = [ x[0] for x in instance.threshold([w], threshold=threshold)[0] ] indices_2 = [instance.items[word] for word in above_threshold_2] diff --git a/tests/test_vectorize.py b/tests/test_vectorize.py index 3f17ed5..57c2f03 100644 --- a/tests/test_vectorize.py +++ b/tests/test_vectorize.py @@ -1,6 +1,5 @@ import logging import unittest -from typing import Hashable, List, Tuple import numpy as np @@ -10,8 +9,8 @@ class TestVectorize(unittest.TestCase): - def data(self) -> Tuple[List[Hashable], np.ndarray]: - words: List[Hashable] = [ + def data(self) -> tuple[list[str], np.ndarray]: + words: list[str] = [ "donatello", "leonardo", "raphael", @@ -45,9 +44,13 @@ def test_vectorize_unk(self) -> None: words, vectors = self.data() words.append("") vectors = np.concatenate([vectors, np.zeros((1, vectors.shape[1]))]) - reach = Reach(vectors, words, unk_index=len(words) - 1) + reach = Reach(vectors, words) + reach.unk_token = "" + + self.assertIsNotNone(reach._unk_index) - self.assertEqual(reach.indices[reach.unk_index], "") # type: ignore + assert reach._unk_index is not None + self.assertEqual(reach.indices[reach._unk_index], "") self.assertTrue(np.allclose(vectors[-1], np.zeros(reach.size))) def test_bow_no_unk(self) -> None: @@ -70,10 +73,11 @@ def test_bow_unk(self) -> None: words, vectors = self.data() words.append("") vectors = np.concatenate([vectors, np.zeros((1, vectors.shape[1]))]) - reach = Reach(vectors, words, unk_index=len(words) - 1) + reach = Reach(vectors, words) + reach.unk_token = "" bow = reach.bow(["donatello", "leonardo", "rgieurghegh"]) - self.assertEqual(bow, [0, 1, reach.unk_index]) + self.assertEqual(bow, [0, 1, reach._unk_index]) bow = reach.bow(["donatello", "leonardo", "rgieurghegh"], remove_oov=True) self.assertEqual(bow, [0, 1]) @@ -130,7 +134,8 @@ def test_mean_pool_unk(self) -> None: words, vectors = self.data() words.append("") vectors = np.concatenate([vectors, np.zeros((1, vectors.shape[1]))]) - reach = Reach(vectors, words, unk_index=len(words) - 1) + reach = Reach(vectors, words) + reach.unk_token = "" vec = reach.mean_pool(["donatello", "dog"]) self.assertTrue(np.allclose(vec, reach["donatello"] / 2))