From a72d0e2b606bf77a8d8a9513d3361cd10e512ecb Mon Sep 17 00:00:00 2001 From: stephantul Date: Sun, 14 Jul 2024 21:31:36 +0200 Subject: [PATCH] Add legacy loading --- reach/legacy/__init__.py | 0 reach/legacy/load.py | 20 ++++++++++++++++++ reach/reach.py | 41 +++++++++++++++++++++++++----------- tests/test_io.py | 45 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 94 insertions(+), 12 deletions(-) create mode 100644 reach/legacy/__init__.py create mode 100644 reach/legacy/load.py diff --git a/reach/legacy/__init__.py b/reach/legacy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/reach/legacy/load.py b/reach/legacy/load.py new file mode 100644 index 0000000..b151291 --- /dev/null +++ b/reach/legacy/load.py @@ -0,0 +1,20 @@ +import json +from pathlib import Path + +import numpy as np +import numpy.typing as npt + + +def load_old_fast_format_data( + path: Path, +) -> tuple[npt.NDArray, list[str], str | None, str]: + """Load data from fast format.""" + with open(f"{path}_items.json") as file_handle: + items = json.load(file_handle) + tokens, unk_index, name = items["items"], items["unk_index"], items["name"] + + with open(f"{path}_vectors.npy", "rb") as file_handle: + vectors = np.load(file_handle) + + unk_token = tokens[unk_index] if unk_index is not None else None + return vectors, tokens, unk_token, name diff --git a/reach/reach.py b/reach/reach.py index 5f89ebf..982c2c9 100644 --- a/reach/reach.py +++ b/reach/reach.py @@ -12,6 +12,9 @@ from numpy import typing as npt from tqdm import tqdm +from reach.legacy.load import load_old_fast_format_data + + Dtype: TypeAlias = str | np.dtype File = Path | TextIOWrapper PathLike = str | Path @@ -1041,20 +1044,34 @@ def load_fast_format( """ filename_path = Path(filename) - with open(filename) as file_handle: - data: dict[str, Any] = json.load(file_handle) - items: list[str] = data["items"] - metadata: dict[str, Any] = data["metadata"] - unk_token = metadata.pop("unk_token") - name = metadata.pop("name") - numpy_path = filename_path.parent / Path(data["vectors_path"]) - - if not numpy_path.exists(): - raise ValueError(f"Could not find the vectors file at {numpy_path}") + try: + with open(filename) as file_handle: + data: dict[str, Any] = json.load(file_handle) + items: list[str] = data["items"] + + metadata: dict[str, Any] = data["metadata"] + unk_token = metadata.pop("unk_token") + name = metadata.pop("name") + numpy_path = filename_path.parent / Path(data["vectors_path"]) + + if not numpy_path.exists(): + raise ValueError(f"Could not find the vectors file at {numpy_path}") + + with open(numpy_path, "rb") as file_handle: + vectors: npt.NDArray = np.load(file_handle) + except FileNotFoundError as exc: + logger.warning("Attempting to load from old format.") + try: + vectors, items, unk_token, name = load_old_fast_format_data( + filename_path + ) + metadata = {} + except FileNotFoundError: + logger.warning("Loading from old format failed") + # NOTE: reraise old exception. + raise exc - with open(numpy_path, "rb") as file_handle: - vectors: npt.NDArray = np.load(file_handle) vectors = vectors.astype(desired_dtype) instance = cls(vectors, items, name=name, metadata=metadata) instance.unk_token = unk_token diff --git a/tests/test_io.py b/tests/test_io.py index 9a00e67..d4c04a1 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -1,3 +1,4 @@ +import json import unittest from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory @@ -249,6 +250,50 @@ def test_save_load_fast_format(self) -> None: self.assertEqual(instance._unk_index, instance_2._unk_index) self.assertEqual(instance.name, instance_2.name) + def test_save_load_fast_format_old(self) -> None: + with TemporaryDirectory() as temp_folder: + lines = self.lines() + + temp_folder_path = Path(temp_folder) + + temp_file_name = temp_folder_path / "test.vec" + with open(temp_file_name, "w") as tempfile: + tempfile.write(lines) + tempfile.seek(0) + + instance = Reach.load(temp_file_name) + fast_format_file = temp_folder_path / "temp" + + items_dict = { + "items": instance.sorted_items, + "unk_index": instance._unk_index, + "name": instance.name, + } + + json.dump(items_dict, open(f"{fast_format_file}_items.json", "w")) + np.save(f"{fast_format_file}_vectors.npy", instance.vectors) + + instance_2 = Reach.load_fast_format(fast_format_file) + + self.assertEqual(instance.size, instance_2.size) + self.assertEqual(len(instance), len(instance_2)) + self.assertEqual(instance.items, instance_2.items) + self.assertTrue(np.allclose(instance.vectors, instance_2.vectors)) + self.assertEqual(instance._unk_index, instance_2._unk_index) + self.assertEqual(instance.name, instance_2.name) + + fast_format_file_2 = temp_folder_path / "temp.reach" + + instance.save_fast_format(fast_format_file_2) + instance_3 = Reach.load_fast_format(fast_format_file_2) + + self.assertEqual(instance.size, instance_3.size) + self.assertEqual(len(instance), len(instance_3)) + self.assertEqual(instance.items, instance_3.items) + self.assertTrue(np.allclose(instance.vectors, instance_3.vectors)) + self.assertEqual(instance._unk_index, instance_3._unk_index) + self.assertEqual(instance.name, instance_3.name) + def test_save_load(self) -> None: with NamedTemporaryFile("w+") as tempfile: lines = self.lines()