Skip to content

Commit

Permalink
Merge pull request #54 from stephantul/set-unk-token
Browse files Browse the repository at this point in the history
feat: add unk token
  • Loading branch information
stephantul authored Jul 14, 2024
2 parents e8bbde3 + f7cc6c4 commit 994f66a
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 68 deletions.
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
[project]
name = "reach"
readme = "README.md"

[build-system]
requires = ["setuptools >= 40.6.0", "wheel"]
build-backend = "setuptools.build_meta"

[tool.ruff]
exclude = [".venv/"]
target-version = "py311"

[[tool.mypy.overrides]]
module = [
"tqdm.*",
"ahocorasick.*",
"setuptools.*",
]
ignore_missing_imports = true
3 changes: 1 addition & 2 deletions reach/autoreach.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,9 @@ def __init__(
items: list[str],
lowercase: str | bool = "auto",
name: str = "",
unk_index: int | None = None,
) -> None:
"""Initialize a Reach instance with an array and list of strings."""
super().__init__(vectors, items, name, unk_index)
super().__init__(vectors, items, name)
self.automaton = Automaton()
if not all(isinstance(item, str) for item in self.items):
raise ValueError("All your items should be strings.")
Expand Down
108 changes: 75 additions & 33 deletions reach/reach.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,9 @@ class Reach:
name : string, optional, default ''
A string giving the name of the current reach. Only useful if you
have multiple spaces and want to keep track of them.
unk_index : int or None, optional, default None
The index of the UNK item. If this is None, any attempts at vectorizing
OOV items will throw an error.
Attributes
----------
unk_index : int
The integer index of your unknown glyph. This glyph will be inserted
into your BoW space whenever an unknown item is encountered.
name : string
The name of the Reach instance.
Expand All @@ -60,7 +54,6 @@ def __init__(
vectors: Matrix,
items: list[str],
name: str = "",
unk_index: int | None = None,
) -> None:
"""Initialize a Reach instance with an array and list of items."""
if len(items) != len(vectors):
Expand All @@ -79,8 +72,27 @@ def __init__(
self._items: dict[str, int] = {w: idx for idx, w in enumerate(items)}
self._indices: dict[int, str] = {idx: w for w, idx in self.items.items()}
self.vectors = np.asarray(vectors)
self.unk_index = unk_index
self.name = name
self._unk_token: str | None = None
self._unk_index: int | None = None

@property
def unk_token(self) -> str | None:
"""The unknown token."""
return self._unk_token

@unk_token.setter
def unk_token(self, token: str | None) -> None:
if token is None:
if self.unk_token is not None:
logger.info(f"Setting unk token from {self.unk_token} to None.")
self._unk_token = None
self._unk_index = None
else:
if token not in self.items:
self.insert([token])
self._unk_token = token
self._unk_index = self.items[token]

def __len__(self) -> int:
"""The number of the items in the vector space."""
Expand Down Expand Up @@ -154,14 +166,44 @@ def _normalize_or_copy(vectors: npt.NDArray) -> npt.NDArray:
return vectors
return Reach.normalize(vectors, norms)

def insert(self, tokens: list[str], vectors: npt.NDArray | None = None) -> None:
"""
Insert new items into the vector space.
Parameters
----------
tokens : list
A list of items to insert into the vector space.
vectors : numpy array, optional, default None
The vectors to insert into the vector space. If this is None,
the vectors will be set to zero.
"""
if vectors is None:
vectors = np.zeros((len(tokens), self.size), dtype=self.vectors.dtype)
else:
vectors = np.asarray(vectors, dtype=self.vectors.dtype)

if len(tokens) != len(vectors):
raise ValueError(
f"Your tokens and vectors are not the same length: {len(tokens)} != {len(vectors)}"
)

for token in tokens:
if token in self.items:
raise ValueError(f"Token {token} is already in the vector space.")
self.items[token] = len(self.items)
self.indices[len(self.items) - 1] = token
self.vectors = np.concatenate([self.vectors, vectors], 0)

@classmethod
def load(
cls,
vector_file: File | str,
wordlist: tuple[str, ...] | None = None,
num_to_load: int | None = None,
truncate_embeddings: int | None = None,
unk_word: str | None = None,
unk_token: str | None = None,
sep: str = " ",
recover_from_errors: bool = False,
desired_dtype: Dtype = "float32",
Expand Down Expand Up @@ -190,8 +232,8 @@ def load(
truncate_embeddings : int, optional, default None
If this value is not None, the vectors in the vector space will
be truncated to the number of dimensions indicated by this value.
unk_word : object
The object to treat as UNK in your vector space. If this is not
unk_token : str
The string to treat as UNK in your vector space. If this is not
in your items dictionary after loading, we add it with a zero
vector.
recover_from_errors : bool
Expand Down Expand Up @@ -232,25 +274,21 @@ def load(
if came_from_path:
file_handle.close()

if unk_word is not None:
if unk_word not in items:
unk_vec = np.zeros((1, vectors.shape[1]), dtype=desired_dtype)
vectors = np.concatenate([unk_vec, vectors], 0)
items = [unk_word] + items
unk_index = 0
else:
unk_index = items.index(unk_word)
else:
unk_index = None

# NOTE: we use type: ignore because we pass a list of strings, which is hashable
return cls(
instance = cls(
vectors,
items, # type: ignore
items,
name=name,
unk_index=unk_index,
)

if unk_token is not None:
if unk_token not in items:
logger.info(f"Adding unk token {unk_token} to the vocabulary.")
instance.insert([unk_token])
instance.unk_token = unk_token

return instance

@staticmethod
def _load(
file_handle: TextIOWrapper,
Expand Down Expand Up @@ -491,15 +529,15 @@ def bow(self, tokens: Tokens, remove_oov: bool = False) -> list[int]:
except KeyError as exc:
if remove_oov:
continue
if self.unk_index is None:
if self._unk_index is None:
raise ValueError(
"You supplied OOV items but didn't "
"provide the index of the replacement "
"glyph. Either set remove_oov to True, "
"or set unk_index to the index of the "
"item which replaces any OOV items."
) from exc
out.append(self.unk_index)
out.append(self._unk_index)

return out

Expand Down Expand Up @@ -865,13 +903,14 @@ def intersect(self, itemlist: Tokens) -> Reach:
itemlist = list(set(self.items) & set(itemlist))
# Get indices of intersection.
indices = sorted([self.items[item] for item in itemlist])
# Set unk_index to None if it is None or if it is not in indices
unk_index = self.unk_index if self.unk_index in indices else None
# Index vectors
vectors = self.vectors[indices]
# Index words
itemlist = [self.indices[index] for index in indices]
return Reach(vectors, itemlist, unk_index=unk_index, name=self.name)
instance = Reach(vectors, itemlist, name=self.name)
instance.unk_token = self.unk_token

return instance

def union(self, other: Reach, check: bool = True) -> Reach:
"""
Expand Down Expand Up @@ -946,7 +985,7 @@ def save_fast_format(self, filename: str) -> None:
"""
items, _ = zip(*sorted(self.items.items(), key=lambda x: x[1]))
items_dict = {"items": items, "unk_index": self.unk_index, "name": self.name}
items_dict = {"items": items, "unk_token": self.unk_token, "name": self.name}

with open(f"{filename}_items.json", "w") as file_handle:
json.dump(items_dict, file_handle)
Expand Down Expand Up @@ -975,12 +1014,15 @@ def load_fast_format(
"""
with open(f"{filename}_items.json") as file_handle:
items = json.load(file_handle)
words, unk_index, name = items["items"], items["unk_index"], items["name"]
words, unk_token, name = items["items"], items["unk_token"], items["name"]

with open(f"{filename}_vectors.npy", "rb") as file_handle:
vectors: npt.NDArray = np.load(file_handle)
vectors = vectors.astype(desired_dtype)
return cls(vectors, words, unk_index=unk_index, name=name)
instance = cls(vectors, words, name=name)
instance.unk_token = unk_token

return instance


def normalize(vectors: npt.NDArray, norms: npt.NDArray | None = None) -> npt.NDArray:
Expand Down
Empty file added tests/__init__.py
Empty file.
5 changes: 2 additions & 3 deletions tests/test_auto.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import unittest
from typing import Hashable, List, Tuple

import numpy as np

from reach import AutoReach, Reach


class TestAuto(unittest.TestCase):
def data(self) -> Tuple[List[Hashable], np.ndarray]:
words: List[Hashable] = [
def data(self) -> tuple[list[str], np.ndarray]:
words: list[str] = [
"donatello",
"leonardo",
"raphael",
Expand Down
8 changes: 3 additions & 5 deletions tests/test_init.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import unittest
from typing import Hashable, List, Tuple

import numpy as np

from reach import Reach


class TestInit(unittest.TestCase):
def data(self) -> Tuple[List[Hashable], np.ndarray]:
words: List[Hashable] = [
def data(self) -> tuple[list[str], np.ndarray]:
words: list[str] = [
"donatello",
"leonardo",
"raphael",
Expand Down Expand Up @@ -48,8 +47,7 @@ def test_init(self) -> None:
instance = Reach(vectors, words, name="sensei")
self.assertEqual(instance.name, "sensei")

instance = Reach(vectors, words, unk_index=1)
self.assertEqual(instance.unk_index, 1)
instance = Reach(vectors, words)
self.assertEqual(list(instance.sorted_items), words)

with self.assertRaises(AttributeError):
Expand Down
20 changes: 10 additions & 10 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,20 @@ def test_unk(self) -> None:
lines = self.lines()
tempfile.write(lines)
tempfile.seek(0)
instance = Reach.load(tempfile.name, unk_word=None)
self.assertEqual(instance.unk_index, None)
instance = Reach.load(tempfile.name, unk_token=None)
self.assertEqual(instance._unk_index, None)

desired_dtype = "float32"
instance = Reach.load(
tempfile.name, unk_word="[UNK]", desired_dtype=desired_dtype
tempfile.name, unk_token="[UNK]", desired_dtype=desired_dtype
)
self.assertEqual(instance.unk_index, 0)
self.assertEqual(instance.items["[UNK]"], instance.unk_index)
self.assertEqual(instance._unk_index, 6)
self.assertEqual(instance.items["[UNK]"], instance._unk_index)
self.assertEqual(instance.vectors.dtype, desired_dtype)

instance = Reach.load(tempfile.name, unk_word="splinter")
self.assertEqual(instance.unk_index, 2)
self.assertEqual(instance.items["splinter"], instance.unk_index)
instance = Reach.load(tempfile.name, unk_token="splinter")
self.assertEqual(instance._unk_index, 2)
self.assertEqual(instance.items["splinter"], instance._unk_index)

def test_limit(self) -> None:
with NamedTemporaryFile(mode="w+") as tempfile:
Expand Down Expand Up @@ -240,7 +240,7 @@ def test_save_load_fast_format(self) -> None:
self.assertEqual(instance.size, instance_2.size)
self.assertEqual(len(instance), len(instance_2))
self.assertTrue(np.allclose(instance.vectors, instance_2.vectors))
self.assertEqual(instance.unk_index, instance_2.unk_index)
self.assertEqual(instance._unk_index, instance_2._unk_index)
self.assertEqual(instance.name, instance_2.name)

def test_save_load(self) -> None:
Expand All @@ -257,5 +257,5 @@ def test_save_load(self) -> None:
self.assertEqual(instance.size, instance_2.size)
self.assertEqual(len(instance), len(instance_2))
self.assertTrue(np.allclose(instance.vectors, instance_2.vectors))
self.assertEqual(instance.unk_index, instance_2.unk_index)
self.assertEqual(instance._unk_index, instance_2._unk_index)
self.assertEqual(instance.name, instance_2.name)
15 changes: 8 additions & 7 deletions tests/test_similarity.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import logging
import unittest
from itertools import combinations
from typing import Hashable, List, Tuple
from typing import cast

import numpy as np
import numpy.typing as npt

from reach import Reach, normalize

logger = logging.getLogger(__name__)


class TestSimilarity(unittest.TestCase):
def data(self) -> Tuple[List[Hashable], np.ndarray]:
words: List[Hashable] = [
def data(self) -> tuple[list[str], np.ndarray]:
words: list[str] = [
"donatello",
"leonardo",
"raphael",
Expand Down Expand Up @@ -47,7 +48,7 @@ def test_normalize_vector(self) -> None:
def test_normalize_norm(self) -> None:
x = np.arange(10)
result = Reach.normalize(x)
result_2 = Reach.normalize(x, np.linalg.norm(x))
result_2 = Reach.normalize(x, cast(npt.NDArray, np.linalg.norm(x)))

self.assertTrue(np.allclose(result, result_2))

Expand Down Expand Up @@ -91,7 +92,7 @@ def test_ranking(self) -> None:
argsorted_matrix = np.flip(np.argsort(sim_matrix, axis=1), axis=1)[:, 1:]

for idx, w in enumerate(instance.items):
similar_words: List[Hashable] = [
similar_words: list[str] = [
x[0] for x in instance.most_similar([w], num=10)[0]
]
indices = [instance.items[word] for word in similar_words]
Expand Down Expand Up @@ -142,7 +143,7 @@ def test_threshold(self) -> None:

threshold = 0.0
for index, w in enumerate(instance.items):
above_threshold_1: List[Hashable] = [
above_threshold_1: list[str] = [
x[0] for x in instance.threshold([w], threshold=threshold)[0]
]
indices_1 = [instance.items[word] for word in above_threshold_1]
Expand All @@ -155,7 +156,7 @@ def test_threshold(self) -> None:

threshold = 0.9
for w in instance.items:
above_threshold_2: List[Hashable] = [
above_threshold_2: list[str] = [
x[0] for x in instance.threshold([w], threshold=threshold)[0]
]
indices_2 = [instance.items[word] for word in above_threshold_2]
Expand Down
Loading

0 comments on commit 994f66a

Please sign in to comment.