diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f80e69a..d7c4930 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -79,64 +79,6 @@ jobs: # name: codecov-umbrella verbose: true - doctests: - # This action runs doctests for coverage collection and uploads them to codecov.io. - # This requires the secret `CODECOV_TOKEN` be set as secret on GitHub, both for - # Actions and Dependabot - - name: "${{ matrix.os }} / 3.8 / doctest" - strategy: - max-parallel: 4 - fail-fast: false - matrix: - os: [ubuntu] - - runs-on: ${{ matrix.os }}-latest - continue-on-error: true # allow failure until doctests are added - env: - OS: ${{ matrix.os }}-latest - steps: - - uses: actions/checkout@v4 - with: - submodules: true - - - name: Set up the environment - uses: pdm-project/setup-pdm@v4 - id: setup-python - with: - python-version: ${{ env.MINIMUM_PYTHON_VERSION }} - - - name: Load cached venv - id: cached-venv - uses: actions/cache@v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/pyproject.toml') }}-${{ hashFiles('.github/workflows/test.yml') }} - - - name: Install dependencies - if: steps.cached-venv.outputs.cache-hit != 'true' - run: make install-dev - #---------------------------------------------- - # Run tests and upload coverage - #---------------------------------------------- - - name: make doc-tests - run: make doc-tests cov_report=xml - - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v4 - with: - token: ${{ secrets.CODECOV_TOKEN }} - # directory: ./coverage - env_vars: OS,PYTHON,TESTTYPE - fail_ci_if_error: true - # files: ./coverage/coverage.xml - # flags: unittests - # name: codecov-umbrella - verbose: true - env: - PYTHON: ${{ env.MINIMUM_PYTHON_VERSION }} - TESTTYPE: doctest - minimal: # This action chooses the oldest version of the dependencies permitted by Cargo.toml to ensure # that this crate is compatible with the minimal version that this crate and its dependencies @@ -169,7 +111,6 @@ jobs: if: always() needs: - coverage - - doctests - minimal runs-on: ubuntu-latest permissions: {} diff --git a/.gitignore b/.gitignore index 04f1bf2..5729ade 100644 --- a/.gitignore +++ b/.gitignore @@ -159,3 +159,7 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ .pdm-python + +/examples/*/.snakemake +/examples/*/export +/examples/*/sparv-workdir diff --git a/Makefile b/Makefile index f189ca3..a8feaeb 100644 --- a/Makefile +++ b/Makefile @@ -50,7 +50,7 @@ help: @echo "" @echo "publish [branch=]" @echo " pushes the given branch including tags to origin, for CI to publish based on tags. (Default: branch='main')" - @echo " Typically used after `make bumpversion`" + @echo " Typically used after 'make bumpversion'" @echo "" @echo "prepare-release" @echo " run tasks to prepare a release" @@ -89,6 +89,11 @@ install-dev: install: pdm sync --prod +lock: pdm.lock + +pdm.lock: pyproject.toml + pdm lock + .PHONY: test test: ${INVENV} pytest -vv ${tests} @@ -142,20 +147,27 @@ publish: .PHONY: prepare-release -prepare-release: tests/requirements-testing.lock +prepare-release: update-changelog tests/requirements-testing.lock # we use lock extension so that dependabot doesn't pick up changes in this file -tests/requirements-testing.lock: pyproject.toml +tests/requirements-testing.lock: pyproject.toml pdm.lock pdm export --dev --format requirements --output $@ -.PHONY: kb-bert-prepare-release -sparv-sbx-sentence-sentiment-kb-sent-prepare-release: sparv-sbx-sentence-sentiment-kb-sent/CHANGELOG.md - +.PHONY: update-changelog update-changelog: CHANGELOG.md sparv-sbx-sentence-sentiment-kb-sent/CHANGELOG.md +.PHONY: CHANGELOG.md CHANGELOG.md: git cliff --unreleased --prepend $@ +# update snapshots for `syrupy` +.PHONY: snapshot-update +snapshot-update: + ${INVENV} pytest --snapshot-update + +.PHONY: kb-bert-prepare-release +sparv-sbx-sentence-sentiment-kb-sent-prepare-release: sparv-sbx-sentence-sentiment-kb-sent/CHANGELOG.md + .PHONY: sparv-sbx-sentence-sentiment-kb-sent/CHANGELOG.md sparv-sbx-sentence-sentiment-kb-sent/CHANGELOG.md: git cliff --unreleased --include-path "sparv-sbx-sentence-sentiment-kb-sent/**/*" --include-path "examples/sparv-sbx-sentence-sentiment-kb-sent/**/*" --prepend $@ diff --git a/examples/sparv-sbx-sentence-sentiment-kb-sent/config.yaml b/examples/sparv-sbx-sentence-sentiment-kb-sent/config.yaml new file mode 100644 index 0000000..35e742b --- /dev/null +++ b/examples/sparv-sbx-sentence-sentiment-kb-sent/config.yaml @@ -0,0 +1,16 @@ +metadata: + id: example-sparv-sbx-sentence-kb-sent + language: swe + +import: + importer: text_import:parse + +export: + annotations: + - + # - + - :stanza.pos + - :sbx_sentence_sentiment_kb_sent.sbx-sentence-sentiment--kb-sent + +sparv: + compression: none diff --git a/examples/sparv-sbx-sentence-sentiment-kb-sent/source/small.txt b/examples/sparv-sbx-sentence-sentiment-kb-sent/source/small.txt new file mode 120000 index 0000000..31ab14a --- /dev/null +++ b/examples/sparv-sbx-sentence-sentiment-kb-sent/source/small.txt @@ -0,0 +1 @@ +../../texts/small.txt \ No newline at end of file diff --git a/examples/texts/small.txt b/examples/texts/small.txt new file mode 100644 index 0000000..ee5df9e --- /dev/null +++ b/examples/texts/small.txt @@ -0,0 +1,3 @@ +Stora regnmängder väntas under måndagen och SMHI har utfärdat en gul varning för skyfallsliknande regn över stora delar av landets södra halva. +Jag är förvånad, chockad och bestört över det här, säger Anders Persson till SVT Nyheter Småland. +Vi hoppas och tror att detta också snabbt ska kunna komma på plats, säger Garborg. diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..f8584b4 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +python_version = 3.8 diff --git a/pdm.lock b/pdm.lock index 8e98e0f..8f4caee 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:b0567665804410c8a28af54a39f497ef8c1c9227852b339b988139390544dac2" +content_hash = "sha256:c392a5b1b02da98f72b705681a112add2c4f8aa8a120fab08d1b3a4941da1dd7" [[package]] name = "annotated-types" @@ -1850,28 +1850,28 @@ files = [ [[package]] name = "ruff" -version = "0.3.7" +version = "0.4.5" requires_python = ">=3.7" summary = "An extremely fast Python linter and code formatter, written in Rust." groups = ["dev"] files = [ - {file = "ruff-0.3.7-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:0e8377cccb2f07abd25e84fc5b2cbe48eeb0fea9f1719cad7caedb061d70e5ce"}, - {file = "ruff-0.3.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:15a4d1cc1e64e556fa0d67bfd388fed416b7f3b26d5d1c3e7d192c897e39ba4b"}, - {file = "ruff-0.3.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d28bdf3d7dc71dd46929fafeec98ba89b7c3550c3f0978e36389b5631b793663"}, - {file = "ruff-0.3.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:379b67d4f49774ba679593b232dcd90d9e10f04d96e3c8ce4a28037ae473f7bb"}, - {file = "ruff-0.3.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c060aea8ad5ef21cdfbbe05475ab5104ce7827b639a78dd55383a6e9895b7c51"}, - {file = "ruff-0.3.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:ebf8f615dde968272d70502c083ebf963b6781aacd3079081e03b32adfe4d58a"}, - {file = "ruff-0.3.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d48098bd8f5c38897b03604f5428901b65e3c97d40b3952e38637b5404b739a2"}, - {file = "ruff-0.3.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:da8a4fda219bf9024692b1bc68c9cff4b80507879ada8769dc7e985755d662ea"}, - {file = "ruff-0.3.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c44e0149f1d8b48c4d5c33d88c677a4aa22fd09b1683d6a7ff55b816b5d074f"}, - {file = "ruff-0.3.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3050ec0af72b709a62ecc2aca941b9cd479a7bf2b36cc4562f0033d688e44fa1"}, - {file = "ruff-0.3.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a29cc38e4c1ab00da18a3f6777f8b50099d73326981bb7d182e54a9a21bb4ff7"}, - {file = "ruff-0.3.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5b15cc59c19edca917f51b1956637db47e200b0fc5e6e1878233d3a938384b0b"}, - {file = "ruff-0.3.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e491045781b1e38b72c91247cf4634f040f8d0cb3e6d3d64d38dcf43616650b4"}, - {file = "ruff-0.3.7-py3-none-win32.whl", hash = "sha256:bc931de87593d64fad3a22e201e55ad76271f1d5bfc44e1a1887edd0903c7d9f"}, - {file = "ruff-0.3.7-py3-none-win_amd64.whl", hash = "sha256:5ef0e501e1e39f35e03c2acb1d1238c595b8bb36cf7a170e7c1df1b73da00e74"}, - {file = "ruff-0.3.7-py3-none-win_arm64.whl", hash = "sha256:789e144f6dc7019d1f92a812891c645274ed08af6037d11fc65fcbc183b7d59f"}, - {file = "ruff-0.3.7.tar.gz", hash = "sha256:d5c1aebee5162c2226784800ae031f660c350e7a3402c4d1f8ea4e97e232e3ba"}, + {file = "ruff-0.4.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:8f58e615dec58b1a6b291769b559e12fdffb53cc4187160a2fc83250eaf54e96"}, + {file = "ruff-0.4.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:84dd157474e16e3a82745d2afa1016c17d27cb5d52b12e3d45d418bcc6d49264"}, + {file = "ruff-0.4.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25f483ad9d50b00e7fd577f6d0305aa18494c6af139bce7319c68a17180087f4"}, + {file = "ruff-0.4.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:63fde3bf6f3ad4e990357af1d30e8ba2730860a954ea9282c95fc0846f5f64af"}, + {file = "ruff-0.4.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78e3ba4620dee27f76bbcad97067766026c918ba0f2d035c2fc25cbdd04d9c97"}, + {file = "ruff-0.4.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:441dab55c568e38d02bbda68a926a3d0b54f5510095c9de7f95e47a39e0168aa"}, + {file = "ruff-0.4.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1169e47e9c4136c997f08f9857ae889d614c5035d87d38fda9b44b4338909cdf"}, + {file = "ruff-0.4.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:755ac9ac2598a941512fc36a9070a13c88d72ff874a9781493eb237ab02d75df"}, + {file = "ruff-0.4.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f4b02a65985be2b34b170025a8b92449088ce61e33e69956ce4d316c0fe7cce0"}, + {file = "ruff-0.4.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:75a426506a183d9201e7e5664de3f6b414ad3850d7625764106f7b6d0486f0a1"}, + {file = "ruff-0.4.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:6e1b139b45e2911419044237d90b60e472f57285950e1492c757dfc88259bb06"}, + {file = "ruff-0.4.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a6f29a8221d2e3d85ff0c7b4371c0e37b39c87732c969b4d90f3dad2e721c5b1"}, + {file = "ruff-0.4.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d6ef817124d72b54cc923f3444828ba24fa45c3164bc9e8f1813db2f3d3a8a11"}, + {file = "ruff-0.4.5-py3-none-win32.whl", hash = "sha256:aed8166c18b1a169a5d3ec28a49b43340949e400665555b51ee06f22813ef062"}, + {file = "ruff-0.4.5-py3-none-win_amd64.whl", hash = "sha256:b0b03c619d2b4350b4a27e34fd2ac64d0dabe1afbf43de57d0f9d8a05ecffa45"}, + {file = "ruff-0.4.5-py3-none-win_arm64.whl", hash = "sha256:9d15de3425f53161b3f5a5658d4522e4eee5ea002bf2ac7aa380743dd9ad5fba"}, + {file = "ruff-0.4.5.tar.gz", hash = "sha256:286eabd47e7d4d521d199cab84deca135557e6d1e0f0d01c29e757c3cb151b54"}, ] [[package]] @@ -2076,7 +2076,7 @@ files = [ [[package]] name = "sparv-sbx-sentence-sentiment-kb-sent" version = "0.1.0" -requires_python = ">= 3.8,<3.12" +requires_python = ">= 3.8.1,<3.12" editable = true path = "./sparv-sbx-sentence-sentiment-kb-sent" summary = "A sparv plugin for computing word neighbours using a BERT model." @@ -2128,6 +2128,20 @@ files = [ {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, ] +[[package]] +name = "syrupy" +version = "4.6.1" +requires_python = ">=3.8.1,<4" +summary = "Pytest Snapshot Test Utility" +groups = ["dev"] +dependencies = [ + "pytest<9.0.0,>=7.0.0", +] +files = [ + {file = "syrupy-4.6.1-py3-none-any.whl", hash = "sha256:203e52f9cb9fa749cf683f29bd68f02c16c3bc7e7e5fe8f2fc59bdfe488ce133"}, + {file = "syrupy-4.6.1.tar.gz", hash = "sha256:37a835c9ce7857eeef86d62145885e10b3cb9615bc6abeb4ce404b3f18e1bb36"}, +] + [[package]] name = "tabulate" version = "0.9.0" diff --git a/pyproject.toml b/pyproject.toml index f5b6f46..3c5e706 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "sparv-sbx-sentence-sentiment-workspace" dependencies = [] -requires-python = ">=3.8,<3.12" +requires-python = ">=3.8.1,<3.12" version = "0.0.0" [tool.pdm.dev-dependencies] @@ -10,6 +10,7 @@ dev = [ "pytest>=8.1.1", "pytest-cov>=4.1.0", "mypy>=1.9.0", - "ruff>=0.3.2", + "ruff>=0.4.5", "bump-my-version>=0.19.0", + "syrupy>=4.6.1", ] diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..d743b6a --- /dev/null +++ b/ruff.toml @@ -0,0 +1,61 @@ +line-length = 97 + +target-version = "py38" + +[lint] +select = [ + "A", # flake8-builtins + "ANN", # flake8-annotations + "ARG", # flake8-unused-arguments + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "COM", # flake8-commas + "D", # pydocstyle + "D400", # pydocstyle: ends-in-period + "D401", # pydocstyle: non-imperative-mood + "E", # pycodestyle: errors + "F", # Pyflakes + "FLY", # flynt + "FURB", # refurb + "G", # flake8-logging-format + "I", # isort + "ISC", # flake8-implicit-str-concat + "N", # pep8-naming + "PERF", # Perflint + "PIE", # flake8-pie + "PL", # Pylint + # "PT", # flake8-pytest-style + "PTH", # flake8-use-pathlib + "Q", # flake8-quotes + "RET", # flake8-return + "RSE", # flake8-raise + "RUF", # Ruff-specific rules + "SIM", # flake8-simplify + "T20", # flake8-print + "TID", # flake8-tidy-imports + "UP", # pyupgrade + "W", # pycodestyle: warnings +] +ignore = [ + "ANN101", # flake8-annotations: missing-type-self (deprecated) + "ANN102", # flake8-annotations: missing-type-cls (deprecated) + "ANN401", # flake8-annotations: any-type + "B008", # flake8-bugbear: function-call-in-default-argument + "ISC001", + "COM812", # flake8-commas: missing-trailing-comma + "PLR09", # Pylint: too-many-* + "SIM105", # flake8-simplify: suppressible-exception +] +preview = true + +# Avoid trying to fix flake8-bugbear (`B`) violations. +unfixable = ["B"] + + +[lint.pydocstyle] +convention = "google" + + +# Ignore `E402` (import violations) in all `__init__.py` files, and in `path/to/file.py`. +[lint.per-file-ignores] +"*/tests/*" = ["D", "ARG002", "E501"] diff --git a/sparv-sbx-sentence-sentiment-kb-sent/README.md b/sparv-sbx-sentence-sentiment-kb-sent/README.md index ee1fb4a..c8a69cb 100644 --- a/sparv-sbx-sentence-sentiment-kb-sent/README.md +++ b/sparv-sbx-sentence-sentiment-kb-sent/README.md @@ -13,7 +13,7 @@ Plugin for applying bert masking as a [Sparv](https://github.com/spraakbanken/sp ## Install -First, install Sparv, as suggested: +First, install Sparv as suggested: ```bash pipx install sparv-pipeline diff --git a/sparv-sbx-sentence-sentiment-kb-sent/pdm.lock b/sparv-sbx-sentence-sentiment-kb-sent/pdm.lock index 00414ea..cbea02c 100644 --- a/sparv-sbx-sentence-sentiment-kb-sent/pdm.lock +++ b/sparv-sbx-sentence-sentiment-kb-sent/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:e2f1b35f371e06c6f8984cbf664e879ec980dd33ff1fe50a42ac2750118d45bf" +content_hash = "sha256:ae5c879ee52fe6a0ae43754676bbb64a7960e079d7948483ac996fef40aab011" [[package]] name = "appdirs" @@ -1761,6 +1761,20 @@ files = [ {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, ] +[[package]] +name = "syrupy" +version = "4.6.1" +requires_python = ">=3.8.1,<4" +summary = "Pytest Snapshot Test Utility" +groups = ["dev"] +dependencies = [ + "pytest<9.0.0,>=7.0.0", +] +files = [ + {file = "syrupy-4.6.1-py3-none-any.whl", hash = "sha256:203e52f9cb9fa749cf683f29bd68f02c16c3bc7e7e5fe8f2fc59bdfe488ce133"}, + {file = "syrupy-4.6.1.tar.gz", hash = "sha256:37a835c9ce7857eeef86d62145885e10b3cb9615bc6abeb4ce404b3f18e1bb36"}, +] + [[package]] name = "tabulate" version = "0.9.0" diff --git a/sparv-sbx-sentence-sentiment-kb-sent/pyproject.toml b/sparv-sbx-sentence-sentiment-kb-sent/pyproject.toml index 7f65ce1..5d239af 100644 --- a/sparv-sbx-sentence-sentiment-kb-sent/pyproject.toml +++ b/sparv-sbx-sentence-sentiment-kb-sent/pyproject.toml @@ -3,12 +3,13 @@ name = "sparv-sbx-sentence-sentiment-kb-sent" version = "0.1.0" description = "A sparv plugin for computing word neighbours using a BERT model." authors = [ + { name = "Språkbanken Text", email = "sb-info@svenska.gu.se" }, { name = "Kristoffer Andersson", email = "kristoffer.andersson@gu.se" }, ] dependencies = ["sparv-pipeline >=5.2.0", "transformers>=4.34.1"] license = "MIT" readme = "README.md" -requires-python = ">= 3.8,<3.12" +requires-python = ">= 3.8.1,<3.12" classifiers = [ "Development Status :: 3 - Alpha", # "Development Status :: 4 - Beta", @@ -54,4 +55,7 @@ allow-direct-references = true [tool.pdm.dev-dependencies] -dev = ["pytest>=8.0.0"] +dev = [ + "pytest>=8.0.0", + "syrupy>=4.6.1", +] diff --git a/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/__init__.py b/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/__init__.py index 3dc1f76..3567771 100644 --- a/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/__init__.py +++ b/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/__init__.py @@ -1 +1,8 @@ +"""Sparv plugin for annotating sentences with sentiment analysis.""" + +from sbx_sentence_sentiment_kb_sent.annotations import annotate_sentence_sentiment + +__all__ = ["annotate_sentence_sentiment"] + +__description__ = "Annotate sentence with sentiment analysis." __version__ = "0.1.0" diff --git a/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/annotations.py b/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/annotations.py new file mode 100644 index 0000000..1cbd28c --- /dev/null +++ b/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/annotations.py @@ -0,0 +1,32 @@ +"""Sparv annotator.""" + +from sparv import api as sparv_api # type: ignore [import-untyped] + +from sbx_sentence_sentiment_kb_sent.sentiment_analyzer import SentimentAnalyzer + +logger = sparv_api.get_logger(__name__) + + +@sparv_api.annotator("Sentiment analysis of sentences", language=["swe"]) +def annotate_sentence_sentiment( + out_sentence_sentiment: sparv_api.Output = sparv_api.Output( + ":sbx_sentence_sentiment_kb_sent.sbx-sentence-sentiment--kb-sent", + # cls="sbx_sentence_sentiment_kb_sent", + description="Sentiment analysis of sentence with KBLab/robust-swedish-sentiment-multiclass", # noqa: E501 + ), + word: sparv_api.Annotation = sparv_api.Annotation(""), + sentence: sparv_api.Annotation = sparv_api.Annotation(""), +) -> None: + """Sentiment analysis of sentence with KBLab/robust-swedish-sentiment-multiclass.""" + sentences, _orphans = sentence.get_children(word) + token_word = list(word.read()) + out_sentence_sentiment_annotation = sentence.create_empty_attribute() + + analyzer = SentimentAnalyzer.default() + + logger.progress(total=len(sentences)) # type: ignore + for sent_i, sent in enumerate(sentences): + sent_to_tag = [token_word[token_index] for token_index in sent] + out_sentence_sentiment_annotation[sent_i] = analyzer.analyze_sentence(sent_to_tag) + + out_sentence_sentiment.write(out_sentence_sentiment_annotation) diff --git a/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/py.typed b/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/sentiment_analyzer.py b/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/sentiment_analyzer.py new file mode 100644 index 0000000..dbf9633 --- /dev/null +++ b/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/sentiment_analyzer.py @@ -0,0 +1,102 @@ +"""Sentiment analyzer.""" + +from typing import List, Optional + +from sparv import api as sparv_api # type: ignore [import-untyped] +from transformers import ( # type: ignore [import-untyped] + AutoModelForSequenceClassification, + AutoTokenizer, + MegatronBertForSequenceClassification, + PreTrainedTokenizerFast, + pipeline, +) + +logger = sparv_api.get_logger(__name__) + +TOKENIZER_NAME = "KBLab/megatron-bert-large-swedish-cased-165k" +TOKENIZER_REVISION = "90c57ab49e27b820bd85308a488409dfea25600d" +MODEL_NAME = "KBLab/robust-swedish-sentiment-multiclass" +MODEL_REVISION = "b0ec32dca56aa6182a6955c8f12129bbcbc7fdbd" + +TOK_SEP = " " + + +class SentimentAnalyzer: + """Sentiment analyzer.""" + + def __init__( + self, + *, + tokenizer: PreTrainedTokenizerFast, + model: MegatronBertForSequenceClassification, + num_decimals: int = 3, + ) -> None: + """Create a SentimentAnalyzer using the given tokenizer and model. + + The given number of num_decimals works both as rounding and cut-off. + + Args: + tokenizer (PreTrainedTokenizerFast): the tokenizer to use + model (MegatronBertForSequenceClassification): the model to use + num_decimals (int): number of decimals to use (defaults to 3) + """ + logger.debug("type(tokenizer)=%s", type(tokenizer)) + logger.debug("type(model)=%s", type(model)) + self.tokenizer = tokenizer + self.model = model + self.num_decimals = num_decimals + self.classifier = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer) + + @classmethod + def default(cls) -> "SentimentAnalyzer": + """Create a SentimentAnalyzer with default tokenizer and model. + + Returns: + SentimentAnalyzer: the create SentimentAnalyzer + """ + tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, revision=TOKENIZER_REVISION) + model = AutoModelForSequenceClassification.from_pretrained( + MODEL_NAME, revision=MODEL_REVISION + ) + return cls(model=model, tokenizer=tokenizer) + + def analyze_sentence(self, text: List[str]) -> Optional[str]: + """Analyze a sentence. + + Args: + text (Iterable[str]): the text to analyze + + Returns: + List[Optional[str]]: the sentence annotations. + """ + sentence = TOK_SEP.join(text) + + classifications = self.classifier(sentence) + logger.debug("classifications=%s", classifications) + collect_label_and_score = ((clss["label"], clss["score"]) for clss in classifications) + score_format, score_pred = SCORE_FORMAT_AND_PREDICATE[self.num_decimals] + + format_scores = ( + (label, score_format.format(score)) for label, score in collect_label_and_score + ) + filter_out_zero_scores = ( + (label, score) for label, score in format_scores if not score_pred(score) + ) + classification_str = "|".join( + f"{label}:{score}" for label, score in filter_out_zero_scores + ) + return f"|{classification_str}|" if classification_str else "|" + + +SCORE_FORMAT_AND_PREDICATE = { + 1: ("{:.1f}", lambda s: s.endswith(".0")), + 2: ("{:.2f}", lambda s: s.endswith(".00")), + 3: ("{:.3f}", lambda s: s.endswith(".000")), + 4: ("{:.4f}", lambda s: s.endswith(".0000")), + 5: ("{:.5f}", lambda s: s.endswith(".00000")), + 6: ("{:.6f}", lambda s: s.endswith(".000000")), + 7: ("{:.7f}", lambda s: s.endswith(".0000000")), + 8: ("{:.8f}", lambda s: s.endswith(".00000000")), + 9: ("{:.9f}", lambda s: s.endswith(".000000000")), + 10: ("{:.10f}", lambda s: s.endswith(".0000000000")), +} diff --git a/sparv-sbx-sentence-sentiment-kb-sent/tests/__init__.py b/sparv-sbx-sentence-sentiment-kb-sent/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sparv-sbx-sentence-sentiment-kb-sent/tests/__snapshots__/test_annotations.ambr b/sparv-sbx-sentence-sentiment-kb-sent/tests/__snapshots__/test_annotations.ambr new file mode 100644 index 0000000..b0e948b --- /dev/null +++ b/sparv-sbx-sentence-sentiment-kb-sent/tests/__snapshots__/test_annotations.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_annotate_sentence_sentiment + list([ + '|POSITIVE:0.866|', + '|NEUTRAL:0.963|', + ]) +# --- diff --git a/sparv-sbx-sentence-sentiment-kb-sent/tests/requirements-testing.lock b/sparv-sbx-sentence-sentiment-kb-sent/tests/requirements-testing.lock index d95e4ed..d96bce1 100644 --- a/sparv-sbx-sentence-sentiment-kb-sent/tests/requirements-testing.lock +++ b/sparv-sbx-sentence-sentiment-kb-sent/tests/requirements-testing.lock @@ -817,6 +817,9 @@ stopit==1.1.2 \ sympy==1.12 \ --hash=sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5 \ --hash=sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8 +syrupy==4.6.1 \ + --hash=sha256:203e52f9cb9fa749cf683f29bd68f02c16c3bc7e7e5fe8f2fc59bdfe488ce133 \ + --hash=sha256:37a835c9ce7857eeef86d62145885e10b3cb9615bc6abeb4ce404b3f18e1bb36 tabulate==0.9.0 \ --hash=sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c \ --hash=sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f diff --git a/sparv-sbx-sentence-sentiment-kb-sent/tests/test_annotations.py b/sparv-sbx-sentence-sentiment-kb-sent/tests/test_annotations.py new file mode 100644 index 0000000..b39a102 --- /dev/null +++ b/sparv-sbx-sentence-sentiment-kb-sent/tests/test_annotations.py @@ -0,0 +1,17 @@ +from sbx_sentence_sentiment_kb_sent.annotations import annotate_sentence_sentiment + +from tests.testing import MemoryOutput, MockAnnotation + + +def test_annotate_sentence_sentiment(snapshot) -> None: # noqa: ANN001 + output: MemoryOutput = MemoryOutput() + word = MockAnnotation( + name="", values=["Han", "var", "glad", ".", "Rihanna", "uppges", "gravid", "."] + ) + sentence = MockAnnotation( + name="", children={"": [[0, 1, 2, 3], [4, 5, 6, 7]]} + ) + + annotate_sentence_sentiment(output, word, sentence) + + assert output.values == snapshot diff --git a/sparv-sbx-sentence-sentiment-kb-sent/tests/test_sentiment_analyzer.py b/sparv-sbx-sentence-sentiment-kb-sent/tests/test_sentiment_analyzer.py new file mode 100644 index 0000000..7ba7775 --- /dev/null +++ b/sparv-sbx-sentence-sentiment-kb-sent/tests/test_sentiment_analyzer.py @@ -0,0 +1,29 @@ +from typing import Optional + +import pytest +from sbx_sentence_sentiment_kb_sent.sentiment_analyzer import SentimentAnalyzer + + +@pytest.fixture(name="sentiment_analyzer", scope="session") +def fixture_sentiment_analyzer() -> SentimentAnalyzer: + return SentimentAnalyzer.default() + + +def test_neutral_text(sentiment_analyzer: SentimentAnalyzer) -> None: + text = "Vi hoppas och tror att detta också snabbt ska kunna komma på plats , säger Garborg .".split( + " " + ) + + actual = sentiment_analyzer.analyze_sentence(text) + + actual = remove_scores(actual) + expected = "|NEUTRAL|" + + assert actual == expected + + +def remove_scores(actual: Optional[str]) -> Optional[str]: + """Remove scores.""" + if not actual: + return actual + return "|".join(x.split(":")[0] for x in actual.split("|")) diff --git a/sparv-sbx-sentence-sentiment-kb-sent/tests/test_version.py b/sparv-sbx-sentence-sentiment-kb-sent/tests/test_version.py deleted file mode 100644 index 3e2edb0..0000000 --- a/sparv-sbx-sentence-sentiment-kb-sent/tests/test_version.py +++ /dev/null @@ -1,5 +0,0 @@ -import sbx_sentence_sentiment_kb_sent - - -def test_version() -> None: - assert sbx_sentence_sentiment_kb_sent.__version__ == "0.1.0" diff --git a/sparv-sbx-sentence-sentiment-kb-sent/tests/testing.py b/sparv-sbx-sentence-sentiment-kb-sent/tests/testing.py new file mode 100644 index 0000000..6eec252 --- /dev/null +++ b/sparv-sbx-sentence-sentiment-kb-sent/tests/testing.py @@ -0,0 +1,69 @@ +from typing import Dict, Generator, Generic, List, Optional, Tuple, TypeVar + +from sparv.api.classes import Annotation, BaseAnnotation, Output # type: ignore [import-untyped] +from sparv.core import log_handler # type: ignore [import-untyped] # noqa: F401 + + +class MockAnnotation(Annotation): + def __init__( + self, + name: str = "", + source_file: Optional[str] = None, + values: Optional[List[str]] = None, + children: Optional[Dict[str, List[List[int]]]] = None, + ) -> None: + super().__init__(name) + self._values = values or [] + self._children = children or {} + + def read(self, allow_newlines: bool = False) -> Generator[str, None, None]: + """Yield each line from the annotation.""" + if not self._values: + return + yield from self._values + + def get_children( + self, + child: BaseAnnotation, + *, + orphan_alert: bool = False, + preserve_parent_annotation_order: bool = False, + ) -> Tuple[List, List]: + """Return two lists. + + The first one is a list with n (= total number of parents) elements where every element is a list + of indices in the child annotation. + The second one is a list of orphans, i.e. containing indices in the child annotation that have no parent. + Both parents and children are sorted according to their position in the source file, unless + preserve_parent_annotation_order is set to True, in which case the parents keep the order from the parent + annotation. + """ + return self._children[child.name], [] + + def create_empty_attribute(self) -> List: + return [None] * max(len(val) for val in self._children.values()) + + +T = TypeVar("T") + + +class MemoryOutput(Output, Generic[T]): + def __init__(self) -> None: + self.values: List[T] = [] + + def write( + self, + values: List[T], + *, + append: bool = False, + allow_newlines: bool = False, + source_file: Optional[str] = None, + ) -> None: + """Write an annotation to file. Existing annotation will be overwritten. + + 'values' should be a list of values. + """ + if append: + self.values.extend(values) + else: + self.values = values