Skip to content

Commit

Permalink
Use fasttext via floret
Browse files: browse the repository at this point in the history
  • Loading branch information
juhoinkinen committed Jul 5, 2024
1 parent e6a20fc commit c6d23d8
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 14 deletions.
13 changes: 3 additions & 10 deletions annif/backend/fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os.path
from typing import TYPE_CHECKING, Any

import fasttext
import floret

import annif.util
from annif.exception import NotInitializedException, NotSupportedException
Expand Down Expand Up @@ -65,14 +65,7 @@ def default_params(self) -> dict[str, Any]:

@staticmethod
def _load_model(path: str) -> _FastText:
# monkey patch fasttext.FastText.eprint to avoid spurious warning
# see https://github.com/facebookresearch/fastText/issues/1067
orig_eprint = fasttext.FastText.eprint
fasttext.FastText.eprint = lambda x: None
model = fasttext.load_model(path)
# restore the original eprint
fasttext.FastText.eprint = orig_eprint
return model
return floret.load_model(path)

def initialize(self, parallel: bool = False) -> None:
if self._model is None:
Expand Down Expand Up @@ -132,7 +125,7 @@ def _create_model(self, params: dict[str, Any], jobs: int) -> None:
if jobs != 0: # jobs set by user to non-default value
params["thread"] = jobs
self.debug("Model parameters: {}".format(params))
self._model = fasttext.train_supervised(trainpath, **params)
self._model = floret.train_supervised(trainpath, **params)
self._model.save_model(modelpath)

def _train(
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ simplemma = "0.9.*"
jsonschema = "4.21.*"
huggingface-hub = "0.22.*"

fasttext-wheel = { version = "0.9.2", optional = true }
floret = { version = "~0.10.5", optional = true }
voikko = { version = "0.5.*", optional = true }
tensorflow-cpu = { version = "2.15.*", optional = true, python = "<3.12" }
lmdb = { version = "1.4.1", optional = true }
Expand All @@ -71,7 +71,7 @@ isort = "*"
schemathesis = "3.*.*"

[tool.poetry.extras]
fasttext = ["fasttext-wheel"]
fasttext = ["floret"]
voikko = ["voikko"]
nn = ["tensorflow-cpu", "lmdb"]
omikuji = ["omikuji"]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def test_fill_params_with_defaults(project):


@pytest.mark.skipif(
importlib.util.find_spec("fasttext") is not None,
reason="test requires that fastText is NOT installed",
importlib.util.find_spec("floret") is not None,
reason="test requires that floret is NOT installed",
)
def test_get_backend_fasttext_not_installed():
with pytest.raises(ValueError) as excinfo:
Expand Down

0 comments on commit c6d23d8

Please sign in to comment.