diff --git a/.github/ISSUE_TEMPLATE/1_bug_report.yaml b/.github/ISSUE_TEMPLATE/1_bug_report.yaml
index 96120764..c0f48ede 100644
--- a/.github/ISSUE_TEMPLATE/1_bug_report.yaml
+++ b/.github/ISSUE_TEMPLATE/1_bug_report.yaml
@@ -38,4 +38,4 @@ body:
attributes:
label: Relevant log output
description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks.
- render: shell
\ No newline at end of file
+ render: shell
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f7d371e4..c4d87bbe 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,14 +1,6 @@
fail_fast: true
repos:
- - repo: local
- hooks:
- - id: ruff
- name: ruff
- files: \.py$
- entry: ruff
- language: python
- exclude: ^(tests|scripts|notebooks|quadra/utils/tests)/
- repo: local
hooks:
- id: jupyter-nb-clear-output
@@ -23,47 +15,16 @@ repos:
- id: prettier
name: (prettier) Reformat YAML files with prettier
types: [yaml]
-
- - repo: https://github.com/myint/autoflake
- rev: v1.4
- hooks:
- - id: autoflake
- name: Remove unused variables and imports
- language: python
- entry: autoflake
- types: [python]
- args:
- [
- "--in-place",
- "--remove-all-unused-imports",
- "--remove-unused-variables",
- "--expand-star-imports",
- "--ignore-init-module-imports",
- ]
- files: \.py$
- - repo: local
- hooks:
- - id: isort
- name: (isort) Sorting import statements
- args: [--settings-path=pyproject.toml]
- language: python
- types: [python]
- files: \.py$
- entry: isort
- - repo: local
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.4.1
hooks:
- - id: black
- name: (black) Format Python code
- args: [--config=pyproject.toml]
- language: python
- types: [python]
- entry: black
- - id: black-jupyter
- name: (black-jupyter) Format jupyter notebooks
- args: [--config=pyproject.toml]
- language: python
- types: [jupyter]
- entry: black
+ - id: ruff
+ types_or: [python, pyi, jupyter]
+ args: ["--config", "pyproject.toml", "--fix"]
+ exclude: ^(tests|scripts|notebooks|quadra/utils/tests)/
+ - id: ruff-format
+ types_or: [python, pyi, jupyter]
+ args: ["--config", "pyproject.toml"]
- repo: local
hooks:
- id: pylint
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8b3df508..7296d4e5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,17 @@
# Changelog
All notable changes to this project will be documented in this file.
+### [2.1.4]
+
+#### Updated
+
+- Remove black from pre-commit hooks
+- Use ruff as the main formatter and linting tool
+- Upgrade mypy version
+- Upgrade mlflow version
+- Apply new pre-commits to all tests
+- Update most of typing to py310 style
+
### [2.1.3]
#### Updated
diff --git a/README.md b/README.md
index 5dcab234..20b5e713 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,11 @@
@@ -251,7 +247,7 @@ First clone the repository from `Github`, then we need to install the package wi
Now you can start developing and the pre-commit hooks will run automatically to prevent you from committing code that does not pass the linting and formatting checks.
-We rely on a combination of `Black`, `Pylint`, `Mypy`, `Ruff` and `Isort` to enforce code quality.
+We rely on a combination of `Pylint`, `Mypy` and `Ruff` to enforce code quality.
## Building Documentations
@@ -273,7 +269,7 @@ This project is based on many open-source libraries and frameworks, we would lik
- Documentation website is using [Material for MkDocs](https://squidfunk.github.io/mkdocs-material/) and [MkDocs](https://www.mkdocs.org/). For code documentation we are using [Mkdocstrings](https://mkdocstrings.github.io/). For releasing software versions we combine [Bumpver](https://github.com/mbarkhau/bumpver) and [Mike](https://github.com/jimporter/mike).
- Models can be exported in different ways (`torchscript` or `torch` file). We have also added [ONNX](https://onnx.ai/) support for some models.
- Testing framework is based on [Pytest](https://docs.pytest.org/en/) and related plug-ins.
-- Code quality is ensured by [pre-commit](https://pre-commit.com/) hooks. We are using [Black](https://github.com/psf/black) for formatting, [Pylint](https://www.pylint.org/) for linting, [Mypy](https://mypy.readthedocs.io/en/stable/) for type checking, [Isort](https://pycqa.github.io/isort/) for sorting imports, and [Ruff](https://github.com/astral-sh/ruff) for checking futher code and documentation quality.
+- Code quality is ensured by [pre-commit](https://pre-commit.com/) hooks. We are using [Ruff](https://github.com/astral-sh/ruff) for linting, enforcing code quality and formatting, [Pylint](https://www.pylint.org/) for in depth linting and [Mypy](https://mypy.readthedocs.io/en/stable/) for type checking.
## FAQ
diff --git a/docs/gen_ref_pages.py b/docs/gen_ref_pages.py
index f06ea3c9..cb1c3ff6 100644
--- a/docs/gen_ref_pages.py
+++ b/docs/gen_ref_pages.py
@@ -17,7 +17,7 @@ def init_file_imports(init_file: Path) -> bool:
Returns:
True if the file imports anything, False otherwise.
"""
- with open(init_file, "r") as fd:
+ with open(init_file) as fd:
return any(line.startswith("__all__") for line in fd.readlines())
@@ -61,7 +61,7 @@ def add_submodules_as_list(parent_folder: Path) -> str:
full_doc_path = full_doc_path.with_name("index.md")
readme_file = path.parent / "README.md"
if readme_file.exists():
- with open(readme_file, "r") as rfd:
+ with open(readme_file) as rfd:
MARKDOWN_CONTENTS += f"{rfd.read()}\n"
if not init_file_imports(path):
MARKDOWN_CONTENTS += add_submodules_as_list(path.parent)
@@ -76,14 +76,14 @@ def add_submodules_as_list(parent_folder: Path) -> str:
mkdocs_gen_files.set_edit_path(full_doc_path, path)
-with open(Path("README.md"), "r") as read_fd:
+with open(Path("README.md")) as read_fd:
readme = read_fd.read()
readme = readme.replace("# Quadra: Deep Learning Experiment Orchestration Library", "# Home")
readme = readme.replace("docs/", "")
with mkdocs_gen_files.open("getting_started.md", "w") as nav_file: # (2)
nav_file.write(readme)
-with open("CHANGELOG.md", "r") as change_fd:
+with open("CHANGELOG.md") as change_fd:
changelog = change_fd.read()
changelog = changelog.replace("All notable changes to this project will be documented in this file.", "")
with mkdocs_gen_files.open("reference/CHANGELOG.md", "w") as nav_file: # (2)
diff --git a/poetry.lock b/poetry.lock
index febd57db..b98fed53 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -187,7 +187,7 @@ test = ["flake8 (==3.7.9)", "mock (==2.0.0)", "pylint (==1.9.3)"]
[[package]]
name = "anomalib"
-version = "0.7.0+obx.1.3.1"
+version = "0.7.0+obx.1.3.2"
description = "anomalib - Anomaly Detection Library"
optional = false
python-versions = ">=3.7"
@@ -382,52 +382,6 @@ files = [
tests = ["pytest (>=3.2.1,!=3.3.0)"]
typecheck = ["mypy"]
-[[package]]
-name = "black"
-version = "22.12.0"
-description = "The uncompromising code formatter."
-optional = false
-python-versions = ">=3.7"
-files = [
- {file = "black-22.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eedd20838bd5d75b80c9f5487dbcb06836a43833a37846cf1d8c1cc01cef59d"},
- {file = "black-22.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:159a46a4947f73387b4d83e87ea006dbb2337eab6c879620a3ba52699b1f4351"},
- {file = "black-22.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d30b212bffeb1e252b31dd269dfae69dd17e06d92b87ad26e23890f3efea366f"},
- {file = "black-22.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:7412e75863aa5c5411886804678b7d083c7c28421210180d67dfd8cf1221e1f4"},
- {file = "black-22.12.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c116eed0efb9ff870ded8b62fe9f28dd61ef6e9ddd28d83d7d264a38417dcee2"},
- {file = "black-22.12.0-cp37-cp37m-win_amd64.whl", hash = "sha256:1f58cbe16dfe8c12b7434e50ff889fa479072096d79f0a7f25e4ab8e94cd8350"},
- {file = "black-22.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:77d86c9f3db9b1bf6761244bc0b3572a546f5fe37917a044e02f3166d5aafa7d"},
- {file = "black-22.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:82d9fe8fee3401e02e79767016b4907820a7dc28d70d137eb397b92ef3cc5bfc"},
- {file = "black-22.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:101c69b23df9b44247bd88e1d7e90154336ac4992502d4197bdac35dd7ee3320"},
- {file = "black-22.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:559c7a1ba9a006226f09e4916060982fd27334ae1998e7a38b3f33a37f7a2148"},
- {file = "black-22.12.0-py3-none-any.whl", hash = "sha256:436cc9167dd28040ad90d3b404aec22cedf24a6e4d7de221bec2730ec0c97bcf"},
- {file = "black-22.12.0.tar.gz", hash = "sha256:229351e5a18ca30f447bf724d007f890f97e13af070bb6ad4c0a441cd7596a2f"},
-]
-
-[package.dependencies]
-click = ">=8.0.0"
-mypy-extensions = ">=0.4.3"
-pathspec = ">=0.9.0"
-platformdirs = ">=2"
-tomli = {version = ">=1.1.0", markers = "python_full_version < \"3.11.0a7\""}
-typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""}
-
-[package.extras]
-colorama = ["colorama (>=0.4.3)"]
-d = ["aiohttp (>=3.7.4)"]
-jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
-uvloop = ["uvloop (>=0.15.2)"]
-
-[[package]]
-name = "blinker"
-version = "1.7.0"
-description = "Fast, simple object-to-object and broadcast signaling"
-optional = false
-python-versions = ">=3.8"
-files = [
- {file = "blinker-1.7.0-py3-none-any.whl", hash = "sha256:c3f865d4d54db7abc53758a01601cf343fe55b84c1de4e3fa910e420b438d5b9"},
- {file = "blinker-1.7.0.tar.gz", hash = "sha256:e6820ff6fa4e4d1d8e2747c2283749c3f547e4fee112b98555cdcdae32996182"},
-]
-
[[package]]
name = "boto3"
version = "1.26.165"
@@ -1120,26 +1074,6 @@ files = [
docs = ["ipython", "matplotlib", "numpydoc", "sphinx"]
tests = ["pytest", "pytest-cov", "pytest-xdist"]
-[[package]]
-name = "databricks-cli"
-version = "0.18.0"
-description = "A command line interface for Databricks"
-optional = false
-python-versions = ">=3.7"
-files = [
- {file = "databricks-cli-0.18.0.tar.gz", hash = "sha256:87569709eda9af3e9db8047b691e420b5e980c62ef01675575c0d2b9b4211eb7"},
- {file = "databricks_cli-0.18.0-py2.py3-none-any.whl", hash = "sha256:1176a5f42d3e8af4abfc915446fb23abc44513e325c436725f5898cbb9e3384b"},
-]
-
-[package.dependencies]
-click = ">=7.0"
-oauthlib = ">=3.1.0"
-pyjwt = ">=1.7.0"
-requests = ">=2.17.3"
-six = ">=1.10.0"
-tabulate = ">=0.7.7"
-urllib3 = ">=1.26.7,<3"
-
[[package]]
name = "defusedxml"
version = "0.7.1"
@@ -1176,27 +1110,6 @@ files = [
{file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"},
]
-[[package]]
-name = "docker"
-version = "6.1.3"
-description = "A Python library for the Docker Engine API."
-optional = false
-python-versions = ">=3.7"
-files = [
- {file = "docker-6.1.3-py3-none-any.whl", hash = "sha256:aecd2277b8bf8e506e484f6ab7aec39abe0038e29fa4a6d3ba86c3fe01844ed9"},
- {file = "docker-6.1.3.tar.gz", hash = "sha256:aa6d17830045ba5ef0168d5eaa34d37beeb113948c413affe1d5991fc11f9a20"},
-]
-
-[package.dependencies]
-packaging = ">=14.0"
-pywin32 = {version = ">=304", markers = "sys_platform == \"win32\""}
-requests = ">=2.26.0"
-urllib3 = ">=1.26.0"
-websocket-client = ">=0.32.0"
-
-[package.extras]
-ssh = ["paramiko (>=2.4.3)"]
-
[[package]]
name = "docker-pycreds"
version = "0.4.0"
@@ -1428,29 +1341,6 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1
testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"]
typing = ["typing-extensions (>=4.8)"]
-[[package]]
-name = "flask"
-version = "2.3.3"
-description = "A simple framework for building complex web applications."
-optional = false
-python-versions = ">=3.8"
-files = [
- {file = "flask-2.3.3-py3-none-any.whl", hash = "sha256:f69fcd559dc907ed196ab9df0e48471709175e696d6e698dd4dbe940f96ce66b"},
- {file = "flask-2.3.3.tar.gz", hash = "sha256:09c347a92aa7ff4a8e7f3206795f30d826654baf38b873d0744cd571ca609efc"},
-]
-
-[package.dependencies]
-blinker = ">=1.6.2"
-click = ">=8.1.3"
-importlib-metadata = {version = ">=3.6.0", markers = "python_version < \"3.10\""}
-itsdangerous = ">=2.1.2"
-Jinja2 = ">=3.1.2"
-Werkzeug = ">=2.3.7"
-
-[package.extras]
-async = ["asgiref (>=3.2)"]
-dotenv = ["python-dotenv"]
-
[[package]]
name = "flatbuffers"
version = "23.5.26"
@@ -1958,26 +1848,6 @@ files = [
[package.extras]
protobuf = ["grpcio-tools (>=1.60.0)"]
-[[package]]
-name = "gunicorn"
-version = "20.1.0"
-description = "WSGI HTTP Server for UNIX"
-optional = false
-python-versions = ">=3.5"
-files = [
- {file = "gunicorn-20.1.0-py3-none-any.whl", hash = "sha256:9dcc4547dbb1cb284accfb15ab5667a0e5d1881cc443e0677b4882a4067a807e"},
- {file = "gunicorn-20.1.0.tar.gz", hash = "sha256:e0a968b5ba15f8a328fdfd7ab1fcb5af4470c28aaf7e55df02a99bc13138e6e8"},
-]
-
-[package.dependencies]
-setuptools = ">=3.0"
-
-[package.extras]
-eventlet = ["eventlet (>=0.24.1)"]
-gevent = ["gevent (>=1.4.0)"]
-setproctitle = ["setproctitle"]
-tornado = ["tornado (>=0.2)"]
-
[[package]]
name = "h11"
version = "0.14.0"
@@ -2391,17 +2261,6 @@ pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib"
plugins = ["setuptools"]
requirements-deprecated-finder = ["pip-api", "pipreqs"]
-[[package]]
-name = "itsdangerous"
-version = "2.1.2"
-description = "Safely pass data to untrusted environments and back."
-optional = false
-python-versions = ">=3.7"
-files = [
- {file = "itsdangerous-2.1.2-py3-none-any.whl", hash = "sha256:2c2349112351b88699d8d4b6b075022c0808887cb7ad10069318a8b0bc88db44"},
- {file = "itsdangerous-2.1.2.tar.gz", hash = "sha256:5dbbc68b317e5e42f327f9021763545dc3fc3bfe22e6deb96aaf1fc38874156a"},
-]
-
[[package]]
name = "jaraco-classes"
version = "3.3.0"
@@ -3399,85 +3258,38 @@ files = [
griffe = ">=0.35"
mkdocstrings = ">=0.20"
-[[package]]
-name = "mlflow"
-version = "2.3.1"
-description = "MLflow: A Platform for ML Development and Productionization"
-optional = false
-python-versions = ">=3.8"
-files = [
- {file = "mlflow-2.3.1-py3-none-any.whl", hash = "sha256:699c512d659c7463a498e087c5f74d3d139b5708cf6aaaccfa398d7b0c095204"},
- {file = "mlflow-2.3.1.tar.gz", hash = "sha256:63439397b2718ce5747288ef5475f46b3716b370a517be3e3c67b799a247a186"},
-]
-
-[package.dependencies]
-alembic = "<1.10.0 || >1.10.0,<2"
-click = ">=7.0,<9"
-cloudpickle = "<3"
-databricks-cli = ">=0.8.7,<1"
-docker = ">=4.0.0,<7"
-entrypoints = "<1"
-Flask = "<3"
-gitpython = ">=2.1.0,<4"
-gunicorn = {version = "<21", markers = "platform_system != \"Windows\""}
-importlib-metadata = ">=3.7.0,<4.7.0 || >4.7.0,<7"
-Jinja2 = [
- {version = ">=2.11,<4", markers = "platform_system != \"Windows\""},
- {version = ">=3.0,<4", markers = "platform_system == \"Windows\""},
-]
-markdown = ">=3.3,<4"
-matplotlib = "<4"
-numpy = "<2"
-packaging = "<24"
-pandas = "<3"
-protobuf = ">=3.12.0,<5"
-pyarrow = ">=4.0.0,<12"
-pytz = "<2024"
-pyyaml = ">=5.1,<7"
-querystring-parser = "<2"
-requests = ">=2.17.3,<3"
-scikit-learn = "<2"
-scipy = "<2"
-sqlalchemy = ">=1.4.0,<3"
-sqlparse = ">=0.4.0,<1"
-waitress = {version = "<3", markers = "platform_system == \"Windows\""}
-
-[package.extras]
-aliyun-oss = ["aliyunstoreplugin"]
-databricks = ["azure-storage-file-datalake (>12)", "boto3 (>1)", "google-cloud-storage (>=1.30.0)"]
-extras = ["azureml-core (>=1.2.0)", "boto3", "google-cloud-storage (>=1.30.0)", "kubernetes", "mlserver (>=1.2.0,!=1.3.1)", "mlserver-mlflow (>=1.2.0,!=1.3.1)", "prometheus-flask-exporter", "pyarrow", "pysftp", "requests-auth-aws-sigv4", "virtualenv"]
-sqlserver = ["mlflow-dbstore"]
-
[[package]]
name = "mlflow-skinny"
-version = "2.3.1"
-description = "MLflow: A Platform for ML Development and Productionization"
+version = "2.12.1"
+description = "MLflow is an open source platform for the complete machine learning lifecycle"
optional = false
python-versions = ">=3.8"
files = [
- {file = "mlflow-skinny-2.3.1.tar.gz", hash = "sha256:0ee64e7119028f5cd5b5c7204424b41c003f833b9e29ff9d21ef67d60ee655ab"},
- {file = "mlflow_skinny-2.3.1-py3-none-any.whl", hash = "sha256:33ba9668ff027af8ef865ecaf9f4984e113d46e2a9bcd93b38b326c237e5b13c"},
+ {file = "mlflow_skinny-2.12.1-py3-none-any.whl", hash = "sha256:51539de93a7f8b74b2d6b4307f204235469f6fadd19078599abd12a4bcf6b5b0"},
+ {file = "mlflow_skinny-2.12.1.tar.gz", hash = "sha256:a5c3bb2f111867db988d4cdd782b6224fca4fc3d86706e9e587c379973ce7353"},
]
[package.dependencies]
click = ">=7.0,<9"
-cloudpickle = "<3"
-databricks-cli = ">=0.8.7,<1"
+cloudpickle = "<4"
entrypoints = "<1"
-gitpython = ">=2.1.0,<4"
-importlib-metadata = ">=3.7.0,<4.7.0 || >4.7.0,<7"
-packaging = "<24"
-protobuf = ">=3.12.0,<5"
-pytz = "<2024"
+gitpython = ">=3.1.9,<4"
+importlib-metadata = ">=3.7.0,<4.7.0 || >4.7.0,<8"
+packaging = "<25"
+protobuf = ">=3.12.0,<6"
+pytz = "<2025"
pyyaml = ">=5.1,<7"
requests = ">=2.17.3,<3"
sqlparse = ">=0.4.0,<1"
[package.extras]
aliyun-oss = ["aliyunstoreplugin"]
-databricks = ["azure-storage-file-datalake (>12)", "boto3 (>1)", "google-cloud-storage (>=1.30.0)"]
-extras = ["azureml-core (>=1.2.0)", "boto3", "google-cloud-storage (>=1.30.0)", "kubernetes", "mlserver (>=1.2.0,!=1.3.1)", "mlserver-mlflow (>=1.2.0,!=1.3.1)", "prometheus-flask-exporter", "pyarrow", "pysftp", "requests-auth-aws-sigv4", "virtualenv"]
+databricks = ["azure-storage-file-datalake (>12)", "boto3 (>1)", "botocore", "google-cloud-storage (>=1.30.0)"]
+extras = ["azureml-core (>=1.2.0)", "boto3", "botocore", "google-cloud-storage (>=1.30.0)", "kubernetes", "mlserver (>=1.2.0,!=1.3.1,<1.4.0)", "mlserver-mlflow (>=1.2.0,!=1.3.1,<1.4.0)", "prometheus-flask-exporter", "pyarrow", "pysftp", "requests-auth-aws-sigv4", "virtualenv"]
+gateway = ["aiohttp (<4)", "boto3 (>=1.28.56,<2)", "fastapi (<1)", "pydantic (>=1.0,<3)", "slowapi (>=0.1.9,<1)", "tiktoken (<1)", "uvicorn[standard] (<1)", "watchfiles (<1)"]
+genai = ["aiohttp (<4)", "boto3 (>=1.28.56,<2)", "fastapi (<1)", "pydantic (>=1.0,<3)", "slowapi (>=0.1.9,<1)", "tiktoken (<1)", "uvicorn[standard] (<1)", "watchfiles (<1)"]
sqlserver = ["mlflow-dbstore"]
+xethub = ["mlflow-xethub"]
[[package]]
name = "monotonic"
@@ -3683,48 +3495,49 @@ yaml = ["PyYAML (>=5.1.0)"]
[[package]]
name = "mypy"
-version = "1.0.1"
+version = "1.10.0"
description = "Optional static typing for Python"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
files = [
- {file = "mypy-1.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:71a808334d3f41ef011faa5a5cd8153606df5fc0b56de5b2e89566c8093a0c9a"},
- {file = "mypy-1.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:920169f0184215eef19294fa86ea49ffd4635dedfdea2b57e45cb4ee85d5ccaf"},
- {file = "mypy-1.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27a0f74a298769d9fdc8498fcb4f2beb86f0564bcdb1a37b58cbbe78e55cf8c0"},
- {file = "mypy-1.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:65b122a993d9c81ea0bfde7689b3365318a88bde952e4dfa1b3a8b4ac05d168b"},
- {file = "mypy-1.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:5deb252fd42a77add936b463033a59b8e48eb2eaec2976d76b6878d031933fe4"},
- {file = "mypy-1.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2013226d17f20468f34feddd6aae4635a55f79626549099354ce641bc7d40262"},
- {file = "mypy-1.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:48525aec92b47baed9b3380371ab8ab6e63a5aab317347dfe9e55e02aaad22e8"},
- {file = "mypy-1.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c96b8a0c019fe29040d520d9257d8c8f122a7343a8307bf8d6d4a43f5c5bfcc8"},
- {file = "mypy-1.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:448de661536d270ce04f2d7dddaa49b2fdba6e3bd8a83212164d4174ff43aa65"},
- {file = "mypy-1.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:d42a98e76070a365a1d1c220fcac8aa4ada12ae0db679cb4d910fabefc88b994"},
- {file = "mypy-1.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e64f48c6176e243ad015e995de05af7f22bbe370dbb5b32bd6988438ec873919"},
- {file = "mypy-1.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fdd63e4f50e3538617887e9aee91855368d9fc1dea30da743837b0df7373bc4"},
- {file = "mypy-1.0.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:dbeb24514c4acbc78d205f85dd0e800f34062efcc1f4a4857c57e4b4b8712bff"},
- {file = "mypy-1.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a2948c40a7dd46c1c33765718936669dc1f628f134013b02ff5ac6c7ef6942bf"},
- {file = "mypy-1.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5bc8d6bd3b274dd3846597855d96d38d947aedba18776aa998a8d46fabdaed76"},
- {file = "mypy-1.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:17455cda53eeee0a4adb6371a21dd3dbf465897de82843751cf822605d152c8c"},
- {file = "mypy-1.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e831662208055b006eef68392a768ff83596035ffd6d846786578ba1714ba8f6"},
- {file = "mypy-1.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e60d0b09f62ae97a94605c3f73fd952395286cf3e3b9e7b97f60b01ddfbbda88"},
- {file = "mypy-1.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:0af4f0e20706aadf4e6f8f8dc5ab739089146b83fd53cb4a7e0e850ef3de0bb6"},
- {file = "mypy-1.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:24189f23dc66f83b839bd1cce2dfc356020dfc9a8bae03978477b15be61b062e"},
- {file = "mypy-1.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93a85495fb13dc484251b4c1fd7a5ac370cd0d812bbfc3b39c1bafefe95275d5"},
- {file = "mypy-1.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f546ac34093c6ce33f6278f7c88f0f147a4849386d3bf3ae193702f4fe31407"},
- {file = "mypy-1.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c6c2ccb7af7154673c591189c3687b013122c5a891bb5651eca3db8e6c6c55bd"},
- {file = "mypy-1.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:15b5a824b58c7c822c51bc66308e759243c32631896743f030daf449fe3677f3"},
- {file = "mypy-1.0.1-py3-none-any.whl", hash = "sha256:eda5c8b9949ed411ff752b9a01adda31afe7eae1e53e946dbdf9db23865e66c4"},
- {file = "mypy-1.0.1.tar.gz", hash = "sha256:28cea5a6392bb43d266782983b5a4216c25544cd7d80be681a155ddcdafd152d"},
-]
-
-[package.dependencies]
-mypy-extensions = ">=0.4.3"
+ {file = "mypy-1.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:da1cbf08fb3b851ab3b9523a884c232774008267b1f83371ace57f412fe308c2"},
+ {file = "mypy-1.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:12b6bfc1b1a66095ab413160a6e520e1dc076a28f3e22f7fb25ba3b000b4ef99"},
+ {file = "mypy-1.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e36fb078cce9904c7989b9693e41cb9711e0600139ce3970c6ef814b6ebc2b2"},
+ {file = "mypy-1.10.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2b0695d605ddcd3eb2f736cd8b4e388288c21e7de85001e9f85df9187f2b50f9"},
+ {file = "mypy-1.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:cd777b780312ddb135bceb9bc8722a73ec95e042f911cc279e2ec3c667076051"},
+ {file = "mypy-1.10.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3be66771aa5c97602f382230165b856c231d1277c511c9a8dd058be4784472e1"},
+ {file = "mypy-1.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8b2cbaca148d0754a54d44121b5825ae71868c7592a53b7292eeb0f3fdae95ee"},
+ {file = "mypy-1.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ec404a7cbe9fc0e92cb0e67f55ce0c025014e26d33e54d9e506a0f2d07fe5de"},
+ {file = "mypy-1.10.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e22e1527dc3d4aa94311d246b59e47f6455b8729f4968765ac1eacf9a4760bc7"},
+ {file = "mypy-1.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:a87dbfa85971e8d59c9cc1fcf534efe664d8949e4c0b6b44e8ca548e746a8d53"},
+ {file = "mypy-1.10.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a781f6ad4bab20eef8b65174a57e5203f4be627b46291f4589879bf4e257b97b"},
+ {file = "mypy-1.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b808e12113505b97d9023b0b5e0c0705a90571c6feefc6f215c1df9381256e30"},
+ {file = "mypy-1.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f55583b12156c399dce2df7d16f8a5095291354f1e839c252ec6c0611e86e2e"},
+ {file = "mypy-1.10.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4cf18f9d0efa1b16478c4c129eabec36148032575391095f73cae2e722fcf9d5"},
+ {file = "mypy-1.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:bc6ac273b23c6b82da3bb25f4136c4fd42665f17f2cd850771cb600bdd2ebeda"},
+ {file = "mypy-1.10.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9fd50226364cd2737351c79807775136b0abe084433b55b2e29181a4c3c878c0"},
+ {file = "mypy-1.10.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f90cff89eea89273727d8783fef5d4a934be2fdca11b47def50cf5d311aff727"},
+ {file = "mypy-1.10.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fcfc70599efde5c67862a07a1aaf50e55bce629ace26bb19dc17cece5dd31ca4"},
+ {file = "mypy-1.10.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:075cbf81f3e134eadaf247de187bd604748171d6b79736fa9b6c9685b4083061"},
+ {file = "mypy-1.10.0-cp38-cp38-win_amd64.whl", hash = "sha256:3f298531bca95ff615b6e9f2fc0333aae27fa48052903a0ac90215021cdcfa4f"},
+ {file = "mypy-1.10.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fa7ef5244615a2523b56c034becde4e9e3f9b034854c93639adb667ec9ec2976"},
+ {file = "mypy-1.10.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3236a4c8f535a0631f85f5fcdffba71c7feeef76a6002fcba7c1a8e57c8be1ec"},
+ {file = "mypy-1.10.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a2b5cdbb5dd35aa08ea9114436e0d79aceb2f38e32c21684dcf8e24e1e92821"},
+ {file = "mypy-1.10.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92f93b21c0fe73dc00abf91022234c79d793318b8a96faac147cd579c1671746"},
+ {file = "mypy-1.10.0-cp39-cp39-win_amd64.whl", hash = "sha256:28d0e038361b45f099cc086d9dd99c15ff14d0188f44ac883010e172ce86c38a"},
+ {file = "mypy-1.10.0-py3-none-any.whl", hash = "sha256:f8c083976eb530019175aabadb60921e73b4f45736760826aa1689dda8208aee"},
+ {file = "mypy-1.10.0.tar.gz", hash = "sha256:3d087fcbec056c4ee34974da493a826ce316947485cef3901f511848e687c131"},
+]
+
+[package.dependencies]
+mypy-extensions = ">=1.0.0"
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
-typing-extensions = ">=3.10"
+typing-extensions = ">=4.1.0"
[package.extras]
dmypy = ["psutil (>=4.0)"]
install-types = ["pip"]
-python2 = ["typed-ast (>=1.4.0,<2)"]
+mypyc = ["setuptools (>=50)"]
reports = ["lxml"]
[[package]]
@@ -4814,43 +4627,6 @@ files = [
{file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
]
-[[package]]
-name = "pyarrow"
-version = "11.0.0"
-description = "Python library for Apache Arrow"
-optional = false
-python-versions = ">=3.7"
-files = [
- {file = "pyarrow-11.0.0-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:40bb42afa1053c35c749befbe72f6429b7b5f45710e85059cdd534553ebcf4f2"},
- {file = "pyarrow-11.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7c28b5f248e08dea3b3e0c828b91945f431f4202f1a9fe84d1012a761324e1ba"},
- {file = "pyarrow-11.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a37bc81f6c9435da3c9c1e767324ac3064ffbe110c4e460660c43e144be4ed85"},
- {file = "pyarrow-11.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad7c53def8dbbc810282ad308cc46a523ec81e653e60a91c609c2233ae407689"},
- {file = "pyarrow-11.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:25aa11c443b934078bfd60ed63e4e2d42461682b5ac10f67275ea21e60e6042c"},
- {file = "pyarrow-11.0.0-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:e217d001e6389b20a6759392a5ec49d670757af80101ee6b5f2c8ff0172e02ca"},
- {file = "pyarrow-11.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ad42bb24fc44c48f74f0d8c72a9af16ba9a01a2ccda5739a517aa860fa7e3d56"},
- {file = "pyarrow-11.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d942c690ff24a08b07cb3df818f542a90e4d359381fbff71b8f2aea5bf58841"},
- {file = "pyarrow-11.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f010ce497ca1b0f17a8243df3048055c0d18dcadbcc70895d5baf8921f753de5"},
- {file = "pyarrow-11.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:2f51dc7ca940fdf17893227edb46b6784d37522ce08d21afc56466898cb213b2"},
- {file = "pyarrow-11.0.0-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:1cbcfcbb0e74b4d94f0b7dde447b835a01bc1d16510edb8bb7d6224b9bf5bafc"},
- {file = "pyarrow-11.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaee8f79d2a120bf3e032d6d64ad20b3af6f56241b0ffc38d201aebfee879d00"},
- {file = "pyarrow-11.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:410624da0708c37e6a27eba321a72f29d277091c8f8d23f72c92bada4092eb5e"},
- {file = "pyarrow-11.0.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2d53ba72917fdb71e3584ffc23ee4fcc487218f8ff29dd6df3a34c5c48fe8c06"},
- {file = "pyarrow-11.0.0-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:f12932e5a6feb5c58192209af1d2607d488cb1d404fbc038ac12ada60327fa34"},
- {file = "pyarrow-11.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:41a1451dd895c0b2964b83d91019e46f15b5564c7ecd5dcb812dadd3f05acc97"},
- {file = "pyarrow-11.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:becc2344be80e5dce4e1b80b7c650d2fc2061b9eb339045035a1baa34d5b8f1c"},
- {file = "pyarrow-11.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f40be0d7381112a398b93c45a7e69f60261e7b0269cc324e9f739ce272f4f70"},
- {file = "pyarrow-11.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:362a7c881b32dc6b0eccf83411a97acba2774c10edcec715ccaab5ebf3bb0835"},
- {file = "pyarrow-11.0.0-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:ccbf29a0dadfcdd97632b4f7cca20a966bb552853ba254e874c66934931b9841"},
- {file = "pyarrow-11.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3e99be85973592051e46412accea31828da324531a060bd4585046a74ba45854"},
- {file = "pyarrow-11.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69309be84dcc36422574d19c7d3a30a7ea43804f12552356d1ab2a82a713c418"},
- {file = "pyarrow-11.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da93340fbf6f4e2a62815064383605b7ffa3e9eeb320ec839995b1660d69f89b"},
- {file = "pyarrow-11.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:caad867121f182d0d3e1a0d36f197df604655d0b466f1bc9bafa903aa95083e4"},
- {file = "pyarrow-11.0.0.tar.gz", hash = "sha256:5461c57dbdb211a632a48facb9b39bbeb8a7905ec95d768078525283caef5f6d"},
-]
-
-[package.dependencies]
-numpy = ">=1.16.6"
-
[[package]]
name = "pyasn1"
version = "0.5.1"
@@ -5006,23 +4782,6 @@ files = [
plugins = ["importlib-metadata"]
windows-terminal = ["colorama (>=0.4.6)"]
-[[package]]
-name = "pyjwt"
-version = "2.8.0"
-description = "JSON Web Token implementation in Python"
-optional = false
-python-versions = ">=3.7"
-files = [
- {file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"},
- {file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"},
-]
-
-[package.extras]
-crypto = ["cryptography (>=3.4.0)"]
-dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
-docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
-tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
-
[[package]]
name = "pylint"
version = "2.16.4"
@@ -5315,29 +5074,6 @@ files = [
{file = "pytz-2023.3.post1.tar.gz", hash = "sha256:7b4fddbeb94a1eba4b557da24f19fdf9db575192544270a9101d8509f9f43d7b"},
]
-[[package]]
-name = "pywin32"
-version = "306"
-description = "Python for Window Extensions"
-optional = false
-python-versions = "*"
-files = [
- {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"},
- {file = "pywin32-306-cp310-cp310-win_amd64.whl", hash = "sha256:84f4471dbca1887ea3803d8848a1616429ac94a4a8d05f4bc9c5dcfd42ca99c8"},
- {file = "pywin32-306-cp311-cp311-win32.whl", hash = "sha256:e65028133d15b64d2ed8f06dd9fbc268352478d4f9289e69c190ecd6818b6407"},
- {file = "pywin32-306-cp311-cp311-win_amd64.whl", hash = "sha256:a7639f51c184c0272e93f244eb24dafca9b1855707d94c192d4a0b4c01e1100e"},
- {file = "pywin32-306-cp311-cp311-win_arm64.whl", hash = "sha256:70dba0c913d19f942a2db25217d9a1b726c278f483a919f1abfed79c9cf64d3a"},
- {file = "pywin32-306-cp312-cp312-win32.whl", hash = "sha256:383229d515657f4e3ed1343da8be101000562bf514591ff383ae940cad65458b"},
- {file = "pywin32-306-cp312-cp312-win_amd64.whl", hash = "sha256:37257794c1ad39ee9be652da0462dc2e394c8159dfd913a8a4e8eb6fd346da0e"},
- {file = "pywin32-306-cp312-cp312-win_arm64.whl", hash = "sha256:5821ec52f6d321aa59e2db7e0a35b997de60c201943557d108af9d4ae1ec7040"},
- {file = "pywin32-306-cp37-cp37m-win32.whl", hash = "sha256:1c73ea9a0d2283d889001998059f5eaaba3b6238f767c9cf2833b13e6a685f65"},
- {file = "pywin32-306-cp37-cp37m-win_amd64.whl", hash = "sha256:72c5f621542d7bdd4fdb716227be0dd3f8565c11b280be6315b06ace35487d36"},
- {file = "pywin32-306-cp38-cp38-win32.whl", hash = "sha256:e4c092e2589b5cf0d365849e73e02c391c1349958c5ac3e9d5ccb9a28e017b3a"},
- {file = "pywin32-306-cp38-cp38-win_amd64.whl", hash = "sha256:e8ac1ae3601bee6ca9f7cb4b5363bf1c0badb935ef243c4733ff9a393b1690c0"},
- {file = "pywin32-306-cp39-cp39-win32.whl", hash = "sha256:e25fd5b485b55ac9c057f67d94bc203f3f6595078d1fb3b458c9c28b7153a802"},
- {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"},
-]
-
[[package]]
name = "pywin32-ctypes"
version = "0.2.2"
@@ -5440,20 +5176,6 @@ opencv-python-headless = ">=4.0.1"
scikit-learn = ">=0.19.1"
typing-extensions = "*"
-[[package]]
-name = "querystring-parser"
-version = "1.2.4"
-description = "QueryString parser for Python/Django that correctly handles nested dictionaries"
-optional = false
-python-versions = "*"
-files = [
- {file = "querystring_parser-1.2.4-py2.py3-none-any.whl", hash = "sha256:d2fa90765eaf0de96c8b087872991a10238e89ba015ae59fedfed6bd61c242a0"},
- {file = "querystring_parser-1.2.4.tar.gz", hash = "sha256:644fce1cffe0530453b43a83a38094dbe422ccba8c9b2f2a1c00280e14ca8a62"},
-]
-
-[package.dependencies]
-six = "*"
-
[[package]]
name = "rapidfuzz"
version = "3.6.1"
@@ -5776,32 +5498,6 @@ files = [
[package.dependencies]
pyasn1 = ">=0.1.3"
-[[package]]
-name = "ruff"
-version = "0.0.257"
-description = "An extremely fast Python linter, written in Rust."
-optional = false
-python-versions = ">=3.7"
-files = [
- {file = "ruff-0.0.257-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:7280640690c1d0046b20e0eb924319a89d8e22925d7d232180ce31196e7478f8"},
- {file = "ruff-0.0.257-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:4582b73da61ab410ffda35b2987a6eacb33f18263e1c91810f0b9779ec4f41a9"},
- {file = "ruff-0.0.257-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5acae9878f1136893e266348acdb9d30dfae23c296d3012043816432a5abdd51"},
- {file = "ruff-0.0.257-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d9f0912d045eee15e8e02e335c16d7a7f9fb6821aa5eb1628eeb5bbfa3d88908"},
- {file = "ruff-0.0.257-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a9542c34ee5298b31be6c6ba304f14b672dcf104846ee65adb2466d3e325870"},
- {file = "ruff-0.0.257-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:3464f1ad4cea6c4b9325da13ae306bd22bf15d226e18d19c52db191b1f4355ac"},
- {file = "ruff-0.0.257-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a54bfd559e558ee0df2a2f3756423fe6a9de7307bc290d807c3cdf351cb4c24"},
- {file = "ruff-0.0.257-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3438fd38446e1a0915316f4085405c9feca20fe00a4b614995ab7034dbfaa7ff"},
- {file = "ruff-0.0.257-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:358cc2b547bd6451dcf2427b22a9c29a2d9c34e66576c693a6381c5f2ed3011d"},
- {file = "ruff-0.0.257-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:783390f1e94a168c79d7004426dae3e4ae2999cc85f7d00fdd86c62262b71854"},
- {file = "ruff-0.0.257-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:aaa3b5b6929c63a854b6bcea7a229453b455ab26337100b2905fae4523ca5667"},
- {file = "ruff-0.0.257-py3-none-musllinux_1_2_i686.whl", hash = "sha256:4ecd7a84db4816df2dcd0f11c5365a9a2cf4fa70a19b3ac161b7b0bfa592959d"},
- {file = "ruff-0.0.257-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:3db8d77d5651a2c0d307102d717627a025d4488d406f54c2764b21cfbe11d822"},
- {file = "ruff-0.0.257-py3-none-win32.whl", hash = "sha256:d2c8755fa4f6c5e5ec032ad341ca3beeecd16786e12c3f26e6b0cc40418ae998"},
- {file = "ruff-0.0.257-py3-none-win_amd64.whl", hash = "sha256:3cec07d6fecb1ebbc45ea8eeb1047b929caa2f7dfb8dd4b0e1869ff789326da5"},
- {file = "ruff-0.0.257-py3-none-win_arm64.whl", hash = "sha256:352f1bdb9b433b3b389aee512ffb0b82226ae1e25b3d92e4eaf0e7be6b1b6f6a"},
- {file = "ruff-0.0.257.tar.gz", hash = "sha256:fedfd06a37ddc17449203c3e38fc83fb68de7f20b5daa0ee4e60d3599b38bab0"},
-]
-
[[package]]
name = "s3transfer"
version = "0.6.2"
@@ -7251,21 +6947,6 @@ platformdirs = ">=3.9.1,<5"
docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"]
test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"]
-[[package]]
-name = "waitress"
-version = "2.1.2"
-description = "Waitress WSGI server"
-optional = false
-python-versions = ">=3.7.0"
-files = [
- {file = "waitress-2.1.2-py3-none-any.whl", hash = "sha256:7500c9625927c8ec60f54377d590f67b30c8e70ef4b8894214ac6e4cad233d2a"},
- {file = "waitress-2.1.2.tar.gz", hash = "sha256:780a4082c5fbc0fde6a2fcfe5e26e6efc1e8f425730863c04085769781f51eba"},
-]
-
-[package.extras]
-docs = ["Sphinx (>=1.8.1)", "docutils", "pylons-sphinx-themes (>=1.0.9)"]
-testing = ["coverage (>=5.0)", "pytest", "pytest-cover"]
-
[[package]]
name = "wandb"
version = "0.12.17"
@@ -7365,22 +7046,6 @@ files = [
{file = "webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923"},
]
-[[package]]
-name = "websocket-client"
-version = "1.7.0"
-description = "WebSocket client for Python with low level API options"
-optional = false
-python-versions = ">=3.8"
-files = [
- {file = "websocket-client-1.7.0.tar.gz", hash = "sha256:10e511ea3a8c744631d3bd77e61eb17ed09304c413ad42cf6ddfa4c7787e8fe6"},
- {file = "websocket_client-1.7.0-py3-none-any.whl", hash = "sha256:f4c3d22fec12a2461427a29957ff07d35098ee2d976d3ba244e688b8b4057588"},
-]
-
-[package.extras]
-docs = ["Sphinx (>=6.0)", "sphinx-rtd-theme (>=1.1.0)"]
-optional = ["python-socks", "wsaccel"]
-test = ["websockets"]
-
[[package]]
name = "werkzeug"
version = "3.0.1"
@@ -7829,4 +7494,4 @@ onnx = ["onnx", "onnxruntime_gpu", "onnxsim"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.11"
-content-hash = "211f46e98afb35d377534cd3f659f1d697f4acc128900726ff47a2c38bd458bc"
+content-hash = "ad1f194222d7e7133cb7b7e0bccf34311d2fd56933d197d3f073b50893a8277d"
diff --git a/pylintrc b/pylintrc
index c805a91b..5740000b 100644
--- a/pylintrc
+++ b/pylintrc
@@ -66,7 +66,6 @@ suggestion-mode=yes
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
-
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
@@ -84,38 +83,109 @@ confidence=
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=raw-checker-failed,
- bad-inline-option,
- locally-disabled,
- file-ignored,
- suppressed-message,
- useless-suppression,
- deprecated-pragma,
- use-symbolic-message-instead,
- unspecified-encoding,
- missing-function-docstring, # This should be dealt by interrogate
- unidiomatic-typecheck,
- abstract-method,
- arguments-differ,
- not-callable,
- import-error,
- too-many-locals,
- too-many-arguments,
- too-few-public-methods,
- attribute-defined-outside-init,
- too-many-instance-attributes,
- missing-module-docstring,
- too-many-statements,
- too-many-branches,
- function-redefined,
- no-member,
- unbalanced-tuple-unpacking,
- duplicate-code,
- unsubscriptable-object,
- too-many-nested-blocks,
- fixme,
- too-many-ancestors,
- broad-exception-caught,
-
+ bad-inline-option,
+ locally-disabled,
+ file-ignored,
+ suppressed-message,
+ useless-suppression,
+ deprecated-pragma,
+ use-symbolic-message-instead,
+ unspecified-encoding,
+ # Handled by interrogate
+ missing-function-docstring,
+ unidiomatic-typecheck,
+ abstract-method,
+ arguments-differ,
+ not-callable,
+ import-error,
+ too-many-locals,
+ too-many-arguments,
+ too-few-public-methods,
+ attribute-defined-outside-init,
+ too-many-instance-attributes,
+ missing-module-docstring,
+ too-many-statements,
+ too-many-branches,
+ function-redefined,
+ no-member,
+ unbalanced-tuple-unpacking,
+ duplicate-code,
+ unsubscriptable-object,
+ too-many-nested-blocks,
+ fixme,
+ too-many-ancestors,
+ broad-exception-caught,
+ # Taken from https://pypi.org/project/pylint-to-ruff/
+ # Already checked thanks to ruff
+ line-too-long,
+ trailing-whitespace,
+ useless-import-alias,
+ non-ascii-name,
+ unnecessary-dunder-call,
+ unnecessary-direct-lambda-call,
+ syntax-error,
+ return-in-init,
+ return-outside-function,
+ yield-outside-function,
+ nonlocal-and-global,
+ continue-in-finally,
+ nonlocal-without-binding,
+ duplicate-bases,
+ unexpected-special-method-signature,
+ invalid-all-object,
+ invalid-all-format,
+ potential-index-error,
+ misplaced-bare-raise,
+ dict-iter-missing-items,
+ await-outside-async,
+ logging-too-many-args,
+ logging-too-few-args,
+ bad-string-format-type,
+ bad-str-strip-call,
+ invalid-envvar-value,
+ singledispatch-method,
+ singledispatchmethod-function,
+ bidirectional-unicode,
+ invalid-character-backspace,
+ invalid-character-sub,
+ invalid-character-esc,
+ invalid-character-nul,
+ invalid-character-zero-width-space,
+ modified-iterating-set,
+ comparison-with-itself,
+ no-classmethod-decorator,
+ no-staticmethod-decorator,
+ useless-object-inheritance,
+ too-many-public-methods,
+ too-many-return-statements,
+ too-many-boolean-expressions,
+ redefined-argument-from-local,
+ useless-return,
+ unnecessary-comprehension,
+ unnecessary-dict-index-lookup,
+ unnecessary-list-index-lookup,
+ unnecessary-lambda,
+ useless-else-on-loop,
+ self-assigning-variable,
+ redeclared-assigned-name,
+ assert-on-string-literal,
+ duplicate-value,
+ nan-comparison,
+ bad-staticmethod-argument,
+ super-without-brackets,
+ import-self,
+ global-variable-not-assigned,
+ global-statement,
+ global-at-module-level,
+ unused-import,
+ unused-variable,
+ bare-except,
+ binary-op-exception,
+ bad-open-mode,
+ invalid-envvar-default,
+ subprocess-popen-preexec-fn,
+ useless-with-lock,
+ nested-min-max,
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
@@ -123,7 +193,6 @@ disable=raw-checker-failed,
# it should appear only once). See also the "--disable" option for examples.
enable=c-extension-no-member
-
[REPORTS]
# Python expression which should return a score less than or equal to 10. You
@@ -148,7 +217,6 @@ reports=no
# Activate the evaluation score.
score=yes
-
[REFACTORING]
# Maximum number of nested blocks for function / method body
@@ -160,7 +228,6 @@ max-nested-blocks=5
# printed.
never-returning-functions=sys.exit,argparse.parse_error
-
[LOGGING]
# The type of string formatting that logging methods do. `old` means using %
@@ -171,7 +238,6 @@ logging-format-style=old
# function parameter format.
logging-modules=logging
-
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
@@ -182,7 +248,6 @@ check-quote-consistency=no
# on implicit string concatenation in sequences defined over several lines.
check-str-concat-over-line-jumps=no
-
[SPELLING]
# Limits count of emitted suggestions for spelling mistakes.
@@ -206,7 +271,6 @@ spelling-private-dict-file=
# --spelling-private-dict-file option) instead of raising a message.
spelling-store-unknown-words=no
-
[TYPECHECK]
# List of decorators that produce context managers, such as
@@ -265,7 +329,6 @@ mixin-class-rgx=.*[Mm]ixin
# List of decorators that change the signature of a decorated function.
signature-mutators=
-
[BASIC]
# Naming style matching correct argument names.
@@ -285,12 +348,13 @@ attr-naming-style=snake_case
#attr-rgx=
# Bad variable names which should always be refused, separated by a comma.
-bad-names=foo,
- bar,
- baz,
- toto,
- tutu,
- tata
+bad-names=
+ foo,
+ bar,
+ baz,
+ toto,
+ tutu,
+ tata
# Bad variable names regexes, separated by a comma. If names match any regex,
# they will always be refused
@@ -340,12 +404,13 @@ function-naming-style=snake_case
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
-good-names=i,
- j,
- k,
- ex,
- Run,
- _
+good-names=
+ i,
+ j,
+ k,
+ ex,
+ Run,
+ _
# Good variable names regexes, separated by a comma. If names match any regex,
# they will always be accepted
@@ -401,18 +466,16 @@ variable-naming-style=snake_case
# naming style.
#variable-rgx=
-
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
- XXX,
- TODO
+ XXX,
+ TODO
# Regular expression of note tags to take in consideration.
#notes-rgx=
-
[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
@@ -442,7 +505,6 @@ single-line-class-stmt=no
# else.
single-line-if-stmt=no
-
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
@@ -457,8 +519,7 @@ allowed-redefined-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
-callbacks=cb_,
- _cb
+callbacks=cb_,_cb
# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
@@ -475,7 +536,6 @@ init-import=no
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
-
[SIMILARITIES]
# Comments are removed from the similarity computation
@@ -493,7 +553,6 @@ ignore-signatures=no
# Minimum lines number of a similarity.
min-similarity-lines=4
-
[IMPORTS]
# List of modules that can be imported at any level, not just the top level
@@ -534,7 +593,6 @@ known-third-party=enchant
# Couples of modules and preferred modules, separated by a comma.
preferred-modules=
-
[DESIGN]
# List of regular expressions of class ancestor names to ignore when counting
@@ -575,25 +633,24 @@ max-statements=50
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
-
[CLASSES]
# Warn about protected attribute access inside special methods
check-protected-access-in-special-methods=no
# List of method names used to declare (i.e. assign) instance attributes.
-defining-attr-methods=__init__,
- __new__,
- setUp,
- __post_init__
-
-# List of member names, which should be excluded from the protected access
-# warning.
-exclude-protected=_asdict,
- _fields,
- _replace,
- _source,
- _make
+defining-attr-methods=
+ __init__,
+ __new__,
+ setUp,
+ __post_init__
+ # List of member names, which should be excluded from the protected access
+ # warning.
+ exclude-protected=_asdict,
+ _fields,
+ _replace,
+ _source,
+ _make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
@@ -601,10 +658,8 @@ valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=cls
-
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "BaseException, Exception".
-overgeneral-exceptions=builtins.BaseException,
- builtins.Exception
+overgeneral-exceptions=builtins.BaseException,builtins.Exception
diff --git a/pyproject.toml b/pyproject.toml
index 7bf2d7c7..ac91d7c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,33 +1,33 @@
[tool.poetry]
name = "quadra"
-version = "2.1.3"
+version = "2.1.4"
description = "Deep Learning experiment orchestration library"
authors = [
- "Federico Belotti ",
- "Silvia Bianchetti ",
- "Refik Can Malli ",
- "Lorenzo Mammana ",
- "Alessandro Polidori ",
+ "Federico Belotti ",
+ "Silvia Bianchetti ",
+ "Refik Can Malli ",
+ "Lorenzo Mammana ",
+ "Alessandro Polidori ",
]
license = "Apache-2.0"
readme = "README.md"
keywords = ["deep learning", "experiment", "lightning", "hydra-core"]
classifiers = [
- "Programming Language :: Python :: 3",
- "Intended Audience :: Developers",
- "Intended Audience :: Education",
- "Intended Audience :: Science/Research",
- "Topic :: Scientific/Engineering :: Artificial Intelligence",
- "Topic :: Software Development",
- "Topic :: Software Development :: Libraries",
- "Topic :: Software Development :: Libraries :: Python Modules",
- "License :: OSI Approved :: Apache Software License",
+ "Programming Language :: Python :: 3",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Education",
+ "Intended Audience :: Science/Research",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Software Development",
+ "Topic :: Software Development :: Libraries",
+ "Topic :: Software Development :: Libraries :: Python Modules",
+ "License :: OSI Approved :: Apache Software License",
]
homepage = "https://orobix.github.io/quadra"
repository = "https://github.com/orobix/quadra"
packages = [
- { include = "quadra" },
- { include = "hydra_plugins", from = "quadra_hydra_plugin" },
+ { include = "quadra" },
+ { include = "hydra_plugins", from = "quadra_hydra_plugin" },
]
[build-system]
@@ -42,10 +42,10 @@ python = ">=3.9,<3.11"
poetry = "1.7.1"
torch = [
- { url = "https://download.pytorch.org/whl/cu121/torch-2.1.2%2Bcu121-cp310-cp310-linux_x86_64.whl", markers = "sys_platform == 'linux' and python_version == '3.10'" },
- { url = "https://download.pytorch.org/whl/cu121/torch-2.1.2%2Bcu121-cp310-cp310-win_amd64.whl", markers = "sys_platform == 'win32' and python_version == '3.10'" },
- { url = "https://download.pytorch.org/whl/cu121/torch-2.1.2%2Bcu121-cp39-cp39-linux_x86_64.whl", markers = "sys_platform == 'linux' and python_version == '3.9'" },
- { url = "https://download.pytorch.org/whl/cu121/torch-2.1.2%2Bcu121-cp39-cp39-win_amd64.whl", markers = "sys_platform == 'win32' and python_version == '3.9'" },
+ { url = "https://download.pytorch.org/whl/cu121/torch-2.1.2%2Bcu121-cp310-cp310-linux_x86_64.whl", markers = "sys_platform == 'linux' and python_version == '3.10'" },
+ { url = "https://download.pytorch.org/whl/cu121/torch-2.1.2%2Bcu121-cp310-cp310-win_amd64.whl", markers = "sys_platform == 'win32' and python_version == '3.10'" },
+ { url = "https://download.pytorch.org/whl/cu121/torch-2.1.2%2Bcu121-cp39-cp39-linux_x86_64.whl", markers = "sys_platform == 'linux' and python_version == '3.9'" },
+ { url = "https://download.pytorch.org/whl/cu121/torch-2.1.2%2Bcu121-cp39-cp39-win_amd64.whl", markers = "sys_platform == 'win32' and python_version == '3.9'" },
]
torchvision = { version = "~0.16", source = "torch_cu121" }
@@ -55,8 +55,7 @@ torchmetrics = "~0.10"
hydra_core = "~1.3"
hydra_colorlog = "~1.2"
hydra_optuna_sweeper = "~1.2"
-mlflow = "2.3.1"
-mlflow_skinny = "2.3.1"
+mlflow-skinny = "^2.3.1"
boto3 = "~1.26"
minio = "~7.1"
tensorboard = "~2.11"
@@ -106,13 +105,10 @@ optional = true
hydra-plugins = { path = "quadra_hydra_plugin" }
# Dev dependencies
interrogate = "~1.5"
-black = "~22.12"
-isort = "~5.11"
pre_commit = "~3.0"
pylint = "~2.16"
types_pyyaml = "~6.0.12"
-mypy = "~1.0"
-ruff = "0.0.257"
+mypy = "^1.9.0"
pandas_stubs = "~1.5.3"
twine = "~4.0"
poetry-bumpversion = "~0.3"
@@ -152,35 +148,6 @@ onnx = ["onnx", "onnxsim", "onnxruntime_gpu"]
search = '__version__ = "{current_version}"'
replace = '__version__ = "{new_version}"'
-# Black formatting
-[tool.black]
-line-length = 120
-include = '\.pyi?$'
-exclude = '''
-/(
- \.eggs # exclude a few common directories in the
- | \.git # root of the project
- | \.hg
- | \.mypy_cache
- | \.tox
- | \.venv
- | _build
- | buck-out
- | build
- | dist
- )/
-'''
-
-# iSort
-[tool.isort]
-force_grid_wrap = 0
-use_parentheses = true
-ensure_newline_before_comments = true
-line_length = 120
-multi_line_output = 3
-include_trailing_comma = true
-skip_gitignore = true
-
# Pytest
[tool.pytest.ini_options]
testpaths = ["tests"]
@@ -210,23 +177,23 @@ quiet = false
whitelist_regex = []
color = true
ignore_regex = [
- "^get$",
- "^mock_.*",
- ".*BaseClass.*",
- ".*on_train.*",
- ".*on_validation.*",
- ".*on_test.*",
- ".*on_predict.*",
- ".*forward.*",
- ".*backward.*",
- ".*training_step.*",
- ".*validation_step.*",
- ".*test_step.*",
- ".*predict_step.*",
- ".*train_epoch.*",
- ".*validation_epoch.*",
- ".*test_epoch.*",
- ".*on_fit.*",
+ "^get$",
+ "^mock_.*",
+ ".*BaseClass.*",
+ ".*on_train.*",
+ ".*on_validation.*",
+ ".*on_test.*",
+ ".*on_predict.*",
+ ".*forward.*",
+ ".*backward.*",
+ ".*training_step.*",
+ ".*validation_step.*",
+ ".*test_step.*",
+ ".*predict_step.*",
+ ".*train_epoch.*",
+ ".*validation_epoch.*",
+ ".*test_epoch.*",
+ ".*on_fit.*",
]
generate-badge = "docs/images"
badge-format = "svg"
@@ -266,22 +233,63 @@ warn_return_any = false
exclude = ["quadra/utils/tests", "tests"]
[tool.ruff]
-select = ["D"]
+extend-include = ["*.ipynb"]
+target-version = "py39"
+# Orobix guidelines
+line-length = 120
+indent-width = 4
+
+[tool.ruff.lint]
+select = [
+ # pycodestyle
+ "E",
+ # pycodestyle
+ "W",
+ # Pyflakes
+ "F",
+ # pyupgrade
+ "UP",
+ # flake8-bugbear
+ "B",
+ # flake8-simplify
+ "SIM",
+ # isort
+ "I",
+ # flake8-comprehensions
+ "C4",
+ # docstrings
+ "D",
+ # Pylint
+ "PL",
+]
+
ignore = [
- "D100",
- # this is controlled by interrogate with exlude_regex
- # we can skip it here
- "D102",
- "D104",
- "D105",
- "D107",
- # no blank line after summary line. This might be not required.
- # usually we violate this rule
- "D205",
+ "D100", # Missing docstring in public module
+ # this is controlled by interrogate with exlude_regex
+ # we can skip it here
+ "D102",
+ "D104", # Missing docstring in public package
+ "D105", # Missing docstring for magic method (def __*__)
+ "D107", # Missing docstring in __init__
+ "D205", # no blank line after summary line. Usually we violate this rule
+ "E731", # Do not assign a lambda expression, use a def
+ "E741", # Checks for the use of the characters 'l', 'O', or 'I' as variable names.
+ "E402", # Module level import not at top of file
+ "SIM108", # https://github.com/astral-sh/ruff/issues/5528
+ "SIM117", # Single with statement instead of multiple with statements
+ # Pylint specific ignores
+ "PLR0912", # too-many-branches
+ "PLR0913", # too-many-arguments
+ "PLR0914", # too-many-locals
+ "PLR0915", # too-many-statements
+ "PLR1702", # too-many-nested-blocks
+ "PLW1514", # unspecified-encoding
+ "PLR2004", # magic-value-comparison
]
-exclude = ["Makefile", ".gitignore"]
-[tool.ruff.pydocstyle]
+exclude = ["Makefile", ".gitignore", "tests"]
+
+[tool.ruff.lint.pydocstyle]
convention = "google"
[tool.pytest_env]
diff --git a/quadra/__init__.py b/quadra/__init__.py
index 1d9d954e..d8e4ae22 100644
--- a/quadra/__init__.py
+++ b/quadra/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "2.1.3"
+__version__ = "2.1.4"
def get_version():
diff --git a/quadra/callbacks/anomalib.py b/quadra/callbacks/anomalib.py
index 70f3abe5..80cefb8c 100644
--- a/quadra/callbacks/anomalib.py
+++ b/quadra/callbacks/anomalib.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any
import cv2
import matplotlib
@@ -7,7 +9,12 @@
import numpy as np
import pytorch_lightning as pl
from anomalib.models.components.base import AnomalyModule
-from anomalib.post_processing import add_anomalous_label, add_normal_label, compute_mask, superimpose_anomaly_map
+from anomalib.post_processing import (
+ add_anomalous_label,
+ add_normal_label,
+ compute_mask,
+ superimpose_anomaly_map,
+)
from anomalib.pre_processing.transforms import Denormalize
from anomalib.utils.loggers import AnomalibWandbLogger
from pytorch_lightning import Callback
@@ -29,12 +36,12 @@ class Visualizer:
"""
def __init__(self) -> None:
- self.images: List[Dict] = []
+ self.images: list[dict] = []
self.figure: matplotlib.figure.Figure
self.axis: np.ndarray
- def add_image(self, image: np.ndarray, title: str, color_map: Optional[str] = None):
+ def add_image(self, image: np.ndarray, title: str, color_map: str | None = None):
"""Add image to figure.
Args:
@@ -140,7 +147,7 @@ def on_test_batch_end(
self,
trainer: pl.Trainer,
pl_module: AnomalyModule,
- outputs: Optional[STEP_OUTPUT],
+ outputs: STEP_OUTPUT | None,
batch: Any,
batch_idx: int,
dataloader_idx: int = 0,
@@ -161,27 +168,29 @@ def on_test_batch_end(
assert outputs is not None and isinstance(outputs, dict)
- if any(x not in outputs.keys() for x in ["image_path", "image", "mask", "anomaly_maps", "label"]):
+ if any(x not in outputs for x in ["image_path", "image", "mask", "anomaly_maps", "label"]):
# I'm probably in the classification scenario so I can't use the visualizer
return
- if self.inputs_are_normalized:
- normalize = False # anomaly maps are already normalized
- else:
- normalize = True # raw anomaly maps. Still need to normalize
-
if self.threshold_type == "pixel":
if hasattr(pl_module.pixel_metrics.F1Score, "threshold"):
threshold = pl_module.pixel_metrics.F1Score.threshold
else:
raise AttributeError("Metric has no threshold attribute")
+ elif hasattr(pl_module.image_metrics.F1Score, "threshold"):
+ threshold = pl_module.image_metrics.F1Score.threshold
else:
- if hasattr(pl_module.image_metrics.F1Score, "threshold"):
- threshold = pl_module.image_metrics.F1Score.threshold
- else:
- raise AttributeError("Metric has no threshold attribute")
-
- for filename, image, true_mask, anomaly_map, gt_label, pred_label, anomaly_score in tqdm(
+ raise AttributeError("Metric has no threshold attribute")
+
+ for (
+ filename,
+ image,
+ true_mask,
+ anomaly_map,
+ gt_label,
+ pred_label,
+ anomaly_score,
+ ) in tqdm(
zip(
outputs["image_path"],
outputs["image"],
@@ -192,63 +201,68 @@ def on_test_batch_end(
outputs["pred_scores"],
)
):
- image = Denormalize()(image.cpu())
- true_mask = true_mask.cpu().numpy()
- anomaly_map = anomaly_map.cpu().numpy()
+ denormalized_image = Denormalize()(image.cpu())
+ current_true_mask = true_mask.cpu().numpy()
+ current_anomaly_map = anomaly_map.cpu().numpy()
output_label_folder = "ok" if pred_label == gt_label else "wrong"
if self.plot_only_wrong and output_label_folder == "ok":
continue
- heat_map = superimpose_anomaly_map(anomaly_map, image, normalize=normalize)
+ heatmap = superimpose_anomaly_map(
+ current_anomaly_map, denormalized_image, normalize=not self.inputs_are_normalized
+ )
+
if isinstance(threshold, float):
- pred_mask = compute_mask(anomaly_map, threshold)
+ pred_mask = compute_mask(current_anomaly_map, threshold)
else:
raise TypeError("Threshold should be float")
- vis_img = mark_boundaries(image, pred_mask, color=(1, 0, 0), mode="thick")
+ vis_img = mark_boundaries(denormalized_image, pred_mask, color=(1, 0, 0), mode="thick")
visualizer = Visualizer()
if self.task == "segmentation":
- visualizer.add_image(image=image, title="Image")
+ visualizer.add_image(image=denormalized_image, title="Image")
if "mask" in outputs:
- true_mask = true_mask * 255
- visualizer.add_image(image=true_mask, color_map="gray", title="Ground Truth")
- visualizer.add_image(image=heat_map, title="Predicted Heat Map")
+ current_true_mask = current_true_mask * 255
+ visualizer.add_image(image=current_true_mask, color_map="gray", title="Ground Truth")
+ visualizer.add_image(image=heatmap, title="Predicted Heat Map")
visualizer.add_image(image=pred_mask, color_map="gray", title="Predicted Mask")
visualizer.add_image(image=vis_img, title="Segmentation Result")
elif self.task == "classification":
- gt_im = add_anomalous_label(image) if gt_label else add_normal_label(image)
+ gt_im = add_anomalous_label(denormalized_image) if gt_label else add_normal_label(denormalized_image)
visualizer.add_image(gt_im, title="Image/True label")
if anomaly_score >= threshold:
- image_classified = add_anomalous_label(heat_map, anomaly_score)
+ image_classified = add_anomalous_label(heatmap, anomaly_score)
else:
- image_classified = add_normal_label(heat_map, 1 - anomaly_score)
+ image_classified = add_normal_label(heatmap, 1 - anomaly_score)
visualizer.add_image(image=image_classified, title="Prediction")
visualizer.generate()
visualizer.figure.suptitle(
- f"F1 threshold: {threshold}, Mask_max: {anomaly_map.max():.3f}, Anomaly_score: {anomaly_score:.3f}"
+ f"F1 threshold: {threshold}, Mask_max: {current_anomaly_map.max():.3f}, "
+ f"Anomaly_score: {anomaly_score:.3f}"
)
- filename = Path(filename)
- self._add_images(visualizer, filename, output_label_folder)
+ path_filename = Path(filename)
+ self._add_images(visualizer, path_filename, output_label_folder)
visualizer.close()
if self.plot_raw_outputs:
- for raw_output, raw_name in zip([heat_map, vis_img], ["heatmap", "segmentation"]):
+ for raw_output, raw_name in zip([heatmap, vis_img], ["heatmap", "segmentation"]):
+ current_raw_output = raw_output
if raw_name == "segmentation":
- raw_output = (raw_output * 255).astype(np.uint8)
- raw_output = cv2.cvtColor(raw_output, cv2.COLOR_RGB2BGR)
+ current_raw_output = (raw_output * 255).astype(np.uint8)
+ current_raw_output = cv2.cvtColor(current_raw_output, cv2.COLOR_RGB2BGR)
raw_filename = (
Path(self.output_path)
/ "images"
/ output_label_folder
- / filename.parent.name
+ / path_filename.parent.name
/ "raw_outputs"
- / Path(filename.stem + f"_{raw_name}.png")
+ / Path(path_filename.stem + f"_{raw_name}.png")
)
raw_filename.parent.mkdir(parents=True, exist_ok=True)
- cv2.imwrite(str(raw_filename), raw_output)
+ cv2.imwrite(str(raw_filename), current_raw_output)
def on_test_end(self, _trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
"""Sync logs.
diff --git a/quadra/callbacks/mlflow.py b/quadra/callbacks/mlflow.py
index 17b7b6f4..f54e1fbb 100644
--- a/quadra/callbacks/mlflow.py
+++ b/quadra/callbacks/mlflow.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
import glob
import os
-from typing import Any, Dict, Optional
+from typing import Any, Literal
import torch
from pytorch_lightning import Callback, LightningModule, Trainer
@@ -123,7 +125,7 @@ class LogGradients(Callback):
def __init__(
self,
norm: int = 2,
- tag: Optional[str] = None,
+ tag: str | None = None,
sep: str = "/",
round_to: int = 3,
log_all_grads: bool = False,
@@ -134,12 +136,9 @@ def __init__(
self.round_to = round_to
self.log_all_grads = log_all_grads
- def _grad_norm(self, named_params) -> Dict:
+ def _grad_norm(self, named_params) -> dict:
"""Compute the gradient norm and return it in a dictionary."""
- if self.tag is None:
- grad_tag = ""
- else:
- grad_tag = "_" + self.tag
+ grad_tag = "" if self.tag is None else "_" + self.tag
results = {}
for name, p in named_params:
if p.requires_grad and p.grad is not None:
@@ -161,7 +160,7 @@ def on_train_batch_end(
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
- unused: Optional[int] = 0,
+ unused: int | None = 0,
) -> None:
"""Method called at the end of the train batch
Args:
@@ -259,7 +258,7 @@ class LogLearningRate(LearningRateMonitor):
log_momentum: If True, log momentum as well.
"""
- def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False):
+ def __init__(self, logging_interval: Literal["step", "epoch"] | None = None, log_momentum: bool = False):
super().__init__(logging_interval=logging_interval, log_momentum=log_momentum)
def on_train_batch_start(self, trainer, *args, **kwargs):
diff --git a/quadra/datamodules/anomaly.py b/quadra/datamodules/anomaly.py
index eb2e1a8d..913c9d28 100644
--- a/quadra/datamodules/anomaly.py
+++ b/quadra/datamodules/anomaly.py
@@ -1,6 +1,7 @@
+from __future__ import annotations
+
import os
import pathlib
-from typing import Optional, Tuple, Union
import albumentations
import pandas as pd
@@ -43,22 +44,22 @@ class AnomalyDataModule(BaseDataModule):
def __init__(
self,
data_path: str,
- category: Optional[str] = None,
- image_size: Optional[Union[int, Tuple[int, int]]] = None,
+ category: str | None = None,
+ image_size: int | tuple[int, int] | None = None,
train_batch_size: int = 32,
test_batch_size: int = 32,
num_workers: int = 8,
- train_transform: Optional[albumentations.Compose] = None,
- val_transform: Optional[albumentations.Compose] = None,
- test_transform: Optional[albumentations.Compose] = None,
+ train_transform: albumentations.Compose | None = None,
+ val_transform: albumentations.Compose | None = None,
+ test_transform: albumentations.Compose | None = None,
seed: int = 0,
task: str = "segmentation",
- mask_suffix: Optional[str] = None,
+ mask_suffix: str | None = None,
create_test_set_if_empty: bool = True,
phase: str = "train",
name: str = "anomaly_datamodule",
- valid_area_mask: Optional[str] = None,
- crop_area: Optional[Tuple[int, int, int, int]] = None,
+ valid_area_mask: str | None = None,
+ crop_area: tuple[int, int, int, int] | None = None,
**kwargs,
) -> None:
super().__init__(
@@ -108,7 +109,7 @@ def _prepare_data(self) -> None:
create_test_set_if_empty=self.create_test_set_if_empty,
)
- def setup(self, stage: Optional[str] = None) -> None:
+ def setup(self, stage: str | None = None) -> None:
"""Setup data module based on stages of training."""
if stage == "fit" and self.phase == "train":
self.train_dataset = AnomalyDataset(
diff --git a/quadra/datamodules/base.py b/quadra/datamodules/base.py
index 451321a5..3371b584 100644
--- a/quadra/datamodules/base.py
+++ b/quadra/datamodules/base.py
@@ -1,10 +1,13 @@
+from __future__ import annotations
+
import multiprocessing as mp
import multiprocessing.pool as mpp
import os
import pickle as pkl
import typing
+from collections.abc import Callable, Iterable, Sequence
from functools import wraps
-from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, Union, cast
+from typing import Any, Literal, Union, cast
import albumentations
import numpy as np
@@ -105,7 +108,7 @@ def istarmap(self, func: Callable, iterable: Iterable, chunksize: int = 1):
"""Starmap-version of imap."""
self._check_running()
if chunksize < 1:
- raise ValueError("Chunksize must be 1+, not {0:n}".format(chunksize))
+ raise ValueError(f"Chunksize must be 1+, not {chunksize:n}")
task_batches = mpp.Pool._get_tasks(func, iterable, chunksize)
result = mpp.IMapIterator(self)
@@ -146,13 +149,13 @@ def __init__(
batch_size: int = 32,
seed: int = 42,
load_aug_images: bool = False,
- aug_name: Optional[str] = None,
- n_aug_to_take: Optional[int] = None,
- replace_str_from: Optional[str] = None,
- replace_str_to: Optional[str] = None,
- train_transform: Optional[albumentations.Compose] = None,
- val_transform: Optional[albumentations.Compose] = None,
- test_transform: Optional[albumentations.Compose] = None,
+ aug_name: str | None = None,
+ n_aug_to_take: int | None = None,
+ replace_str_from: str | None = None,
+ replace_str_to: str | None = None,
+ train_transform: albumentations.Compose | None = None,
+ val_transform: albumentations.Compose | None = None,
+ test_transform: albumentations.Compose | None = None,
enable_hashing: bool = True,
hash_size: Literal[32, 64, 128] = 64,
hash_type: Literal["content", "size"] = "content",
@@ -178,7 +181,7 @@ def __init__(
self.n_aug_to_take = n_aug_to_take
self.replace_str_from = replace_str_from
self.replace_str_to = replace_str_to
- self.extra_args: Dict[str, Any] = {}
+ self.extra_args: dict[str, Any] = {}
self.train_dataset: TrainDataset
self.val_dataset: ValDataset
self.test_dataset: TestDataset
@@ -282,7 +285,7 @@ def prepare_data(self) -> None:
self.hash_data()
self.save_checkpoint()
- def __getstate__(self) -> Dict[str, Any]:
+ def __getstate__(self) -> dict[str, Any]:
"""This method is called when pickling the object.
It's useful to remove attributes that shouldn't be pickled.
"""
@@ -341,12 +344,12 @@ def restore_checkpoint(self) -> None:
# TODO: Check if this function can be removed
def load_augmented_samples(
self,
- samples: List[str],
- targets: List[Any],
- replace_str_from: Optional[str] = None,
- replace_str_to: Optional[str] = None,
+ samples: list[str],
+ targets: list[Any],
+ replace_str_from: str | None = None,
+ replace_str_to: str | None = None,
shuffle: bool = False,
- ) -> Tuple[List[str], List[str]]:
+ ) -> tuple[list[str], list[str]]:
"""Loads augmented samples."""
if self.n_aug_to_take is None:
raise ValueError("`n_aug_to_take` is not set. Cannot load augmented samples.")
@@ -355,9 +358,10 @@ def load_augmented_samples(
for sample, label in zip(samples, targets):
aug_samples.append(sample)
aug_labels.append(label)
+ final_sample = sample
if replace_str_from is not None and replace_str_to is not None:
- sample = sample.replace(replace_str_from, replace_str_to)
- base, ext = os.path.splitext(sample)
+ final_sample = final_sample.replace(replace_str_from, replace_str_to)
+ base, ext = os.path.splitext(final_sample)
for k in range(self.n_aug_to_take):
aug_samples.append(base + "_" + str(k + 1) + ext)
aug_labels.append(label)
diff --git a/quadra/datamodules/classification.py b/quadra/datamodules/classification.py
index 3f3b1fa2..00a47a37 100644
--- a/quadra/datamodules/classification.py
+++ b/quadra/datamodules/classification.py
@@ -1,7 +1,10 @@
# pylint: disable=unsupported-assignment-operation,unsubscriptable-object
+from __future__ import annotations
+
import os
import random
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type
+from collections.abc import Callable
+from typing import Any
import albumentations
import numpy as np
@@ -53,29 +56,29 @@ class ClassificationDataModule(BaseDataModule):
def __init__(
self,
data_path: str,
- dataset: Type[ImageClassificationListDataset] = ImageClassificationListDataset,
+ dataset: type[ImageClassificationListDataset] = ImageClassificationListDataset,
name: str = "classification_datamodule",
num_workers: int = 8,
batch_size: int = 32,
seed: int = 42,
- val_size: Optional[float] = 0.2,
+ val_size: float | None = 0.2,
test_size: float = 0.2,
- num_data_class: Optional[int] = None,
- exclude_filter: Optional[List[str]] = None,
- include_filter: Optional[List[str]] = None,
- label_map: Optional[Dict[str, Any]] = None,
+ num_data_class: int | None = None,
+ exclude_filter: list[str] | None = None,
+ include_filter: list[str] | None = None,
+ label_map: dict[str, Any] | None = None,
load_aug_images: bool = False,
- aug_name: Optional[str] = None,
- n_aug_to_take: Optional[int] = 4,
- replace_str_from: Optional[str] = None,
- replace_str_to: Optional[str] = None,
- train_transform: Optional[albumentations.Compose] = None,
- val_transform: Optional[albumentations.Compose] = None,
- test_transform: Optional[albumentations.Compose] = None,
- train_split_file: Optional[str] = None,
- test_split_file: Optional[str] = None,
- val_split_file: Optional[str] = None,
- class_to_idx: Optional[Dict[str, int]] = None,
+ aug_name: str | None = None,
+ n_aug_to_take: int | None = 4,
+ replace_str_from: str | None = None,
+ replace_str_to: str | None = None,
+ train_transform: albumentations.Compose | None = None,
+ val_transform: albumentations.Compose | None = None,
+ test_transform: albumentations.Compose | None = None,
+ train_split_file: str | None = None,
+ test_split_file: str | None = None,
+ val_split_file: str | None = None,
+ class_to_idx: dict[str, int] | None = None,
**kwargs: Any,
):
super().__init__(
@@ -105,7 +108,7 @@ def __init__(
self.train_split_file = train_split_file
self.test_split_file = test_split_file
self.val_split_file = val_split_file
- self.class_to_idx: Optional[Dict[str, int]]
+ self.class_to_idx: dict[str, int] | None
if class_to_idx is not None:
self.class_to_idx = class_to_idx
@@ -118,7 +121,7 @@ def __init__(
else:
self.num_classes = len(self.class_to_idx)
- def _read_split(self, split_file: str) -> Tuple[List[str], List[str]]:
+ def _read_split(self, split_file: str) -> tuple[list[str], list[str]]:
"""Reads split file.
Args:
@@ -128,7 +131,7 @@ def _read_split(self, split_file: str) -> Tuple[List[str], List[str]]:
List of paths to images.
"""
samples, targets = [], []
- with open(split_file, "r") as f:
+ with open(split_file) as f:
split = f.readlines()
for row in split:
csv_values = row.split(",")
@@ -143,7 +146,7 @@ def _read_split(self, split_file: str) -> Tuple[List[str], List[str]]:
# log.warning(f"{sample_path} does not exist")
return samples, targets
- def _find_classes_from_data_path(self, data_path: str) -> Optional[Dict[str, int]]:
+ def _find_classes_from_data_path(self, data_path: str) -> dict[str, int] | None:
"""Given a data_path, build a random class_to_idx from the subdirectories.
Args:
@@ -161,22 +164,27 @@ def _find_classes_from_data_path(self, data_path: str) -> Optional[Dict[str, int
item_path = os.path.join(data_path, item)
# Check if it's a directory and not starting with "."
- if os.path.isdir(item_path) and not item.startswith("."):
+ if (
+ os.path.isdir(item_path)
+ and not item.startswith(".")
# Check if there's at least one image file in the subdirectory
- if any(
+ and any(
os.path.splitext(file)[1].lower().endswith(tuple(utils.IMAGE_EXTENSIONS))
for file in os.listdir(item_path)
- ):
- subdirectories.append(item)
+ )
+ ):
+ subdirectories.append(item)
+
if len(subdirectories) > 0:
return {cl: idx for idx, cl in enumerate(sorted(subdirectories))}
return None
+
return None
@staticmethod
def _find_images_and_targets(
- root_folder: str, class_to_idx: Optional[Dict[str, int]] = None
- ) -> Tuple[List[Tuple[str, int]], Dict[str, int]]:
+ root_folder: str, class_to_idx: dict[str, int] | None = None
+ ) -> tuple[list[tuple[str, int]], dict[str, int]]:
"""Collects the samples from item folders."""
images_and_targets, class_to_idx = find_images_and_targets(
folder=root_folder, types=utils.IMAGE_EXTENSIONS, class_to_idx=class_to_idx
@@ -184,14 +192,14 @@ def _find_images_and_targets(
return images_and_targets, class_to_idx
def _filter_images_and_targets(
- self, images_and_targets: List[Tuple[str, int]], class_to_idx: Dict[str, int]
- ) -> Tuple[List[str], List[str]]:
+ self, images_and_targets: list[tuple[str, int]], class_to_idx: dict[str, int]
+ ) -> tuple[list[str], list[str]]:
"""Filters the images and targets."""
- samples: List[str] = []
- targets: List[str] = []
+ samples: list[str] = []
+ targets: list[str] = []
idx_to_class = {v: k for k, v in class_to_idx.items()}
+ images_and_targets = [(str(image_path), target) for image_path, target in images_and_targets]
for image_path, target in images_and_targets:
- image_path = str(image_path)
target_class = idx_to_class[target]
if self.exclude_filter is not None and any(
exclude_filter in image_path for exclude_filter in self.exclude_filter
@@ -216,12 +224,12 @@ def _prepare_data(self) -> None:
if self.label_map is not None:
all_targets, _ = group_labels(all_targets, self.label_map)
- samples_train: List[str] = []
- targets_train: List[str] = []
- samples_test: List[str] = []
- targets_test: List[str] = []
- samples_val: List[str] = []
- targets_val: List[str] = []
+ samples_train: list[str] = []
+ targets_train: list[str] = []
+ samples_test: list[str] = []
+ targets_test: list[str] = []
+ samples_val: list[str] = []
+ targets_val: list[str] = []
if self.test_size < 1.0:
samples_train, samples_test, targets_train, targets_test = train_test_split(
@@ -295,7 +303,7 @@ def _prepare_data(self) -> None:
# )
unique_targets = [str(t) for t in np.unique(targets_train)]
if self.class_to_idx is None:
- sorted_targets = list(sorted(unique_targets, key=natural_key))
+ sorted_targets = sorted(unique_targets, key=natural_key)
class_to_idx = {c: idx for idx, c in enumerate(sorted_targets)}
self.class_to_idx = class_to_idx
log.info("Class_to_idx not provided in config, building it from targets: %s", class_to_idx)
@@ -308,13 +316,13 @@ def _prepare_data(self) -> None:
"The number of classes in the class_to_idx dictionary does not match the number of unique targets."
f" `class_to_idx`: {self.class_to_idx}, `unique_targets`: {unique_targets}"
)
- if not all(c in unique_targets for c in self.class_to_idx.keys()):
+ if not all(c in unique_targets for c in self.class_to_idx):
raise ValueError(
"The classes in the class_to_idx dictionary do not match the available unique targets in the"
" datasset. `class_to_idx`: {self.class_to_idx}, `unique_targets`: {unique_targets}"
)
- def setup(self, stage: Optional[str] = None) -> None:
+ def setup(self, stage: str | None = None) -> None:
"""Setup data module based on stages of training."""
if stage in ["train", "fit"]:
self.train_dataset = self.dataset(
@@ -447,26 +455,26 @@ class SklearnClassificationDataModule(BaseDataModule):
def __init__(
self,
data_path: str,
- exclude_filter: Optional[List[str]] = None,
- include_filter: Optional[List[str]] = None,
+ exclude_filter: list[str] | None = None,
+ include_filter: list[str] | None = None,
val_size: float = 0.2,
- class_to_idx: Optional[Dict[str, int]] = None,
- label_map: Optional[Dict[str, Any]] = None,
+ class_to_idx: dict[str, int] | None = None,
+ label_map: dict[str, Any] | None = None,
seed: int = 42,
batch_size: int = 32,
num_workers: int = 6,
- train_transform: Optional[albumentations.Compose] = None,
- val_transform: Optional[albumentations.Compose] = None,
- test_transform: Optional[albumentations.Compose] = None,
- roi: Optional[Tuple[int, int, int, int]] = None,
+ train_transform: albumentations.Compose | None = None,
+ val_transform: albumentations.Compose | None = None,
+ test_transform: albumentations.Compose | None = None,
+ roi: tuple[int, int, int, int] | None = None,
n_splits: int = 1,
phase: str = "train",
cache: bool = False,
- limit_training_data: Optional[int] = None,
- train_split_file: Optional[str] = None,
- test_split_file: Optional[str] = None,
+ limit_training_data: int | None = None,
+ train_split_file: str | None = None,
+ test_split_file: str | None = None,
name: str = "sklearn_classification_datamodule",
- dataset: Type[ImageClassificationListDataset] = ImageClassificationListDataset,
+ dataset: type[ImageClassificationListDataset] = ImageClassificationListDataset,
**kwargs: Any,
):
super().__init__(
@@ -496,8 +504,8 @@ def __init__(
self.val_size = val_size
self.label_map = label_map
self.full_dataset: ImageClassificationListDataset
- self.train_dataset: List[ImageClassificationListDataset]
- self.val_dataset: List[ImageClassificationListDataset]
+ self.train_dataset: list[ImageClassificationListDataset]
+ self.val_dataset: list[ImageClassificationListDataset]
def _prepare_data(self) -> None:
"""Prepares the data for the data module."""
@@ -598,7 +606,7 @@ def predict_dataloader(self) -> DataLoader:
"""Returns a dataloader used for predictions."""
return self.test_dataloader()
- def train_dataloader(self) -> List[DataLoader]:
+ def train_dataloader(self) -> list[DataLoader]:
"""Returns a list of train dataloader.
Raises:
@@ -624,7 +632,7 @@ def train_dataloader(self) -> List[DataLoader]:
)
return loader
- def val_dataloader(self) -> List[DataLoader]:
+ def val_dataloader(self) -> list[DataLoader]:
"""Returns a list of validation dataloader.
Raises:
@@ -736,23 +744,23 @@ class MultilabelClassificationDataModule(BaseDataModule):
def __init__(
self,
data_path: str,
- images_and_labels_file: Optional[str] = None,
- train_split_file: Optional[str] = None,
- test_split_file: Optional[str] = None,
- val_split_file: Optional[str] = None,
+ images_and_labels_file: str | None = None,
+ train_split_file: str | None = None,
+ test_split_file: str | None = None,
+ val_split_file: str | None = None,
name: str = "multilabel_datamodule",
dataset: Callable = MultilabelClassificationDataset,
- num_classes: Optional[int] = None,
+ num_classes: int | None = None,
num_workers: int = 16,
batch_size: int = 64,
test_batch_size: int = 64,
seed: int = 42,
- val_size: Optional[float] = 0.2,
- test_size: Optional[float] = 0.2,
- train_transform: Optional[albumentations.Compose] = None,
- val_transform: Optional[albumentations.Compose] = None,
- test_transform: Optional[albumentations.Compose] = None,
- class_to_idx: Optional[Dict[str, int]] = None,
+ val_size: float | None = 0.2,
+ test_size: float | None = 0.2,
+ train_transform: albumentations.Compose | None = None,
+ val_transform: albumentations.Compose | None = None,
+ test_transform: albumentations.Compose | None = None,
+ class_to_idx: dict[str, int] | None = None,
**kwargs,
):
super().__init__(
@@ -785,7 +793,7 @@ def __init__(
self.val_dataset: MultilabelClassificationDataset
self.test_dataset: MultilabelClassificationDataset
- def _read_split(self, split_file: str) -> Tuple[List[str], List[List[str]]]:
+ def _read_split(self, split_file: str) -> tuple[list[str], list[list[str]]]:
"""Reads split file.
Args:
@@ -895,7 +903,7 @@ def _prepare_data(self) -> None:
test_df["split"] = "test"
self.data = pd.concat([train_df, val_df, test_df], axis=0)
- def setup(self, stage: Optional[str] = None) -> None:
+ def setup(self, stage: str | None = None) -> None:
"""Setup data module based on stages of training."""
if stage in ["train", "fit"]:
train_samples = self.data[self.data["split"] == "train"]["samples"].tolist()
diff --git a/quadra/datamodules/generic/imagenette.py b/quadra/datamodules/generic/imagenette.py
index 88554cf2..32b902cd 100644
--- a/quadra/datamodules/generic/imagenette.py
+++ b/quadra/datamodules/generic/imagenette.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
import os
import shutil
-from typing import Any, Dict, Optional
+from typing import Any
import pandas as pd
from sklearn.model_selection import train_test_split
@@ -46,7 +48,7 @@ def __init__(
name: str = "imagenette_classification_datamodule",
imagenette_version: str = "320",
force_download: bool = False,
- class_to_idx: Optional[Dict[str, int]] = None,
+ class_to_idx: dict[str, int] | None = None,
**kwargs: Any,
):
if imagenette_version not in ["320", "160", "full"]:
@@ -135,12 +137,8 @@ class ImagenetteSSLDataModule(ImagenetteClassificationDataModule, SSLDataModule)
def __init__(
self,
- *args,
+ *args: Any,
name="imagenette_ssl",
- **kwargs,
+ **kwargs: Any,
):
- super().__init__(
- name=name,
- *args,
- **kwargs,
- )
+ super().__init__(*args, name=name, **kwargs) # type: ignore[misc]
diff --git a/quadra/datamodules/generic/mnist.py b/quadra/datamodules/generic/mnist.py
index 89477d2c..b00be49c 100644
--- a/quadra/datamodules/generic/mnist.py
+++ b/quadra/datamodules/generic/mnist.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
import os
import shutil
-from typing import Any, Optional
+from typing import Any
import cv2
from torchvision.datasets.mnist import MNIST
@@ -15,7 +17,7 @@ class MNISTAnomalyDataModule(AnomalyDataModule):
"""Standard anomaly datamodule with automatic download of the MNIST dataset."""
def __init__(
- self, data_path: str, good_number: int, limit_data: int = 100, category: Optional[str] = None, **kwargs: Any
+ self, data_path: str, good_number: int, limit_data: int = 100, category: str | None = None, **kwargs: Any
):
"""Initialize the MNIST anomaly datamodule.
diff --git a/quadra/datamodules/generic/oxford_pet.py b/quadra/datamodules/generic/oxford_pet.py
index 066b82bc..ca8f5406 100644
--- a/quadra/datamodules/generic/oxford_pet.py
+++ b/quadra/datamodules/generic/oxford_pet.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
import os
-from typing import Any, Dict, Optional, Type
+from typing import Any
import albumentations
import cv2
@@ -36,17 +38,17 @@ class OxfordPetSegmentationDataModule(SegmentationMulticlassDataModule):
def __init__(
self,
data_path: str,
- idx_to_class: Dict,
+ idx_to_class: dict,
name: str = "oxford_pet_segmentation_datamodule",
- dataset: Type[SegmentationDatasetMulticlass] = SegmentationDatasetMulticlass,
+ dataset: type[SegmentationDatasetMulticlass] = SegmentationDatasetMulticlass,
batch_size: int = 32,
test_size: float = 0.3,
val_size: float = 0.3,
seed: int = 42,
num_workers: int = 6,
- train_transform: Optional[albumentations.Compose] = None,
- test_transform: Optional[albumentations.Compose] = None,
- val_transform: Optional[albumentations.Compose] = None,
+ train_transform: albumentations.Compose | None = None,
+ test_transform: albumentations.Compose | None = None,
+ val_transform: albumentations.Compose | None = None,
**kwargs: Any,
):
super().__init__(
@@ -88,11 +90,7 @@ def _preprocess_mask(self, mask: np.ndarray) -> np.ndarray:
def _check_exists(self, image_folder: str, annotation_folder: str) -> bool:
"""Check if the dataset is already downloaded."""
- for folder in (image_folder, annotation_folder):
- if not (os.path.exists(folder) and os.path.isdir(folder)):
- return False
-
- return True
+ return all(os.path.exists(folder) and os.path.isdir(folder) for folder in (image_folder, annotation_folder))
def download_data(self):
"""Download the dataset if it is not already downloaded."""
@@ -102,7 +100,7 @@ def download_data(self):
for url, md5 in self._RESOURCES:
download_and_extract_archive(url, download_root=self.data_path, md5=md5, remove_finished=True)
log.info("Fixing corrupted files...")
- images_filenames = list(sorted(os.listdir(image_folder)))
+ images_filenames = sorted(os.listdir(image_folder))
for filename in images_filenames:
file_wo_ext = os.path.splitext(os.path.basename(filename))[0]
try:
diff --git a/quadra/datamodules/patch.py b/quadra/datamodules/patch.py
index 3311718f..4396fb70 100644
--- a/quadra/datamodules/patch.py
+++ b/quadra/datamodules/patch.py
@@ -1,6 +1,7 @@
+from __future__ import annotations
+
import json
import os
-from typing import Dict, List, Optional
import albumentations
import pandas as pd
@@ -35,19 +36,19 @@ class PatchSklearnClassificationDataModule(BaseDataModule):
def __init__(
self,
data_path: str,
- class_to_idx: Dict,
+ class_to_idx: dict,
name: str = "patch_classification_datamodule",
train_filename: str = "dataset.txt",
- exclude_filter: Optional[List[str]] = None,
- include_filter: Optional[List[str]] = None,
+ exclude_filter: list[str] | None = None,
+ include_filter: list[str] | None = None,
seed: int = 42,
batch_size: int = 32,
num_workers: int = 6,
- train_transform: Optional[albumentations.Compose] = None,
- val_transform: Optional[albumentations.Compose] = None,
- test_transform: Optional[albumentations.Compose] = None,
+ train_transform: albumentations.Compose | None = None,
+ val_transform: albumentations.Compose | None = None,
+ test_transform: albumentations.Compose | None = None,
balance_classes: bool = False,
- class_to_skip_training: Optional[list] = None,
+ class_to_skip_training: list | None = None,
**kwargs,
):
super().__init__(
@@ -79,12 +80,12 @@ def __init__(
def _prepare_data(self):
"""Prepare data function."""
if os.path.isfile(os.path.join(self.data_path, "info.json")):
- with open(os.path.join(self.data_path, "info.json"), "r") as f:
+ with open(os.path.join(self.data_path, "info.json")) as f:
self.info = PatchDatasetInfo(**json.load(f))
else:
raise FileNotFoundError("No `info.json` file found in the dataset folder")
- split_df_list: List[pd.DataFrame] = []
+ split_df_list: list[pd.DataFrame] = []
if os.path.isfile(os.path.join(self.train_folder, self.train_filename)):
train_samples, train_labels = load_train_file(
train_file_path=os.path.join(self.train_folder, self.train_filename),
@@ -119,7 +120,7 @@ def _prepare_data(self):
raise ValueError("No data found in all split folders")
self.data = pd.concat(split_df_list, axis=0)
- def setup(self, stage: Optional[str] = None) -> None:
+ def setup(self, stage: str | None = None) -> None:
"""Setup function."""
if stage == "fit":
self.train_dataset = PatchSklearnClassificationTrainDataset(
diff --git a/quadra/datamodules/segmentation.py b/quadra/datamodules/segmentation.py
index a9bedc1c..efa5cdc8 100644
--- a/quadra/datamodules/segmentation.py
+++ b/quadra/datamodules/segmentation.py
@@ -1,8 +1,10 @@
# pylint: disable=unsubscriptable-object,unsupported-assignment-operation,unsupported-membership-test
+from __future__ import annotations
+
import glob
import os
import random
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any
import albumentations
import cv2
@@ -48,16 +50,16 @@ def __init__(
test_size: float = 0.3,
val_size: float = 0.3,
seed: int = 42,
- dataset: Type[SegmentationDataset] = SegmentationDataset,
+ dataset: type[SegmentationDataset] = SegmentationDataset,
batch_size: int = 32,
num_workers: int = 6,
- train_transform: Optional[albumentations.Compose] = None,
- test_transform: Optional[albumentations.Compose] = None,
- val_transform: Optional[albumentations.Compose] = None,
- train_split_file: Optional[str] = None,
- test_split_file: Optional[str] = None,
- val_split_file: Optional[str] = None,
- num_data_class: Optional[int] = None,
+ train_transform: albumentations.Compose | None = None,
+ test_transform: albumentations.Compose | None = None,
+ val_transform: albumentations.Compose | None = None,
+ train_split_file: str | None = None,
+ test_split_file: str | None = None,
+ val_split_file: str | None = None,
+ num_data_class: int | None = None,
exclude_good: bool = False,
**kwargs: Any,
):
@@ -104,7 +106,7 @@ def _resolve_label(path: str) -> int:
return 1
- def _read_folder(self, data_path: str) -> Tuple[List[str], List[int], List[str]]:
+ def _read_folder(self, data_path: str) -> tuple[list[str], list[int], list[str]]:
"""Read a folder containing images and masks subfolders.
Args:
@@ -137,7 +139,7 @@ def _read_folder(self, data_path: str) -> Tuple[List[str], List[int], List[str]]
return samples, targets, masks
- def _read_split(self, split_file: str) -> Tuple[List[str], List[int], List[str]]:
+ def _read_split(self, split_file: str) -> tuple[list[str], list[int], list[str]]:
"""Reads split file.
Args:
@@ -147,7 +149,7 @@ def _read_split(self, split_file: str) -> Tuple[List[str], List[int], List[str]]
List of paths to images, List of labels.
"""
samples, targets, masks = [], [], []
- with open(split_file, "r") as f:
+ with open(split_file) as f:
split = f.read().splitlines()
for sample in split:
sample_path = os.path.join(self.data_path, sample)
@@ -398,22 +400,22 @@ class SegmentationMulticlassDataModule(BaseDataModule):
def __init__(
self,
data_path: str,
- idx_to_class: Dict,
+ idx_to_class: dict,
name: str = "multiclass_segmentation_datamodule",
- dataset: Type[SegmentationDatasetMulticlass] = SegmentationDatasetMulticlass,
+ dataset: type[SegmentationDatasetMulticlass] = SegmentationDatasetMulticlass,
batch_size: int = 32,
test_size: float = 0.3,
val_size: float = 0.3,
seed: int = 42,
num_workers: int = 6,
- train_transform: Optional[albumentations.Compose] = None,
- test_transform: Optional[albumentations.Compose] = None,
- val_transform: Optional[albumentations.Compose] = None,
- train_split_file: Optional[str] = None,
- test_split_file: Optional[str] = None,
- val_split_file: Optional[str] = None,
+ train_transform: albumentations.Compose | None = None,
+ test_transform: albumentations.Compose | None = None,
+ val_transform: albumentations.Compose | None = None,
+ train_split_file: str | None = None,
+ test_split_file: str | None = None,
+ val_split_file: str | None = None,
exclude_good: bool = False,
- num_data_train: Optional[int] = None,
+ num_data_train: int | None = None,
one_hot_encoding: bool = False,
**kwargs: Any,
):
@@ -471,7 +473,7 @@ def _resolve_label(self, path: str) -> np.ndarray:
return one_hot
- def _read_folder(self, data_path: str) -> Tuple[List[str], List[np.ndarray], List[str]]:
+ def _read_folder(self, data_path: str) -> tuple[list[str], list[np.ndarray], list[str]]:
"""Read a folder containing images and masks subfolders.
Args:
@@ -504,7 +506,7 @@ def _read_folder(self, data_path: str) -> Tuple[List[str], List[np.ndarray], Lis
return samples, targets, masks
- def _read_split(self, split_file: str) -> Tuple[List[str], List[np.ndarray], List[str]]:
+ def _read_split(self, split_file: str) -> tuple[list[str], list[np.ndarray], list[str]]:
"""Reads split file.
Args:
@@ -514,7 +516,7 @@ def _read_split(self, split_file: str) -> Tuple[List[str], List[np.ndarray], Lis
List of paths to images, labels and mask paths.
"""
samples, targets, masks = [], [], []
- with open(split_file, "r") as f:
+ with open(split_file) as f:
split = f.read().splitlines()
for sample in split:
sample_path = os.path.join(self.data_path, sample)
diff --git a/quadra/datamodules/ssl.py b/quadra/datamodules/ssl.py
index 6d7e727a..14762b7a 100644
--- a/quadra/datamodules/ssl.py
+++ b/quadra/datamodules/ssl.py
@@ -1,5 +1,7 @@
# pylint: disable=unsubscriptable-object
-from typing import Any, Optional, Union
+from __future__ import annotations
+
+from typing import Any
import numpy as np
import torch
@@ -28,7 +30,7 @@ class SSLDataModule(ClassificationDataModule):
def __init__(
self,
data_path: str,
- augmentation_dataset: Union[TwoAugmentationDataset, TwoSetAugmentationDataset],
+ augmentation_dataset: TwoAugmentationDataset | TwoSetAugmentationDataset,
name: str = "ssl_datamodule",
split_validation: bool = True,
**kwargs: Any,
@@ -39,10 +41,10 @@ def __init__(
**kwargs,
)
self.augmentation_dataset = augmentation_dataset
- self.classifier_train_dataset: Optional[torch.utils.data.Dataset] = None
+ self.classifier_train_dataset: torch.utils.data.Dataset | None = None
self.split_validation = split_validation
- def setup(self, stage: Optional[str] = None) -> None:
+ def setup(self, stage: str | None = None) -> None:
"""Setup data module based on stages of training."""
if stage == "fit":
self.train_dataset = self.dataset(
diff --git a/quadra/datasets/anomaly.py b/quadra/datasets/anomaly.py
index 81ec294d..5ba04ac2 100644
--- a/quadra/datasets/anomaly.py
+++ b/quadra/datasets/anomaly.py
@@ -1,7 +1,8 @@
+from __future__ import annotations
+
import os
import random
from pathlib import Path
-from typing import Dict, Optional, Tuple, Union
import albumentations as alb
import cv2
@@ -77,10 +78,10 @@ def split_normal_images_in_train_set(samples: DataFrame, split_ratio: float = 0.
def make_anomaly_dataset(
path: Path,
- split: Optional[str] = None,
+ split: str | None = None,
split_ratio: float = 0.1,
seed: int = 0,
- mask_suffix: Optional[str] = None,
+ mask_suffix: str | None = None,
create_test_set_if_empty: bool = True,
) -> DataFrame:
"""Create dataframe by parsing a folder following the MVTec data file structure.
@@ -202,8 +203,8 @@ def __init__(
transform: alb.Compose,
samples: DataFrame,
task: str = "segmentation",
- valid_area_mask: Optional[str] = None,
- crop_area: Optional[Tuple[int, int, int, int]] = None,
+ valid_area_mask: str | None = None,
+ crop_area: tuple[int, int, int, int] | None = None,
) -> None:
self.task = task
self.transform = transform
@@ -213,19 +214,19 @@ def __init__(
self.split = self.samples.split.unique()[0]
self.crop_area = crop_area
- self.valid_area_mask: Optional[np.ndarray] = None
+ self.valid_area_mask: np.ndarray | None = None
if valid_area_mask is not None:
if not os.path.exists(valid_area_mask):
raise RuntimeError(f"Valid area mask {valid_area_mask} does not exist.")
- self.valid_area_mask = cv2.imread(valid_area_mask, 0) > 0
+ self.valid_area_mask = cv2.imread(valid_area_mask, 0) > 0 # type: ignore[operator]
def __len__(self) -> int:
"""Get length of the dataset."""
return len(self.samples)
- def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
+ def __getitem__(self, index: int) -> dict[str, str | Tensor]:
"""Get dataset item for the index ``index``.
Args:
@@ -235,7 +236,7 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
Dict of image tensor during training.
Otherwise, Dict containing image path, target path, image tensor, label and transformed bounding box.
"""
- item: Dict[str, Union[str, Tensor]] = {}
+ item: dict[str, str | Tensor] = {}
image_path = self.samples.samples.iloc[index]
image = cv2.imread(image_path)
@@ -263,12 +264,11 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
# If good images have no associated mask create an empty one
if label_index == 0:
mask = np.zeros(shape=original_image_shape[:2])
+ elif os.path.isfile(mask_path):
+ mask = cv2.imread(mask_path, flags=0) / 255.0 # type: ignore[operator]
else:
- if os.path.isfile(mask_path):
- mask = cv2.imread(mask_path, flags=0) / 255.0
- else:
- # We need ones in the mask to compute correctly at least image level f1 score
- mask = np.ones(shape=original_image_shape[:2])
+ # We need ones in the mask to compute correctly at least image level f1 score
+ mask = np.ones(shape=original_image_shape[:2])
if self.valid_area_mask is not None:
mask = mask * self.valid_area_mask
diff --git a/quadra/datasets/classification.py b/quadra/datasets/classification.py
index fdc6692e..11c8b490 100644
--- a/quadra/datasets/classification.py
+++ b/quadra/datasets/classification.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
import warnings
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from collections.abc import Callable
import cv2
import numpy as np
@@ -37,15 +39,15 @@ class ImageClassificationListDataset(Dataset):
def __init__(
self,
- samples: List[str],
- targets: List[Union[str, int]],
- class_to_idx: Optional[Dict] = None,
- resize: Optional[int] = None,
- roi: Optional[Tuple[int, int, int, int]] = None,
- transform: Optional[Callable] = None,
+ samples: list[str],
+ targets: list[str | int],
+ class_to_idx: dict | None = None,
+ resize: int | None = None,
+ roi: tuple[int, int, int, int] | None = None,
+ transform: Callable | None = None,
rgb: bool = True,
channel: int = 3,
- allow_missing_label: Optional[bool] = False,
+ allow_missing_label: bool | None = False,
):
super().__init__()
assert len(samples) == len(
@@ -64,6 +66,7 @@ def __init__(
"be careful because None labels will not work inside Dataloaders"
),
UserWarning,
+ stacklevel=2,
)
targets = [-1 if target is None else target for target in targets]
@@ -87,7 +90,7 @@ def __init__(
self.transform = transform
- def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
+ def __getitem__(self, idx) -> tuple[np.ndarray, np.ndarray]:
path, y = self.samples[idx]
# Load image
@@ -137,12 +140,12 @@ class ClassificationDataset(ImageClassificationListDataset):
def __init__(
self,
- samples: List[str],
- targets: List[Union[str, int]],
- class_to_idx: Optional[Dict] = None,
- resize: Optional[int] = None,
- roi: Optional[Tuple[int, int, int, int]] = None,
- transform: Optional[Callable] = None,
+ samples: list[str],
+ targets: list[str | int],
+ class_to_idx: dict | None = None,
+ resize: int | None = None,
+ roi: tuple[int, int, int, int] | None = None,
+ transform: Callable | None = None,
rgb: bool = True,
channel: int = 3,
random_padding: bool = False,
@@ -191,10 +194,10 @@ class MultilabelClassificationDataset(torch.utils.data.Dataset):
def __init__(
self,
- samples: List[str],
+ samples: list[str],
targets: np.ndarray,
- class_to_idx: Optional[Dict] = None,
- transform: Optional[Callable] = None,
+ class_to_idx: dict | None = None,
+ transform: Callable | None = None,
rgb: bool = True,
):
super().__init__()
diff --git a/quadra/datasets/patch.py b/quadra/datasets/patch.py
index f4d34c9e..1a4917a1 100644
--- a/quadra/datasets/patch.py
+++ b/quadra/datasets/patch.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
import os
import random
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from collections.abc import Callable
import cv2
import h5py
@@ -33,11 +35,11 @@ class PatchSklearnClassificationTrainDataset(Dataset):
def __init__(
self,
data_path: str,
- samples: List[str],
- targets: List[Union[str, int]],
- class_to_idx: Optional[Dict] = None,
- resize: Optional[int] = None,
- transform: Optional[Callable] = None,
+ samples: list[str],
+ targets: list[str | int],
+ class_to_idx: dict | None = None,
+ resize: int | None = None,
+ transform: Callable | None = None,
rgb: bool = True,
channel: int = 3,
balance_classes: bool = False,
@@ -51,8 +53,8 @@ def __init__(
if balance_classes:
samples_array = np.array(samples)
targets_array = np.array(targets)
- samples_to_use: List[str] = []
- targets_to_use: List[Union[str, int]] = []
+ samples_to_use: list[str] = []
+ targets_to_use: list[str | int] = []
cls, counts = np.unique(targets_array, return_counts=True)
max_count = np.max(counts)
@@ -88,7 +90,7 @@ def __init__(
self.transform = transform
- def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
+ def __getitem__(self, idx) -> tuple[np.ndarray, np.ndarray]:
path, y = self.samples[idx]
h5_file = h5py.File(path)
diff --git a/quadra/datasets/segmentation.py b/quadra/datasets/segmentation.py
index 267b72b3..529a0024 100644
--- a/quadra/datasets/segmentation.py
+++ b/quadra/datasets/segmentation.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
import os
-from typing import Any, Callable, Dict, List, Optional, Union
+from collections.abc import Callable
+from typing import Any
import albumentations
import cv2
@@ -31,16 +34,16 @@ class SegmentationDataset(torch.utils.data.Dataset):
def __init__(
self,
- image_paths: List[str],
- mask_paths: List[str],
- batch_size: Optional[int] = None,
- object_masks: Optional[List[Union[np.ndarray, Any]]] = None,
+ image_paths: list[str],
+ mask_paths: list[str],
+ batch_size: int | None = None,
+ object_masks: list[np.ndarray | Any] | None = None,
resize: int = 224,
- mask_preprocess: Optional[Callable] = None,
- labels: Optional[List[str]] = None,
- transform: Optional[albumentations.Compose] = None,
+ mask_preprocess: Callable | None = None,
+ labels: list[str] | None = None,
+ transform: albumentations.Compose | None = None,
mask_smoothing: bool = False,
- defect_transform: Optional[albumentations.Compose] = None,
+ defect_transform: albumentations.Compose | None = None,
):
self.transform = transform
self.defect_transform = defect_transform
@@ -82,14 +85,13 @@ def __getitem__(self, index):
else:
mask_path = self.mask_paths[index]
mask = cv2.imread(str(mask_path), 0)
- if self.defect_transform is not None:
- if label == 1 and np.sum(mask) == 0:
- if object_mask is not None:
- object_mask *= 255
- aug = self.defect_transform(image=image, mask=mask, object_mask=object_mask, label=label)
- image = aug["image"]
- mask = aug["mask"]
- label = aug["label"]
+ if self.defect_transform is not None and label == 1 and np.sum(mask) == 0:
+ if object_mask is not None:
+ object_mask *= 255
+ aug = self.defect_transform(image=image, mask=mask, object_mask=object_mask, label=label)
+ image = aug["image"]
+ mask = aug["mask"]
+ label = aug["label"]
if self.mask_preprocess:
mask = self.mask_preprocess(mask)
if object_mask is not None:
@@ -151,11 +153,11 @@ class SegmentationDatasetMulticlass(torch.utils.data.Dataset):
def __init__(
self,
- image_paths: List[str],
- mask_paths: List[str],
- idx_to_class: Dict,
- batch_size: Optional[int] = None,
- transform: Optional[albumentations.Compose] = None,
+ image_paths: list[str],
+ mask_paths: list[str],
+ idx_to_class: dict,
+ batch_size: int | None = None,
+ transform: albumentations.Compose | None = None,
one_hot: bool = False,
):
self.transform = transform
diff --git a/quadra/datasets/ssl.py b/quadra/datasets/ssl.py
index 14e74b5d..d5b1b2ec 100644
--- a/quadra/datasets/ssl.py
+++ b/quadra/datasets/ssl.py
@@ -1,7 +1,8 @@
+from __future__ import annotations
+
import random
from collections.abc import Iterable
from enum import Enum
-from typing import Tuple, Union
import albumentations as A
import numpy as np
@@ -29,7 +30,7 @@ class TwoAugmentationDataset(Dataset):
def __init__(
self,
dataset: Dataset,
- transform: Union[A.Compose, Tuple[A.Compose, A.Compose]],
+ transform: A.Compose | tuple[A.Compose, A.Compose],
strategy: AugmentationStrategy = AugmentationStrategy.SAME_IMAGE,
):
self.dataset = dataset
@@ -82,7 +83,7 @@ class TwoSetAugmentationDataset(Dataset):
def __init__(
self,
dataset: Dataset,
- global_transforms: Tuple[A.Compose, A.Compose],
+ global_transforms: tuple[A.Compose, A.Compose],
local_transform: A.Compose,
num_local_transforms: int,
):
diff --git a/quadra/losses/classification/focal.py b/quadra/losses/classification/focal.py
index 6080bdbf..1c9da49b 100644
--- a/quadra/losses/classification/focal.py
+++ b/quadra/losses/classification/focal.py
@@ -1,5 +1,6 @@
+from __future__ import annotations
+
import warnings
-from typing import Optional
import torch
import torch.nn.functional as F
@@ -9,8 +10,8 @@
def one_hot(
labels: torch.Tensor,
num_classes: int,
- device: Optional[torch.device] = None,
- dtype: Optional[torch.dtype] = None,
+ device: torch.device | None = None,
+ dtype: torch.dtype | None = None,
eps: float = 1e-6,
) -> torch.Tensor:
r"""Convert an integer label x-D tensor to a one-hot (x+1)-D tensor.
@@ -61,7 +62,7 @@ def focal_loss(
alpha: float,
gamma: float = 2.0,
reduction: str = "none",
- eps: Optional[float] = None,
+ eps: float | None = None,
) -> torch.Tensor:
r"""Criterion that computes Focal loss.
@@ -187,12 +188,12 @@ class FocalLoss(nn.Module):
>>> output.backward()
"""
- def __init__(self, alpha: float, gamma: float = 2.0, reduction: str = "none", eps: Optional[float] = None) -> None:
+ def __init__(self, alpha: float, gamma: float = 2.0, reduction: str = "none", eps: float | None = None) -> None:
super().__init__()
self.alpha: float = alpha
self.gamma: float = gamma
self.reduction: str = reduction
- self.eps: Optional[float] = eps
+ self.eps: float | None = eps
def forward(self, input_tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Forward call computation."""
@@ -205,7 +206,7 @@ def binary_focal_loss_with_logits(
alpha: float = 0.25,
gamma: float = 2.0,
reduction: str = "none",
- eps: Optional[float] = None,
+ eps: float | None = None,
) -> torch.Tensor:
r"""Function that computes Binary Focal loss.
diff --git a/quadra/losses/classification/prototypical.py b/quadra/losses/classification/prototypical.py
index 7ad846c8..bb7b966a 100644
--- a/quadra/losses/classification/prototypical.py
+++ b/quadra/losses/classification/prototypical.py
@@ -1,5 +1,4 @@
-# coding=utf-8
-from typing import Optional
+from __future__ import annotations
import torch
from torch.nn import functional as F
@@ -80,7 +79,7 @@ def prototypical_loss(
coords: torch.Tensor,
target: torch.Tensor,
n_support: int,
- prototypes: Optional[torch.Tensor] = None,
+ prototypes: torch.Tensor | None = None,
sen: bool = True,
eps_pos: float = 1.0,
eps_neg: float = -1e-7,
diff --git a/quadra/metrics/segmentation.py b/quadra/metrics/segmentation.py
index 5ef351ac..ea20c803 100644
--- a/quadra/metrics/segmentation.py
+++ b/quadra/metrics/segmentation.py
@@ -1,14 +1,16 @@
-from typing import List, Tuple, cast
+from __future__ import annotations
+
+from typing import cast
import numpy as np
import torch
from scipy.optimize import linear_sum_assignment
-from skimage.measure import label, regionprops
+from skimage.measure import label, regionprops # pylint: disable=no-name-in-module
from quadra.utils.evaluation import dice
-def _pad_to_shape(a: np.ndarray, shape: Tuple, constant_values: int = 0) -> np.ndarray:
+def _pad_to_shape(a: np.ndarray, shape: tuple, constant_values: int = 0) -> np.ndarray:
"""Pad lower - right with 0s
Args:
a: numpy array to pad
@@ -98,7 +100,7 @@ def _get_dice_matrix(
def segmentation_props(
pred: np.ndarray, mask: np.ndarray
-) -> Tuple[float, float, float, float, List[float], float, int, int, int, int]:
+) -> tuple[float, float, float, float, list[float], float, int, int, int, int]:
"""Return some information regarding a segmentation task.
Args:
@@ -158,7 +160,7 @@ def segmentation_props(
fn_num = 0
fp_area = 0.0
fn_area = 0.0
- fp_hist: List[float] = []
+ fp_hist: list[float] = []
if n_labels_pred > 0 and n_labels_mask > 0:
dice_mat = _get_dice_matrix(labels_pred, n_labels_pred, labels_mask, n_labels_mask)
# Thresholding over Dice scores
@@ -169,8 +171,7 @@ def segmentation_props(
# Add dummy Dices so LSA is unique and i can compute FP and FN
dice_mat = _pad_to_shape(dice_mat, (max_dim, max_dim), 1)
lsa = linear_sum_assignment(dice_mat, maximize=False)
- for (row, col) in zip(lsa[0], lsa[1]):
-
+ for row, col in zip(lsa[0], lsa[1]):
# More preds than GTs --> False Positive
if row < n_labels_pred and col >= n_labels_mask:
min_row = pred_bbox[row][0]
diff --git a/quadra/models/base.py b/quadra/models/base.py
index c0bab73a..2fbd19fc 100644
--- a/quadra/models/base.py
+++ b/quadra/models/base.py
@@ -1,7 +1,8 @@
from __future__ import annotations
import inspect
-from typing import Any, Sequence
+from collections.abc import Sequence
+from typing import Any
import torch
from torch import nn
diff --git a/quadra/models/classification/backbones.py b/quadra/models/classification/backbones.py
index 2c5aca78..f3651292 100644
--- a/quadra/models/classification/backbones.py
+++ b/quadra/models/classification/backbones.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from typing import Any
import timm
@@ -15,8 +17,8 @@ class TorchHubNetworkBuilder(BaseNetworkBuilder):
repo_or_dir: The name of the repository or the path to the directory containing the model.
model_name: The name of the model within the repository.
pretrained: Whether to load the pretrained weights for the model.
- pre_classifier: Pre classifier as a torch.nn.Module. Defaults to nn.Identity().
- classifier: Classifier as a torch.nn.Module. Defaults to nn.Identity().
+ pre_classifier: Pre classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
+ classifier: Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
freeze: Whether to freeze the feature extractor. Defaults to True.
hyperspherical: Whether to map features to an hypersphere. Defaults to False.
flatten_features: Whether to flatten the features before the pre_classifier. Defaults to True.
@@ -28,8 +30,8 @@ def __init__(
repo_or_dir: str,
model_name: str,
pretrained: bool = True,
- pre_classifier: nn.Module = nn.Identity(),
- classifier: nn.Module = nn.Identity(),
+ pre_classifier: nn.Module | None = None,
+ classifier: nn.Module | None = None,
freeze: bool = True,
hyperspherical: bool = False,
flatten_features: bool = True,
@@ -55,8 +57,8 @@ class TorchVisionNetworkBuilder(BaseNetworkBuilder):
Args:
model_name: Torchvision model function that will be evaluated, for example: torchvision.models.resnet18.
pretrained: Whether to load the pretrained weights for the model.
- pre_classifier: Pre classifier as a torch.nn.Module. Defaults to nn.Identity().
- classifier: Classifier as a torch.nn.Module. Defaults to nn.Identity().
+ pre_classifier: Pre classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
+ classifier: Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
freeze: Whether to freeze the feature extractor. Defaults to True.
hyperspherical: Whether to map features to an hypersphere. Defaults to False.
flatten_features: Whether to flatten the features before the pre_classifier. Defaults to True.
@@ -67,8 +69,8 @@ def __init__(
self,
model_name: str,
pretrained: bool = True,
- pre_classifier: nn.Module = nn.Identity(),
- classifier: nn.Module = nn.Identity(),
+ pre_classifier: nn.Module | None = None,
+ classifier: nn.Module | None = None,
freeze: bool = True,
hyperspherical: bool = False,
flatten_features: bool = True,
@@ -95,8 +97,8 @@ class TimmNetworkBuilder(BaseNetworkBuilder):
Args:
model_name: Timm model name
pretrained: Whether to load the pretrained weights for the model.
- pre_classifier: Pre classifier as a torch.nn.Module. Defaults to nn.Identity().
- classifier: Classifier as a torch.nn.Module. Defaults to nn.Identity().
+ pre_classifier: Pre classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
+ classifier: Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
freeze: Whether to freeze the feature extractor. Defaults to True.
hyperspherical: Whether to map features to an hypersphere. Defaults to False.
flatten_features: Whether to flatten the features before the pre_classifier. Defaults to True.
@@ -107,8 +109,8 @@ def __init__(
self,
model_name: str,
pretrained: bool = True,
- pre_classifier: nn.Module = nn.Identity(),
- classifier: nn.Module = nn.Identity(),
+ pre_classifier: nn.Module | None = None,
+ classifier: nn.Module | None = None,
freeze: bool = True,
hyperspherical: bool = False,
flatten_features: bool = True,
diff --git a/quadra/models/classification/base.py b/quadra/models/classification/base.py
index 24837eaf..68beff5f 100644
--- a/quadra/models/classification/base.py
+++ b/quadra/models/classification/base.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from __future__ import annotations
from torch import nn
@@ -11,8 +11,8 @@ class BaseNetworkBuilder(nn.Module):
Args:
features_extractor: Feature extractor as a toch.nn.Module.
- pre_classifier: Pre classifier as a torch.nn.Module. Defaults to nn.Identity().
- classifier: Classifier as a torch.nn.Module. Defaults to nn.Identity().
+ pre_classifier: Pre classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
+ classifier: Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
freeze: Whether to freeze the feature extractor. Defaults to True.
hyperspherical: Whether to map features to an hypersphere. Defaults to False.
flatten_features: Whether to flatten the features before the pre_classifier. May be required if your model
@@ -22,13 +22,19 @@ class BaseNetworkBuilder(nn.Module):
def __init__(
self,
features_extractor: nn.Module,
- pre_classifier: nn.Module = nn.Identity(),
- classifier: nn.Module = nn.Identity(),
+ pre_classifier: nn.Module | None = None,
+ classifier: nn.Module | None = None,
freeze: bool = True,
hyperspherical: bool = False,
flatten_features: bool = True,
):
super().__init__()
+ if pre_classifier is None:
+ pre_classifier = nn.Identity()
+
+ if classifier is None:
+ classifier = nn.Identity()
+
self.features_extractor = features_extractor
self.freeze = freeze
self.hyperspherical = hyperspherical
@@ -36,7 +42,7 @@ def __init__(
self.classifier = classifier
self.flatten: bool = False
self._hyperspherical: bool = False
- self.l2: Optional[L2Norm] = None
+ self.l2: L2Norm | None = None
self.flatten_features = flatten_features
self.freeze = freeze
diff --git a/quadra/models/evaluation.py b/quadra/models/evaluation.py
index df60c2ec..47b2b6db 100644
--- a/quadra/models/evaluation.py
+++ b/quadra/models/evaluation.py
@@ -71,7 +71,7 @@ def device(self) -> str:
@device.setter
def device(self, device: str):
"""Set the device of the model."""
- if device == "cuda" and not ":" in device:
+ if device == "cuda" and ":" not in device:
device = f"{device}:0"
self._device = device
@@ -194,10 +194,11 @@ def generate_session_options(self) -> ort.SessionOptions:
dict[str, Any], OmegaConf.to_container(self.config.session_options, resolve=True)
)
for key, value in session_options_dict.items():
+ final_value = value
if isinstance(value, dict) and "_target_" in value:
- value = instantiate(value)
+ final_value = instantiate(final_value)
- setattr(session_options, key, value)
+ setattr(session_options, key, final_value)
return session_options
@@ -240,7 +241,7 @@ def _forward_from_pytorch(self, input_dict: dict[str, torch.Tensor]):
for k, v in input_dict.items():
if not v.is_contiguous():
# If not contiguous onnx give wrong results
- v = v.contiguous()
+ v = v.contiguous() # noqa: PLW2901
io_binding.bind_input(
name=k,
diff --git a/quadra/modules/base.py b/quadra/modules/base.py
index 36cb28e4..ac11bfd7 100644
--- a/quadra/modules/base.py
+++ b/quadra/modules/base.py
@@ -1,4 +1,7 @@
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import Any
import pytorch_lightning as pl
import sklearn
@@ -25,9 +28,9 @@ class BaseLightningModule(pl.LightningModule):
def __init__(
self,
model: nn.Module,
- optimizer: Optional[Optimizer] = None,
- lr_scheduler: Optional[object] = None,
- lr_scheduler_interval: Optional[str] = "epoch",
+ optimizer: Optimizer | None = None,
+ lr_scheduler: object | None = None,
+ lr_scheduler_interval: str | None = "epoch",
):
super().__init__()
self.model = ModelSignatureWrapper(model)
@@ -45,7 +48,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
return self.model(x)
- def configure_optimizers(self) -> Tuple[List[Any], List[Dict[str, Any]]]:
+ def configure_optimizers(self) -> tuple[list[Any], list[dict[str, Any]]]:
"""Get default optimizer if not passed a value.
Returns:
@@ -88,14 +91,14 @@ def __init__(
self,
model: nn.Module,
criterion: nn.Module,
- classifier: Optional[sklearn.base.ClassifierMixin] = None,
- optimizer: Optional[Optimizer] = None,
- lr_scheduler: Optional[object] = None,
- lr_scheduler_interval: Optional[str] = "epoch",
+ classifier: sklearn.base.ClassifierMixin | None = None,
+ optimizer: Optimizer | None = None,
+ lr_scheduler: object | None = None,
+ lr_scheduler_interval: str | None = "epoch",
):
super().__init__(model, optimizer, lr_scheduler, lr_scheduler_interval)
self.criterion = criterion
- self.classifier_train_loader: Optional[torch.utils.data.DataLoader]
+ self.classifier_train_loader: torch.utils.data.DataLoader | None
if classifier is None:
self.classifier = LogisticRegression(max_iter=10000, n_jobs=8, random_state=42)
else:
@@ -137,7 +140,7 @@ def on_validation_start(self) -> None:
if self.classifier_train_loader is not None:
self.fit_estimator()
- def validation_step(self, batch: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int) -> None:
+ def validation_step(self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int) -> None:
# pylint: disable=unused-argument
if self.classifier_train_loader is None:
# Compute loss
@@ -175,8 +178,8 @@ def __init__(
self,
model: torch.nn.Module,
loss_fun: Callable,
- optimizer: Optional[Optimizer] = None,
- lr_scheduler: Optional[object] = None,
+ optimizer: Optimizer | None = None,
+ lr_scheduler: object | None = None,
):
super().__init__(model, optimizer, lr_scheduler)
self.loss_fun = loss_fun
@@ -192,7 +195,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.model(x)
return x
- def step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+ def step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute loss
Args:
batch: batch.
@@ -223,7 +226,7 @@ def compute_loss(self, pred_masks: torch.Tensor, target_masks: torch.Tensor) ->
loss = self.loss_fun(pred_masks, target_masks)
return loss
- def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int):
+ def training_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int):
"""Training step."""
# pylint: disable=unused-argument
pred_masks, target_masks = self.step(batch)
@@ -236,7 +239,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
)
return loss
- def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx):
+ def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx):
"""Validation step."""
# pylint: disable=unused-argument
pred_masks, target_masks = self.step(batch)
@@ -249,7 +252,7 @@ def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
)
return loss
- def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int):
+ def test_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int):
"""Test step."""
# pylint: disable=unused-argument
pred_masks, target_masks = self.step(batch)
@@ -264,9 +267,9 @@ def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batc
def predict_step(
self,
- batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+ batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
batch_idx: int,
- dataloader_idx: Optional[int] = None,
+ dataloader_idx: int | None = None,
) -> Any:
"""Predict step."""
# pylint: disable=unused-argument
@@ -289,12 +292,12 @@ def __init__(
self,
model: torch.nn.Module,
loss_fun: Callable,
- optimizer: Optional[Optimizer] = None,
- lr_scheduler: Optional[object] = None,
+ optimizer: Optimizer | None = None,
+ lr_scheduler: object | None = None,
):
super().__init__(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, loss_fun=loss_fun)
- def step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+ def step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute step
Args:
batch: batch.
diff --git a/quadra/modules/classification/base.py b/quadra/modules/classification/base.py
index d1678155..4bb6bdfc 100644
--- a/quadra/modules/classification/base.py
+++ b/quadra/modules/classification/base.py
@@ -1,4 +1,6 @@
-from typing import Any, List, Optional, Tuple, Union, cast
+from __future__ import annotations
+
+from typing import Any, cast
import numpy as np
import timm
@@ -37,9 +39,9 @@ def __init__(
self,
model: nn.Module,
criterion: nn.Module,
- optimizer: Union[None, optim.Optimizer] = None,
- lr_scheduler: Union[None, object] = None,
- lr_scheduler_interval: Optional[str] = "epoch",
+ optimizer: None | optim.Optimizer = None,
+ lr_scheduler: None | object = None,
+ lr_scheduler_interval: str | None = "epoch",
gradcam: bool = False,
):
super().__init__(model, optimizer, lr_scheduler, lr_scheduler_interval)
@@ -49,8 +51,8 @@ def __init__(
self.train_acc = torchmetrics.Accuracy()
self.val_acc = torchmetrics.Accuracy()
self.test_acc = torchmetrics.Accuracy()
- self.cam: Optional[GradCAM] = None
- self.grad_rollout: Optional[VitAttentionGradRollout] = None
+ self.cam: GradCAM | None = None
+ self.grad_rollout: VitAttentionGradRollout | None = None
if not isinstance(self.model.features_extractor, timm.models.resnet.ResNet) and not is_vision_transformer(
cast(BaseNetworkBuilder, self.model).features_extractor
@@ -60,12 +62,12 @@ def __init__(
)
self.gradcam = False
- self.original_requires_grads: List[bool] = []
+ self.original_requires_grads: list[bool] = []
def forward(self, x: torch.Tensor):
return self.model(x)
- def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
+ def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int):
# pylint: disable=unused-argument
im, target = batch
outputs = self(im)
@@ -87,7 +89,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
)
return loss
- def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
+ def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int):
# pylint: disable=unused-argument
im, target = batch
outputs = self(im)
@@ -109,7 +111,7 @@ def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: i
)
return loss
- def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
+ def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int):
# pylint: disable=unused-argument
im, target = batch
outputs = self(im)
@@ -134,7 +136,7 @@ def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
def prepare_gradcam(self) -> None:
"""Instantiate gradcam handlers."""
if isinstance(self.model.features_extractor, timm.models.resnet.ResNet):
- target_layers = [cast(BaseNetworkBuilder, self.model).features_extractor.layer4[-1]] # type: ignore[index]
+ target_layers = [cast(BaseNetworkBuilder, self.model).features_extractor.layer4[-1]]
# Get model current device
device = next(self.model.parameters()).device
@@ -186,6 +188,7 @@ def on_predict_end(self) -> None:
return super().on_predict_end()
+ # pylint: disable=unused-argument
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
"""Prediction step.
@@ -241,9 +244,9 @@ def __init__(
self,
model: nn.Sequential,
criterion: nn.Module,
- optimizer: Union[None, optim.Optimizer] = None,
- lr_scheduler: Union[None, object] = None,
- lr_scheduler_interval: Optional[str] = "epoch",
+ optimizer: None | optim.Optimizer = None,
+ lr_scheduler: None | object = None,
+ lr_scheduler_interval: str | None = "epoch",
gradcam: bool = False,
):
super().__init__(model, optimizer, lr_scheduler, lr_scheduler_interval)
diff --git a/quadra/modules/ssl/barlowtwins.py b/quadra/modules/ssl/barlowtwins.py
index c0bab973..2cd1944c 100644
--- a/quadra/modules/ssl/barlowtwins.py
+++ b/quadra/modules/ssl/barlowtwins.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from __future__ import annotations
import sklearn
import torch
@@ -25,12 +25,11 @@ def __init__(
model: nn.Module,
projection_mlp: nn.Module,
criterion: nn.Module,
- classifier: Optional[sklearn.base.ClassifierMixin] = None,
- optimizer: Optional[optim.Optimizer] = None,
- lr_scheduler: Optional[object] = None,
- lr_scheduler_interval: Optional[str] = "epoch",
+ classifier: sklearn.base.ClassifierMixin | None = None,
+ optimizer: optim.Optimizer | None = None,
+ lr_scheduler: object | None = None,
+ lr_scheduler_interval: str | None = "epoch",
):
-
super().__init__(model, criterion, classifier, optimizer, lr_scheduler, lr_scheduler_interval)
# self.save_hyperparameters()
self.projection_mlp = projection_mlp
@@ -41,7 +40,7 @@ def forward(self, x):
z = self.projection_mlp(x)
return z
- def training_step(self, batch: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int):
+ def training_step(self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int):
# pylint: disable=unused-argument
# Compute loss
(im_x, im_y), _ = batch
diff --git a/quadra/modules/ssl/byol.py b/quadra/modules/ssl/byol.py
index 64fceb4a..c7b97b13 100644
--- a/quadra/modules/ssl/byol.py
+++ b/quadra/modules/ssl/byol.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
import math
-from typing import Any, Callable, List, Optional, Sized, Tuple, Union
+from collections.abc import Callable, Sized
+from typing import Any
import sklearn
import torch
@@ -36,12 +39,12 @@ def __init__(
student_prediction_mlp: nn.Module,
teacher_projection_mlp: nn.Module,
criterion: nn.Module,
- classifier: Optional[sklearn.base.ClassifierMixin] = None,
- optimizer: Optional[Optimizer] = None,
- lr_scheduler: Optional[object] = None,
- lr_scheduler_interval: Optional[str] = "epoch",
+ classifier: sklearn.base.ClassifierMixin | None = None,
+ optimizer: Optimizer | None = None,
+ lr_scheduler: object | None = None,
+ lr_scheduler_interval: str | None = "epoch",
teacher_momentum: float = 0.9995,
- teacher_momentum_cosine_decay: Optional[bool] = True,
+ teacher_momentum_cosine_decay: bool | None = True,
):
super().__init__(
model=student,
@@ -116,7 +119,7 @@ def on_train_start(self) -> None:
else:
raise ValueError("BYOL requires `max_epochs` to be set and `train_dataloader` to be initialized.")
- def training_step(self, batch: Tuple[List[torch.Tensor], torch.Tensor], *args: Any) -> torch.Tensor:
+ def training_step(self, batch: tuple[list[torch.Tensor], torch.Tensor], *args: Any) -> torch.Tensor:
[image1, image2], _ = batch
online_pred_one = self.student_prediction_mlp(self.student_projection_mlp(self.model(image1)))
@@ -137,8 +140,8 @@ def optimizer_step(
self,
epoch: int,
batch_idx: int,
- optimizer: Union[Optimizer, LightningOptimizer],
- optimizer_closure: Optional[Callable[[], Any]] = None,
+ optimizer: Optimizer | LightningOptimizer,
+ optimizer_closure: Callable[[], Any] | None = None,
) -> None:
"""Override optimizer step to update the teacher parameters."""
super().optimizer_step(
@@ -162,7 +165,7 @@ def calculate_accuracy(self, batch):
def on_test_epoch_start(self) -> None:
self.fit_estimator()
- def test_step(self, batch, *args: List[Any]) -> None:
+ def test_step(self, batch, *args: list[Any]) -> None:
"""Calculate accuracy on the test set for the given batch."""
acc = self.calculate_accuracy(batch)
self.log(name="test_acc", value=acc, on_step=False, on_epoch=True, prog_bar=True)
diff --git a/quadra/modules/ssl/common.py b/quadra/modules/ssl/common.py
index 312e9d4c..b3e6eb98 100644
--- a/quadra/modules/ssl/common.py
+++ b/quadra/modules/ssl/common.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, Union
+from __future__ import annotations
import torch
from torch import nn
@@ -17,10 +17,10 @@ class ProjectionHead(torch.nn.Module):
`non_linearity_layer`.
"""
- def __init__(self, blocks: List[Tuple[Optional[torch.nn.Module], ...]]):
+ def __init__(self, blocks: list[tuple[torch.nn.Module | None, ...]]):
super().__init__()
- layers: List[nn.Module] = []
+ layers: list[nn.Module] = []
for linear, batch_norm, non_linearity in blocks:
if linear:
layers.append(linear)
@@ -223,11 +223,11 @@ def __init__(
):
super().__init__()
num_layers = max(num_layers, 1)
- self.mlp: Union[nn.Linear, nn.Sequential]
+ self.mlp: nn.Linear | nn.Sequential
if num_layers == 1:
self.mlp = nn.Linear(input_dim, bottleneck_dim)
else:
- layers: List[nn.Module] = [nn.Linear(input_dim, hidden_dim)]
+ layers: list[nn.Module] = [nn.Linear(input_dim, hidden_dim)]
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
@@ -240,7 +240,7 @@ def __init__(
self.mlp = nn.Sequential(*layers)
self.apply(self._init_weights)
self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, output_dim, bias=False))
- self.last_layer.weight_g.data.fill_(1) # type: ignore[operator]
+ self.last_layer.weight_g.data.fill_(1)
if norm_last_layer:
self.last_layer.weight_g.requires_grad = False
diff --git a/quadra/modules/ssl/dino.py b/quadra/modules/ssl/dino.py
index 59555e8c..81720193 100644
--- a/quadra/modules/ssl/dino.py
+++ b/quadra/modules/ssl/dino.py
@@ -1,4 +1,7 @@
-from typing import Any, Callable, List, Optional, Tuple, Union
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import Any
import sklearn
import torch
@@ -39,12 +42,12 @@ def __init__(
teacher_projection_mlp: nn.Module,
criterion: nn.Module,
freeze_last_layer: int = 1,
- classifier: Optional[sklearn.base.ClassifierMixin] = None,
- optimizer: Optional[Optimizer] = None,
- lr_scheduler: Optional[object] = None,
- lr_scheduler_interval: Optional[str] = "epoch",
+ classifier: sklearn.base.ClassifierMixin | None = None,
+ optimizer: Optimizer | None = None,
+ lr_scheduler: object | None = None,
+ lr_scheduler_interval: str | None = "epoch",
teacher_momentum: float = 0.9995,
- teacher_momentum_cosine_decay: Optional[bool] = True,
+ teacher_momentum_cosine_decay: bool | None = True,
):
super().__init__(
student=student,
@@ -93,7 +96,7 @@ def initialize_teacher(self):
self.teacher_initialized = True
- def student_multicrop_forward(self, x: List[torch.Tensor]) -> torch.Tensor:
+ def student_multicrop_forward(self, x: list[torch.Tensor]) -> torch.Tensor:
"""Student forward on the multicrop imges.
Args:
@@ -111,7 +114,7 @@ def student_multicrop_forward(self, x: List[torch.Tensor]) -> torch.Tensor:
chunks = logits.chunk(n_crops) # n_crops * (n_samples, out_dim)
return chunks
- def teacher_multicrop_forward(self, x: List[torch.Tensor]) -> torch.Tensor:
+ def teacher_multicrop_forward(self, x: list[torch.Tensor]) -> torch.Tensor:
"""Teacher forward on the multicrop imges.
Args:
@@ -143,7 +146,7 @@ def cancel_gradients_last_layer(self, epoch: int, freeze_last_layer: int):
if "last_layer" in n:
p.grad = None
- def training_step(self, batch: Tuple[List[torch.Tensor], torch.Tensor], *args: Any) -> torch.Tensor:
+ def training_step(self, batch: tuple[list[torch.Tensor], torch.Tensor], *args: Any) -> torch.Tensor:
images, _ = batch
with torch.no_grad():
teacher_output = self.teacher_multicrop_forward(images[:2])
@@ -157,8 +160,8 @@ def training_step(self, batch: Tuple[List[torch.Tensor], torch.Tensor], *args: A
def configure_gradient_clipping(
self,
optimizer: Optimizer,
- gradient_clip_val: Optional[Union[int, float]] = None,
- gradient_clip_algorithm: Optional[str] = None,
+ gradient_clip_val: int | float | None = None,
+ gradient_clip_algorithm: str | None = None,
):
"""Configure gradient clipping for the optimizer."""
if gradient_clip_algorithm is not None and gradient_clip_val is not None:
@@ -170,8 +173,8 @@ def optimizer_step(
self,
epoch: int,
batch_idx: int,
- optimizer: Union[Optimizer, LightningOptimizer],
- optimizer_closure: Optional[Callable[[], Any]] = None,
+ optimizer: Optimizer | LightningOptimizer,
+ optimizer_closure: Callable[[], Any] | None = None,
) -> None:
"""Override optimizer step to update the teacher parameters."""
super().optimizer_step(
diff --git a/quadra/modules/ssl/hyperspherical.py b/quadra/modules/ssl/hyperspherical.py
index b7242cac..86a21a8d 100644
--- a/quadra/modules/ssl/hyperspherical.py
+++ b/quadra/modules/ssl/hyperspherical.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
+from collections.abc import Callable
from enum import Enum
-from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
@@ -48,20 +50,20 @@ class TLHyperspherical(BaseLightningModule):
def __init__(
self,
model: nn.Module,
- optimizer: Optional[optim.Optimizer] = None,
- lr_scheduler: Optional[object] = None,
+ optimizer: optim.Optimizer | None = None,
+ lr_scheduler: object | None = None,
align_weight: float = 1,
unifo_weight: float = 1,
classifier_weight: float = 1,
align_loss_type: AlignLoss = AlignLoss.L2,
classifier_loss: bool = False,
- num_classes: Optional[int] = None,
+ num_classes: int | None = None,
):
super().__init__(model, optimizer, lr_scheduler)
- self.align_loss_fun: Union[
- Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor],
- Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
- ]
+ self.align_loss_fun: (
+ Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]
+ | Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+ )
self.align_weight = align_weight
self.unifo_weight = unifo_weight
self.classifier_weight = classifier_weight
@@ -73,9 +75,8 @@ def __init__(
else:
raise ValueError("The align loss must be one of 'AlignLoss.L2' (L2 distance) or AlignLoss.COSINE")
- if classifier_loss:
- if model.classifier is None:
- raise AssertionError("Classifier is not defined")
+ if classifier_loss and model.classifier is None:
+ raise AssertionError("Classifier is not defined")
self.classifier_loss = classifier_loss
self.num_classes = num_classes
diff --git a/quadra/modules/ssl/idmm.py b/quadra/modules/ssl/idmm.py
index 01adbbad..cbc3e1cc 100644
--- a/quadra/modules/ssl/idmm.py
+++ b/quadra/modules/ssl/idmm.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from __future__ import annotations
import sklearn
import timm
@@ -31,13 +31,12 @@ def __init__(
prediction_mlp: torch.nn.Module,
criterion: torch.nn.Module,
multiview_loss: bool = True,
- mixup_fn: Optional[timm.data.Mixup] = None,
- classifier: Optional[sklearn.base.ClassifierMixin] = None,
- optimizer: Optional[torch.optim.Optimizer] = None,
- lr_scheduler: Optional[object] = None,
- lr_scheduler_interval: Optional[str] = "epoch",
+ mixup_fn: timm.data.Mixup | None = None,
+ classifier: sklearn.base.ClassifierMixin | None = None,
+ optimizer: torch.optim.Optimizer | None = None,
+ lr_scheduler: object | None = None,
+ lr_scheduler_interval: str | None = "epoch",
):
-
super().__init__(
model,
criterion,
diff --git a/quadra/modules/ssl/simclr.py b/quadra/modules/ssl/simclr.py
index 4e1b01f6..99837b35 100644
--- a/quadra/modules/ssl/simclr.py
+++ b/quadra/modules/ssl/simclr.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from __future__ import annotations
import sklearn
import torch
@@ -26,10 +26,10 @@ def __init__(
model: nn.Module,
projection_mlp: nn.Module,
criterion: torch.nn.Module,
- classifier: Optional[sklearn.base.ClassifierMixin] = None,
- optimizer: Optional[torch.optim.Optimizer] = None,
- lr_scheduler: Optional[object] = None,
- lr_scheduler_interval: Optional[str] = "epoch",
+ classifier: sklearn.base.ClassifierMixin | None = None,
+ optimizer: torch.optim.Optimizer | None = None,
+ lr_scheduler: object | None = None,
+ lr_scheduler_interval: str | None = "epoch",
):
super().__init__(
model,
@@ -47,7 +47,7 @@ def forward(self, x):
return x
def training_step(
- self, batch: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int
+ self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int
) -> torch.Tensor:
"""Args:
batch: The batch of data
diff --git a/quadra/modules/ssl/simsiam.py b/quadra/modules/ssl/simsiam.py
index 2a7692b1..b2c2d14e 100644
--- a/quadra/modules/ssl/simsiam.py
+++ b/quadra/modules/ssl/simsiam.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from __future__ import annotations
import sklearn
import torch
@@ -26,12 +26,11 @@ def __init__(
projection_mlp: torch.nn.Module,
prediction_mlp: torch.nn.Module,
criterion: torch.nn.Module,
- classifier: Optional[sklearn.base.ClassifierMixin] = None,
- optimizer: Optional[torch.optim.Optimizer] = None,
- lr_scheduler: Optional[object] = None,
- lr_scheduler_interval: Optional[str] = "epoch",
+ classifier: sklearn.base.ClassifierMixin | None = None,
+ optimizer: torch.optim.Optimizer | None = None,
+ lr_scheduler: object | None = None,
+ lr_scheduler_interval: str | None = "epoch",
):
-
super().__init__(
model,
criterion,
@@ -50,7 +49,7 @@ def forward(self, x):
p = self.prediction_mlp(z)
return p, z.detach()
- def training_step(self, batch: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int):
+ def training_step(self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int):
# pylint: disable=unused-argument
# Compute loss
(im_x, im_y), _ = batch
diff --git a/quadra/modules/ssl/vicreg.py b/quadra/modules/ssl/vicreg.py
index b36b124a..60bb4826 100644
--- a/quadra/modules/ssl/vicreg.py
+++ b/quadra/modules/ssl/vicreg.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from __future__ import annotations
import sklearn
import torch
@@ -26,12 +26,11 @@ def __init__(
model: nn.Module,
projection_mlp: nn.Module,
criterion: nn.Module,
- classifier: Optional[sklearn.base.ClassifierMixin] = None,
- optimizer: Optional[optim.Optimizer] = None,
- lr_scheduler: Optional[object] = None,
- lr_scheduler_interval: Optional[str] = "epoch",
+ classifier: sklearn.base.ClassifierMixin | None = None,
+ optimizer: optim.Optimizer | None = None,
+ lr_scheduler: object | None = None,
+ lr_scheduler_interval: str | None = "epoch",
):
-
super().__init__(
model,
criterion,
@@ -49,7 +48,7 @@ def forward(self, x):
z = self.projection_mlp(x)
return z
- def training_step(self, batch: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int):
+ def training_step(self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int):
# pylint: disable=unused-argument
# Compute loss
(im_x, im_y), _ = batch
diff --git a/quadra/optimizers/lars.py b/quadra/optimizers/lars.py
index 661a4d77..81a67846 100644
--- a/quadra/optimizers/lars.py
+++ b/quadra/optimizers/lars.py
@@ -2,11 +2,14 @@
- https://arxiv.org/pdf/1708.03888.pdf
- https://github.com/pytorch/pytorch/blob/1.6/torch/optim/sgd.py.
"""
-from typing import Callable, List, Optional
+
+from __future__ import annotations
+
+from collections.abc import Callable
import torch
from torch.nn import Parameter
-from torch.optim.optimizer import Optimizer, _RequiredParameter, required # type: ignore[attr-defined]
+from torch.optim.optimizer import Optimizer, _RequiredParameter, required
class LARS(Optimizer):
@@ -60,7 +63,7 @@ class LARS(Optimizer):
def __init__(
self,
- params: List[Parameter],
+ params: list[Parameter],
lr: _RequiredParameter = required,
momentum: float = 0,
dampening: float = 0,
@@ -69,7 +72,7 @@ def __init__(
trust_coefficient: float = 0.001,
eps: float = 1e-8,
):
- if lr is not required and lr < 0.0:
+ if lr is not required and lr < 0.0: # type: ignore[operator]
raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0:
raise ValueError(f"Invalid momentum value: {momentum}")
@@ -98,7 +101,7 @@ def __setstate__(self, state):
group.setdefault("nesterov", False)
@torch.no_grad()
- def step(self, closure: Optional[Callable] = None):
+ def step(self, closure: Callable | None = None):
"""Performs a single optimization step.
Args:
@@ -125,13 +128,12 @@ def step(self, closure: Optional[Callable] = None):
g_norm = torch.norm(p.grad.data)
# lars scaling + weight decay part
- if weight_decay != 0:
- if p_norm != 0 and g_norm != 0:
- lars_lr = p_norm / (g_norm + p_norm * weight_decay + self.eps)
- lars_lr *= self.trust_coefficient
+ if weight_decay != 0 and p_norm != 0 and g_norm != 0:
+ lars_lr = p_norm / (g_norm + p_norm * weight_decay + self.eps)
+ lars_lr *= self.trust_coefficient
- d_p = d_p.add(p, alpha=weight_decay)
- d_p *= lars_lr
+ d_p = d_p.add(p, alpha=weight_decay)
+ d_p *= lars_lr
# sgd part
if momentum != 0:
diff --git a/quadra/optimizers/sam.py b/quadra/optimizers/sam.py
index f1350401..b73b8086 100644
--- a/quadra/optimizers/sam.py
+++ b/quadra/optimizers/sam.py
@@ -1,4 +1,7 @@
-from typing import Any, Callable, List, Optional
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import Any
import torch
from torch.nn import Parameter
@@ -19,7 +22,7 @@ class SAM(torch.optim.Optimizer):
def __init__(
self,
- params: List[Parameter],
+ params: list[Parameter],
base_optimizer: torch.optim.Optimizer,
rho: float = 0.05,
adaptive: bool = True,
@@ -85,7 +88,7 @@ def second_step(self, zero_grad: bool = False) -> None:
self.zero_grad()
@torch.no_grad()
- def step(self, closure: Optional[Callable] = None) -> None:
+ def step(self, closure: Callable | None = None) -> None: # type: ignore[override]
"""Step for SAM optimizer.
Args:
diff --git a/quadra/schedulers/base.py b/quadra/schedulers/base.py
index f406d2d4..6a4f55e1 100644
--- a/quadra/schedulers/base.py
+++ b/quadra/schedulers/base.py
@@ -1,4 +1,4 @@
-from typing import Tuple
+from __future__ import annotations
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
@@ -11,7 +11,7 @@ class LearningRateScheduler(_LRScheduler):
Do not use this class directly, use one of the sub classes.
"""
- def __init__(self, optimizer: Optimizer, init_lr: Tuple[float, ...]):
+ def __init__(self, optimizer: Optimizer, init_lr: tuple[float, ...]):
# pylint: disable=super-init-not-called
self.optimizer = optimizer
self.init_lr = init_lr
@@ -20,7 +20,7 @@ def step(self, *args, **kwargs):
"""Base method, must be implemented by the sub classes."""
raise NotImplementedError
- def set_lr(self, lr: Tuple[float, ...]):
+ def set_lr(self, lr: tuple[float, ...]):
"""Set the learning rate for the optimizer."""
if self.optimizer is not None:
for i, g in enumerate(self.optimizer.param_groups):
@@ -29,11 +29,10 @@ def set_lr(self, lr: Tuple[float, ...]):
lr_to_set = self.init_lr[0]
else:
lr_to_set = self.init_lr[i]
+ elif len(lr) == 1:
+ lr_to_set = lr[0]
else:
- if len(lr) == 1:
- lr_to_set = lr[0]
- else:
- lr_to_set = lr[i]
+ lr_to_set = lr[i]
g["lr"] = lr_to_set
def get_lr(self):
diff --git a/quadra/schedulers/warmup.py b/quadra/schedulers/warmup.py
index 15ae0207..53426569 100644
--- a/quadra/schedulers/warmup.py
+++ b/quadra/schedulers/warmup.py
@@ -1,5 +1,6 @@
+from __future__ import annotations
+
import math
-from typing import List, Optional, Tuple
import torch
@@ -10,12 +11,12 @@
def cosine_annealing_with_warmup(
- init_lrs: List[float],
+ init_lrs: list[float],
step: int,
total_steps: int,
warmup_steps: int,
lr_reduce_factor: float = 0.001,
-) -> List[float]:
+) -> list[float]:
"""Cosine learning rate scheduler with linear warmup helper function.
Args:
@@ -83,14 +84,13 @@ def __init__(
optimizer: torch.optim.Optimizer,
batch_size: int,
total_epochs: int,
- init_lr: Tuple[float, ...] = (0.01,),
+ init_lr: tuple[float, ...] = (0.01,),
lr_scale: float = 256.0,
linear_warmup_epochs: int = 10,
lr_reduce_factor: float = 0.001,
- len_loader: Optional[int] = None,
+ len_loader: int | None = None,
scheduler_interval: str = "epoch",
) -> None:
-
super().__init__(optimizer, init_lr)
assert batch_size > 0
assert total_epochs > 0
diff --git a/quadra/tasks/anomaly.py b/quadra/tasks/anomaly.py
index 712a03e0..ec617cd9 100644
--- a/quadra/tasks/anomaly.py
+++ b/quadra/tasks/anomaly.py
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
import csv
import glob
import json
import os
from collections import Counter
-from typing import Any, Dict, Generic, List, Literal, Optional, TypeVar, Union, cast
+from typing import Any, Generic, Literal, TypeVar, cast
import cv2
import hydra
@@ -50,7 +52,7 @@ def __init__(
self,
config: DictConfig,
module_function: DictConfig,
- checkpoint_path: Optional[str] = None,
+ checkpoint_path: str | None = None,
run_test: bool = True,
report: bool = True,
):
@@ -64,7 +66,7 @@ def __init__(
self.module_function = module_function
self.export_folder = "deployment_model"
self.report_path = ""
- self.test_results: Optional[List[Dict]] = None
+ self.test_results: list[dict] | None = None
@property
def module(self) -> AnomalyModule:
@@ -154,11 +156,11 @@ def _generate_report(self) -> None:
json.dump(self.test_results[0], f)
all_output = cast(
- List[Dict], self.trainer.predict(model=self.module, dataloaders=self.datamodule.test_dataloader())
+ list[dict], self.trainer.predict(model=self.module, dataloaders=self.datamodule.test_dataloader())
)
- all_output_flatten: Dict[str, Union[torch.Tensor, List]] = {}
+ all_output_flatten: dict[str, torch.Tensor | list] = {}
- for key in all_output[0].keys():
+ for key in all_output[0]:
if type(all_output[0][key]) == torch.Tensor:
tensor_gatherer = torch.cat([x[key] for x in all_output])
all_output_flatten[key] = tensor_gatherer
@@ -185,7 +187,7 @@ def _generate_report(self) -> None:
gt_labels = [class_to_idx[x] for x in named_labels]
pred_labels = []
- for i, x in enumerate(named_labels):
+ for i, _ in enumerate(named_labels):
pred_label = all_output_flatten["pred_labels"][i].item()
if pred_label == 0:
@@ -230,11 +232,13 @@ def _generate_report(self) -> None:
# Lightning has a callback attribute but is not inside the __init__ so mypy complains
if any(
- isinstance(x, MinMaxNormalizationCallback) for x in self.trainer.callbacks # type: ignore[attr-defined]
+ isinstance(x, MinMaxNormalizationCallback)
+ for x in self.trainer.callbacks # type: ignore[attr-defined]
):
threshold = torch.tensor(0.5)
elif any(
- isinstance(x, ThresholdNormalizationCallback) for x in self.trainer.callbacks # type: ignore[attr-defined]
+ isinstance(x, ThresholdNormalizationCallback)
+ for x in self.trainer.callbacks # type: ignore[attr-defined]
):
threshold = torch.tensor(100.0)
else:
@@ -287,7 +291,7 @@ def _upload_artifacts(self):
mflow_logger.experiment.log_artifact(run_id=mflow_logger.run_id, local_path="test_confusion_matrix.png")
mflow_logger.experiment.log_artifact(run_id=mflow_logger.run_id, local_path="avg_score_by_label.csv")
- if "visualizer" in self.config.callbacks.keys():
+ if "visualizer" in self.config.callbacks:
artifacts = glob.glob(os.path.join(self.config.callbacks.visualizer.output_path, "**", "*"))
for a in artifacts:
mflow_logger.experiment.log_artifact(
@@ -299,7 +303,7 @@ def _upload_artifacts(self):
artifacts.append("test_confusion_matrix.png")
artifacts.append("avg_score_by_label.csv")
- if "visualizer" in self.config.callbacks.keys():
+ if "visualizer" in self.config.callbacks:
artifacts.extend(
glob.glob(os.path.join(self.config.callbacks.visualizer.output_path, "**/*"), recursive=True)
)
@@ -339,8 +343,8 @@ def __init__(
config: DictConfig,
model_path: str,
use_training_threshold: bool = False,
- device: Optional[str] = None,
- training_threshold_type: Optional[Literal["image", "pixel"]] = None,
+ device: str | None = None,
+ training_threshold_type: Literal["image", "pixel"] | None = None,
):
super().__init__(config=config, model_path=model_path, device=device)
diff --git a/quadra/tasks/base.py b/quadra/tasks/base.py
index 72a1edcc..60a35b80 100644
--- a/quadra/tasks/base.py
+++ b/quadra/tasks/base.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
import json
import os
from pathlib import Path
-from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
+from typing import Any, Generic, TypeVar
import hydra
import torch
@@ -34,7 +36,7 @@ def __init__(self, config: DictConfig):
self.config = config
self.export_folder: str = "deployment_model"
self._datamodule: DataModuleT
- self.metadata: Dict[str, Any]
+ self.metadata: dict[str, Any]
self.save_config()
def save_config(self) -> None:
@@ -103,7 +105,7 @@ class LightningTask(Generic[DataModuleT], Task[DataModuleT]):
def __init__(
self,
config: DictConfig,
- checkpoint_path: Optional[str] = None,
+ checkpoint_path: str | None = None,
run_test: bool = False,
report: bool = False,
):
@@ -112,9 +114,9 @@ def __init__(
self.run_test = run_test
self.report = report
self._module: LightningModule
- self._devices: Union[int, List[int]]
- self._callbacks: List[Callback]
- self._logger: List[Logger]
+ self._devices: int | list[int]
+ self._callbacks: list[Callback]
+ self._logger: list[Logger]
self._trainer: Trainer
def prepare(self) -> None:
@@ -160,7 +162,7 @@ def trainer(self, trainer_config: DictConfig) -> None:
self._trainer = trainer
@property
- def callbacks(self) -> List[Callback]:
+ def callbacks(self) -> list[Callback]:
"""List[Callback]: The callbacks."""
return self._callbacks
@@ -182,10 +184,9 @@ def callbacks(self, callbacks_config) -> None:
with open_dict(cb_conf):
del cb_conf.disable
- if not torch.cuda.is_available():
- # Skip the gpu stats logger callback if no gpu is available to avoid errors
- if cb_conf["_target_"] == "nvitop.callbacks.lightning.GpuStatsLogger":
- continue
+ # Skip the gpu stats logger callback if no gpu is available to avoid errors
+ if not torch.cuda.is_available() and cb_conf["_target_"] == "nvitop.callbacks.lightning.GpuStatsLogger":
+ continue
log.info("Instantiating callback <%s>", cb_conf["_target_"])
instatiated_callbacks.append(hydra.utils.instantiate(cb_conf))
@@ -194,7 +195,7 @@ def callbacks(self, callbacks_config) -> None:
log.warning("No callback found in configuration.")
@property
- def logger(self) -> List[Logger]:
+ def logger(self) -> list[Logger]:
"""List[Logger]: The loggers."""
return self._logger
@@ -219,7 +220,7 @@ def logger(self, logger_config) -> None:
log.warning("No logger found in configuration.")
@property
- def devices(self) -> Union[int, List[int]]:
+ def devices(self) -> int | list[int]:
"""List[int]: The devices ids."""
return self._devices
@@ -282,11 +283,12 @@ def finalize(self) -> None:
export_folder=self.export_folder,
)
- if not self.config.trainer.get("fast_dev_run"):
- if self.trainer.checkpoint_callback is not None and hasattr(
- self.trainer.checkpoint_callback, "best_model_path"
- ):
- log.info("Best model ckpt: %s", self.trainer.checkpoint_callback.best_model_path)
+ if (
+ not self.config.trainer.get("fast_dev_run")
+ and self.trainer.checkpoint_callback is not None
+ and hasattr(self.trainer.checkpoint_callback, "best_model_path")
+ ):
+ log.info("Best model ckpt: %s", self.trainer.checkpoint_callback.best_model_path)
def add_callback(self, callback: Callback):
"""Add a callback to the trainer.
@@ -334,7 +336,7 @@ def __init__(
self,
config: DictConfig,
model_path: str,
- device: Optional[str] = None,
+ device: str | None = None,
):
super().__init__(config=config)
@@ -344,7 +346,7 @@ def __init__(
self.device = device
self.config = config
- self.model_data: Dict[str, Any]
+ self.model_data: dict[str, Any]
self.model_path = model_path
self._deployment_model: BaseEvaluationModel
self.deployment_model_type: str
diff --git a/quadra/tasks/classification.py b/quadra/tasks/classification.py
index 2f0adec3..eb05fb52 100644
--- a/quadra/tasks/classification.py
+++ b/quadra/tasks/classification.py
@@ -6,7 +6,7 @@
import typing
from copy import deepcopy
from pathlib import Path
-from typing import Any, Dict, Generic, List, Optional, cast
+from typing import Any, Generic, cast
import cv2
import hydra
@@ -82,8 +82,8 @@ def __init__(
self,
config: DictConfig,
output: DictConfig,
- checkpoint_path: Optional[str] = None,
- lr_multiplier: Optional[float] = None,
+ checkpoint_path: str | None = None,
+ lr_multiplier: float | None = None,
gradcam: bool = False,
report: bool = False,
run_test: bool = False,
@@ -102,11 +102,11 @@ def __init__(
self._model: nn.Module
self._optimizer: torch.optim.Optimizer
self._scheduler: torch.optim.lr_scheduler._LRScheduler
- self.model_json: Optional[Dict[str, Any]] = None
+ self.model_json: dict[str, Any] | None = None
self.export_folder: str = "deployment_model"
self.deploy_info_file: str = "model.json"
self.report_confmat: pd.DataFrame
- self.best_model_path: Optional[str] = None
+ self.best_model_path: str | None = None
@property
def optimizer(self) -> torch.optim.Optimizer:
@@ -172,7 +172,7 @@ def module(self) -> ClassificationModule:
return self._module
@LightningTask.module.setter
- def module(self, module_config):
+ def module(self, module_config): # noqa: F811
"""Set the module of the model."""
module = hydra.utils.instantiate(
module_config,
@@ -435,7 +435,7 @@ def generate_report(self) -> None:
else:
utils.upload_file_tensorboard(a, tensorboard_logger)
- def freeze_layers_by_name(self, freeze_parameters_name: List[str]):
+ def freeze_layers_by_name(self, freeze_parameters_name: list[str]):
"""Freeze layers specified in freeze_parameters_name.
Args:
@@ -453,7 +453,7 @@ def freeze_layers_by_name(self, freeze_parameters_name: List[str]):
log.info("Frozen %d parameters", count_frozen)
- def freeze_parameters_by_index(self, freeze_parameters_index: List[int]):
+ def freeze_parameters_by_index(self, freeze_parameters_index: list[int]):
"""Freeze parameters specified in freeze_parameters_name.
Args:
@@ -507,7 +507,7 @@ def __init__(
self._backbone: ModelSignatureWrapper
self._trainer: SklearnClassificationTrainer
self._model: ClassifierMixin
- self.metadata: Dict[str, Any] = {
+ self.metadata: dict[str, Any] = {
"test_confusion_matrix": [],
"test_accuracy": [],
"test_results": [],
@@ -516,8 +516,8 @@ def __init__(
}
self.export_folder = "deployment_model"
self.deploy_info_file = "model.json"
- self.train_dataloader_list: List[torch.utils.data.DataLoader] = []
- self.test_dataloader_list: List[torch.utils.data.DataLoader] = []
+ self.train_dataloader_list: list[torch.utils.data.DataLoader] = []
+ self.test_dataloader_list: list[torch.utils.data.DataLoader] = []
self.automatic_batch_size = automatic_batch_size
self.save_model_summary = save_model_summary
self.half_precision = half_precision
@@ -866,7 +866,7 @@ class SklearnTestClassification(Evaluation[SklearnClassificationDataModuleT]):
"""
def __init__(
- self, # pylint: disable=W0613
+ self,
config: DictConfig,
output: DictConfig,
model_path: str,
@@ -879,10 +879,10 @@ def __init__(
self.output = output
self._backbone: BaseEvaluationModel
self._classifier: ClassifierMixin
- self.class_to_idx: Dict[str, int]
- self.idx_to_class: Dict[int, str]
+ self.class_to_idx: dict[str, int]
+ self.idx_to_class: dict[int, str]
self.test_dataloader: torch.utils.data.DataLoader
- self.metadata: Dict[str, Any] = {
+ self.metadata: dict[str, Any] = {
"test_confusion_matrix": None,
"test_accuracy": None,
"test_results": None,
@@ -1050,7 +1050,7 @@ def __init__(
model_path: str,
report: bool = True,
gradcam: bool = False,
- device: Optional[str] = None,
+ device: str | None = None,
):
super().__init__(config=config, model_path=model_path, device=device)
self.report_path = "test_output"
@@ -1199,7 +1199,7 @@ def test(self) -> None:
grayscale_cam = ndimage.zoom(grayscale_cam_low_res, zoom_factors, order=1)
grayscale_cams_list.append(torch.from_numpy(grayscale_cam))
- grayscale_cams: Optional[torch.Tensor] = None
+ grayscale_cams: torch.Tensor | None = None
if self.gradcam:
grayscale_cams = torch.cat(grayscale_cams_list, dim=0)
diff --git a/quadra/tasks/patch.py b/quadra/tasks/patch.py
index cc117a27..d48e17fc 100644
--- a/quadra/tasks/patch.py
+++ b/quadra/tasks/patch.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
import json
import os
from pathlib import Path
-from typing import Any, Dict, List, cast
+from typing import Any, cast
import hydra
import torch
@@ -48,11 +50,11 @@ def __init__(
self.device: str = device
self.output: DictConfig = output
self.return_polygon: bool = True
- self.reconstruction_results: Dict[str, Any]
+ self.reconstruction_results: dict[str, Any]
self._backbone: ModelSignatureWrapper
self._trainer: SklearnClassificationTrainer
self._model: ClassifierMixin
- self.metadata: Dict[str, Any] = {
+ self.metadata: dict[str, Any] = {
"test_confusion_matrix": [],
"test_accuracy": [],
"test_results": [],
@@ -131,9 +133,7 @@ def train(self) -> None:
self.datamodule.setup(stage="fit")
class_to_keep = None
if hasattr(self.datamodule, "class_to_skip_training") and self.datamodule.class_to_skip_training is not None:
- class_to_keep = [
- x for x in self.datamodule.class_to_idx.keys() if x not in self.datamodule.class_to_skip_training
- ]
+ class_to_keep = [x for x in self.datamodule.class_to_idx if x not in self.datamodule.class_to_skip_training]
self.model = self.config.model
self.trainer.change_classifier(self.model)
@@ -165,7 +165,7 @@ def generate_report(self) -> None:
idx_to_class = {v: k for k, v in self.datamodule.class_to_idx.items()}
datamodule: PatchSklearnClassificationDataModule = self.datamodule
- val_img_info: List[PatchDatasetFileFormat] = datamodule.info.val_files
+ val_img_info: list[PatchDatasetFileFormat] = datamodule.info.val_files
for img_info in val_img_info:
if not os.path.isabs(img_info.image_path):
img_info.image_path = os.path.join(datamodule.data_path, img_info.image_path)
@@ -293,16 +293,16 @@ def __init__(
self.output = output
self._backbone: BaseEvaluationModel
self._classifier: ClassifierMixin
- self.class_to_idx: Dict[str, int]
- self.idx_to_class: Dict[int, str]
- self.metadata: Dict[str, Any] = {
+ self.class_to_idx: dict[str, int]
+ self.idx_to_class: dict[int, str]
+ self.metadata: dict[str, Any] = {
"test_confusion_matrix": None,
"test_accuracy": None,
"test_results": None,
"test_labels": None,
}
- self.class_to_skip: List[str] = []
- self.reconstruction_results: Dict[str, Any]
+ self.class_to_skip: list[str] = []
+ self.reconstruction_results: dict[str, Any]
self.return_polygon: bool = True
def prepare(self) -> None:
@@ -336,7 +336,7 @@ def test(self) -> None:
class_to_keep = None
if self.class_to_skip is not None:
- class_to_keep = [x for x in self.datamodule.class_to_idx.keys() if x not in self.class_to_skip]
+ class_to_keep = [x for x in self.datamodule.class_to_idx if x not in self.class_to_skip]
_, pd_cm, accuracy, res, _ = self.trainer.test(
test_dataloader=test_dataloader,
idx_to_class=self.idx_to_class,
diff --git a/quadra/tasks/segmentation.py b/quadra/tasks/segmentation.py
index 18974475..327ba620 100644
--- a/quadra/tasks/segmentation.py
+++ b/quadra/tasks/segmentation.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
import json
import os
import typing
-from typing import Any, Dict, Generic, List, Optional
+from typing import Any, Generic
import cv2
import hydra
@@ -42,9 +44,9 @@ def __init__(
self,
config: DictConfig,
num_viz_samples: int = 5,
- checkpoint_path: Optional[str] = None,
+ checkpoint_path: str | None = None,
run_test: bool = False,
- evaluate: Optional[DictConfig] = None,
+ evaluate: DictConfig | None = None,
report: bool = False,
):
super().__init__(
@@ -56,7 +58,7 @@ def __init__(
self.evaluate = evaluate
self.num_viz_samples = num_viz_samples
self.export_folder: str = "deployment_model"
- self.exported_model_path: Optional[str] = None
+ self.exported_model_path: str | None = None
if self.evaluate and any(self.evaluate.values()):
if (
self.config.export is None
@@ -86,13 +88,14 @@ def module(self, module_config) -> None:
"""Set the module."""
log.info("Instantiating model <%s>", module_config.model["_target_"])
- if isinstance(self.datamodule, SegmentationMulticlassDataModule):
- if module_config.model.num_classes != (len(self.datamodule.idx_to_class) + 1):
- log.warning(
- f"Number of classes in the model ({module_config.model.num_classes}) does not match the number of "
- + f"classes in the datamodule ({len(self.datamodule.idx_to_class)}). Updating the model..."
- )
- module_config.model.num_classes = len(self.datamodule.idx_to_class) + 1
+ if isinstance(self.datamodule, SegmentationMulticlassDataModule) and module_config.model.num_classes != (
+ len(self.datamodule.idx_to_class) + 1
+ ):
+ log.warning(
+ f"Number of classes in the model ({module_config.model.num_classes}) does not match the number of "
+ + f"classes in the datamodule ({len(self.datamodule.idx_to_class)}). Updating the model..."
+ )
+ module_config.model.num_classes = len(self.datamodule.idx_to_class) + 1
model = hydra.utils.instantiate(module_config.model)
model = ModelSignatureWrapper(model)
@@ -181,7 +184,7 @@ def generate_report(self) -> None:
"""Generate a report for the task."""
if self.evaluate is not None:
log.info("Generating evaluation report!")
- eval_tasks: List[SegmentationEvaluation] = []
+ eval_tasks: list[SegmentationEvaluation] = []
if self.evaluate.analysis:
if self.exported_model_path is None:
raise ValueError(
@@ -242,7 +245,7 @@ def __init__(
self,
config: DictConfig,
model_path: str,
- device: Optional[str] = "cpu",
+ device: str | None = "cpu",
):
super().__init__(config=config, model_path=model_path, device=device)
self.config = config
@@ -267,7 +270,7 @@ def prepare(self) -> None:
@torch.no_grad()
def inference(
self, dataloader: DataLoader, deployment_model: BaseEvaluationModel, device: torch.device
- ) -> Dict[str, torch.Tensor]:
+ ) -> dict[str, torch.Tensor]:
"""Run inference on the dataloader and return the output.
Args:
@@ -306,10 +309,10 @@ def __init__(
self,
config: DictConfig,
model_path: str,
- device: Optional[str] = None,
+ device: str | None = None,
):
super().__init__(config=config, model_path=model_path, device=device)
- self.test_output: Dict[str, Any] = {}
+ self.test_output: dict[str, Any] = {}
def train(self) -> None:
"""Skip training."""
@@ -325,8 +328,8 @@ def test(self) -> None:
"""Run testing."""
log.info("Starting inference for analysis.")
- stages: List[str] = []
- dataloaders: List[torch.utils.data.DataLoader] = []
+ stages: list[str] = []
+ dataloaders: list[torch.utils.data.DataLoader] = []
# if self.datamodule.train_dataset_available:
# stages.append("train")
diff --git a/quadra/tasks/ssl.py b/quadra/tasks/ssl.py
index 4638a2c6..63817a15 100644
--- a/quadra/tasks/ssl.py
+++ b/quadra/tasks/ssl.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
import json
import os
-from typing import Any, List, Optional, Tuple, cast
+from typing import Any, cast
import hydra
import torch
@@ -36,7 +38,7 @@ def __init__(
config: DictConfig,
run_test: bool = False,
report: bool = False,
- checkpoint_path: Optional[str] = None,
+ checkpoint_path: str | None = None,
):
super().__init__(
config=config,
@@ -49,7 +51,7 @@ def __init__(
self._lr_scheduler: torch.optim.lr_scheduler._LRScheduler
self.export_folder = "deployment_model"
- def learnable_parameters(self) -> List[nn.Parameter]:
+ def learnable_parameters(self) -> list[nn.Parameter]:
"""Get the learnable parameters."""
raise NotImplementedError("This method must be implemented by the subclass")
@@ -127,7 +129,7 @@ class Simsiam(SSL):
def __init__(
self,
config: DictConfig,
- checkpoint_path: Optional[str] = None,
+ checkpoint_path: str | None = None,
run_test: bool = False,
):
super().__init__(config=config, checkpoint_path=checkpoint_path, run_test=run_test)
@@ -135,7 +137,7 @@ def __init__(
self.projection_mlp: nn.Module
self.prediction_mlp: nn.Module
- def learnable_parameters(self) -> List[nn.Parameter]:
+ def learnable_parameters(self) -> list[nn.Parameter]:
"""Get the learnable parameters."""
return list(
list(self.backbone.parameters())
@@ -196,14 +198,14 @@ class SimCLR(SSL):
def __init__(
self,
config: DictConfig,
- checkpoint_path: Optional[str] = None,
+ checkpoint_path: str | None = None,
run_test: bool = False,
):
super().__init__(config=config, checkpoint_path=checkpoint_path, run_test=run_test)
self.backbone: nn.Module
self.projection_mlp: nn.Module
- def learnable_parameters(self) -> List[nn.Parameter]:
+ def learnable_parameters(self) -> list[nn.Parameter]:
"""Get the learnable parameters."""
return list(self.backbone.parameters()) + list(self.projection_mlp.parameters())
@@ -257,7 +259,7 @@ class Barlow(SimCLR):
def __init__(
self,
config: DictConfig,
- checkpoint_path: Optional[str] = None,
+ checkpoint_path: str | None = None,
run_test: bool = False,
):
super().__init__(config=config, checkpoint_path=checkpoint_path, run_test=run_test)
@@ -295,7 +297,7 @@ class BYOL(SSL):
def __init__(
self,
config: DictConfig,
- checkpoint_path: Optional[str] = None,
+ checkpoint_path: str | None = None,
run_test: bool = False,
**kwargs: Any,
):
@@ -311,7 +313,7 @@ def __init__(
self.student_prediction_mlp: nn.Module
self.teacher_projection_mlp: nn.Module
- def learnable_parameters(self) -> List[nn.Parameter]:
+ def learnable_parameters(self) -> list[nn.Parameter]:
"""Get the learnable parameters."""
return list(
list(self.student_model.parameters())
@@ -377,7 +379,7 @@ class DINO(SSL):
def __init__(
self,
config: DictConfig,
- checkpoint_path: Optional[str] = None,
+ checkpoint_path: str | None = None,
run_test: bool = False,
):
super().__init__(config=config, checkpoint_path=checkpoint_path, run_test=run_test)
@@ -386,7 +388,7 @@ def __init__(
self.student_projection_mlp: nn.Module
self.teacher_projection_mlp: nn.Module
- def learnable_parameters(self) -> List[nn.Parameter]:
+ def learnable_parameters(self) -> list[nn.Parameter]:
"""Get the learnable parameters."""
return list(
list(self.student_model.parameters()) + list(self.student_projection_mlp.parameters()),
@@ -451,7 +453,7 @@ def __init__(
config: DictConfig,
model_path: str,
report_folder: str = "embeddings",
- embedding_image_size: Optional[int] = None,
+ embedding_image_size: int | None = None,
):
super().__init__(config=config)
@@ -515,7 +517,7 @@ def test(self) -> None:
self.datamodule.setup("test")
dataloader = self.datamodule.test_dataloader()
images = []
- metadata: List[Tuple[int, str, str]] = []
+ metadata: list[tuple[int, str, str]] = []
embeddings = []
std = torch.tensor(self.config.transforms.std).view(1, -1, 1, 1)
mean = torch.tensor(self.config.transforms.mean).view(1, -1, 1, 1)
diff --git a/quadra/trainers/classification.py b/quadra/trainers/classification.py
index cca99515..c3bb3c3a 100644
--- a/quadra/trainers/classification.py
+++ b/quadra/trainers/classification.py
@@ -1,4 +1,6 @@
-from typing import Dict, List, Optional, Tuple, Union, cast
+from __future__ import annotations
+
+from typing import cast
import numpy as np
import pandas as pd
@@ -29,7 +31,7 @@ class SklearnClassificationTrainer:
def __init__(
self,
- input_shape: List,
+ input_shape: list,
backbone: torch.nn.Module,
random_state: int = 42,
classifier: ClassifierMixin = LogisticRegression,
@@ -58,9 +60,9 @@ def change_classifier(self, classifier: ClassifierMixin):
def fit(
self,
- train_dataloader: Optional[DataLoader] = None,
- train_features: Optional[ndarray] = None,
- train_labels: Optional[ndarray] = None,
+ train_dataloader: DataLoader | None = None,
+ train_features: ndarray | None = None,
+ train_labels: ndarray | None = None,
):
"""Fit classifier on training set."""
# Extract feature
@@ -91,16 +93,16 @@ def fit(
def test(
self,
test_dataloader: DataLoader,
- test_labels: Optional[ndarray] = None,
- test_features: Optional[ndarray] = None,
- class_to_keep: Optional[List[int]] = None,
- idx_to_class: Optional[Dict[int, str]] = None,
+ test_labels: ndarray | None = None,
+ test_features: ndarray | None = None,
+ class_to_keep: list[int] | None = None,
+ idx_to_class: dict[int, str] | None = None,
predict_proba: bool = True,
gradcam: bool = False,
- ) -> Union[
- Tuple[Union[str, Dict], DataFrame, float, DataFrame, Optional[np.ndarray]],
- Tuple[None, None, None, DataFrame, Optional[np.ndarray]],
- ]:
+ ) -> (
+ tuple[str | dict, DataFrame, float, DataFrame, np.ndarray | None]
+ | tuple[None, None, None, DataFrame, np.ndarray | None]
+ ):
"""Test classifier on test set.
Args:
@@ -148,7 +150,7 @@ def test(
raise ValueError("You must provide `idx_to_class` and `test_labels` when using `class_to_keep`")
filtered_test_labels = [int(x) if idx_to_class[x] in class_to_keep else -1 for x in final_test_labels]
else:
- filtered_test_labels = cast(List[int], final_test_labels.tolist())
+ filtered_test_labels = cast(list[int], final_test_labels.tolist())
if not hasattr(test_dataloader.dataset, "x"):
raise ValueError("Current dataset doesn't provide an `x` attribute")
diff --git a/quadra/utils/anomaly.py b/quadra/utils/anomaly.py
index 779b3924..b8d46a6a 100644
--- a/quadra/utils/anomaly.py
+++ b/quadra/utils/anomaly.py
@@ -10,7 +10,7 @@
except ImportError:
from typing import Any
- from typing_extensions import TypeAlias
+ from typing_extensions import TypeAlias # noqa
# MyPy wants TypeAlias, but pylint has problems dealing with it
@@ -100,10 +100,12 @@ def _normalize_batch(self, outputs, pl_module):
"""Normalize a batch of predictions."""
image_threshold = pl_module.image_threshold.value.cpu()
pixel_threshold = pl_module.pixel_threshold.value.cpu()
- outputs["pred_scores"] = normalize_anomaly_score(outputs["pred_scores"], image_threshold)
+ outputs["pred_scores"] = normalize_anomaly_score(outputs["pred_scores"], image_threshold.item())
threshold = pixel_threshold if self.threshold_type == "pixel" else image_threshold
- if "anomaly_maps" in outputs.keys():
+ threshold = threshold.item()
+
+ if "anomaly_maps" in outputs:
outputs["anomaly_maps"] = normalize_anomaly_score(outputs["anomaly_maps"], threshold)
if "box_scores" in outputs:
diff --git a/quadra/utils/classification.py b/quadra/utils/classification.py
index 58682962..7fb17815 100644
--- a/quadra/utils/classification.py
+++ b/quadra/utils/classification.py
@@ -4,7 +4,8 @@
import os
import random
import re
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple, Union
+from collections.abc import Generator, Sequence
+from typing import TYPE_CHECKING, Any
import matplotlib.pyplot as plt
import numpy as np
@@ -28,7 +29,7 @@
def get_file_condition(
- file_name: str, root: str, exclude_filter: Optional[List[str]] = None, include_filter: Optional[List[str]] = None
+ file_name: str, root: str, exclude_filter: list[str] | None = None, include_filter: list[str] | None = None
):
"""Check if a file should be included or excluded based on the filters provided.
@@ -45,9 +46,10 @@ def get_file_condition(
if any(fil in root for fil in exclude_filter):
return False
- if include_filter is not None:
- if not any(fil in file_name for fil in include_filter) and not any(fil in root for fil in include_filter):
- return False
+ if include_filter is not None and (
+ not any(fil in file_name for fil in include_filter) and not any(fil in root for fil in include_filter)
+ ):
+ return False
return True
@@ -59,14 +61,14 @@ def natural_key(string_):
def find_images_and_targets(
folder: str,
- types: Optional[list] = None,
- class_to_idx: Optional[Dict[str, int]] = None,
+ types: list | None = None,
+ class_to_idx: dict[str, int] | None = None,
leaf_name_only: bool = True,
sort: bool = True,
- exclude_filter: Optional[list] = None,
- include_filter: Optional[list] = None,
- label_map: Optional[Dict[str, Any]] = None,
-) -> Tuple[np.ndarray, np.ndarray, dict]:
+ exclude_filter: list | None = None,
+ include_filter: list | None = None,
+ label_map: dict[str, Any] | None = None,
+) -> tuple[np.ndarray, np.ndarray, dict]:
"""Given a folder, extract the absolute path of all the files with a valid extension.
Then assign a label based on subfolder name.
@@ -125,7 +127,7 @@ def find_images_and_targets(
if class_to_idx is None:
# building class index
unique_labels = set(labels)
- sorted_labels = list(sorted(unique_labels, key=natural_key))
+ sorted_labels = sorted(unique_labels, key=natural_key)
class_to_idx = {str(c): idx for idx, c in enumerate(sorted_labels)}
images_and_targets = [(f, l) for f, l in zip(filenames, labels) if l in class_to_idx]
@@ -138,13 +140,13 @@ def find_images_and_targets(
def find_test_image(
folder: str,
- types: Optional[List[str]] = None,
- exclude_filter: Optional[List[str]] = None,
- include_filter: Optional[List[str]] = None,
+ types: list[str] | None = None,
+ exclude_filter: list[str] | None = None,
+ include_filter: list[str] | None = None,
include_none_class: bool = True,
- test_split_file: Optional[str] = None,
+ test_split_file: str | None = None,
label_map=None,
-) -> Tuple[List[str], List[Optional[str]]]:
+) -> tuple[list[str], list[str | None]]:
"""Given a path extract images and labels with filters, labels are based on the parent folder name of the images
Args:
folder: root directory containing the images
@@ -163,11 +165,8 @@ def find_test_image(
filenames = []
for root, _, files in os.walk(folder, topdown=False, followlinks=True):
- if root != folder:
- rel_path = os.path.relpath(root, folder)
- else:
- rel_path = ""
- label: Optional[str] = os.path.basename(rel_path)
+ rel_path = os.path.relpath(root, folder) if root != folder else ""
+ label: str | None = os.path.basename(rel_path)
for f in files:
if not get_file_condition(
file_name=f, root=root, exclude_filter=exclude_filter, include_filter=include_filter
@@ -195,7 +194,7 @@ def find_test_image(
if not os.path.exists(test_split_file):
raise FileNotFoundError(f"test_split_file {test_split_file} does not exist")
- with open(test_split_file, "r") as test_file:
+ with open(test_split_file) as test_file:
test_split = test_file.read().splitlines()
file_samples = []
@@ -230,9 +229,7 @@ def find_test_image(
return filenames, labels
-def group_labels(
- labels: Sequence[Optional[str]], class_mapping: Dict[str, Union[Optional[str], List[str]]]
-) -> Tuple[List, Dict]:
+def group_labels(labels: Sequence[str | None], class_mapping: dict[str, str | None | list[str]]) -> tuple[list, dict]:
"""Group labels based on class_mapping.
Raises:
@@ -258,8 +255,8 @@ def group_labels(
```
"""
grouped_labels = []
- specified_targets = [k for k in class_mapping.keys() if class_mapping[k] is not None]
- non_specified_targets = [k for k in class_mapping.keys() if class_mapping[k] is None]
+ specified_targets = [k for k in class_mapping if class_mapping[k] is not None]
+ non_specified_targets = [k for k in class_mapping if class_mapping[k] is None]
if len(non_specified_targets) > 1:
raise ValueError(f"More than one non specified target: {non_specified_targets}")
for label in labels:
@@ -282,7 +279,7 @@ def group_labels(
return grouped_labels, class_to_idx
-def filter_with_file(list_of_full_paths: List[str], file_path: str, root_path: str) -> Tuple[List[str], List[bool]]:
+def filter_with_file(list_of_full_paths: list[str], file_path: str, root_path: str) -> tuple[list[str], list[bool]]:
"""Filter a list of items using a file containing the items to keep. Paths inside file
should be relative to root_path not absolute to avoid user related issues.
@@ -298,7 +295,7 @@ def filter_with_file(list_of_full_paths: List[str], file_path: str, root_path: s
filtered_full_paths = []
filter_mask = []
- with open(file_path, "r") as f:
+ with open(file_path) as f:
for relative_path in f.read().splitlines():
full_path = os.path.join(root_path, relative_path)
if full_path in list_of_full_paths:
@@ -312,17 +309,17 @@ def filter_with_file(list_of_full_paths: List[str], file_path: str, root_path: s
def get_split(
image_dir: str,
- exclude_filter: Optional[List[str]] = None,
- include_filter: Optional[List[str]] = None,
+ exclude_filter: list[str] | None = None,
+ include_filter: list[str] | None = None,
test_size: float = 0.3,
random_state: int = 42,
- class_to_idx: Optional[Dict[str, int]] = None,
- label_map: Optional[Dict] = None,
+ class_to_idx: dict[str, int] | None = None,
+ label_map: dict | None = None,
n_splits: int = 1,
include_none_class: bool = False,
- limit_training_data: Optional[int] = None,
- train_split_file: Optional[str] = None,
-) -> Tuple[np.ndarray, np.ndarray, Generator[List, None, None], Dict]:
+ limit_training_data: int | None = None,
+ train_split_file: str | None = None,
+) -> tuple[np.ndarray, np.ndarray, Generator[list, None, None], dict]:
"""Given a folder, extract the absolute path of all the files with a valid extension and name
and split them into train/test.
@@ -364,7 +361,7 @@ def get_split(
class_to_idx.pop(_cl)
if train_split_file is not None:
- with open(train_split_file, "r") as f:
+ with open(train_split_file) as f:
train_split = f.read().splitlines()
file_samples = []
@@ -413,9 +410,9 @@ def save_classification_result(
test_dataloader: DataLoader,
config: DictConfig,
output: DictConfig,
- accuracy: Optional[float] = None,
- confmat: Optional[pd.DataFrame] = None,
- grayscale_cams: Optional[np.ndarray] = None,
+ accuracy: float | None = None,
+ confmat: pd.DataFrame | None = None,
+ grayscale_cams: np.ndarray | None = None,
):
"""Save csv results, confusion matrix and example images.
@@ -468,9 +465,8 @@ def save_classification_result(
os.makedirs(original_images_folder)
gradcam_folder = os.path.join(images_folder, "gradcam")
- if save_gradcams:
- if not os.path.isdir(gradcam_folder):
- os.makedirs(gradcam_folder)
+ if save_gradcams and not os.path.isdir(gradcam_folder):
+ os.makedirs(gradcam_folder)
for v in np.unique([results["real_label"], results["pred_label"]]):
if np.isnan(v) or v == -1:
@@ -518,11 +514,11 @@ def save_classification_result(
def get_results(
- test_labels: Union[np.ndarray, List[int]],
- pred_labels: Union[np.ndarray, List[int]],
- idx_to_labels: Optional[Dict] = None,
+ test_labels: np.ndarray | list[int],
+ pred_labels: np.ndarray | list[int],
+ idx_to_labels: dict | None = None,
cl_rep_digits: int = 3,
-) -> Tuple[Union[str, Dict], pd.DataFrame, float]:
+) -> tuple[str | dict, pd.DataFrame, float]:
"""Get prediction results from predicted and test labels.
Args:
diff --git a/quadra/utils/deprecation.py b/quadra/utils/deprecation.py
index dfdd4fe5..9c6ebd1a 100644
--- a/quadra/utils/deprecation.py
+++ b/quadra/utils/deprecation.py
@@ -1,5 +1,5 @@
import functools
-from typing import Callable
+from collections.abc import Callable
from quadra.utils.utils import get_logger
diff --git a/quadra/utils/evaluation.py b/quadra/utils/evaluation.py
index cf04ca97..35a3e04a 100644
--- a/quadra/utils/evaluation.py
+++ b/quadra/utils/evaluation.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
import os
from ast import literal_eval
+from collections.abc import Callable
from functools import wraps
-from typing import Callable, Dict, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
@@ -33,7 +35,7 @@ def dice(
target: torch.Tensor,
smooth: float = 1.0,
eps: float = 1e-8,
- reduction: Optional[str] = "mean",
+ reduction: str | None = "mean",
) -> torch.Tensor:
"""Dice loss computation function.
@@ -94,12 +96,12 @@ def calculate_mask_based_metrics(
show_orj_predictions: bool = False,
metric: Callable = score_dice,
multilabel: bool = False,
- n_classes: Optional[int] = None,
-) -> Tuple[
- Dict[str, float],
- Dict[str, List[np.ndarray]],
- Dict[str, List[np.ndarray]],
- Dict[str, List[Union[str, float]]],
+ n_classes: int | None = None,
+) -> tuple[
+ dict[str, float],
+ dict[str, list[np.ndarray]],
+ dict[str, list[np.ndarray]],
+ dict[str, list[str | float]],
]:
"""Calculate metrics based on masks and predictions.
@@ -154,13 +156,13 @@ def calculate_mask_based_metrics(
result["num_good_image"] = 0
result["num_bad_image"] = 0
bad_dice, good_dice = [], []
- fg: Dict[str, List[np.ndarray]] = {"image": [], "mask": [], "thresh_pred": []}
- fb: Dict[str, List[np.ndarray]] = {"image": [], "mask": [], "thresh_pred": []}
+ fg: dict[str, list[np.ndarray]] = {"image": [], "mask": [], "thresh_pred": []}
+ fb: dict[str, list[np.ndarray]] = {"image": [], "mask": [], "thresh_pred": []}
if show_orj_predictions:
fg["pred"] = []
fb["pred"] = []
- area_graph: Dict[str, List[Union[str, float]]] = {
+ area_graph: dict[str, list[str | float]] = {
"Defect Area Percentage": [],
"Accuracy": [],
}
@@ -220,7 +222,7 @@ def calculate_mask_based_metrics(
def create_mask_report(
stage: str,
- output: Dict[str, torch.Tensor],
+ output: dict[str, torch.Tensor],
mean: npt.ArrayLike,
std: npt.ArrayLike,
report_path: str,
@@ -231,7 +233,7 @@ def create_mask_report(
threshold: float = 0.5,
metric: Callable = score_dice,
show_orj_predictions: bool = False,
-) -> List[str]:
+) -> list[str]:
"""Create report for segmentation experiment
Args:
stage: stage name. Train, validation or test
diff --git a/quadra/utils/export.py b/quadra/utils/export.py
index 90d43784..34807e55 100644
--- a/quadra/utils/export.py
+++ b/quadra/utils/export.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
import os
-from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, TypeVar, Union, cast
+from collections.abc import Sequence
+from typing import Any, Literal, TypeVar, cast
import torch
from anomalib.models.cflow import CflowLightning
@@ -29,12 +32,12 @@
def generate_torch_inputs(
- input_shapes: List[Any],
- device: Union[str, torch.device],
+ input_shapes: list[Any],
+ device: str | torch.device,
half_precision: bool = False,
dtype: torch.dtype = torch.float32,
batch_size: int = 1,
-) -> Union[List[Any], Tuple[Any, ...], torch.Tensor]:
+) -> list[Any] | tuple[Any, ...] | torch.Tensor:
"""Given a list of input shapes that can contain either lists, tuples or dicts, with tuples being the input shapes
of the model, generate a list of torch tensors with the given device and dtype.
"""
@@ -71,11 +74,11 @@ def generate_torch_inputs(
def extract_torch_model_inputs(
- model: Union[nn.Module, ModelSignatureWrapper],
- input_shapes: Optional[List[Any]] = None,
+ model: nn.Module | ModelSignatureWrapper,
+ input_shapes: list[Any] | None = None,
half_precision: bool = False,
batch_size: int = 1,
-) -> Optional[Tuple[Union[List[Any], Tuple[Any, ...], torch.Tensor], List[Any]]]:
+) -> tuple[list[Any] | tuple[Any, ...] | torch.Tensor, list[Any]] | None:
"""Extract the input shapes for the given model and generate a list of torch tensors with the
given device and dtype.
@@ -85,9 +88,8 @@ def extract_torch_model_inputs(
half_precision: If True, the model will be exported with half precision
batch_size: Batch size for the input shapes
"""
- if isinstance(model, ModelSignatureWrapper):
- if input_shapes is None:
- input_shapes = model.input_shapes
+ if isinstance(model, ModelSignatureWrapper) and input_shapes is None:
+ input_shapes = model.input_shapes
if input_shapes is None:
log.warning(
@@ -112,10 +114,10 @@ def extract_torch_model_inputs(
def export_torchscript_model(
model: nn.Module,
output_path: str,
- input_shapes: Optional[List[Any]] = None,
+ input_shapes: list[Any] | None = None,
half_precision: bool = False,
model_name: str = "model.pt",
-) -> Optional[Tuple[str, Any]]:
+) -> tuple[str, Any] | None:
"""Export a PyTorch model with TorchScript.
Args:
@@ -175,10 +177,10 @@ def export_onnx_model(
model: nn.Module,
output_path: str,
onnx_config: DictConfig,
- input_shapes: Optional[List[Any]] = None,
+ input_shapes: list[Any] | None = None,
half_precision: bool = False,
model_name: str = "model.onnx",
-) -> Optional[Tuple[str, Any]]:
+) -> tuple[str, Any] | None:
"""Export a PyTorch model with ONNX.
Args:
@@ -240,16 +242,15 @@ def export_onnx_model(
if hasattr(onnx_config, "fixed_batch_size") and onnx_config.fixed_batch_size is not None:
dynamic_axes = None
- else:
- if dynamic_axes is None:
- dynamic_axes = {}
- for i, _ in enumerate(input_names):
- dynamic_axes[input_names[i]] = {0: "batch_size"}
+ elif dynamic_axes is None:
+ dynamic_axes = {}
+ for i, _ in enumerate(input_names):
+ dynamic_axes[input_names[i]] = {0: "batch_size"}
- for i, _ in enumerate(output_names):
- dynamic_axes[output_names[i]] = {0: "batch_size"}
+ for i, _ in enumerate(output_names):
+ dynamic_axes[output_names[i]] = {0: "batch_size"}
- onnx_config = cast(Dict[str, Any], OmegaConf.to_container(onnx_config, resolve=True))
+ onnx_config = cast(dict[str, Any], OmegaConf.to_container(onnx_config, resolve=True))
onnx_config["input_names"] = input_names
onnx_config["output_names"] = output_names
@@ -331,10 +332,10 @@ def export_model(
model: Any,
export_folder: str,
half_precision: bool,
- input_shapes: Optional[List[Any]] = None,
- idx_to_class: Optional[Dict[int, str]] = None,
+ input_shapes: list[Any] | None = None,
+ idx_to_class: dict[int, str] | None = None,
pytorch_model_type: Literal["backbone", "model"] = "model",
-) -> Tuple[Dict[str, Any], Dict[str, str]]:
+) -> tuple[dict[str, Any], dict[str, str]]:
"""Generate deployment models for the task.
Args:
@@ -427,7 +428,7 @@ def import_deployment_model(
model_path: str,
inference_config: DictConfig,
device: str,
- model_architecture: Optional[nn.Module] = None,
+ model_architecture: nn.Module | None = None,
) -> BaseEvaluationModel:
"""Try to import a model for deployment, currently only supports torchscript .pt files and
state dictionaries .pth files.
@@ -444,7 +445,7 @@ def import_deployment_model(
log.info("Importing trained model")
file_extension = os.path.splitext(os.path.basename(model_path))[1]
- deployment_model: Optional[BaseEvaluationModel] = None
+ deployment_model: BaseEvaluationModel | None = None
if file_extension == ".pt":
deployment_model = TorchscriptEvaluationModel(config=inference_config.torchscript)
diff --git a/quadra/utils/imaging.py b/quadra/utils/imaging.py
index affffd99..6d83c48a 100644
--- a/quadra/utils/imaging.py
+++ b/quadra/utils/imaging.py
@@ -1,10 +1,10 @@
-from typing import Tuple
+from __future__ import annotations
import cv2
import numpy as np
-def crop_image(image: np.ndarray, roi: Tuple[int, int, int, int]) -> np.ndarray:
+def crop_image(image: np.ndarray, roi: tuple[int, int, int, int]) -> np.ndarray:
"""Crop an image given a roi in proper format.
Args:
diff --git a/quadra/utils/mlflow.py b/quadra/utils/mlflow.py
index 2ea1b28f..33549576 100644
--- a/quadra/utils/mlflow.py
+++ b/quadra/utils/mlflow.py
@@ -8,7 +8,8 @@
except ImportError:
MLFLOW_AVAILABLE = False
-from typing import Any, Sequence
+from collections.abc import Sequence
+from typing import Any
import torch
from pytorch_lightning import Trainer
diff --git a/quadra/utils/model_manager.py b/quadra/utils/model_manager.py
index b2b1696a..baee6d65 100644
--- a/quadra/utils/model_manager.py
+++ b/quadra/utils/model_manager.py
@@ -256,9 +256,8 @@ def register_best_model(
if mode == "max":
if run.data.metrics[metric] > best_run.data.metrics[metric]:
best_run = run
- else:
- if run.data.metrics[metric] < best_run.data.metrics[metric]:
- best_run = run
+ elif run.data.metrics[metric] < best_run.data.metrics[metric]:
+ best_run = run
if best_run is None:
log.error("No runs found for experiment %s with the given metric", experiment_name)
diff --git a/quadra/utils/models.py b/quadra/utils/models.py
index 9de61d17..e78cb10b 100644
--- a/quadra/utils/models.py
+++ b/quadra/utils/models.py
@@ -1,6 +1,9 @@
+from __future__ import annotations
+
import math
import warnings
-from typing import Callable, List, Optional, Tuple, Type, Union, cast
+from collections.abc import Callable
+from typing import Union, cast
import numpy as np
import timm
@@ -40,7 +43,7 @@ def net_hat(input_size: int, output_size: int) -> torch.nn.Sequential:
return torch.nn.Sequential(torch.nn.Linear(input_size, output_size))
-def create_net_hat(dims: List[int], act_fun: Callable = torch.nn.ReLU, dropout_p: float = 0) -> torch.nn.Sequential:
+def create_net_hat(dims: list[int], act_fun: Callable = torch.nn.ReLU, dropout_p: float = 0) -> torch.nn.Sequential:
"""Create a sequence of linear layers with activation functions and dropout.
Args:
@@ -52,7 +55,7 @@ def create_net_hat(dims: List[int], act_fun: Callable = torch.nn.ReLU, dropout_p
Sequence of linear layers of dimension specified by the input, each linear layer is followed
by an activation function and optionally a dropout layer with the input probability
"""
- components: List[nn.Module] = []
+ components: list[nn.Module] = []
for i, _ in enumerate(dims[:-2]):
if dropout_p > 0:
components.append(torch.nn.Dropout(dropout_p))
@@ -85,14 +88,14 @@ def init_weights(m):
def get_feature(
- feature_extractor: Union[torch.nn.Module, BaseEvaluationModel],
+ feature_extractor: torch.nn.Module | BaseEvaluationModel,
dl: torch.utils.data.DataLoader,
iteration_over_training: int = 1,
gradcam: bool = False,
- classifier: Optional[ClassifierMixin] = None,
- input_shape: Optional[Tuple[int, int, int]] = None,
- limit_batches: Optional[int] = None,
-) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
+ classifier: ClassifierMixin | None = None,
+ input_shape: tuple[int, int, int] | None = None,
+ limit_batches: int | None = None,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]:
"""Given a dataloader and a PyTorch model, extract features with the model and return features and labels.
Args:
@@ -132,9 +135,9 @@ def get_feature(
)
for p in feature_extractor.features_extractor.layer4[-1].parameters():
p.requires_grad = True
- elif is_vision_transformer(feature_extractor.features_extractor): # type: ignore[arg-type]
+ elif is_vision_transformer(feature_extractor.features_extractor):
grad_rollout = VitAttentionGradRollout(
- feature_extractor.features_extractor, # type: ignore[arg-type]
+ feature_extractor.features_extractor,
classifier=classifier,
example_input=None if input_shape is None else torch.randn(1, *input_shape),
)
@@ -159,10 +162,10 @@ def get_feature(
if gradcam:
y_hat = cast(
- Union[List[torch.Tensor], Tuple[torch.Tensor], torch.Tensor], feature_extractor(x1).detach()
+ Union[list[torch.Tensor], tuple[torch.Tensor], torch.Tensor], feature_extractor(x1).detach()
)
# mypy can't detect that gradcam is true only if we have a features_extractor
- if is_vision_transformer(feature_extractor.features_extractor): # type: ignore[union-attr, arg-type]
+ if is_vision_transformer(feature_extractor.features_extractor): # type: ignore[union-attr]
grayscale_cam_low_res = grad_rollout(
input_tensor=x1, targets_list=y1
) # TODO: We are using labels (y1) but it would be better to use preds
@@ -175,7 +178,7 @@ def get_feature(
feature_extractor.zero_grad(set_to_none=True) # type: ignore[union-attr]
else:
with torch.no_grad():
- y_hat = cast(Union[List[torch.Tensor], Tuple[torch.Tensor], torch.Tensor], feature_extractor(x1))
+ y_hat = cast(Union[list[torch.Tensor], tuple[torch.Tensor], torch.Tensor], feature_extractor(x1))
grayscale_cams = None
if isinstance(y_hat, (list, tuple)):
@@ -275,7 +278,7 @@ def trunc_normal_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a:
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
-def clip_gradients(model: nn.Module, clip: float) -> List[float]:
+def clip_gradients(model: nn.Module, clip: float) -> list[float]:
"""Args:
model: The model
clip: The clip value.
@@ -319,9 +322,7 @@ def clear(self):
"""Clear the grabbed attentions."""
self.attentions = torch.zeros((1, 0))
- def get_attention(
- self, module: nn.Module, input_tensor: torch.Tensor, output: torch.Tensor
- ): # pylint: disable=unused-argument
+ def get_attention(self, module: nn.Module, input_tensor: torch.Tensor, output: torch.Tensor): # pylint: disable=unused-argument
"""Method to be registered to grab attentions."""
self.attentions = output.detach().clone().cpu()
@@ -357,7 +358,7 @@ def process_attention_maps(attentions: torch.Tensor, img_width: int, img_height:
attentions = F.interpolate(attentions, scale_factor=patch_size, mode="nearest")
return attentions
- def forward(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ def forward(self, t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
self.clear()
out = self.model(t)
return (out, self.attentions) # torch.jit.trace does not complain
@@ -378,7 +379,7 @@ class PositionalEncoding1D(torch.nn.Module):
def __init__(self, d_model: int, temperature: float = 10000.0, dropout: float = 0.0, max_len: int = 5000):
super().__init__()
- self.dropout: Union[torch.nn.Dropout, torch.nn.Identity]
+ self.dropout: torch.nn.Dropout | torch.nn.Identity
if dropout > 0:
self.dropout = torch.nn.Dropout(p=dropout)
else:
@@ -431,8 +432,8 @@ def __init__(
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
- act_layer: Type[nn.Module] = torch.nn.GELU,
- norm_layer: Type[torch.nn.LayerNorm] = torch.nn.LayerNorm,
+ act_layer: type[nn.Module] = torch.nn.GELU,
+ norm_layer: type[torch.nn.LayerNorm] = torch.nn.LayerNorm,
mask_diagonal: bool = True,
learnable_temperature: bool = True,
):
diff --git a/quadra/utils/patch/dataset.py b/quadra/utils/patch/dataset.py
index d607f37f..265beff8 100644
--- a/quadra/utils/patch/dataset.py
+++ b/quadra/utils/patch/dataset.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import glob
import itertools
import json
@@ -6,17 +8,18 @@
import random
import shutil
import warnings
+from collections.abc import Callable
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from multiprocessing import Pool
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any
import cv2
import h5py
import numpy as np
from scipy import ndimage
-from skimage.measure import label, regionprops
+from skimage.measure import label, regionprops # pylint: disable=no-name-in-module
from skimage.util import view_as_windows
from skmultilearn.model_selection import iterative_train_test_split
from tqdm import tqdm
@@ -32,29 +35,31 @@ class PatchDatasetFileFormat:
"""Model representing the content of the patch dataset split_files field in the info.json file."""
image_path: str
- mask_path: Optional[str] = None
+ mask_path: str | None = None
@dataclass
class PatchDatasetInfo:
"""Model representing the content of the patch dataset info.json file."""
- patch_size: Optional[Tuple[int, int]]
- patch_number: Optional[Tuple[int, int]]
- annotated_good: Optional[List[int]]
+ patch_size: tuple[int, int] | None
+ patch_number: tuple[int, int] | None
+ annotated_good: list[int] | None
overlap: float
- train_files: List[PatchDatasetFileFormat]
- val_files: List[PatchDatasetFileFormat]
- test_files: List[PatchDatasetFileFormat]
+ train_files: list[PatchDatasetFileFormat]
+ val_files: list[PatchDatasetFileFormat]
+ test_files: list[PatchDatasetFileFormat]
@staticmethod
- def _map_files(files: List[Any]):
+ def _map_files(files: list[Any]):
"""Convert a list of dict to a list of PatchDatasetFileFormat."""
mapped_files = []
for file in files:
+ current_file = file
if isinstance(file, dict):
- file = PatchDatasetFileFormat(**file)
- mapped_files.append(file)
+ current_file = PatchDatasetFileFormat(**current_file)
+ mapped_files.append(current_file)
+
return mapped_files
def __post_init__(self):
@@ -65,10 +70,10 @@ def __post_init__(self):
def get_image_mask_association(
data_folder: str,
- mask_folder: Optional[str] = None,
+ mask_folder: str | None = None,
mask_extension: str = "",
warning_on_missing_mask: bool = True,
-) -> List[Dict]:
+) -> list[dict]:
"""Function used to match images and mask from a folder or sub-folders.
Args:
@@ -140,7 +145,7 @@ def compute_patch_info(
patch_num_h: int,
patch_num_w: int,
overlap: float = 0.0,
-) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+) -> tuple[tuple[int, int], tuple[int, int]]:
"""Compute the patch size and step size given the number of patches and the overlap.
Args:
@@ -207,7 +212,7 @@ def compute_patch_info_from_patch_dim(
patch_height: int,
patch_width: int,
overlap: float = 0.0,
-) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+) -> tuple[tuple[int, int], tuple[int, int]]:
"""Compute patch info given the patch dimension
Args:
img_h: height of the image
@@ -237,7 +242,7 @@ def compute_patch_info_from_patch_dim(
return (patch_num_h, patch_num_w), (step_h, step_w)
-def from_rgb_to_idx(img: np.ndarray, class_to_color: Dict, class_to_idx: Dict) -> np.ndarray:
+def from_rgb_to_idx(img: np.ndarray, class_to_color: dict, class_to_idx: dict) -> np.ndarray:
"""Args:
img: Rgb mask in which each different color is associated with a class
class_to_color: Dict "key": [R, G, B]
@@ -259,17 +264,17 @@ def from_rgb_to_idx(img: np.ndarray, class_to_color: Dict, class_to_idx: Dict) -
def __save_patch_dataset(
image_patches: np.ndarray,
- labelled_patches: Optional[np.ndarray] = None,
- mask_patches: Optional[np.ndarray] = None,
- labelled_mask: Optional[np.ndarray] = None,
+ labelled_patches: np.ndarray | None = None,
+ mask_patches: np.ndarray | None = None,
+ labelled_mask: np.ndarray | None = None,
output_folder: str = "extraction_data",
image_name: str = "example",
area_threshold: float = 0.45,
area_defect_threshold: float = 0.2,
mask_extension: str = "_mask",
save_mask: bool = False,
- mask_output_folder: Optional[str] = None,
- class_to_idx: Optional[Dict] = None,
+ mask_output_folder: str | None = None,
+ class_to_idx: dict | None = None,
) -> None:
"""Given a view_as_window computed patches, masks and labelled mask, save all the images in subdirectory
divided by name and position in the grid, ambiguous patches i.e. the one that contains defects but with not enough
@@ -305,11 +310,10 @@ def __save_patch_dataset(
missing_classes = set(classes_in_mask).difference(class_to_idx.values())
assert len(missing_classes) == 0, f"Found index in mask that has no corresponding class {missing_classes}"
+ elif mask_patches is not None:
+ reference_classes = {k: str(v) for k, v in enumerate(list(np.unique(mask_patches)))}
else:
- if mask_patches is not None:
- reference_classes = {k: str(v) for k, v in enumerate(list(np.unique(mask_patches)))}
- else:
- raise ValueError("If no `class_to_idx` is provided, `mask_patches` must be provided")
+ raise ValueError("If no `class_to_idx` is provided, `mask_patches` must be provided")
log.debug("Classes from mask: %s", reference_classes)
class_to_idx = {v: k for k, v in reference_classes.items()}
@@ -412,29 +416,29 @@ def __save_patch_dataset(
def generate_patch_dataset(
- data_dictionary: List[Dict],
- class_to_idx: Dict,
+ data_dictionary: list[dict],
+ class_to_idx: dict,
val_size: float = 0.3,
test_size: float = 0.0,
seed: int = 42,
- patch_number: Optional[Tuple[int, int]] = None,
- patch_size: Optional[Tuple[int, int]] = None,
+ patch_number: tuple[int, int] | None = None,
+ patch_size: tuple[int, int] | None = None,
overlap: float = 0.0,
output_folder: str = "extraction_data",
save_original_images_and_masks: bool = True,
area_threshold: float = 0.45,
area_defect_threshold: float = 0.2,
mask_extension: str = "_mask",
- mask_output_folder: Optional[str] = None,
+ mask_output_folder: str | None = None,
save_mask: bool = False,
clear_output_folder: bool = False,
- mask_preprocessing: Optional[Callable] = None,
+ mask_preprocessing: Callable | None = None,
train_filename: str = "dataset.txt",
repeat_good_images: int = 1,
balance_defects: bool = True,
- annotated_good: Optional[List[str]] = None,
+ annotated_good: list[str] | None = None,
num_workers: int = 1,
-) -> Optional[Dict]:
+) -> dict | None:
"""Giving a data_dictionary as:
>>> {
>>> 'base_name': '163931_1_5.jpg',
@@ -486,14 +490,13 @@ def generate_patch_dataset(
"""
if len(data_dictionary) == 0:
- warnings.warn("Input data dictionary is empty!", UserWarning)
+ warnings.warn("Input data dictionary is empty!", UserWarning, stacklevel=2)
return None
if val_size < 0 or test_size < 0 or (val_size + test_size) > 1:
raise ValueError("Validation and Test size must be greater or equal than zero and sum up to maximum 1")
- if clear_output_folder:
- if os.path.exists(output_folder):
- shutil.rmtree(output_folder)
+ if clear_output_folder and os.path.exists(output_folder):
+ shutil.rmtree(output_folder)
os.makedirs(output_folder, exist_ok=True)
os.makedirs(os.path.join(output_folder, "original"), exist_ok=True)
if save_original_images_and_masks:
@@ -588,11 +591,11 @@ def generate_patch_dataset(
def multilabel_stratification(
output_folder: str,
- data_dictionary: List[Dict],
+ data_dictionary: list[dict],
num_classes: int,
val_size: float,
test_size: float,
-) -> Tuple[List[Dict], List[Dict], List[Dict]]:
+) -> tuple[list[dict], list[dict], list[dict]]:
"""Split data dictionary using multilabel based stratification, place every sample with None
mask inside the test set,for all the others read the labels contained in the masks
to create one-hot encoded labels.
@@ -622,7 +625,9 @@ def multilabel_stratification(
if len(data_dictionary) == 0:
# All the item in the data dictionary have None mask, put everything in test
warnings.warn(
- "All the images have None mask and the test size is not equal to 1! Put everything in test", UserWarning
+ "All the images have None mask and the test size is not equal to 1! Put everything in test",
+ UserWarning,
+ stacklevel=2,
)
return [], [], test_data_dictionary
@@ -644,7 +649,7 @@ def multilabel_stratification(
else:
y = np.concatenate([y, one_hot])
- x_test: Union[List[Any], np.ndarray]
+ x_test: list[Any] | np.ndarray
if empty_test_size > test_size:
warnings.warn(
@@ -653,6 +658,7 @@ def multilabel_stratification(
f" {empty_test_size}!"
),
UserWarning,
+ stacklevel=2,
)
x_train, _, x_val, _ = iterative_train_test_split(np.expand_dims(np.array(x), 1), y, val_size)
x_test = [q["base_name"] for q in test_data_dictionary]
@@ -689,18 +695,18 @@ def multilabel_stratification(
def generate_patch_sliding_window_dataset(
- data_dictionary: List[Dict],
+ data_dictionary: list[dict],
subfolder_name: str,
- patch_number: Optional[Tuple[int, int]] = None,
- patch_size: Optional[Tuple[int, int]] = None,
+ patch_number: tuple[int, int] | None = None,
+ patch_size: tuple[int, int] | None = None,
overlap: float = 0.0,
output_folder: str = "extraction_data",
area_threshold: float = 0.45,
area_defect_threshold: float = 0.2,
mask_extension: str = "_mask",
- mask_output_folder: Optional[str] = None,
+ mask_output_folder: str | None = None,
save_mask: bool = False,
- class_to_idx: Optional[Dict] = None,
+ class_to_idx: dict | None = None,
) -> None:
"""Giving a data_dictionary as:
>>> {
@@ -800,9 +806,9 @@ def generate_patch_sliding_window_dataset(
def extract_patches(
image: np.ndarray,
- patch_number: Tuple[int, ...],
- patch_size: Tuple[int, ...],
- step: Tuple[int, ...],
+ patch_number: tuple[int, ...],
+ patch_size: tuple[int, ...],
+ step: tuple[int, ...],
overlap: float,
) -> np.ndarray:
"""From an image extract N x M Patch[h, w] if the image is not perfectly divided by the number of patches of given
@@ -910,17 +916,17 @@ def extract_patches(
def generate_patch_sampling_dataset(
- data_dictionary: List[Dict[Any, Any]],
+ data_dictionary: list[dict[Any, Any]],
output_folder: str,
- idx_to_class: Dict,
+ idx_to_class: dict,
overlap: float,
repeat_good_images: int = 1,
balance_defects: bool = True,
- patch_number: Optional[Tuple[int, int]] = None,
- patch_size: Optional[Tuple[int, int]] = None,
+ patch_number: tuple[int, int] | None = None,
+ patch_size: tuple[int, int] | None = None,
subfolder_name: str = "train",
train_filename: str = "dataset.txt",
- annotated_good: Optional[List[int]] = None,
+ annotated_good: list[int] | None = None,
num_workers: int = 1,
) -> None:
"""Generate a dataset of patches.
@@ -1009,18 +1015,18 @@ def generate_patch_sampling_dataset(
def create_h5(
- data_dictionary: List[Dict[Any, Any]],
- idx_to_class: Dict,
+ data_dictionary: list[dict[Any, Any]],
+ idx_to_class: dict,
overlap: float,
repeat_good_images: int,
balance_defects: bool,
output_folder: str,
labelled_masks_path: str,
sampling_dataset_folder: str,
- annotated_good: Optional[List[int]] = None,
- patch_size: Optional[Tuple[int, int]] = None,
- patch_number: Optional[Tuple[int, int]] = None,
-) -> List[str]:
+ annotated_good: list[int] | None = None,
+ patch_size: tuple[int, int] | None = None,
+ patch_number: tuple[int, int] | None = None,
+) -> list[str]:
"""Create h5 files for each image in the dataset.
Args:
@@ -1057,7 +1063,7 @@ def create_h5(
mask = np.zeros([h, w])
else:
# this works even if item["mask"] is already an absolute path
- mask = cv2.imread(os.path.join(output_folder, item["mask"]), 0)
+ mask = cv2.imread(os.path.join(output_folder, item["mask"]), 0) # type: ignore[assignment]
if patch_size is not None:
patch_height = patch_size[1]
@@ -1065,7 +1071,11 @@ def create_h5(
else:
# Mypy complains because patch_number is Optional, but we already checked that it is not None.
[patch_height, patch_width], _ = compute_patch_info(
- h, w, patch_number[0], patch_number[1], overlap # type: ignore[index]
+ h,
+ w,
+ patch_number[0], # type: ignore[index]
+ patch_number[1], # type: ignore[index]
+ overlap,
)
h5_file_name_good = os.path.join(sampling_dataset_folder, f"{os.path.splitext(item['base_name'])[0]}_good.h5")
@@ -1082,7 +1092,7 @@ def create_h5(
f.create_dataset("triangles", data=np.array([], dtype=np.uint8), dtype=np.uint8)
f.create_dataset("triangles_weights", data=np.array([], dtype=np.uint8), dtype=np.uint8)
- for i in range(repeat_good_images):
+ for _ in range(repeat_good_images):
output_list.append(f"{os.path.basename(h5_file_name_good)},{target}\n")
continue
@@ -1254,7 +1264,7 @@ def triangle_area(triangle: np.ndarray) -> float:
return abs(0.5 * (((x2 - x1) * (y3 - y1)) - ((x3 - x1) * (y2 - y1))))
-def triangulate_region(mask: ndimage) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
+def triangulate_region(mask: ndimage) -> tuple[np.ndarray | None, np.ndarray | None]:
"""Extract from a binary image containing a single roi (with or without holes) a list of triangles
(and their normalized area) that completely subdivide an approximated polygon defined around mask contours,
the output can be used to easily sample uniformly points that are almost guarantee to lie inside the roi.
@@ -1302,10 +1312,7 @@ def triangulate_region(mask: ndimage) -> Tuple[Optional[np.ndarray], Optional[np
current_triangles = np.array([list(x) for x in current_triangles])
- if triangles is None:
- triangles = current_triangles
- else:
- triangles = np.concatenate([triangles, current_triangles])
+ triangles = current_triangles if triangles is None else np.concatenate([triangles, current_triangles])
if triangles is None:
return None, None
@@ -1329,10 +1336,10 @@ class InvalidNumWorkersNumberException(Exception):
def load_train_file(
train_file_path: str,
- include_filter: Optional[List[str]] = None,
- exclude_filter: Optional[List[str]] = None,
- class_to_skip: Optional[list] = None,
-) -> Tuple[List[str], List[str]]:
+ include_filter: list[str] | None = None,
+ exclude_filter: list[str] | None = None,
+ class_to_skip: list | None = None,
+) -> tuple[list[str], list[str]]:
"""Load a train file and return a list of samples and a list of targets. It is expected that train files will be in
the same location as the train_file_path.
@@ -1377,7 +1384,7 @@ def load_train_file(
return samples, targets
-def compute_safe_patch_range(sampled_point: int, patch_size: int, image_size: int) -> Tuple[int, int]:
+def compute_safe_patch_range(sampled_point: int, patch_size: int, image_size: int) -> tuple[int, int]:
"""Computes the safe patch size for the given image size.
Args:
@@ -1403,7 +1410,7 @@ def compute_safe_patch_range(sampled_point: int, patch_size: int, image_size: in
return left, right
-def trisample(triangle: np.ndarray) -> Tuple[int, int]:
+def trisample(triangle: np.ndarray) -> tuple[int, int]:
"""Sample a point uniformly in a triangle.
Args:
diff --git a/quadra/utils/patch/metrics.py b/quadra/utils/patch/metrics.py
index 2ff7a544..3c71889f 100644
--- a/quadra/utils/patch/metrics.py
+++ b/quadra/utils/patch/metrics.py
@@ -1,12 +1,13 @@
+from __future__ import annotations
+
import os
import warnings
-from typing import Dict, List, Optional, Tuple
import cv2
import numpy as np
import pandas as pd
from scipy import ndimage
-from skimage.measure import label, regionprops
+from skimage.measure import label, regionprops # pylint: disable=no-name-in-module
from tqdm import tqdm
from quadra.utils import utils
@@ -36,18 +37,18 @@ def get_sorted_patches_by_image(test_results: pd.DataFrame, img_name: str) -> pd
def compute_patch_metrics(
- test_img_info: List[PatchDatasetFileFormat],
+ test_img_info: list[PatchDatasetFileFormat],
test_results: pd.DataFrame,
overlap: float,
- idx_to_class: Dict,
- patch_num_h: Optional[int] = None,
- patch_num_w: Optional[int] = None,
- patch_w: Optional[int] = None,
- patch_h: Optional[int] = None,
+ idx_to_class: dict,
+ patch_num_h: int | None = None,
+ patch_num_w: int | None = None,
+ patch_w: int | None = None,
+ patch_h: int | None = None,
return_polygon: bool = False,
patch_reconstruction_method: str = "priority",
- annotated_good: Optional[List[int]] = None,
-) -> Tuple[int, int, int, List[Dict]]:
+ annotated_good: list[int] | None = None,
+) -> tuple[int, int, int, list[dict]]:
"""Compute the metrics of a patch dataset.
Args:
@@ -103,7 +104,9 @@ def compute_patch_metrics(
if patch_h is not None and patch_w is not None and patch_num_h is not None and patch_num_w is not None:
warnings.warn(
- "Both number of patches and patch dimension are specified, using number of patches by default", UserWarning
+ "Both number of patches and patch dimension are specified, using number of patches by default",
+ UserWarning,
+ stacklevel=2,
)
log.info("Computing patch metrics!")
@@ -188,7 +191,7 @@ def compute_patch_metrics(
if annotated_good is not None:
gt_img[np.isin(gt_img, annotated_good)] = 0
- gt_img_binary = (gt_img > 0).astype(bool)
+ gt_img_binary = (gt_img > 0).astype(bool) # type: ignore[operator]
regions_pred = label(output_mask).astype(np.uint8)
for k in range(1, regions_pred.max() + 1):
@@ -200,8 +203,8 @@ def compute_patch_metrics(
output_mask = (output_mask > 0).astype(np.uint8)
gt_img = label(gt_img)
- for i in range(1, gt_img.max() + 1):
- region = (gt_img == i).astype(bool)
+ for i in range(1, gt_img.max() + 1): # type: ignore[union-attr]
+ region = (gt_img == i).astype(bool) # type: ignore[union-attr]
if np.sum(np.bitwise_and(region, output_mask)) == 0:
false_region_good += 1
else:
@@ -211,16 +214,16 @@ def compute_patch_metrics(
def reconstruct_patch(
- input_img_shape: Tuple[int, ...],
- patch_size: Tuple[int, int],
+ input_img_shape: tuple[int, ...],
+ patch_size: tuple[int, int],
pred: np.ndarray,
patch_num_h: int,
patch_num_w: int,
- idx_to_class: Dict,
- step: Tuple[int, int],
+ idx_to_class: dict,
+ step: tuple[int, int],
return_polygon: bool = True,
method: str = "priority",
-) -> Tuple[np.ndarray, List[Dict]]:
+) -> tuple[np.ndarray, list[dict]]:
"""Reconstructs the prediction image from the patches.
Args:
@@ -269,15 +272,15 @@ def reconstruct_patch(
def _reconstruct_patch_priority(
- input_img_shape: Tuple[int, ...],
- patch_size: Tuple[int, int],
+ input_img_shape: tuple[int, ...],
+ patch_size: tuple[int, int],
pred: np.ndarray,
patch_num_h: int,
patch_num_w: int,
- idx_to_class: Dict,
- step: Tuple[int, int],
+ idx_to_class: dict,
+ step: tuple[int, int],
return_polygon: bool = True,
-) -> Tuple[np.ndarray, List[Dict]]:
+) -> tuple[np.ndarray, list[dict]]:
"""Reconstruct patch polygons using the priority method."""
final_mask = np.zeros([input_img_shape[0], input_img_shape[1]], dtype=np.uint8)
predicted_defect = []
@@ -333,13 +336,13 @@ def _reconstruct_patch_priority(
def _reconstruct_patch_major_voting(
- input_img_shape: Tuple[int, ...],
- patch_size: Tuple[int, int],
+ input_img_shape: tuple[int, ...],
+ patch_size: tuple[int, int],
pred: np.ndarray,
patch_num_h: int,
patch_num_w: int,
- idx_to_class: Dict,
- step: Tuple[int, int],
+ idx_to_class: dict,
+ step: tuple[int, int],
return_polygon: bool = True,
):
"""Reconstruct patch polygons using the major voting method."""
diff --git a/quadra/utils/patch/model.py b/quadra/utils/patch/model.py
index 0420fa14..62d304d7 100644
--- a/quadra/utils/patch/model.py
+++ b/quadra/utils/patch/model.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
import json
import os
-from typing import Any, Dict, List, Optional
+from typing import Any
import matplotlib.pyplot as plt
import numpy as np
@@ -20,13 +22,13 @@
def save_classification_result(
results: pd.DataFrame,
output_folder: str,
- confusion_matrix: Optional[pd.DataFrame],
+ confusion_matrix: pd.DataFrame | None,
accuracy: float,
test_dataloader: DataLoader,
- reconstructions: List[Dict],
+ reconstructions: list[dict],
config: DictConfig,
output: DictConfig,
- ignore_classes: Optional[List[int]] = None,
+ ignore_classes: list[int] | None = None,
):
"""Save classification results.
@@ -120,9 +122,8 @@ def save_classification_result(
if is_polygon:
if len(reconstruction["prediction"]) == 0:
continue
- else:
- if reconstruction["prediction"].sum() == 0:
- continue
+ elif reconstruction["prediction"].sum() == 0:
+ continue
if counter > 5:
break
diff --git a/quadra/utils/patch/visualization.py b/quadra/utils/patch/visualization.py
index fc9f7e2f..df21f78a 100644
--- a/quadra/utils/patch/visualization.py
+++ b/quadra/utils/patch/visualization.py
@@ -1,5 +1,6 @@
+from __future__ import annotations
+
import os
-from typing import Dict, List, Optional
import cv2
import matplotlib.pyplot as plt
@@ -15,10 +16,10 @@
def plot_patch_reconstruction(
- reconstruction: Dict,
- idx_to_class: Dict[int, str],
- class_to_idx: Dict[str, int],
- ignore_classes: Optional[List[int]] = None,
+ reconstruction: dict,
+ idx_to_class: dict[int, str],
+ class_to_idx: dict[str, int],
+ ignore_classes: list[int] | None = None,
is_polygon: bool = True,
) -> Figure:
"""Helper function for plotting the patch reconstruction.
@@ -74,7 +75,7 @@ def plot_patch_reconstruction(
-1,
class_to_idx[c_label],
thickness=cv2.FILLED,
- )
+ ) # type: ignore[call-overload]
else:
out = reconstruction["prediction"]
@@ -103,9 +104,9 @@ def show_mask_on_image(image: np.ndarray, mask: np.ndarray):
def create_rgb_mask(
mask: np.ndarray,
- color_map: Dict,
- ignore_classes: Optional[List[int]] = None,
- ground_truth_mask: Optional[np.ndarray] = None,
+ color_map: dict,
+ ignore_classes: list[int] | None = None,
+ ground_truth_mask: np.ndarray | None = None,
):
"""Convert index mask to RGB mask."""
output_mask = np.zeros([mask.shape[0], mask.shape[1], 3])
@@ -123,15 +124,16 @@ def create_rgb_mask(
def plot_patch_results(
image: np.ndarray,
prediction_image: np.ndarray,
- ground_truth_image: Optional[np.ndarray],
- class_to_idx: Dict[str, int],
+ ground_truth_image: np.ndarray | None,
+ class_to_idx: dict[str, int],
plot_original: bool = True,
- ignore_classes: Optional[List[int]] = None,
+ ignore_classes: list[int] | None = None,
image_height: int = 10,
- save_path: Optional[str] = None,
- cmap: Colormap = get_cmap("tab20"),
+ save_path: str | None = None,
+ cmap: Colormap | None = None,
) -> Figure:
- """Function used to plot the image predicted
+ """Function used to plot the image predicted.
+
Args:
prediction_image: The prediction image
image: The original image to plot
@@ -141,7 +143,7 @@ def plot_patch_results(
ignore_classes: The classes to ignore, default is 0
image_height: The height of the output figure
save_path: The path to save the figure
- cmap: The colormap to use.
+ cmap: The colormap to use. If None, tab20 is used
Returns:
The matplotlib figure
@@ -149,6 +151,9 @@ def plot_patch_results(
if ignore_classes is None:
ignore_classes = [0]
+ if cmap is None:
+ cmap = get_cmap("tab20")
+
image = image[0 : prediction_image.shape[0], 0 : prediction_image.shape[1], :]
idx_to_class = {v: k for k, v in class_to_idx.items()}
diff --git a/quadra/utils/resolver.py b/quadra/utils/resolver.py
index dde43bb1..a4e56a50 100644
--- a/quadra/utils/resolver.py
+++ b/quadra/utils/resolver.py
@@ -1,4 +1,6 @@
-from typing import Any, Tuple
+from __future__ import annotations
+
+from typing import Any
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf
@@ -29,7 +31,7 @@ def multirun_subdir_beautify(subdir: str) -> str:
return subdir
-def as_tuple(*args: Any) -> Tuple[Any, ...]:
+def as_tuple(*args: Any) -> tuple[Any, ...]:
"""Resolves a list of arguments to a tuple."""
return tuple(args)
diff --git a/quadra/utils/tests/fixtures/__init__.py b/quadra/utils/tests/fixtures/__init__.py
index 7834127f..49805902 100644
--- a/quadra/utils/tests/fixtures/__init__.py
+++ b/quadra/utils/tests/fixtures/__init__.py
@@ -1 +1 @@
-from .dataset import *
+from .dataset import * # noqa: F403
diff --git a/quadra/utils/tests/fixtures/dataset/anomaly.py b/quadra/utils/tests/fixtures/dataset/anomaly.py
index 3a3e9546..41fde91b 100644
--- a/quadra/utils/tests/fixtures/dataset/anomaly.py
+++ b/quadra/utils/tests/fixtures/dataset/anomaly.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
import shutil
from dataclasses import dataclass
from pathlib import Path
-from typing import Any, Tuple
+from typing import Any
import cv2
import pytest
@@ -20,13 +22,13 @@ class AnomalyDatasetArguments:
"""
train_samples: int
- val_samples: Tuple[int, int]
- test_samples: Tuple[int, int]
+ val_samples: tuple[int, int]
+ test_samples: tuple[int, int]
def _build_anomaly_dataset(
tmp_path: Path, dataset_arguments: AnomalyDatasetArguments
-) -> Tuple[str, AnomalyDatasetArguments]:
+) -> tuple[str, AnomalyDatasetArguments]:
"""Generate anomaly dataset in the standard mvtec format.
Args:
@@ -86,7 +88,7 @@ def _build_anomaly_dataset(
@pytest.fixture
-def anomaly_dataset(tmp_path: Path, dataset_arguments: AnomalyDatasetArguments) -> Tuple[str, AnomalyDatasetArguments]:
+def anomaly_dataset(tmp_path: Path, dataset_arguments: AnomalyDatasetArguments) -> tuple[str, AnomalyDatasetArguments]:
"""Fixture used to dinamically generate anomaly dataset. By default images are random grayscales with size 10x10.
Args:
@@ -104,7 +106,7 @@ def anomaly_dataset(tmp_path: Path, dataset_arguments: AnomalyDatasetArguments)
@pytest.fixture(
params=[AnomalyDatasetArguments(**{"train_samples": 10, "val_samples": (1, 1), "test_samples": (1, 1)})]
)
-def base_anomaly_dataset(tmp_path: Path, request: Any) -> Tuple[str, AnomalyDatasetArguments]:
+def base_anomaly_dataset(tmp_path: Path, request: Any) -> tuple[str, AnomalyDatasetArguments]:
"""Generate base anomaly dataset with the following parameters:
- train_samples: 10
- val_samples: (10, 10)
diff --git a/quadra/utils/tests/fixtures/dataset/classification.py b/quadra/utils/tests/fixtures/dataset/classification.py
index f643bd62..059c07f1 100644
--- a/quadra/utils/tests/fixtures/dataset/classification.py
+++ b/quadra/utils/tests/fixtures/dataset/classification.py
@@ -1,10 +1,12 @@
+from __future__ import annotations
+
import glob
import os
import random
import shutil
from dataclasses import dataclass
from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
import cv2
import numpy as np
@@ -25,10 +27,10 @@ class ClassificationDatasetArguments:
test_size: test set size
"""
- samples: List[int]
- classes: Optional[List[str]] = None
- val_size: Optional[float] = None
- test_size: Optional[float] = None
+ samples: list[int]
+ classes: list[str] | None = None
+ val_size: float | None = None
+ test_size: float | None = None
@dataclass
@@ -43,11 +45,11 @@ class ClassificationMultilabelDatasetArguments:
percentage_other_classes: probability of adding other classes to the labels of each sample
"""
- samples: List[int]
- classes: Optional[List[str]] = None
- val_size: Optional[float] = None
- test_size: Optional[float] = None
- percentage_other_classes: Optional[float] = 0.0
+ samples: list[int]
+ classes: list[str] | None = None
+ val_size: float | None = None
+ test_size: float | None = None
+ percentage_other_classes: float | None = 0.0
@dataclass
@@ -65,19 +67,19 @@ class ClassificationPatchDatasetArguments:
annotated_good: list of class names that are considered as good annotations (E.g. ["good"])
"""
- samples: List[int]
+ samples: list[int]
overlap: float
- patch_size: Optional[Tuple[int, int]] = None
- patch_number: Optional[Tuple[int, int]] = None
- classes: Optional[List[str]] = None
- val_size: Optional[float] = 0.0
- test_size: Optional[float] = 0.0
- annotated_good: Optional[List[str]] = None
+ patch_size: tuple[int, int] | None = None
+ patch_number: tuple[int, int] | None = None
+ classes: list[str] | None = None
+ val_size: float | None = 0.0
+ test_size: float | None = 0.0
+ annotated_good: list[str] | None = None
def _build_classification_dataset(
tmp_path: Path, dataset_arguments: ClassificationDatasetArguments
-) -> Tuple[str, ClassificationDatasetArguments]:
+) -> tuple[str, ClassificationDatasetArguments]:
"""Generate classification dataset. If val_size or test_size are set, it will generate a train.txt, val.txt and
test.txt file in the dataset directory. By default generated images are 10x10 pixels.
@@ -129,7 +131,7 @@ def _build_classification_dataset(
@pytest.fixture
def classification_dataset(
tmp_path: Path, dataset_arguments: ClassificationDatasetArguments
-) -> Tuple[str, ClassificationDatasetArguments]:
+) -> tuple[str, ClassificationDatasetArguments]:
"""Generate classification dataset. If val_size or test_size are set, it will generate a train.txt, val.txt and
test.txt file in the dataset directory. By default generated images are 10x10 pixels.
@@ -152,7 +154,7 @@ def classification_dataset(
)
]
)
-def base_classification_dataset(tmp_path: Path, request: Any) -> Tuple[str, ClassificationDatasetArguments]:
+def base_classification_dataset(tmp_path: Path, request: Any) -> tuple[str, ClassificationDatasetArguments]:
"""Generate base classification dataset with the following parameters:
- 10 samples per class
- 2 classes (class_1 and class_2)
@@ -172,7 +174,7 @@ def base_classification_dataset(tmp_path: Path, request: Any) -> Tuple[str, Clas
def _build_multilabel_classification_dataset(
tmp_path: Path, dataset_arguments: ClassificationMultilabelDatasetArguments
-) -> Tuple[str, ClassificationMultilabelDatasetArguments]:
+) -> tuple[str, ClassificationMultilabelDatasetArguments]:
"""Generate a multilabel classification dataset.
Generates a samples.txt file in the dataset directory containing the path to the image and the corresponding
classes. If val_size or test_size are set, it will generate a train.txt, val.txt and test.txt file in the
@@ -236,7 +238,7 @@ def _build_multilabel_classification_dataset(
@pytest.fixture
def multilabel_classification_dataset(
tmp_path: Path, dataset_arguments: ClassificationMultilabelDatasetArguments
-) -> Tuple[str, ClassificationMultilabelDatasetArguments]:
+) -> tuple[str, ClassificationMultilabelDatasetArguments]:
"""Fixture to dinamically generate a multilabel classification dataset.
Generates a samples.txt file in the dataset directory containing the path to the image and the corresponding
classes. If val_size or test_size are set, it will generate a train.txt, val.txt and test.txt file in the
@@ -269,7 +271,7 @@ def multilabel_classification_dataset(
)
def base_multilabel_classification_dataset(
tmp_path: Path, request: Any
-) -> Tuple[str, ClassificationMultilabelDatasetArguments]:
+) -> tuple[str, ClassificationMultilabelDatasetArguments]:
"""Fixture to generate base multilabel classification dataset with the following parameters:
- 10 samples per class
- 3 classes (class_1, class_2 and class_3)
@@ -292,7 +294,7 @@ def base_multilabel_classification_dataset(
def _build_classification_patch_dataset(
tmp_path: Path, dataset_arguments: ClassificationDatasetArguments
-) -> Tuple[str, ClassificationDatasetArguments, Dict[str, int]]:
+) -> tuple[str, ClassificationDatasetArguments, dict[str, int]]:
"""Generate a classification patch dataset. By default generated images are 224x224 pixels
and associated masks contains a 50x50 pixels square with the corresponding image class, so at the current stage
is not possible to have images with multiple annotations. The patch dataset will be generated using the standard
@@ -350,7 +352,7 @@ def _build_classification_patch_dataset(
@pytest.fixture
def classification_patch_dataset(
tmp_path: Path, dataset_arguments: ClassificationDatasetArguments
-) -> Tuple[str, ClassificationDatasetArguments, Dict[str, int]]:
+) -> tuple[str, ClassificationDatasetArguments, dict[str, int]]:
"""Fixture to dinamically generate a classification patch dataset.
By default generated images are 224x224 pixels
@@ -386,7 +388,7 @@ def classification_patch_dataset(
)
def base_patch_classification_dataset(
tmp_path: Path, request: Any
-) -> Tuple[str, ClassificationDatasetArguments, Dict[str, int]]:
+) -> tuple[str, ClassificationDatasetArguments, dict[str, int]]:
"""Generate a classification patch dataset with the following parameters:
- 3 classes named bg, a and b
- 5, 5 and 5 samples for each class
diff --git a/quadra/utils/tests/fixtures/dataset/imagenette.py b/quadra/utils/tests/fixtures/dataset/imagenette.py
index 86683aed..3072c88c 100644
--- a/quadra/utils/tests/fixtures/dataset/imagenette.py
+++ b/quadra/utils/tests/fixtures/dataset/imagenette.py
@@ -39,7 +39,7 @@ def _build_imagenette_dataset(tmp_path: Path, classes: int, class_samples: int)
@pytest.fixture
def imagenette_dataset(tmp_path: Path) -> str:
- """Generate a mock imagenette dataset to test efficient_ad model
+ """Generate a mock imagenette dataset to test efficient_ad model.
Args:
tmp_path: Path to temporary directory
diff --git a/quadra/utils/tests/fixtures/dataset/segmentation.py b/quadra/utils/tests/fixtures/dataset/segmentation.py
index d4731121..01fb5b00 100644
--- a/quadra/utils/tests/fixtures/dataset/segmentation.py
+++ b/quadra/utils/tests/fixtures/dataset/segmentation.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
import shutil
from dataclasses import dataclass
from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
import cv2
import numpy as np
@@ -21,15 +23,15 @@ class SegmentationDatasetArguments:
classes: Optional list of class names, must be equal to len(train_samples) - 1
"""
- train_samples: List[int]
- val_samples: Optional[List[int]] = None
- test_samples: Optional[List[int]] = None
- classes: Optional[List[str]] = None
+ train_samples: list[int]
+ val_samples: list[int] | None = None
+ test_samples: list[int] | None = None
+ classes: list[str] | None = None
def _build_segmentation_dataset(
tmp_path: Path, dataset_arguments: SegmentationDatasetArguments
-) -> Tuple[str, SegmentationDatasetArguments, Dict[str, int]]:
+) -> tuple[str, SegmentationDatasetArguments, dict[str, int]]:
"""Generate segmentation dataset.
Args:
@@ -80,7 +82,7 @@ def _build_segmentation_dataset(
@pytest.fixture
def segmentation_dataset(
tmp_path: Path, dataset_arguments: SegmentationDatasetArguments
-) -> Tuple[str, SegmentationDatasetArguments, Dict[str, int]]:
+) -> tuple[str, SegmentationDatasetArguments, dict[str, int]]:
"""Fixture to dinamically generate a segmentation dataset. By default generated images are 224x224 pixels
and associated masks contains a 50x50 pixels square with the corresponding image class, so at the current stage
is not possible to have images with multiple annotations. Split files are saved as train.txt,
@@ -107,7 +109,7 @@ def segmentation_dataset(
)
def base_binary_segmentation_dataset(
tmp_path: Path, request: Any
-) -> Tuple[str, SegmentationDatasetArguments, Dict[str, int]]:
+) -> tuple[str, SegmentationDatasetArguments, dict[str, int]]:
"""Generate a base binary segmentation dataset with the following structure:
- 3 good and 2 bad samples in train set
- 2 good and 2 bad samples in validation set
@@ -140,7 +142,7 @@ def base_binary_segmentation_dataset(
)
def base_multiclass_segmentation_dataset(
tmp_path: Path, request: Any
-) -> Tuple[str, SegmentationDatasetArguments, Dict[str, int]]:
+) -> tuple[str, SegmentationDatasetArguments, dict[str, int]]:
"""Generate a base binary segmentation dataset with the following structure:
- 2 good, 2 defect_1 and 2 defect_2 samples in train set
- 2 good, 2 defect_1 and 2 defect_2 samples in validation set
diff --git a/quadra/utils/tests/fixtures/models/__init__.py b/quadra/utils/tests/fixtures/models/__init__.py
index 38502b2e..866eec92 100644
--- a/quadra/utils/tests/fixtures/models/__init__.py
+++ b/quadra/utils/tests/fixtures/models/__init__.py
@@ -1,3 +1,3 @@
-from .anomaly import *
-from .classification import *
-from .segmentation import *
+from .anomaly import * # noqa: F403
+from .classification import * # noqa: F403
+from .segmentation import * # noqa: F403
diff --git a/quadra/utils/tests/helpers.py b/quadra/utils/tests/helpers.py
index 40bb2cc8..284d93c5 100644
--- a/quadra/utils/tests/helpers.py
+++ b/quadra/utils/tests/helpers.py
@@ -1,6 +1,7 @@
+from __future__ import annotations
+
import os
from pathlib import Path
-from typing import List, Tuple
import numpy as np
import torch
@@ -12,12 +13,12 @@
# taken from hydra unit tests
-def _random_image(size: Tuple[int, int] = (10, 10)) -> np.ndarray:
+def _random_image(size: tuple[int, int] = (10, 10)) -> np.ndarray:
"""Generate random image."""
return np.random.randint(0, 255, size=size, dtype=np.uint8)
-def execute_quadra_experiment(overrides: List[str], experiment_path: Path) -> None:
+def execute_quadra_experiment(overrides: list[str], experiment_path: Path) -> None:
"""Execute quadra experiment."""
with initialize_config_module(config_module="quadra.configs", version_base="1.3.0"):
if not experiment_path.exists():
@@ -49,7 +50,7 @@ def get_quadra_test_device():
return os.environ.get("QUADRA_TEST_DEVICE", "cpu")
-def setup_trainer_for_lightning() -> List[str]:
+def setup_trainer_for_lightning() -> list[str]:
"""Setup trainer for lightning depending on the device. If cuda is used, the device index is also set.
If cpu is used, the trainer is set to lightning_cpu.
diff --git a/quadra/utils/utils.py b/quadra/utils/utils.py
index e33315d7..ebee7812 100644
--- a/quadra/utils/utils.py
+++ b/quadra/utils/utils.py
@@ -2,6 +2,8 @@
Some of them are mostly based on https://github.com/ashleve/lightning-hydra-template.
"""
+from __future__ import annotations
+
import glob
import json
import logging
@@ -9,7 +11,8 @@
import subprocess
import sys
import warnings
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, cast
+from collections.abc import Iterable, Iterator, Sequence
+from typing import Any, cast
import cv2
import dotenv
@@ -39,7 +42,7 @@
ONNX_AVAILABLE = False
-IMAGE_EXTENSIONS: List[str] = [".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif", ".pbm", ".pgm", ".ppm", ".pxm", ".pnm"]
+IMAGE_EXTENSIONS: list[str] = [".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif", ".pbm", ".pgm", ".ppm", ".pxm", ".pnm"]
def get_logger(name=__name__) -> logging.Logger:
@@ -183,28 +186,23 @@ def log_hyperparameters(
hparams["command"] = config.core.command
hparams["library/version"] = str(quadra.__version__)
- # pylint: disable=consider-using-with
- if (
- subprocess.call(
- ["git", "-C", get_original_cwd(), "status"], stderr=subprocess.STDOUT, stdout=open(os.devnull, "w")
- )
- == 0
- ):
- try:
- hparams["git/commit"] = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()
- hparams["git/branch"] = (
- subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]).decode("ascii").strip()
- )
- hparams["git/remote"] = (
- subprocess.check_output(["git", "remote", "get-url", "origin"]).decode("ascii").strip()
- )
- except subprocess.CalledProcessError:
- log.warning(
- "Could not get git commit, branch or remote information, the repository might not have any commits yet "
- "or it might be initialized wrongly."
- )
- else:
- log.warning("Could not find git repository, skipping git commit and branch info")
+ with open(os.devnull, "w") as fnull:
+ if subprocess.call(["git", "-C", get_original_cwd(), "status"], stderr=subprocess.STDOUT, stdout=fnull) == 0:
+ try:
+ hparams["git/commit"] = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()
+ hparams["git/branch"] = (
+ subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]).decode("ascii").strip()
+ )
+ hparams["git/remote"] = (
+ subprocess.check_output(["git", "remote", "get-url", "origin"]).decode("ascii").strip()
+ )
+ except subprocess.CalledProcessError:
+ log.warning(
+ "Could not get git commit, branch or remote information, the repository might not have any commits "
+ " yet or it might have been initialized wrongly."
+ )
+ else:
+ log.warning("Could not find git repository, skipping git commit and branch info")
# send hparams to all loggers
trainer.logger.log_hyperparams(hparams)
@@ -221,18 +219,18 @@ def upload_file_tensorboard(file_path: str, tensorboard_logger: TensorBoardLogge
ext = os.path.splitext(file_path)[1].lower()
if ext == ".json":
- with open(file_path, "r") as f:
+ with open(file_path) as f:
json_content = json.load(f)
json_content = f"```json\n{json.dumps(json_content, indent=4)}\n```"
tensorboard_logger.experiment.add_text(tag=tag, text_string=json_content, global_step=0)
elif ext in [".yaml", ".yml"]:
- with open(file_path, "r") as f:
+ with open(file_path) as f:
yaml_content = f.read()
yaml_content = f"```yaml\n{yaml_content}\n```"
tensorboard_logger.experiment.add_text(tag=tag, text_string=yaml_content, global_step=0)
else:
- with open(file_path, "r", encoding="utf-8") as f:
+ with open(file_path, encoding="utf-8") as f:
tensorboard_logger.experiment.add_text(tag=tag, text_string=f.read().replace("\n", " \n"), global_step=0)
tensorboard_logger.experiment.flush()
@@ -243,8 +241,8 @@ def finish(
module: pl.LightningModule,
datamodule: pl.LightningDataModule,
trainer: pl.Trainer,
- callbacks: List[pl.Callback],
- logger: List[pl.loggers.Logger],
+ callbacks: list[pl.Callback],
+ logger: list[pl.loggers.Logger],
export_folder: str,
) -> None:
"""Upload config files to MLFlow server.
@@ -284,10 +282,10 @@ def finish(
)
deployed_models = glob.glob(os.path.join(export_folder, "*"))
- model_json: Optional[Dict[str, Any]] = None
+ model_json: dict[str, Any] | None = None
if os.path.exists(os.path.join(export_folder, "model.json")):
- with open(os.path.join(export_folder, "model.json"), "r") as json_file:
+ with open(os.path.join(export_folder, "model.json")) as json_file:
model_json = json.load(json_file)
if model_json is not None:
@@ -297,7 +295,7 @@ def finish(
# Input size is not a list of lists
input_size = [input_size]
inputs = cast(
- List[Any],
+ list[Any],
quadra_export.generate_torch_inputs(input_size, device=device, half_precision=half_precision),
)
types_to_upload = config.core.get("upload_models")
@@ -353,7 +351,7 @@ def finish(
tensorboard_logger.experiment.flush()
-def load_envs(env_file: Optional[str] = None) -> None:
+def load_envs(env_file: str | None = None) -> None:
"""Load all the environment variables defined in the `env_file`.
This is equivalent to `. env_file` in bash.
@@ -366,7 +364,7 @@ def load_envs(env_file: Optional[str] = None) -> None:
dotenv.load_dotenv(dotenv_path=env_file, override=True)
-def model_type_from_path(model_path: str) -> Optional[str]:
+def model_type_from_path(model_path: str) -> str | None:
"""Determine the type of the machine learning model based on its file extension.
Parameters:
@@ -422,7 +420,7 @@ def get_device(cuda: bool = True) -> str:
return "cpu"
-def nested_set(dic: Dict, keys: List[str], value: str) -> None:
+def nested_set(dic: dict, keys: list[str], value: str) -> None:
"""Assign the value of a dictionary using nested keys."""
for key in keys[:-1]:
dic = dic.setdefault(key, {})
@@ -430,16 +428,16 @@ def nested_set(dic: Dict, keys: List[str], value: str) -> None:
dic[keys[-1]] = value
-def flatten_list(l: Iterable[Any]) -> Iterator[Any]:
+def flatten_list(input_list: Iterable[Any]) -> Iterator[Any]:
"""Return an iterator over the flattened list.
Args:
- l: the list to be flattened
+ input_list: the list to be flattened
Yields:
- Iterator[Any]: the iterator over the flattend list
+ The iterator over the flattend list
"""
- for v in l:
+ for v in input_list:
if isinstance(v, Iterable) and not isinstance(v, (str, bytes)):
yield from flatten_list(v)
else:
@@ -451,9 +449,8 @@ class HydraEncoder(json.JSONEncoder):
def default(self, o):
"""Convert OmegaConf objects to base python objects."""
- if o is not None:
- if OmegaConf.is_config(o):
- return OmegaConf.to_container(o)
+ if o is not None and OmegaConf.is_config(o):
+ return OmegaConf.to_container(o)
return json.JSONEncoder.default(self, o)
@@ -508,7 +505,7 @@ def concat_all_gather(tensor):
return output
-def get_tensorboard_logger(trainer: pl.Trainer) -> Optional[TensorBoardLogger]:
+def get_tensorboard_logger(trainer: pl.Trainer) -> TensorBoardLogger | None:
"""Safely get tensorboard logger from Lightning Trainer loggers.
Args:
diff --git a/quadra/utils/validator.py b/quadra/utils/validator.py
index d8287ec3..dc1d96d6 100644
--- a/quadra/utils/validator.py
+++ b/quadra/utils/validator.py
@@ -1,19 +1,22 @@
+from __future__ import annotations
+
import difflib
import importlib
import inspect
-from typing import Any, Iterable, List, Tuple, Union
+from collections.abc import Iterable
+from typing import Any
from omegaconf import DictConfig, ListConfig, OmegaConf
from quadra.utils.utils import get_logger
-OMEGACONF_FIELDS: Tuple[str, ...] = ("_target_", "_convert_", "_recursive_", "_args_")
-EXCLUDE_KEYS: Tuple[str, ...] = ("hydra",)
+OMEGACONF_FIELDS: tuple[str, ...] = ("_target_", "_convert_", "_recursive_", "_args_")
+EXCLUDE_KEYS: tuple[str, ...] = ("hydra",)
logger = get_logger(__name__)
-def get_callable_arguments(full_module_path: str) -> Tuple[List[str], bool]:
+def get_callable_arguments(full_module_path: str) -> tuple[list[str], bool]:
"""Gets all arguments from module path.
Args:
@@ -56,7 +59,7 @@ def get_callable_arguments(full_module_path: str) -> Tuple[List[str], bool]:
return arg_names, accepts_kwargs
-def check_all_arguments(callable_variable: str, configuration_arguments: List[str], argument_names: List[str]) -> None:
+def check_all_arguments(callable_variable: str, configuration_arguments: list[str], argument_names: list[str]) -> None:
"""Checks if all arguments passed from configuration are valid for the target class or function.
Args:
@@ -78,7 +81,7 @@ def check_all_arguments(callable_variable: str, configuration_arguments: List[st
raise ValueError(error_string)
-def validate_config(_cfg: Union[DictConfig, ListConfig], package_name: str = "quadra") -> None:
+def validate_config(_cfg: DictConfig | ListConfig, package_name: str = "quadra") -> None:
"""Recursively traverse OmegaConf object and check if arguments are valid for the target class or function.
If not, raise a ValueError with a suggestion for the closest match of the argument name.
@@ -104,7 +107,7 @@ def validate_config(_cfg: Union[DictConfig, ListConfig], package_name: str = "qu
if key == "_target_":
callable_variable = str(_cfg[key])
if callable_variable.startswith(package_name):
- configuration_arguments = [str(x) for x in _cfg.keys() if x not in OMEGACONF_FIELDS]
+ configuration_arguments = [str(x) for x in _cfg if x not in OMEGACONF_FIELDS]
argument_names, accepts_kwargs = get_callable_arguments(callable_variable)
if not accepts_kwargs:
check_all_arguments(callable_variable, configuration_arguments, argument_names)
diff --git a/quadra/utils/visualization.py b/quadra/utils/visualization.py
index 1aef5dc3..bdbfb77a 100644
--- a/quadra/utils/visualization.py
+++ b/quadra/utils/visualization.py
@@ -1,7 +1,10 @@
+from __future__ import annotations
+
import copy
import os
import random
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
+from collections.abc import Callable, Iterable
+from typing import Any
import albumentations
import matplotlib.pyplot as plt
@@ -54,33 +57,32 @@ def create_grid_figure(
nrows: int,
ncols: int,
file_path: str,
- bounds: List[Tuple[float, float]],
- row_names: Optional[Iterable[str]] = None,
- fig_size: Tuple[int, int] = (12, 8),
+ bounds: list[tuple[float, float]],
+ row_names: Iterable[str] | None = None,
+ fig_size: tuple[int, int] = (12, 8),
):
"""Create a grid figure with images.
Args:
- images (Iterable[np.ndarray]): List of images to plot.
- nrows (int): Number of rows in the grid.
- ncols (int): Number of columns in the grid.
- file_path (str): Path to save the figure.
- row_names (Optional[Iterable[str]], optional): Row names. Defaults to None.
- fig_size (Tuple[int, int], optional): Figure size. Defaults to (12, 8).
- bounds (Optional[List[Tuple[float, float]]], optional): Bounds for the images. Defaults to None.
+ images: List of images to plot.
+ nrows: Number of rows in the grid.
+ ncols: Number of columns in the grid.
+ file_path: Path to save the figure.
+ row_names: Row names. Defaults to None.
+ fig_size: Figure size. Defaults to (12, 8).
+ bounds: Bounds for the images. Defaults to None.
"""
default_plt_backend = plt.get_backend()
plt.switch_backend("Agg")
_, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=fig_size, squeeze=False)
for i, row in enumerate(images):
for j, image in enumerate(row):
- if len(image.shape) == 3 and image.shape[0] == 1:
- image = image[0]
- ax[i][j].imshow(image, vmin=bounds[i][0], vmax=bounds[i][1])
+ image_to_plot = image[0] if len(image.shape) == 3 and image.shape[0] == 1 else image
+ ax[i][j].imshow(image_to_plot, vmin=bounds[i][0], vmax=bounds[i][1])
ax[i][j].get_xaxis().set_ticks([])
ax[i][j].get_yaxis().set_ticks([])
if row_names is not None:
- for ax, name in zip(ax[:, 0], row_names):
+ for ax, name in zip(ax[:, 0], row_names): # noqa: B020
ax.set_ylabel(name, rotation=90)
plt.tight_layout()
@@ -138,10 +140,10 @@ def show_mask_on_image(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
def reconstruct_multiclass_mask(
mask: np.ndarray,
- image_shape: Tuple[int, ...],
+ image_shape: tuple[int, ...],
color_map: ListedColormap,
- ignore_class: Optional[int] = None,
- ground_truth_mask: Optional[np.ndarray] = None,
+ ignore_class: int | None = None,
+ ground_truth_mask: np.ndarray | None = None,
) -> np.ndarray:
"""Reconstruct a multiclass mask from a single channel mask.
@@ -172,11 +174,11 @@ def plot_multiclass_prediction(
image: np.ndarray,
prediction_image: np.ndarray,
ground_truth_image: np.ndarray,
- class_to_idx: Dict[str, int],
+ class_to_idx: dict[str, int],
plot_original: bool = True,
- ignore_class: Optional[int] = 0,
+ ignore_class: int | None = 0,
image_height: int = 10,
- save_path: Optional[str] = None,
+ save_path: str | None = None,
color_map: str = "tab20",
) -> None:
"""Function used to plot the image predicted.
@@ -247,16 +249,16 @@ def plot_classification_results(
test_labels: np.ndarray,
class_name: str,
original_folder: str,
- gradcam_folder: Optional[str] = None,
- grayscale_cams: Optional[np.ndarray] = None,
- unorm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
- idx_to_class: Optional[Dict] = None,
- what: Optional[str] = None,
- real_class_to_plot: Optional[int] = None,
- pred_class_to_plot: Optional[int] = None,
- rows: Optional[int] = 1,
+ gradcam_folder: str | None = None,
+ grayscale_cams: np.ndarray | None = None,
+ unorm: Callable[[torch.Tensor], torch.Tensor] | None = None,
+ idx_to_class: dict | None = None,
+ what: str | None = None,
+ real_class_to_plot: int | None = None,
+ pred_class_to_plot: int | None = None,
+ rows: int | None = 1,
cols: int = 4,
- figsize: Tuple[int, int] = (20, 20),
+ figsize: tuple[int, int] = (20, 20),
gradcam: bool = False,
) -> None:
"""Plot and save images extracted from classification. If gradcam is True, same images
diff --git a/quadra/utils/vit_explainability.py b/quadra/utils/vit_explainability.py
index bfaeb0b7..5170a079 100644
--- a/quadra/utils/vit_explainability.py
+++ b/quadra/utils/vit_explainability.py
@@ -2,9 +2,9 @@
# Title: Explainability for Vision Transformers
# Source: https://github.com/jacobgil/vit-explain (MIT license)
# Description: This is a heavily modified version of the original jacobgil code (the underlying math is still the same).
+from __future__ import annotations
import math
-from typing import List, Optional
import numpy as np
import torch
@@ -12,7 +12,7 @@
def rollout(
- attentions: List[torch.Tensor], discard_ratio: float = 0.9, head_fusion: str = "mean", aspect_ratio: float = 1.0
+ attentions: list[torch.Tensor], discard_ratio: float = 0.9, head_fusion: str = "mean", aspect_ratio: float = 1.0
) -> np.ndarray:
"""Apply rollout on Attention matrices.
@@ -41,8 +41,8 @@ def rollout(
flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
_, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
flat.scatter_(-1, indices, 0)
- I = torch.eye(attention_heads_fused.size(-1))
- a = (attention_heads_fused + 1.0 * I) / 2
+ identity_matrix = torch.eye(attention_heads_fused.size(-1))
+ a = (attention_heads_fused + 1.0 * identity_matrix) / 2
a = a / a.sum(dim=-1).unsqueeze(1)
result = torch.matmul(a, result)
# Look at the total attention between the class token and the image patches
@@ -75,7 +75,7 @@ class VitAttentionRollout:
def __init__(
self,
model: torch.nn.Module,
- attention_layer_names: Optional[List[str]] = None,
+ attention_layer_names: list[str] | None = None,
head_fusion: str = "mean",
discard_ratio: float = 0.9,
):
@@ -89,15 +89,19 @@ def __init__(
self.model = model
self.head_fusion = head_fusion
self.discard_ratio = discard_ratio
- self.f_hook_handles: List[torch.utils.hooks.RemovableHandle] = []
+ self.f_hook_handles: list[torch.utils.hooks.RemovableHandle] = []
for name, module in self.model.named_modules():
for layer_name in attention_layer_names:
if layer_name in name:
self.f_hook_handles.append(module.register_forward_hook(self.get_attention))
- self.attentions: List[torch.Tensor] = []
+ self.attentions: list[torch.Tensor] = []
+ # pylint: disable=unused-argument
def get_attention(
- self, module: torch.nn.Module, inpt: torch.Tensor, out: torch.Tensor # pylint: disable=W0613
+ self,
+ module: torch.nn.Module,
+ inpt: torch.Tensor,
+ out: torch.Tensor,
) -> None:
"""Hook to return attention.
@@ -134,7 +138,7 @@ def __call__(self, input_tensor: torch.Tensor) -> np.ndarray:
def grad_rollout(
- attentions: List[torch.Tensor], gradients: List[torch.Tensor], discard_ratio: float = 0.9, aspect_ratio: float = 1.0
+ attentions: list[torch.Tensor], gradients: list[torch.Tensor], discard_ratio: float = 0.9, aspect_ratio: float = 1.0
) -> np.ndarray:
"""Apply gradient rollout on Attention matrices.
@@ -193,10 +197,10 @@ class VitAttentionGradRollout:
def __init__( # pylint: disable=W0102
self,
model: torch.nn.Module,
- attention_layer_names: Optional[List[str]] = None,
+ attention_layer_names: list[str] | None = None,
discard_ratio: float = 0.9,
- classifier: Optional[LinearClassifierMixin] = None,
- example_input: Optional[torch.Tensor] = None,
+ classifier: LinearClassifierMixin | None = None,
+ example_input: torch.Tensor | None = None,
):
if attention_layer_names is None:
attention_layer_names = [
@@ -221,15 +225,15 @@ def __init__( # pylint: disable=W0102
self.model = model # type: ignore[assignment]
self.discard_ratio = discard_ratio
- self.f_hook_handles: List[torch.utils.hooks.RemovableHandle] = []
- self.b_hook_handles: List[torch.utils.hooks.RemovableHandle] = []
+ self.f_hook_handles: list[torch.utils.hooks.RemovableHandle] = []
+ self.b_hook_handles: list[torch.utils.hooks.RemovableHandle] = []
for name, module in self.model.named_modules():
for layer_name in attention_layer_names:
if layer_name in name:
self.f_hook_handles.append(module.register_forward_hook(self.get_attention))
self.b_hook_handles.append(module.register_backward_hook(self.get_attention_gradient))
- self.attentions: List[torch.Tensor] = []
- self.attention_gradients: List[torch.Tensor] = []
+ self.attentions: list[torch.Tensor] = []
+ self.attention_gradients: list[torch.Tensor] = []
# Activate gradients
blocks_list = [x.split("blocks")[1].split(".attn")[0] for x in attention_layer_names]
for name, module in model.named_modules():
@@ -237,8 +241,12 @@ def __init__( # pylint: disable=W0102
if "blocks" in name and any(x in name for x in blocks_list):
p.requires_grad = True
+ # pylint: disable=unused-argument
def get_attention(
- self, module: torch.nn.Module, inpt: torch.Tensor, out: torch.Tensor # pylint: disable=W0613
+ self,
+ module: torch.nn.Module,
+ inpt: torch.Tensor,
+ out: torch.Tensor,
) -> None:
"""Hook to return attention.
@@ -249,8 +257,12 @@ def get_attention(
"""
self.attentions.append(out.detach().clone().cpu())
+ # pylint: disable=unused-argument
def get_attention_gradient(
- self, module: torch.nn.Module, grad_input: torch.Tensor, grad_output: torch.Tensor # pylint: disable=W0613
+ self,
+ module: torch.nn.Module,
+ grad_input: torch.Tensor,
+ grad_output: torch.Tensor,
) -> None:
"""Hook to return attention.
@@ -261,7 +273,7 @@ def get_attention_gradient(
"""
self.attention_gradients.append(grad_input[0].detach().clone().cpu())
- def __call__(self, input_tensor: torch.Tensor, targets_list: List[int]) -> np.ndarray:
+ def __call__(self, input_tensor: torch.Tensor, targets_list: list[int]) -> np.ndarray:
"""Called when the class instance is used as a function.
Args:
diff --git a/tests/configurations/test_experiments.py b/tests/configurations/test_experiments.py
index b59b475a..61f74f80 100644
--- a/tests/configurations/test_experiments.py
+++ b/tests/configurations/test_experiments.py
@@ -1,6 +1,7 @@
+from __future__ import annotations
+
import glob
from pathlib import Path
-from typing import List
import pytest
from hydra import compose, initialize_config_module
@@ -10,10 +11,10 @@
from quadra.utils.validator import validate_config
-def get_experiment_configs(experiment_folder: str) -> List[str]:
+def get_experiment_configs(experiment_folder: str) -> list[str]:
path = Path(__file__).parent.parent.parent / Path(f"quadra/configs/experiment/{experiment_folder}/**/*.yaml")
experiment_paths = glob.glob(str(path), recursive=True)
- experiments: List[str] = []
+ experiments: list[str] = []
for path in experiment_paths:
experiment_tag = path.split("experiment/")[-1]
experiments.append(experiment_tag.split(".yaml")[0])
diff --git a/tests/conftest.py b/tests/conftest.py
index 44b0b87e..29835bff 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,6 @@
import os
+from collections.abc import Callable, Generator
from pathlib import Path
-from typing import Callable, Generator
import pytest
import torch
diff --git a/tests/datamodules/test_classification.py b/tests/datamodules/test_classification.py
index 00cb6292..8ca9daa8 100644
--- a/tests/datamodules/test_classification.py
+++ b/tests/datamodules/test_classification.py
@@ -73,7 +73,7 @@ def test_classification_datamodule_with_splits_phase_test(classification_dataset
datamodule.prepare_data()
datamodule.setup("test")
- with open(os.path.join(data_path, "test.txt"), "r") as f:
+ with open(os.path.join(data_path, "test.txt")) as f:
test_samples_txt = f.read().splitlines()
test_split = datamodule.data["split"] == "test"
@@ -135,31 +135,29 @@ def test_classification_patch_datamodule(classification_patch_dataset: classific
datamodule.setup("fit")
datamodule.setup("test")
- with open(os.path.join(data_path, "info.json"), "r") as f:
+ with open(os.path.join(data_path, "info.json")) as f:
info = PatchDatasetInfo(**json.load(f))
# train samples are named like imagename_class.h5
train_samples_df = datamodule.train_data["samples"].tolist()
- datamodule_train_samples = set(
- [os.path.splitext(os.path.basename("_".join(s.split("_")[0:-1])))[0] for s in train_samples_df]
- )
+ datamodule_train_samples = {
+ os.path.splitext(os.path.basename("_".join(s.split("_")[0:-1])))[0] for s in train_samples_df
+ }
# val samples are named like imagename_patchnumber.xyz
val_samples_df = datamodule.val_data["samples"].tolist()
- datamodule_val_samples = set(
- [os.path.splitext("_".join(os.path.basename(s).split("_")[0:-1]))[0] for s in val_samples_df]
- )
+ datamodule_val_samples = {
+ os.path.splitext("_".join(os.path.basename(s).split("_")[0:-1]))[0] for s in val_samples_df
+ }
# test samples are named like imagename_patchnumber.xyz and may contain #DISCARD# in the name
test_samples_df = datamodule.test_data["samples"].tolist()
- datamodule_test_samples = set(
- [
- os.path.splitext("_".join(os.path.basename(s).replace("#DISCARD#", "").split("_")[0:-1]))[0]
- for s in test_samples_df
- ]
- )
+ datamodule_test_samples = {
+ os.path.splitext("_".join(os.path.basename(s).replace("#DISCARD#", "").split("_")[0:-1]))[0]
+ for s in test_samples_df
+ }
- train_filenames = set([os.path.splitext(os.path.basename(s.image_path))[0] for s in info.train_files])
- val_filenames = set([os.path.splitext(os.path.basename(s.image_path))[0] for s in info.val_files])
- test_filenames = set([os.path.splitext(os.path.basename(s.image_path))[0] for s in info.test_files])
+ train_filenames = {os.path.splitext(os.path.basename(s.image_path))[0] for s in info.train_files}
+ val_filenames = {os.path.splitext(os.path.basename(s.image_path))[0] for s in info.val_files}
+ test_filenames = {os.path.splitext(os.path.basename(s.image_path))[0] for s in info.test_files}
assert datamodule_train_samples == train_filenames
assert datamodule_val_samples == val_filenames
@@ -229,17 +227,17 @@ def test_multilabel_classification_datamodule(multilabel_classification_dataset:
datamodule.setup("fit")
datamodule.setup("test")
- with open(os.path.join(data_path, "train.txt"), "r") as f:
+ with open(os.path.join(data_path, "train.txt")) as f:
train_samples = f.read().splitlines()
train_labels = [s.split(",")[1:] for s in train_samples]
train_labels = np.array([datamodule.class_to_idx[x] for l in train_labels for x in l])
- with open(os.path.join(data_path, "val.txt"), "r") as f:
+ with open(os.path.join(data_path, "val.txt")) as f:
val_samples = f.read().splitlines()
val_labels = [s.split(",")[1:] for s in val_samples]
val_labels = np.array([datamodule.class_to_idx[x] for l in val_labels for x in l])
- with open(os.path.join(data_path, "test.txt"), "r") as f:
+ with open(os.path.join(data_path, "test.txt")) as f:
test_samples = f.read().splitlines()
test_labels = [s.split(",")[1:] for s in test_samples]
test_labels = np.array([datamodule.class_to_idx[x] for l in test_labels for x in l])
diff --git a/tests/datasets/test_classification.py b/tests/datasets/test_classification.py
index 0519cf34..3d5eac13 100644
--- a/tests/datasets/test_classification.py
+++ b/tests/datasets/test_classification.py
@@ -50,11 +50,11 @@ def test_multilabel_classification_dataset(
data_path, _ = base_multilabel_classification_dataset
samples = glob.glob(os.path.join(data_path, "images", "*"))
- with open(os.path.join(data_path, "samples.txt"), "r") as f:
+ with open(os.path.join(data_path, "samples.txt")) as f:
samples_and_targets = [line.strip().split(",") for line in f.readlines()]
samples_mapping = {os.path.basename(st[0]): st[1:] for st in samples_and_targets}
- targets = set([item for sublist in list(samples_mapping.values()) for item in sublist])
+ targets = {item for sublist in list(samples_mapping.values()) for item in sublist}
class_to_idx = {c: i for i, c in enumerate(targets)}
one_hot_encoding = np.zeros((len(samples), len(targets)))
@@ -74,7 +74,7 @@ def test_multilabel_classification_dataset(
assert len(item) == 2
assert isinstance(item[0], np.ndarray)
assert isinstance(item[1], torch.Tensor)
- reverted_classes = set([dataset.idx_to_class[c.item()] for c in torch.where(item[1] == 1)[0]])
+ reverted_classes = {dataset.idx_to_class[c.item()] for c in torch.where(item[1] == 1)[0]}
assert reverted_classes == set(samples_mapping[os.path.basename(dataset.x[i])])
diff --git a/tests/datasets/test_segmentation.py b/tests/datasets/test_segmentation.py
index 0d2baeeb..f49b5a22 100644
--- a/tests/datasets/test_segmentation.py
+++ b/tests/datasets/test_segmentation.py
@@ -1,7 +1,8 @@
# pylint: disable=redefined-outer-name
+from __future__ import annotations
+
import glob
import os
-from typing import Optional
import albumentations as alb
import numpy as np
@@ -21,7 +22,7 @@
@pytest.mark.parametrize("batch_size", [None, 32, 256])
def test_binary_segmentation_dataset(
base_binary_segmentation_dataset: base_binary_segmentation_dataset,
- batch_size: Optional[int],
+ batch_size: int | None,
use_albumentations: bool,
):
data_path, arguments, _ = base_binary_segmentation_dataset
@@ -49,7 +50,7 @@ def test_binary_segmentation_dataset(
count_good = 0
count_bad = 0
- for (image, mask, target) in dataset:
+ for image, mask, target in dataset:
if use_albumentations:
assert isinstance(image, torch.Tensor)
assert isinstance(mask, torch.Tensor)
@@ -75,7 +76,7 @@ def test_binary_segmentation_dataset(
assert count_bad == (arguments.train_samples[1] + arguments.val_samples[1] + arguments.test_samples[1])
dataloader = DataLoader(dataset, batch_size=1)
- for (image, mask, target) in dataloader:
+ for image, mask, target in dataloader:
assert isinstance(image, torch.Tensor)
assert isinstance(mask, torch.Tensor)
if use_albumentations:
@@ -90,7 +91,7 @@ def test_binary_segmentation_dataset(
@pytest.mark.parametrize("one_hot", [True, False])
def test_multiclass_segmentation_dataset(
base_multiclass_segmentation_dataset: base_multiclass_segmentation_dataset,
- batch_size: Optional[int],
+ batch_size: int | None,
use_albumentations: bool,
one_hot: bool,
):
@@ -119,7 +120,7 @@ def test_multiclass_segmentation_dataset(
one_hot=one_hot,
)
- for (image, mask, _) in dataset:
+ for image, mask, _ in dataset:
if use_albumentations:
assert isinstance(image, torch.Tensor)
assert isinstance(mask, np.ndarray)
@@ -137,7 +138,7 @@ def test_multiclass_segmentation_dataset(
assert len(mask.shape) == 2
dataloader = DataLoader(dataset, batch_size=1)
- for (image, mask, _) in dataloader:
+ for image, mask, _ in dataloader:
assert isinstance(image, torch.Tensor)
assert isinstance(mask, torch.Tensor)
if use_albumentations:
diff --git a/tests/models/test_export.py b/tests/models/test_export.py
index 12bd9c92..7a19e395 100644
--- a/tests/models/test_export.py
+++ b/tests/models/test_export.py
@@ -1,7 +1,8 @@
from __future__ import annotations
+from collections.abc import Sequence
from pathlib import Path
-from typing import Any, Sequence
+from typing import Any
import pytest
import torch
@@ -87,7 +88,7 @@ def check_export_model_outputs(
)
models = []
- for export_type, model_path in exported_models.items():
+ for _, model_path in exported_models.items():
model = import_deployment_model(model_path=model_path, inference_config=inference_config, device=device)
models.append(model)
diff --git a/tests/tasks/test_anomaly.py b/tests/tasks/test_anomaly.py
index 7ed0e86f..cc2c11db 100644
--- a/tests/tasks/test_anomaly.py
+++ b/tests/tasks/test_anomaly.py
@@ -1,8 +1,10 @@
# pylint: disable=redefined-outer-name
+from __future__ import annotations
+
import os
import shutil
+from collections.abc import Callable, Generator
from pathlib import Path
-from typing import Callable, Generator, List
import pytest
from pytest_mock import MockerFixture
@@ -68,7 +70,7 @@ def _run_inference_experiment(data_path: str, train_path: str, test_path: str, e
execute_quadra_experiment(overrides=test_overrides, experiment_path=test_path)
-def run_inference_experiments(data_path: str, train_path: str, test_path: str, export_types: List[str]):
+def run_inference_experiments(data_path: str, train_path: str, test_path: str, export_types: list[str]):
"""Run inference experiments for the given export types."""
for export_type in export_types:
cwd = os.getcwd()
@@ -179,7 +181,7 @@ def test_efficientad(
f"model.model.imagenette_dir= {imagenette_path}",
f"model.dataset.task={task}",
f"export.types=[{','.join(BASE_EXPORT_TYPES)}]",
- f"export.input_shapes=[[3,256,256],[3,256,256]]",
+ "export.input_shapes=[[3,256,256],[3,256,256]]",
]
trainer_overrides = setup_trainer_for_lightning()
overrides += BASE_EXPERIMENT_OVERRIDES
diff --git a/tests/tasks/test_classification.py b/tests/tasks/test_classification.py
index 0bd402a1..df0f8093 100644
--- a/tests/tasks/test_classification.py
+++ b/tests/tasks/test_classification.py
@@ -1,8 +1,9 @@
# pylint: disable=redefined-outer-name
+from __future__ import annotations
+
import os
import shutil
from pathlib import Path
-from typing import List
import pytest
@@ -37,7 +38,7 @@
def _run_inference_experiment(
- test_overrides: List[str], data_path: str, train_path: str, test_path: str, export_type: str
+ test_overrides: list[str], data_path: str, train_path: str, test_path: str, export_type: str
):
"""Run an inference experiment for the given export type."""
extension = get_export_extension(export_type)
@@ -49,7 +50,7 @@ def _run_inference_experiment(
def run_inference_experiments(
- test_overrides: List[str], data_path: str, train_path: str, test_path: str, export_types: List[str]
+ test_overrides: list[str], data_path: str, train_path: str, test_path: str, export_types: list[str]
):
"""Run inference experiments for the given export types."""
for export_type in export_types:
@@ -183,7 +184,7 @@ def test_classification(
f"task.gradcam={gradcam}",
"trainer.max_epochs=1",
"task.report=True",
- f"task.run_test=true",
+ "task.run_test=true",
f"export.types=[{','.join(BASE_EXPORT_TYPES)}]",
]
trainer_overrides = setup_trainer_for_lightning()
diff --git a/tests/tasks/test_segmentation.py b/tests/tasks/test_segmentation.py
index d0f7cb18..82c98606 100644
--- a/tests/tasks/test_segmentation.py
+++ b/tests/tasks/test_segmentation.py
@@ -1,8 +1,9 @@
# pylint: disable=redefined-outer-name
+from __future__ import annotations
+
import os
import shutil
from pathlib import Path
-from typing import List
import pytest
@@ -34,7 +35,7 @@
def _run_inference_experiment(
- test_overrides: List[str], data_path: str, train_path: str, test_path: str, export_type: str
+ test_overrides: list[str], data_path: str, train_path: str, test_path: str, export_type: str
):
"""Run an inference experiment for the given export type."""
extension = get_export_extension(export_type)
@@ -46,7 +47,7 @@ def _run_inference_experiment(
def run_inference_experiments(
- test_overrides: List[str], data_path: str, train_path: str, test_path: str, export_types: List[str]
+ test_overrides: list[str], data_path: str, train_path: str, test_path: str, export_types: list[str]
):
"""Run inference experiments for the given export types."""
for export_type in export_types:
diff --git a/tests/utilities/test_mlflow_export.py b/tests/utilities/test_mlflow_export.py
index 1b0e62b3..a38545a0 100644
--- a/tests/utilities/test_mlflow_export.py
+++ b/tests/utilities/test_mlflow_export.py
@@ -9,7 +9,7 @@
except ImportError:
pytest.skip("Mlflow is not installed", allow_module_level=True)
-from typing import Sequence
+from collections.abc import Sequence
import torch
from mlflow.models import infer_signature