From 07e38b5e31a9b09a8151e4073eb66d33a671648f Mon Sep 17 00:00:00 2001
From: R Schwanhold
Date: Fri, 8 Jan 2021 14:19:46 +0100
Subject: [PATCH] Use mypy for type checking (#249)

* add typechecking with mypy
* typecheck packages without __init__
* reformat code
* add mypy to lock file
* resolve circular dependencies and fix bugs
* refactor
* add typecheck to CI
* change initial value to None
* test if CI passes without typecheck
* add type check to CI
* git status to CI for more output
* ignore mypy cache
* reformat code
* add assertion for non optional parameter
* remove redundant imports
---
 .github/workflows/main.yml | 4 +
 .gitignore | 5 +-
 poetry.lock | 55 ++++-
 pyproject.toml | 1 +
 setup.cfg | 5 +
 tests/test_downsampling.py | 10 +-
 typecheck.sh | 3 +
 wkcuber/__main__.py | 7 +-
 wkcuber/api/Dataset.py | 206 +++++++++++-------
 wkcuber/api/Layer.py | 87 +++++---
 wkcuber/api/MagDataset.py | 79 +++++--
 wkcuber/api/Properties/DatasetProperties.py | 75 +++++--
 wkcuber/api/Properties/LayerProperties.py | 94 +++++---
 .../api/Properties/ResolutionProperties.py | 10 +-
 wkcuber/api/TiffData/TiffMag.py | 160 ++++++++------
 wkcuber/api/View.py | 124 ++++++++---
 wkcuber/api/bounding_box.py | 20 +-
 wkcuber/check_equality.py | 25 ++-
 wkcuber/compress.py | 42 ++--
 wkcuber/convert_knossos.py | 23 +-
 wkcuber/convert_nifti.py | 106 +++++----
 wkcuber/cubing.py | 43 +++-
 wkcuber/downsampling.py | 130 ++++++-----
 wkcuber/export_wkw_as_tiff.py | 46 ++--
 wkcuber/image_readers.py | 50 +++--
 wkcuber/knossos.py | 42 ++--
 wkcuber/mag.py | 40 ++--
 wkcuber/metadata.py | 84 ++++---
 wkcuber/recubing.py | 26 ++-
 wkcuber/tile_cubing.py | 43 ++--
 wkcuber/utils.py | 106 +++++----
 31 files changed, 1140 insertions(+), 611 deletions(-)
 create mode 100644 setup.cfg
 create mode 100755 typecheck.sh

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 1c64ad6e6..be577837e 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -34,6 +34,10 @@ jobs:
 
       - name: Lint code
         run: poetry run pylint -j4 wkcuber
+
+      - name: Check typing
+        run: |
+          ./typecheck.sh
 
       - name: Python tests
         run: poetry run pytest tests
diff --git a/.gitignore b/.gitignore
index 5a01e604b..1832fde44 100644
--- a/.gitignore
+++ b/.gitignore
@@ -99,4 +99,7 @@ testdata/WT1_wkw
 testdata/tiff_mag_2_reference
 
 # VSCode
-.vscode/
\ No newline at end of file
+.vscode/
+
+# MyPy
+.mypy_cache/
\ No newline at end of file
diff --git a/poetry.lock b/poetry.lock
index c886122b0..447ee8013 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -265,11 +265,27 @@
 category = "dev"
 optional = false
 python-versions = ">=3.5"
 
+[[package]]
+name = "mypy"
+version = "0.770"
+description = "Optional static typing for Python"
+category = "main"
+optional = false
+python-versions = ">=3.5"
+
+[package.dependencies]
+mypy-extensions = ">=0.4.3,<0.5.0"
+typed-ast = ">=1.4.0,<1.5.0"
+typing-extensions = ">=3.7.4"
+
+[package.extras]
+dmypy = ["psutil (>=4.0)"]
+
 [[package]]
 name = "mypy-extensions"
 version = "0.4.3"
 description = "Experimental type system extensions for programs checked with the mypy typechecker."
-category = "dev" +category = "main" optional = false python-versions = "*" @@ -495,7 +511,7 @@ urllib3 = ">=1.21.1,<1.25.0 || >1.25.0,<1.25.1 || >1.25.1,<1.26" [package.extras] security = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)"] -socks = ["PySocks (>=1.5.6,!=1.5.7)", "win-inet-pton"] +socks = ["PySocks (>=1.5.6,<1.5.7 || >1.5.7)", "win-inet-pton"] [[package]] name = "scikit-image" @@ -514,7 +530,7 @@ PyWavelets = ">=0.4.0" scipy = ">=0.19.0" [package.extras] -docs = ["sphinx (>=1.3,!=1.7.8)", "numpydoc (>=0.9)", "sphinx-gallery", "sphinx-copybutton", "pytest-runner", "scikit-learn", "matplotlib (>=3.0.1)", "dask[array] (>=0.15.0)", "cloudpickle (>=0.2.1)"] +docs = ["sphinx (>=1.3,<1.7.8 || >1.7.8)", "numpydoc (>=0.9)", "sphinx-gallery", "sphinx-copybutton", "pytest-runner", "scikit-learn", "matplotlib (>=3.0.1)", "dask[array] (>=0.15.0)", "cloudpickle (>=0.2.1)"] optional = ["simpleitk", "astropy (>=1.2.0)", "tifffile", "qtpy", "pyamg", "dask[array] (>=0.15.0)", "cloudpickle (>=0.2.1)"] test = ["pytest (!=3.7.3)", "pytest-cov", "pytest-localserver", "flake8", "codecov"] @@ -557,7 +573,7 @@ python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" name = "typed-ast" version = "1.4.0" description = "a fork of Python 2 and 3 ast modules with type comment support" -category = "dev" +category = "main" optional = false python-versions = "*" @@ -565,7 +581,7 @@ python-versions = "*" name = "typing-extensions" version = "3.7.4.3" description = "Backported and Experimental Type Hints for Python 3.5+" -category = "dev" +category = "main" optional = false python-versions = "*" @@ -580,7 +596,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, <4" [package.extras] brotli = ["brotlipy (>=0.6.0)"] secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "ipaddress"] -socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +socks = ["PySocks (>=1.5.6,<1.5.7 || >1.5.7,<2.0)"] [[package]] name = "wkw" @@ -620,7 +636,7 @@ testing = ["pathlib2", "contextlib2", "unittest2"] [metadata] lock-version = "1.1" python-versions = "^3.6" -content-hash = "85bcb808e67144f1a8b9984980cfe6a009e40bc2b8d743f55288ba638e347c00" +content-hash = "3fbd9addb64d94cd70c25606d41719d140bb6a87cc4590accef1c292e5adfd56" [metadata.files] appdirs = [ @@ -817,6 +833,22 @@ more-itertools = [ {file = "more-itertools-8.0.2.tar.gz", hash = "sha256:b84b238cce0d9adad5ed87e745778d20a3f8487d0f0cb8b8a586816c7496458d"}, {file = "more_itertools-8.0.2-py3-none-any.whl", hash = "sha256:c833ef592a0324bcc6a60e48440da07645063c453880c9477ceb22490aec1564"}, ] +mypy = [ + {file = "mypy-0.770-cp35-cp35m-macosx_10_6_x86_64.whl", hash = "sha256:a34b577cdf6313bf24755f7a0e3f3c326d5c1f4fe7422d1d06498eb25ad0c600"}, + {file = "mypy-0.770-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:86c857510a9b7c3104cf4cde1568f4921762c8f9842e987bc03ed4f160925754"}, + {file = "mypy-0.770-cp35-cp35m-win_amd64.whl", hash = "sha256:a8ffcd53cb5dfc131850851cc09f1c44689c2812d0beb954d8138d4f5fc17f65"}, + {file = "mypy-0.770-cp36-cp36m-macosx_10_6_x86_64.whl", hash = "sha256:7687f6455ec3ed7649d1ae574136835a4272b65b3ddcf01ab8704ac65616c5ce"}, + {file = "mypy-0.770-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:3beff56b453b6ef94ecb2996bea101a08f1f8a9771d3cbf4988a61e4d9973761"}, + {file = "mypy-0.770-cp36-cp36m-win_amd64.whl", hash = "sha256:15b948e1302682e3682f11f50208b726a246ab4e6c1b39f9264a8796bb416aa2"}, + {file = "mypy-0.770-cp37-cp37m-macosx_10_6_x86_64.whl", hash = 
"sha256:b90928f2d9eb2f33162405f32dde9f6dcead63a0971ca8a1b50eb4ca3e35ceb8"}, + {file = "mypy-0.770-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:c56ffe22faa2e51054c5f7a3bc70a370939c2ed4de308c690e7949230c995913"}, + {file = "mypy-0.770-cp37-cp37m-win_amd64.whl", hash = "sha256:8dfb69fbf9f3aeed18afffb15e319ca7f8da9642336348ddd6cab2713ddcf8f9"}, + {file = "mypy-0.770-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:219a3116ecd015f8dca7b5d2c366c973509dfb9a8fc97ef044a36e3da66144a1"}, + {file = "mypy-0.770-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:7ec45a70d40ede1ec7ad7f95b3c94c9cf4c186a32f6bacb1795b60abd2f9ef27"}, + {file = "mypy-0.770-cp38-cp38-win_amd64.whl", hash = "sha256:f91c7ae919bbc3f96cd5e5b2e786b2b108343d1d7972ea130f7de27fdd547cf3"}, + {file = "mypy-0.770-py3-none-any.whl", hash = "sha256:3b1fc683fb204c6b4403a1ef23f0b1fac8e4477091585e0c8c54cbdf7d7bb164"}, + {file = "mypy-0.770.tar.gz", hash = "sha256:8a627507ef9b307b46a1fea9513d5c98680ba09591253082b4c48697ba05a4ae"}, +] mypy-extensions = [ {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, @@ -946,18 +978,27 @@ pywavelets = [ {file = "PyWavelets-1.1.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:7947e51ca05489b85928af52a34fe67022ab5b81d4ae32a4109a99e883a0635e"}, {file = "PyWavelets-1.1.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:9e2528823ccf5a0a1d23262dfefe5034dce89cd84e4e124dc553dfcdf63ebb92"}, {file = "PyWavelets-1.1.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:80b924edbc012ded8aa8b91cb2fd6207fb1a9a3a377beb4049b8a07445cec6f0"}, + {file = "PyWavelets-1.1.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:c2a799e79cee81a862216c47e5623c97b95f1abee8dd1f9eed736df23fb653fb"}, {file = "PyWavelets-1.1.1-cp36-cp36m-win32.whl", hash = "sha256:d510aef84d9852653d079c84f2f81a82d5d09815e625f35c95714e7364570ad4"}, {file = "PyWavelets-1.1.1-cp36-cp36m-win_amd64.whl", hash = "sha256:889d4c5c5205a9c90118c1980df526857929841df33e4cd1ff1eff77c6817a65"}, {file = "PyWavelets-1.1.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:68b5c33741d26c827074b3d8f0251de1c3019bb9567b8d303eb093c822ce28f1"}, {file = "PyWavelets-1.1.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:18a51b3f9416a2ae6e9a35c4af32cf520dd7895f2b69714f4aa2f4342fca47f9"}, {file = "PyWavelets-1.1.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:cfe79844526dd92e3ecc9490b5031fca5f8ab607e1e858feba232b1b788ff0ea"}, + {file = "PyWavelets-1.1.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:2f7429eeb5bf9c7068002d0d7f094ed654c77a70ce5e6198737fd68ab85f8311"}, {file = "PyWavelets-1.1.1-cp37-cp37m-win32.whl", hash = "sha256:720dbcdd3d91c6dfead79c80bf8b00a1d8aa4e5d551dc528c6d5151e4efc3403"}, {file = "PyWavelets-1.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:bc5e87b72371da87c9bebc68e54882aada9c3114e640de180f62d5da95749cd3"}, {file = "PyWavelets-1.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:98b2669c5af842a70cfab33a7043fcb5e7535a690a00cd251b44c9be0be418e5"}, {file = "PyWavelets-1.1.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:e02a0558e0c2ac8b8bbe6a6ac18c136767ec56b96a321e0dfde2173adfa5a504"}, {file = "PyWavelets-1.1.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:6162dc0ae04669ea04b4b51420777b9ea2d30b0a9d02901b2a3b4d61d159c2e9"}, + {file = "PyWavelets-1.1.1-cp38-cp38-manylinux2014_aarch64.whl", hash = 
"sha256:39c74740718e420d38c78ca4498568fa57976d78d5096277358e0fa9629a7aea"}, {file = "PyWavelets-1.1.1-cp38-cp38-win32.whl", hash = "sha256:79f5b54f9dc353e5ee47f0c3f02bebd2c899d49780633aa771fed43fa20b3149"}, {file = "PyWavelets-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:935ff247b8b78bdf77647fee962b1cc208c51a7b229db30b9ba5f6da3e675178"}, + {file = "PyWavelets-1.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6ebfefebb5c6494a3af41ad8c60248a95da267a24b79ed143723d4502b1fe4d7"}, + {file = "PyWavelets-1.1.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:6bc78fb9c42a716309b4ace56f51965d8b5662c3ba19d4591749f31773db1125"}, + {file = "PyWavelets-1.1.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:411e17ca6ed8cf5e18a7ca5ee06a91c25800cc6c58c77986202abf98d749273a"}, + {file = "PyWavelets-1.1.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:83c5e3eb78ce111c2f0b45f46106cc697c3cb6c4e5f51308e1f81b512c70c8fb"}, + {file = "PyWavelets-1.1.1-cp39-cp39-win32.whl", hash = "sha256:2b634a54241c190ee989a4af87669d377b37c91bcc9cf0efe33c10ff847f7841"}, + {file = "PyWavelets-1.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:732bab78435c48be5d6bc75486ef629d7c8f112e07b313bf1f1a2220ab437277"}, {file = "PyWavelets-1.1.1.tar.gz", hash = "sha256:1a64b40f6acb4ffbaccce0545d7fc641744f95351f62e4c6aaa40549326008c9"}, ] regex = [ diff --git a/pyproject.toml b/pyproject.toml index 44032a82d..d8561fbf2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ natsort = "^6.2.0" psutil = "^5.6.7" nibabel = "^2.5.1" scikit-image = "^0.16.2" +mypy = "^0.770" [tool.poetry.dev-dependencies] pylint = "^2.6.0" diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 000000000..31ac23023 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,5 @@ +[mypy] +ignore_missing_imports = True + +[mypy-wkcuber.vendor.*] +ignore_errors = True \ No newline at end of file diff --git a/tests/test_downsampling.py b/tests/test_downsampling.py index abc6443ae..eaed27e5a 100644 --- a/tests/test_downsampling.py +++ b/tests/test_downsampling.py @@ -1,4 +1,6 @@ import logging +from typing import Tuple + import numpy as np from wkcuber.downsampling import ( InterpolationModes, @@ -20,7 +22,9 @@ target_info = WkwDatasetInfo("testoutput/WT1_wkw", "color", 2, wkw.Header(np.uint8)) -def read_wkw(wkw_info, offset, size): +def read_wkw( + wkw_info: WkwDatasetInfo, offset: Tuple[int, int, int], size: Tuple[int, int, int] +): with open_wkw(wkw_info) as wkw_dataset: return wkw_dataset.read(offset, size) @@ -120,7 +124,7 @@ def downsample_test_helper(use_compress): assert np.all( target_buffer - == downsample_cube(source_buffer, (2, 2, 2), InterpolationModes.MAX) + == downsample_cube(source_buffer, [2, 2, 2], InterpolationModes.MAX) ) @@ -180,7 +184,7 @@ def test_downsample_multi_channel(): for channel_index in range(num_channels): channels.append( downsample_cube( - source_data[channel_index], (2, 2, 2), InterpolationModes.MAX + source_data[channel_index], [2, 2, 2], InterpolationModes.MAX ) ) joined_buffer = np.stack(channels) diff --git a/typecheck.sh b/typecheck.sh new file mode 100755 index 000000000..896d4059f --- /dev/null +++ b/typecheck.sh @@ -0,0 +1,3 @@ +#!/bin/bash +set -eEuo pipefail +python -m mypy wkcuber --disallow-untyped-defs --show-error-codes --strict-equality --namespace-packages diff --git a/wkcuber/__main__.py b/wkcuber/__main__.py index 0a3ce40a5..2f4bbb3f9 100644 --- a/wkcuber/__main__.py +++ b/wkcuber/__main__.py @@ -4,9 +4,10 @@ from .metadata import write_webknossos_metadata, refresh_metadata from .utils import 
add_isotropic_flag, setup_logging, add_scale_flag from .mag import Mag +from argparse import Namespace, ArgumentParser -def create_parser(): +def create_parser() -> ArgumentParser: parser = create_cubing_parser() parser.add_argument( @@ -32,7 +33,7 @@ def create_parser(): return parser -def main(args): +def main(args: Namespace) -> None: setup_logging(args) bounding_box = cubing( @@ -82,5 +83,5 @@ def main(args): if __name__ == "__main__": - parsed_args = create_parser().parse_args() + parsed_args: Namespace = create_parser().parse_args() main(parsed_args) diff --git a/wkcuber/api/Dataset.py b/wkcuber/api/Dataset.py index 26d4c57d4..aaded747f 100644 --- a/wkcuber/api/Dataset.py +++ b/wkcuber/api/Dataset.py @@ -4,10 +4,13 @@ from os import makedirs, path from os.path import join, normpath, basename from pathlib import Path +from typing import Type, Tuple, Union, Dict, Any, Optional, cast + import numpy as np import os import re +from wkcuber.mag import Mag from wkcuber.utils import logger from wkcuber.api.Properties.DatasetProperties import ( @@ -29,9 +32,11 @@ def is_int(s: str) -> bool: return False -def convert_dtypes(dtype, num_channels, dtype_per_layer_to_dtype_per_channel): - if dtype is None: - return None +def convert_dtypes( + dtype: Union[str, np.dtype], + num_channels: int, + dtype_per_layer_to_dtype_per_channel: bool, +) -> str: op = operator.truediv if dtype_per_layer_to_dtype_per_channel else operator.mul # split the dtype into the actual type and the number of bits @@ -45,7 +50,9 @@ def convert_dtypes(dtype, num_channels, dtype_per_layer_to_dtype_per_channel): return "".join(converted_dtype_parts) -def dtype_per_layer_to_dtype_per_channel(dtype_per_layer, num_channels): +def dtype_per_layer_to_dtype_per_channel( + dtype_per_layer: Union[str, np.dtype], num_channels: int +) -> np.dtype: try: return np.dtype( convert_dtypes( @@ -59,7 +66,9 @@ def dtype_per_layer_to_dtype_per_channel(dtype_per_layer, num_channels): ) -def dtype_per_channel_to_dtype_per_layer(dtype_per_channel, num_channels): +def dtype_per_channel_to_dtype_per_layer( + dtype_per_channel: Union[str, np.dtype], num_channels: int +) -> str: return convert_dtypes( np.dtype(dtype_per_channel), num_channels, @@ -69,11 +78,11 @@ def dtype_per_channel_to_dtype_per_layer(dtype_per_channel, num_channels): class AbstractDataset(ABC): @abstractmethod - def __init__(self, dataset_path): - properties = self._get_properties_type()._from_json( + def __init__(self, dataset_path: Union[str, Path]) -> None: + properties: Properties = self._get_properties_type()._from_json( join(dataset_path, Properties.FILE_NAME) ) - self.layers = {} + self.layers: Dict[str, Layer] = {} self.path = Path(properties.path).parent self.properties = properties self.data_format = "abstract" @@ -91,7 +100,7 @@ def __init__(self, dataset_path): self.layers[layer_name].setup_mag(resolution.mag.to_layer_name()) @classmethod - def create_with_properties(cls, properties): + def create_with_properties(cls, properties: Properties) -> "AbstractDataset": dataset_path = path.dirname(properties.path) if os.path.exists(dataset_path): @@ -114,13 +123,10 @@ def create_with_properties(cls, properties): # initialize object return cls(dataset_path) - def downsample(self, layer, target_mag_shape, source_mag): - raise NotImplemented() - def get_properties(self) -> Properties: return self.properties - def get_layer(self, layer_name) -> Layer: + def get_layer(self, layer_name: str) -> Layer: if layer_name not in self.layers.keys(): raise IndexError( "The layer {} is not a 
layer of this dataset".format(layer_name) @@ -129,13 +135,13 @@ def get_layer(self, layer_name) -> Layer: def add_layer( self, - layer_name, - category, - dtype_per_layer=None, - dtype_per_channel=None, - num_channels=None, - **kwargs, - ): + layer_name: str, + category: str, + dtype_per_layer: Union[str, np.dtype] = None, + dtype_per_channel: Union[str, np.dtype] = None, + num_channels: int = None, + **kwargs: Any, + ) -> Layer: if "dtype" in kwargs: raise ValueError( f"Called Dataset.add_layer with 'dtype'={kwargs['dtype']}. This parameter is deprecated. Use 'dtype_per_layer' or 'dtype_per_channel' instead." @@ -192,34 +198,18 @@ def add_layer( def get_or_add_layer( self, - layer_name, - category, - dtype_per_layer=None, - dtype_per_channel=None, - num_channels=None, - **kwargs, + layer_name: str, + category: str, + dtype_per_layer: Union[str, np.dtype] = None, + dtype_per_channel: Union[str, np.dtype] = None, + num_channels: int = None, + **kwargs: Any, ) -> Layer: if "dtype" in kwargs: raise ValueError( f"Called Dataset.get_or_add_layer with 'dtype'={kwargs['dtype']}. This parameter is deprecated. Use 'dtype_per_layer' or 'dtype_per_channel' instead." ) if layer_name in self.layers.keys(): - assert self.properties.data_layers[layer_name].category == category, ( - f"Cannot get_or_add_layer: The layer '{layer_name}' already exists, but the categories do not match. " - + f"The category of the existing layer is '{self.properties.data_layers[layer_name].category}' " - + f"and the passed parameter is '{category}'." - ) - dtype_per_channel = dtype_per_layer_to_dtype_per_channel( - dtype_per_layer, num_channels - ) - assert ( - dtype_per_layer is None - or self.layers[layer_name].dtype_per_channel == dtype_per_channel - ), ( - f"Cannot get_or_add_layer: The layer '{layer_name}' already exists, but the dtypes do not match. " - + f"The dtype_per_channel of the existing layer is '{self.layers[layer_name].dtype_per_channel}' " - + f"and the passed parameter would result in a dtype_per_channel of '{dtype_per_channel}'." - ) assert ( num_channels is None or self.layers[layer_name].num_channels == num_channels @@ -228,6 +218,28 @@ def get_or_add_layer( + f"The number of channels of the existing layer are '{self.layers[layer_name].num_channels}' " + f"and the passed parameter is '{num_channels}'." ) + assert self.properties.data_layers[layer_name].category == category, ( + f"Cannot get_or_add_layer: The layer '{layer_name}' already exists, but the categories do not match. " + + f"The category of the existing layer is '{self.properties.data_layers[layer_name].category}' " + + f"and the passed parameter is '{category}'." + ) + + if dtype_per_channel is not None or dtype_per_layer is not None: + dtype_per_channel = ( + dtype_per_channel + or dtype_per_layer_to_dtype_per_channel( + dtype_per_layer, + num_channels or self.layers[layer_name].num_channels, + ) + ) + assert ( + dtype_per_channel is None + or self.layers[layer_name].dtype_per_channel == dtype_per_channel + ), ( + f"Cannot get_or_add_layer: The layer '{layer_name}' already exists, but the dtypes do not match. " + + f"The dtype_per_channel of the existing layer is '{self.layers[layer_name].dtype_per_channel}' " + + f"and the passed parameter would result in a dtype_per_channel of '{dtype_per_channel}'." 
+ ) return self.layers[layer_name] else: return self.add_layer( @@ -239,7 +251,7 @@ def get_or_add_layer( **kwargs, ) - def delete_layer(self, layer_name): + def delete_layer(self, layer_name: str) -> None: if layer_name not in self.layers.keys(): raise IndexError( f"Removing layer {layer_name} failed. There is no layer with this name" @@ -249,7 +261,7 @@ def delete_layer(self, layer_name): # delete files on disk rmtree(join(self.path, layer_name)) - def add_symlink_layer(self, foreign_layer_path) -> Layer: + def add_symlink_layer(self, foreign_layer_path: Union[str, Path]) -> Layer: foreign_layer_path = os.path.abspath(foreign_layer_path) layer_name = os.path.basename(os.path.normpath(foreign_layer_path)) if layer_name in self.layers.keys(): @@ -277,33 +289,46 @@ def add_symlink_layer(self, foreign_layer_path) -> Layer: self.layers[layer_name].setup_mag(resolution.mag.to_layer_name()) return self.layers[layer_name] - def get_view(self, layer_name, mag, size, offset=None, is_bounded=True) -> View: + def get_view( + self, + layer_name: str, + mag: Union[str, Mag], + size: Tuple[int, int, int], + offset: Tuple[int, int, int] = None, + is_bounded: bool = True, + ) -> View: layer = self.get_layer(layer_name) mag_ds = layer.get_mag(mag) return mag_ds.get_view(size=size, offset=offset, is_bounded=is_bounded) - def _create_layer(self, layer_name, dtype_per_channel, num_channels) -> Layer: + def _create_layer( + self, layer_name: str, dtype_per_channel: np.dtype, num_channels: int + ) -> Layer: raise NotImplementedError @abstractmethod - def _get_properties_type(self): + def _get_properties_type(self) -> Type[Properties]: pass @abstractmethod - def _get_type(self): + def _get_type(self) -> Type["AbstractDataset"]: pass class WKDataset(AbstractDataset): @classmethod - def create(cls, dataset_path, scale): + def create( + cls, dataset_path: Union[str, Path], scale: Tuple[float, float, float] + ) -> "WKDataset": name = basename(normpath(dataset_path)) properties = WKProperties(join(dataset_path, Properties.FILE_NAME), name, scale) - return WKDataset.create_with_properties(properties) + return cast(WKDataset, WKDataset.create_with_properties(properties)) @classmethod - def get_or_create(cls, dataset_path, scale): + def get_or_create( + cls, dataset_path: Union[str, Path], scale: Tuple[float, float, float] + ) -> "WKDataset": if os.path.exists( join(dataset_path, Properties.FILE_NAME) ): # use the properties file to check if the Dataset exists @@ -315,27 +340,36 @@ def get_or_create(cls, dataset_path, scale): else: return cls.create(dataset_path, scale) - def __init__(self, dataset_path): + def __init__(self, dataset_path: Union[str, Path]) -> None: super().__init__(dataset_path) self.data_format = "wkw" assert isinstance(self.properties, WKProperties) - def to_tiff_dataset(self, new_dataset_path): + def to_tiff_dataset(self, new_dataset_path: Union[str, Path]) -> "TiffDataset": raise NotImplementedError # TODO; implement - def _create_layer(self, layer_name, dtype_per_channel, num_channels) -> Layer: + def _create_layer( + self, layer_name: str, dtype_per_channel: np.dtype, num_channels: int + ) -> Layer: return WKLayer(layer_name, self, dtype_per_channel, num_channels) - def _get_properties_type(self): + def _get_properties_type(self) -> Type[WKProperties]: return WKProperties - def _get_type(self): + def _get_type(self) -> Type["WKDataset"]: return WKDataset class TiffDataset(AbstractDataset): + properties: TiffProperties + @classmethod - def create(cls, dataset_path, scale, pattern="{zzzzz}.tif"): 
+ def create( + cls, + dataset_path: Union[str, Path], + scale: Tuple[float, float, float], + pattern: str = "{zzzzz}.tif", + ) -> "TiffDataset": validate_pattern(pattern) name = basename(normpath(dataset_path)) properties = TiffProperties( @@ -345,10 +379,15 @@ def create(cls, dataset_path, scale, pattern="{zzzzz}.tif"): pattern=pattern, tile_size=None, ) - return TiffDataset.create_with_properties(properties) + return cast(TiffDataset, TiffDataset.create_with_properties(properties)) @classmethod - def get_or_create(cls, dataset_path, scale, pattern=None): + def get_or_create( + cls, + dataset_path: Union[str, Path], + scale: Tuple[float, float, float], + pattern: str = None, + ) -> "TiffDataset": if os.path.exists( join(dataset_path, Properties.FILE_NAME) ): # use the properties file to check if the Dataset exists @@ -367,29 +406,37 @@ def get_or_create(cls, dataset_path, scale, pattern=None): else: return cls.create(dataset_path, scale, pattern) - def __init__(self, dataset_path): + def __init__(self, dataset_path: Union[str, Path]) -> None: super().__init__(dataset_path) self.data_format = "tiff" assert isinstance(self.properties, TiffProperties) - def to_wk_dataset(self, new_dataset_path): + def to_wk_dataset(self, new_dataset_path: Union[str, Path]) -> WKDataset: raise NotImplementedError # TODO; implement - def _create_layer(self, layer_name, dtype_per_channel, num_channels) -> Layer: + def _create_layer( + self, layer_name: str, dtype_per_channel: np.dtype, num_channels: int + ) -> Layer: return TiffLayer(layer_name, self, dtype_per_channel, num_channels) - def _get_properties_type(self): + def _get_properties_type(self) -> Type[TiffProperties]: return TiffProperties - def _get_type(self): + def _get_type(self) -> Type["TiffDataset"]: return TiffDataset class TiledTiffDataset(AbstractDataset): + properties: TiffProperties + @classmethod def create( - cls, dataset_path, scale, tile_size, pattern="{xxxxx}/{yyyyy}/{zzzzz}.tif" - ): + cls, + dataset_path: Union[str, Path], + scale: Tuple[float, float, float], + tile_size: Tuple[int, int], + pattern: str = "{xxxxx}/{yyyyy}/{zzzzz}.tif", + ) -> "TiledTiffDataset": validate_pattern(pattern) name = basename(normpath(dataset_path)) properties = TiffProperties( @@ -399,10 +446,18 @@ def create( pattern=pattern, tile_size=tile_size, ) - return TiledTiffDataset.create_with_properties(properties) + return cast( + TiledTiffDataset, TiledTiffDataset.create_with_properties(properties) + ) @classmethod - def get_or_create(cls, dataset_path, scale, tile_size, pattern=None): + def get_or_create( + cls, + dataset_path: Union[str, Path], + scale: Tuple[float, float, float], + tile_size: Tuple[int, int], + pattern: str = None, + ) -> "TiledTiffDataset": if os.path.exists( join(dataset_path, Properties.FILE_NAME) ): # use the properties file to check if the Dataset exists @@ -410,6 +465,7 @@ def get_or_create(cls, dataset_path, scale, tile_size, pattern=None): assert tuple(ds.properties.scale) == tuple( scale ), f"Cannot get_or_create TiledTiffDataset: The dataset {dataset_path} already exists, but the scales do not match ({ds.properties.scale} != {scale})" + assert ds.properties.tile_size is not None assert tuple(ds.properties.tile_size) == tuple( tile_size ), f"Cannot get_or_create TiledTiffDataset: The dataset {dataset_path} already exists, but the tile sizes do not match ({ds.properties.tile_size} != {tile_size})" @@ -424,25 +480,27 @@ def get_or_create(cls, dataset_path, scale, tile_size, pattern=None): else: return cls.create(dataset_path, scale, 
tile_size, pattern) - def to_wk_dataset(self, new_dataset_path): + def to_wk_dataset(self, new_dataset_path: str) -> WKDataset: raise NotImplementedError # TODO; implement - def __init__(self, dataset_path): + def __init__(self, dataset_path: Union[str, Path]) -> None: super().__init__(dataset_path) self.data_format = "tiled_tiff" assert isinstance(self.properties, TiffProperties) - def _create_layer(self, layer_name, dtype_per_channel, num_channels) -> Layer: + def _create_layer( + self, layer_name: str, dtype_per_channel: np.dtype, num_channels: int + ) -> Layer: return TiledTiffLayer(layer_name, self, dtype_per_channel, num_channels) - def _get_properties_type(self): + def _get_properties_type(self) -> Type[TiffProperties]: return TiffProperties - def _get_type(self): + def _get_type(self) -> Type["TiledTiffDataset"]: return TiledTiffDataset -def validate_pattern(pattern): +def validate_pattern(pattern: str) -> None: assert pattern.count("{") > 0 and pattern.count("}") > 0, ( f"The provided pattern {pattern} is invalid." + " It needs to contain at least one '{' and one '}'." diff --git a/wkcuber/api/Layer.py b/wkcuber/api/Layer.py index e5a8f501d..bd9449cf4 100644 --- a/wkcuber/api/Layer.py +++ b/wkcuber/api/Layer.py @@ -1,12 +1,14 @@ from shutil import rmtree from os.path import join from os import makedirs -from typing import Tuple +from typing import Tuple, Type, Union, Dict, Any, TYPE_CHECKING import numpy as np from wkw import wkw +if TYPE_CHECKING: + from wkcuber.api.Dataset import AbstractDataset, TiffDataset from wkcuber.api.MagDataset import ( MagDataset, WKMagDataset, @@ -23,23 +25,32 @@ class Layer: COLOR_TYPE = "color" SEGMENTATION_TYPE = "segmentation" - def __init__(self, name, dataset, dtype_per_channel, num_channels): + def __init__( + self, + name: str, + dataset: "AbstractDataset", + dtype_per_channel: np.dtype, + num_channels: int, + ) -> None: self.name = name self.dataset = dataset self.dtype_per_channel = dtype_per_channel self.num_channels = num_channels - self.mags = {} + self.mags: Dict[str, Any] = {} full_path = join(dataset.path, name) makedirs(full_path, exist_ok=True) - def get_mag(self, mag) -> MagDataset: + def get_mag(self, mag: Union[str, Mag]) -> MagDataset: mag = Mag(mag).to_layer_name() if mag not in self.mags.keys(): raise IndexError("The mag {} is not a mag of this layer".format(mag)) return self.mags[mag] - def delete_mag(self, mag): + def get_or_add_mag(self, mag: Union[str, Mag], **kwargs: Any) -> MagDataset: + pass + + def delete_mag(self, mag: Union[str, Mag]) -> None: mag = Mag(mag).to_layer_name() if mag not in self.mags.keys(): raise IndexError( @@ -52,12 +63,12 @@ def delete_mag(self, mag): full_path = find_mag_path_on_disk(self.dataset.path, self.name, mag) rmtree(full_path) - def _create_dir_for_mag(self, mag): + def _create_dir_for_mag(self, mag: Union[str, Mag]) -> None: mag = Mag(mag).to_layer_name() full_path = join(self.dataset.path, self.name, mag) makedirs(full_path, exist_ok=True) - def _assert_mag_does_not_exist_yet(self, mag): + def _assert_mag_does_not_exist_yet(self, mag: Union[str, Mag]) -> None: mag = Mag(mag).to_layer_name() if mag in self.mags.keys(): raise IndexError( @@ -68,31 +79,40 @@ def _assert_mag_does_not_exist_yet(self, mag): def set_bounding_box( self, offset: Tuple[int, int, int], size: Tuple[int, int, int] - ): + ) -> None: self.set_bounding_box_offset(offset) self.set_bounding_box_size(size) - def set_bounding_box_offset(self, offset: Tuple[int, int, int]): - size = 
self.dataset.properties.data_layers["color"].get_bounding_box_size() - self.dataset.properties._set_bounding_box_of_layer( - self.name, tuple(offset), tuple(size) - ) + def set_bounding_box_offset(self, offset: Tuple[int, int, int]) -> None: + size: Tuple[int, int, int] = self.dataset.properties.data_layers[ + "color" + ].get_bounding_box_size() + self.dataset.properties._set_bounding_box_of_layer(self.name, offset, size) for _, mag in self.mags.items(): mag.view.global_offset = offset - def set_bounding_box_size(self, size: Tuple[int, int, int]): - offset = self.dataset.properties.data_layers["color"].get_bounding_box_offset() - self.dataset.properties._set_bounding_box_of_layer( - self.name, tuple(offset), tuple(size) - ) + def set_bounding_box_size(self, size: Tuple[int, int, int]) -> None: + offset: Tuple[int, int, int] = self.dataset.properties.data_layers[ + "color" + ].get_bounding_box_offset() + self.dataset.properties._set_bounding_box_of_layer(self.name, offset, size) for _, mag in self.mags.items(): mag.view.size = size + def setup_mag(self, mag: str) -> None: + pass + class WKLayer(Layer): + mags: Dict[str, WKMagDataset] + def add_mag( - self, mag, block_len=None, file_len=None, block_type=None - ) -> WKMagDataset: + self, + mag: Union[str, Mag], + block_len: int = None, + file_len: int = None, + block_type: int = None, + ) -> MagDataset: if block_len is None: block_len = 32 if file_len is None: @@ -107,15 +127,18 @@ def add_mag( self._create_dir_for_mag(mag) self.mags[mag] = WKMagDataset.create(self, mag, block_len, file_len, block_type) - self.dataset.properties._add_mag(self.name, mag, block_len * file_len) + self.dataset.properties._add_mag( + self.name, mag, cube_length=block_len * file_len + ) return self.mags[mag] - def get_or_add_mag( - self, mag, block_len=None, file_len=None, block_type=None - ) -> WKMagDataset: + def get_or_add_mag(self, mag: Union[str, Mag], **kwargs: Any) -> MagDataset: # normalize the name of the mag mag = Mag(mag).to_layer_name() + block_len: int = kwargs.get("block_len", None) + file_len: int = kwargs.get("file_len", None) + block_type: int = kwargs.get("block_type", None) if mag in self.mags.keys(): assert ( @@ -131,7 +154,7 @@ def get_or_add_mag( else: return self.add_mag(mag, block_len, file_len, block_type) - def setup_mag(self, mag): + def setup_mag(self, mag: str) -> None: # This method is used to initialize the mag when opening the Dataset. This does not create e.g. the wk_header. # normalize the name of the mag @@ -148,12 +171,14 @@ def setup_mag(self, mag): self, mag, wk_header.block_len, wk_header.file_len, wk_header.block_type ) self.dataset.properties._add_mag( - self.name, mag, wk_header.block_len * wk_header.file_len + self.name, mag, cube_length=wk_header.block_len * wk_header.file_len ) class TiffLayer(Layer): - def add_mag(self, mag) -> MagDataset: + dataset: "TiffDataset" + + def add_mag(self, mag: Union[str, Mag]) -> MagDataset: # normalize the name of the mag mag = Mag(mag).to_layer_name() @@ -167,7 +192,7 @@ def add_mag(self, mag) -> MagDataset: return self.mags[mag] - def get_or_add_mag(self, mag) -> MagDataset: + def get_or_add_mag(self, mag: Union[str, Mag], **kwargs: Any) -> MagDataset: # normalize the name of the mag mag = Mag(mag).to_layer_name() @@ -176,7 +201,7 @@ def get_or_add_mag(self, mag) -> MagDataset: else: return self.add_mag(mag) - def setup_mag(self, mag): + def setup_mag(self, mag: str) -> None: # This method is used to initialize the mag when opening the Dataset. This does not create e.g. folders. 
# normalize the name of the mag @@ -189,10 +214,10 @@ def setup_mag(self, mag): ) self.dataset.properties._add_mag(self.name, mag) - def _get_mag_dataset_class(self): + def _get_mag_dataset_class(self) -> Type[TiffMagDataset]: return TiffMagDataset class TiledTiffLayer(TiffLayer): - def _get_mag_dataset_class(self): + def _get_mag_dataset_class(self) -> Type[TiledTiffMagDataset]: return TiledTiffMagDataset diff --git a/wkcuber/api/MagDataset.py b/wkcuber/api/MagDataset.py index e3be95fe6..71e766894 100644 --- a/wkcuber/api/MagDataset.py +++ b/wkcuber/api/MagDataset.py @@ -1,16 +1,23 @@ import os from os.path import join +from pathlib import Path +from typing import Type, Tuple, Union, cast, TYPE_CHECKING from wkw import wkw import numpy as np import wkcuber.api as api -from wkcuber.api.View import WKView, TiffView + +if TYPE_CHECKING: + from wkcuber.api.Layer import TiffLayer, WKLayer, Layer +from wkcuber.api.View import WKView, TiffView, View from wkcuber.api.TiffData.TiffMag import TiffMagHeader from wkcuber.mag import Mag -def find_mag_path_on_disk(dataset_path: str, layer_name: str, mag_name: str): +def find_mag_path_on_disk( + dataset_path: Union[str, Path], layer_name: str, mag_name: str +) -> str: mag = Mag(mag_name) short_mag_file_path = join(dataset_path, layer_name, mag.to_layer_name()) long_mag_file_path = join(dataset_path, layer_name, mag.to_long_layer_name()) @@ -21,23 +28,32 @@ def find_mag_path_on_disk(dataset_path: str, layer_name: str, mag_name: str): class MagDataset: - def __init__(self, layer, name): + def __init__(self, layer: "Layer", name: str) -> None: self.layer = layer self.name = name self.header = self.get_header() self.view = self.get_view(offset=(0, 0, 0), is_bounded=False) - def open(self): + def open(self) -> None: self.view.open() - def close(self): + def close(self) -> None: self.view.close() - def read(self, offset=(0, 0, 0), size=None) -> np.array: + def read( + self, + offset: Tuple[int, int, int] = (0, 0, 0), + size: Tuple[int, int, int] = None, + ) -> np.array: return self.view.read(offset, size) - def write(self, data, offset=(0, 0, 0), allow_compressed_write=False): + def write( + self, + data: np.ndarray, + offset: Tuple[int, int, int] = (0, 0, 0), + allow_compressed_write: bool = False, + ) -> None: self._assert_valid_num_channels(data.shape) self.view.write(data, offset, allow_compressed_write) layer_properties = self.layer.dataset.properties.data_layers[self.layer.name] @@ -65,16 +81,23 @@ def write(self, data, offset=(0, 0, 0), allow_compressed_write=False): total_size_in_mag1 = max_end_offset_in_mag1 - np.array(new_offset_in_mag1) total_size = total_size_in_mag1 / mag_np - self.view.size = tuple(total_size) + self.view.size = cast(Tuple[int, int, int], tuple(total_size)) self.layer.dataset.properties._set_bounding_box_of_layer( - self.layer.name, tuple(new_offset_in_mag1), tuple(total_size_in_mag1) + self.layer.name, + cast(Tuple[int, int, int], tuple(new_offset_in_mag1)), + cast(Tuple[int, int, int], tuple(total_size_in_mag1)), ) - def get_header(self): + def get_header(self) -> Union[TiffMagHeader, wkw.Header]: raise NotImplementedError - def get_view(self, size=None, offset=None, is_bounded=True): + def get_view( + self, + size: Tuple[int, int, int] = None, + offset: Tuple[int, int, int] = None, + is_bounded: bool = True, + ) -> View: size_in_properties = self.layer.dataset.properties.data_layers[ self.layer.name ].get_bounding_box_size() @@ -117,10 +140,10 @@ def get_view(self, size=None, offset=None, is_bounded=True): mag_file_path, 
self.header, size, offset, is_bounded ) - def _get_view_type(self): + def _get_view_type(self) -> Type[View]: raise NotImplementedError - def _assert_valid_num_channels(self, write_data_shape): + def _assert_valid_num_channels(self, write_data_shape: Tuple[int, ...]) -> None: num_channels = self.layer.num_channels if len(write_data_shape) == 3: assert ( @@ -133,7 +156,16 @@ def _assert_valid_num_channels(self, write_data_shape): class WKMagDataset(MagDataset): - def __init__(self, layer, name, block_len, file_len, block_type): + header: wkw.Header + + def __init__( + self, + layer: "WKLayer", + name: str, + block_len: int, + file_len: int, + block_type: int, + ) -> None: self.block_len = block_len self.file_len = file_len self.block_type = block_type @@ -150,7 +182,9 @@ def get_header(self) -> wkw.Header: ) @classmethod - def create(cls, layer, name, block_len, file_len, block_type): + def create( + cls, layer: "WKLayer", name: str, block_len: int, file_len: int, block_type: int + ) -> "WKMagDataset": mag_dataset = cls(layer, name, block_len, file_len, block_type) wkw.Dataset.create( join(layer.dataset.path, layer.name, mag_dataset.name), mag_dataset.header @@ -158,12 +192,14 @@ def create(cls, layer, name, block_len, file_len, block_type): return mag_dataset - def _get_view_type(self): + def _get_view_type(self) -> Type[WKView]: return WKView class TiffMagDataset(MagDataset): - def __init__(self, layer, name, pattern): + layer: "TiffLayer" + + def __init__(self, layer: "TiffLayer", name: str, pattern: str) -> None: self.pattern = pattern super().__init__(layer, name) @@ -176,18 +212,19 @@ def get_header(self) -> TiffMagHeader: ) @classmethod - def create(cls, layer, name, pattern): + def create(cls, layer: "TiffLayer", name: str, pattern: str) -> "TiffMagDataset": mag_dataset = cls(layer, name, pattern) return mag_dataset - def _get_view_type(self): + def _get_view_type(self) -> Type[TiffView]: return TiffView class TiledTiffMagDataset(TiffMagDataset): - def get_tile(self, x_index, y_index, z_index) -> np.array: + def get_tile(self, x_index: int, y_index: int, z_index: int) -> np.array: tile_size = self.layer.dataset.properties.tile_size - size = tuple(tile_size) + tuple((1,)) + assert tile_size is not None + size = (tile_size[0], tile_size[1], 1) offset = np.array((0, 0, 0)) + np.array(size) * np.array( (x_index, y_index, z_index) ) diff --git a/wkcuber/api/Properties/DatasetProperties.py b/wkcuber/api/Properties/DatasetProperties.py index f16c737fa..df6e4bcea 100644 --- a/wkcuber/api/Properties/DatasetProperties.py +++ b/wkcuber/api/Properties/DatasetProperties.py @@ -1,4 +1,6 @@ import json +from pathlib import Path +from typing import Union, Tuple, Optional, Dict, Any, cast from wkcuber.api.Layer import Layer from wkcuber.api.Properties.LayerProperties import ( @@ -13,8 +15,15 @@ class Properties: FILE_NAME = "datasource-properties.json" - def __init__(self, path, name, scale, team="", data_layers=None): - self._path = path + def __init__( + self, + path: Union[str, Path], + name: str, + scale: Tuple[float, float, float], + team: str = "", + data_layers: Dict[str, LayerProperties] = None, + ) -> None: + self._path = str(path) self._name = name self._team = team self._scale = scale @@ -24,15 +33,21 @@ def __init__(self, path, name, scale, team="", data_layers=None): self._data_layers = data_layers @classmethod - def _from_json(cls, path): + def _from_json(cls, path: Union[str, Path]) -> "Properties": pass - def _export_as_json(self): + def _export_as_json(self) -> None: pass def 
_add_layer( - self, layer_name, category, element_class, data_format, num_channels=1, **kwargs - ): + self, + layer_name: str, + category: str, + element_class: str, + data_format: str, + num_channels: int = 1, + **kwargs: Dict[str, Any] + ) -> None: # this layer is already in data_layers in case we reconstruct the dataset from a datasource-properties.json if layer_name not in self.data_layers: if category == Layer.SEGMENTATION_TYPE: @@ -40,30 +55,34 @@ def _add_layer( "largest_segment_id" in kwargs ), "When adding a segmentation layer, largest_segment_id has to be supplied." - new_layer = SegmentationLayerProperties( + self.data_layers[layer_name] = SegmentationLayerProperties( layer_name, category, element_class, data_format, num_channels, - largest_segment_id=kwargs["largest_segment_id"], + largest_segment_id=cast(int, kwargs["largest_segment_id"]), ) else: - new_layer = LayerProperties( + self.data_layers[layer_name] = LayerProperties( layer_name, category, element_class, data_format, num_channels ) - self.data_layers[layer_name] = new_layer self._export_as_json() - def _delete_layer(self, layer_name): + def _add_mag(self, layer_name: str, mag: str, **kwargs: int) -> None: + pass + + def _delete_layer(self, layer_name: str) -> None: del self.data_layers[layer_name] self._export_as_json() - def _delete_mag(self, layer_name, mag): + def _delete_mag(self, layer_name: str, mag: str) -> None: self._data_layers[layer_name]._delete_resolution(mag) self._export_as_json() - def _set_bounding_box_of_layer(self, layer_name, offset, size): + def _set_bounding_box_of_layer( + self, layer_name: str, offset: Tuple[int, int, int], size: Tuple[int, int, int] + ) -> None: self._data_layers[layer_name]._set_bounding_box_size(size) self._data_layers[layer_name]._set_bounding_box_offset(offset) self._export_as_json() @@ -91,12 +110,12 @@ def data_layers(self) -> dict: class WKProperties(Properties): @classmethod - def _from_json(cls, path) -> Properties: + def _from_json(cls, path: Union[str, Path]) -> "WKProperties": with open(path) as datasource_properties: data = json.load(datasource_properties) # reconstruct data_layers - data_layers = {} + data_layers: Dict[str, LayerProperties] = {} for layer in data["dataLayers"]: if layer["category"] == Layer.SEGMENTATION_TYPE: data_layers[layer["name"]] = SegmentationLayerProperties._from_json( @@ -111,7 +130,7 @@ def _from_json(cls, path) -> Properties: path, data["id"]["name"], data["scale"], data["id"]["team"], data_layers ) - def _export_as_json(self): + def _export_as_json(self) -> None: data = { "id": {"name": self.name, "team": self.team}, "scale": self.scale, @@ -123,7 +142,8 @@ def _export_as_json(self): with open(self.path, "w") as outfile: json.dump(data, outfile, indent=4, separators=(",", ": ")) - def _add_mag(self, layer_name, mag, cube_length): + def _add_mag(self, layer_name: str, mag: str, **kwargs: int) -> None: + assert "cube_length" in kwargs # this mag is already in wkw_magnifications in case we reconstruct the dataset from a datasource-properties.json if not any( [ @@ -132,26 +152,33 @@ def _add_mag(self, layer_name, mag, cube_length): ] ): self._data_layers[layer_name]._add_resolution( - WkResolution(mag, cube_length) + WkResolution(mag, kwargs["cube_length"]) ) self._export_as_json() class TiffProperties(Properties): def __init__( - self, path, name, scale, pattern, team="", data_layers=None, tile_size=(32, 32) - ): + self, + path: Union[str, Path], + name: str, + scale: Tuple[float, float, float], + pattern: str, + team: str = "", + 
data_layers: Dict[str, LayerProperties] = None, + tile_size: Optional[Tuple[int, int]] = (32, 32), + ) -> None: super().__init__(path, name, scale, team, data_layers) self.pattern = pattern self.tile_size = tile_size @classmethod - def _from_json(cls, path) -> Properties: + def _from_json(cls, path: Union[str, Path]) -> Properties: with open(path) as datasource_properties: data = json.load(datasource_properties) # reconstruct data_layers - data_layers = {} + data_layers: Dict[str, LayerProperties] = {} for layer in data["dataLayers"]: if layer["category"] == Layer.SEGMENTATION_TYPE: data_layers[layer["name"]] = SegmentationLayerProperties._from_json( @@ -172,7 +199,7 @@ def _from_json(cls, path) -> Properties: tile_size=data.get("tile_size"), ) - def _export_as_json(self): + def _export_as_json(self) -> None: data = { "id": {"name": self.name, "team": self.team}, "scale": self.scale, @@ -188,7 +215,7 @@ def _export_as_json(self): with open(self.path, "w") as outfile: json.dump(data, outfile, indent=4, separators=(",", ": ")) - def _add_mag(self, layer_name, mag): + def _add_mag(self, layer_name: str, mag: str, **kwargs: int) -> None: # this mag is already in wkw_magnifications in case we reconstruct the dataset from a datasource-properties.json if not any( [ diff --git a/wkcuber/api/Properties/LayerProperties.py b/wkcuber/api/Properties/LayerProperties.py index 7e8ee1163..c0c81dae6 100644 --- a/wkcuber/api/Properties/LayerProperties.py +++ b/wkcuber/api/Properties/LayerProperties.py @@ -1,12 +1,20 @@ from os.path import join, dirname, isfile +from pathlib import Path +from typing import Tuple, Type, Union, Any, Dict, List, Optional, cast from wkw import wkw +from wkcuber.api.Properties.ResolutionProperties import Resolution from wkcuber.mag import Mag from wkcuber.api.bounding_box import BoundingBox -def extract_num_channels(num_channels_in_properties, path, layer, mag): +def extract_num_channels( + num_channels_in_properties: Optional[int], + path: Union[str, Path], + layer: str, + mag: Optional[Dict[str, int]], +) -> int: # if a wk dataset is not created with this API, then it most likely doesn't have the attribute 'num_channels' in the # datasource-properties.json. In this case we need to extract the 'num_channels' from the 'header.wkw'. 
if num_channels_in_properties is not None: @@ -38,13 +46,13 @@ def extract_num_channels(num_channels_in_properties, path, layer, mag): class LayerProperties: def __init__( self, - name, - category, - element_class, - data_format, - num_channels, - bounding_box=None, - resolutions=None, + name: str, + category: str, + element_class: str, + data_format: str, + num_channels: int, + bounding_box: Dict[str, Union[int, Tuple[int, int, int]]] = None, + resolutions: List[Resolution] = None, ): self._name = name self._category = category @@ -57,9 +65,9 @@ def __init__( "height": 0, "depth": 0, } - self._wkw_magnifications = resolutions or [] + self._wkw_magnifications: List[Resolution] = resolutions or [] - def _to_json(self): + def _to_json(self) -> Dict[str, Any]: return { "name": self.name, "category": self.category, @@ -78,7 +86,12 @@ def _to_json(self): } @classmethod - def _from_json(cls, json_data, resolution_type, dataset_path): + def _from_json( + cls, + json_data: Dict[str, Any], + resolution_type: Type[Resolution], + dataset_path: Union[str, Path], + ) -> "LayerProperties": # create LayerProperties without resolutions layer_properties = cls( json_data["name"], @@ -102,10 +115,10 @@ def _from_json(cls, json_data, resolution_type, dataset_path): return layer_properties - def _add_resolution(self, resolution): + def _add_resolution(self, resolution: Resolution) -> None: self._wkw_magnifications.append(resolution) - def _delete_resolution(self, resolution): + def _delete_resolution(self, resolution: str) -> None: resolutions_to_delete = [ res for res in self._wkw_magnifications if res.mag == Mag(resolution) ] @@ -116,25 +129,27 @@ def get_bounding_box(self) -> BoundingBox: return BoundingBox(self.get_bounding_box_offset(), self.get_bounding_box_size()) - def get_bounding_box_size(self) -> tuple: + def get_bounding_box_size(self) -> Tuple[int, int, int]: return ( self.bounding_box["width"], self.bounding_box["height"], self.bounding_box["depth"], ) - def get_bounding_box_offset(self) -> tuple: - return tuple(self.bounding_box["topLeft"]) + def get_bounding_box_offset(self) -> Tuple[int, int, int]: + return cast(Tuple[int, int, int], tuple(self.bounding_box["topLeft"])) - def _set_bounding_box_size(self, size): + def _set_bounding_box_size(self, size: Tuple[int, int, int]) -> None: # Cast to int in case the provided parameter contains numpy integer self._bounding_box["width"] = int(size[0]) self._bounding_box["height"] = int(size[1]) self._bounding_box["depth"] = int(size[2]) - def _set_bounding_box_offset(self, offset): + def _set_bounding_box_offset(self, offset: Tuple[int, int, int]) -> None: # Cast to int in case the provided parameter contains numpy integer - self._bounding_box["topLeft"] = tuple(map(int, offset)) + self._bounding_box["topLeft"] = cast( + Tuple[int, int, int], tuple(map(int, offset)) + ) @property def name(self) -> str: @@ -145,11 +160,11 @@ def category(self) -> str: return self._category @property - def element_class(self): + def element_class(self) -> str: return self._element_class @property - def data_format(self): + def data_format(self) -> str: return self._data_format @property @@ -161,23 +176,23 @@ def bounding_box(self) -> dict: return self._bounding_box @property - def wkw_magnifications(self) -> dict: + def wkw_magnifications(self) -> List[Resolution]: return self._wkw_magnifications class SegmentationLayerProperties(LayerProperties): def __init__( self, - name, - category, - element_class, - data_format, - num_channels, - bounding_box=None, - 
resolutions=None, - largest_segment_id=None, - mappings=None, - ): + name: str, + category: str, + element_class: str, + data_format: str, + num_channels: int, + bounding_box: Dict[str, Union[int, Tuple[int, int, int]]] = None, + resolutions: List[Resolution] = None, + largest_segment_id: int = None, + mappings: List[str] = None, + ) -> None: super().__init__( name, category, @@ -187,10 +202,13 @@ def __init__( bounding_box, resolutions, ) + # The parameter largest_segment_id is in fact not optional. + # However, specifying the parameter as not optional, would require to change the parameter order + assert largest_segment_id is not None self._largest_segment_id = largest_segment_id self._mappings = mappings - def _to_json(self): + def _to_json(self) -> Dict[str, Any]: json_properties = super()._to_json() json_properties["largestSegmentId"] = self._largest_segment_id if self._mappings is not None: @@ -198,7 +216,12 @@ def _to_json(self): return json_properties @classmethod - def _from_json(cls, json_data, resolution_type, dataset_path): + def _from_json( + cls, + json_data: Dict[str, Any], + resolution_type: Type[Resolution], + dataset_path: Union[str, Path], + ) -> "SegmentationLayerProperties": # create LayerProperties without resolutions layer_properties = cls( json_data["name"], @@ -225,8 +248,9 @@ def _from_json(cls, json_data, resolution_type, dataset_path): @property def largest_segment_id(self) -> int: + assert self._largest_segment_id is not None return self._largest_segment_id @property - def mappings(self) -> list: + def mappings(self) -> Optional[List[str]]: return self._mappings diff --git a/wkcuber/api/Properties/ResolutionProperties.py b/wkcuber/api/Properties/ResolutionProperties.py index 9f12d479c..e639096c2 100644 --- a/wkcuber/api/Properties/ResolutionProperties.py +++ b/wkcuber/api/Properties/ResolutionProperties.py @@ -1,15 +1,17 @@ +from typing import Any, Dict + from wkcuber.mag import Mag class Resolution: - def __init__(self, mag): + def __init__(self, mag: str): self._mag = Mag(mag) def _to_json(self) -> dict: return {"resolution": self.mag.to_array()} @classmethod - def _from_json(cls, json_data): + def _from_json(cls, json_data: Dict[str, Any]) -> "Resolution": return cls(json_data["resolution"]) @property @@ -18,7 +20,7 @@ def mag(self) -> Mag: class WkResolution(Resolution): - def __init__(self, mag, cube_length): + def __init__(self, mag: str, cube_length: int): super().__init__(mag) self._cube_length = cube_length @@ -26,7 +28,7 @@ def _to_json(self) -> dict: return {"resolution": self.mag.to_array(), "cubeLength": self.cube_length} @classmethod - def _from_json(cls, json_data): + def _from_json(cls, json_data: Dict[str, Any]) -> "WkResolution": return cls(json_data["resolution"], json_data["cubeLength"]) @property diff --git a/wkcuber/api/TiffData/TiffMag.py b/wkcuber/api/TiffData/TiffMag.py index 21b433006..ffdd46d7e 100644 --- a/wkcuber/api/TiffData/TiffMag.py +++ b/wkcuber/api/TiffData/TiffMag.py @@ -1,6 +1,7 @@ import itertools import re -from typing import Optional, List, Tuple, Set +from types import TracebackType +from typing import Optional, List, Tuple, Set, Type, Iterator, Sequence, cast, Union from skimage import io import numpy as np @@ -49,7 +50,7 @@ def detect_tile_ranges( logger.info( f"Auto-detected tile ranges from tif directory structure: z {detected_z_range} x {detected_x_range} y {detected_y_range}" ) - return (detected_z_range, detected_x_range, detected_y_range) + return detected_z_range, detected_x_range, detected_y_range raise 
Exception("Couldn't auto-detect tile ranges from wkw or tile path pattern") @@ -60,7 +61,7 @@ def detect_tile_ranges_from_pattern_recursively( z_values: Set[int], x_values: Set[int], y_values: Set[int], -) -> Tuple[Optional[range], Optional[range], Optional[range]]: +) -> Tuple[range, range, range]: ( current_pattern_element, prefix, @@ -122,7 +123,7 @@ def advance_to_next_relevant_pattern_element( return current_pattern_element, prefix, remaining_pattern_elements -def values_to_range(values: Set[int]): +def values_to_range(values: Set[int]) -> range: if len(values) > 0: return range(min(values), max(values) + 1) return range(0, 0) @@ -152,7 +153,9 @@ def detect_value( return [] -def to_file_name(pattern, x, y, z) -> str: +def to_file_name( + pattern: str, x: Optional[int], y: Optional[int], z: Optional[int] +) -> str: file_name = pattern if x is not None: file_name = replace_coordinate(file_name, "x", x) @@ -163,17 +166,39 @@ def to_file_name(pattern, x, y, z) -> str: return file_name +class TiffMagHeader: + def __init__( + self, + pattern: str = "{zzzzz}.tif", + dtype_per_channel: np.dtype = np.dtype("uint8"), + num_channels: int = 1, + tile_size: Optional[Tuple[int, int]] = (32, 32), + ) -> None: + self.pattern = pattern + self.dtype_per_channel = np.dtype(dtype_per_channel) + self.num_channels = num_channels + self.tile_size = tile_size + + class TiffMag: - def __init__(self, root, header): + def __init__(self, root: str, header: TiffMagHeader) -> None: self.root = root self.tiffs = dict() self.header = header - z_range, x_range, y_range = detect_tile_ranges(self.root, self.header.pattern) - z_range = [None] if z_range == range(0, 0) else z_range - y_range = [None] if y_range == range(0, 0) else y_range - x_range = [None] if x_range == range(0, 0) else x_range + detected_z_range, detected_x_range, detected_y_range = detect_tile_ranges( + self.root, self.header.pattern + ) + z_range: List[Optional[int]] = ( + [None] if detected_z_range == range(0, 0) else list(detected_z_range) + ) + y_range: List[Optional[int]] = ( + [None] if detected_y_range == range(0, 0) else list(detected_y_range) + ) + x_range: List[Optional[int]] = ( + [None] if detected_x_range == range(0, 0) else list(detected_x_range) + ) available_tiffs = list(itertools.product(x_range, y_range, z_range)) for xyz in available_tiffs: @@ -182,18 +207,19 @@ def __init__(self, root, header): self.get_file_name_for_layer(xyz) ) # open is lazy - def read(self, off, shape) -> np.array: + def read(self, off: Tuple[int, int, int], shape: Tuple[int, int, int]) -> np.array: # modify the shape to also include the num_channels - shape = tuple(shape) + tuple([self.header.num_channels]) - - data = np.zeros(shape=shape, dtype=self.header.dtype_per_channel) + shape_with_num_channels = shape + (self.header.num_channels,) + data = np.zeros( + shape=shape_with_num_channels, dtype=self.header.dtype_per_channel + ) for ( xyz, _, offset_in_output_data, offset_in_input_data, - ) in self.calculate_relevant_slices(off, shape): - x, y, z = xyz + ) in self.calculate_relevant_slices(off, shape_with_num_channels): + _, _, z = xyz z_index_in_data = z - off[2] if xyz in self.tiffs: @@ -201,8 +227,10 @@ def read(self, off, shape) -> np.array: loaded_data = np.array( self.tiffs[xyz].read(), self.header.dtype_per_channel )[ - offset_in_output_data[0] : offset_in_output_data[0] + shape[0], - offset_in_output_data[1] : offset_in_output_data[1] + shape[1], + offset_in_output_data[0] : offset_in_output_data[0] + + shape_with_num_channels[0], + 
offset_in_output_data[1] : offset_in_output_data[1] + + shape_with_num_channels[1], ] index_slice = [ @@ -226,7 +254,7 @@ def read(self, off, shape) -> np.array: data = np.moveaxis(data, -1, 0) return data - def write(self, off, data): + def write(self, off: Tuple[int, int, int], data: np.ndarray) -> None: if not len(data.shape) == 3: # reformat array to have the channels as the first index (similar to wkw) # this is only necessary if the data has a dedicated dimensions for the num_channels @@ -265,20 +293,28 @@ def write(self, off, data): self.tiffs[xyz].merge_with_image(pixel_data, offset_in_output_data) - def compress(self, dst_path: str, compress_files: bool = False): - raise NotImplementedError - - def list_files(self): + def list_files(self) -> Iterator[str]: _, file_extension = os.path.splitext(self.header.pattern) file_paths = list(iglob(os.path.join(self.root, "*" + file_extension))) for file_path in file_paths: yield os.path.relpath(os.path.normpath(file_path), self.root) - def close(self): + def close(self) -> None: return - def calculate_relevant_slices(self, offset, shape): + def calculate_relevant_slices( + self, + offset: Tuple[int, int, int], + shape: Union[Tuple[int, int, int, int], Tuple[int, int, int]], + ) -> Iterator[ + Tuple[ + Tuple[Optional[int], Optional[int], int], + Tuple[int, ...], + Tuple[int, int], + Tuple[int, int], + ] + ]: """ The purpose of this method is to find out which tiles need to be touched. Each tile is specified by its (x, y, z)-dimensions. @@ -297,17 +333,17 @@ def calculate_relevant_slices(self, offset, shape): if tile_size is None: x_first_index = None - x_indices = [None] + x_indices: List[Union[int, None]] = [None] y_first_index = None - y_indices = [None] + y_indices: List[Union[int, None]] = [None] else: x_first_index = offset[0] // tile_size[0] # floor division x_last_index = np.math.ceil(max_indices[0] / tile_size[0]) - x_indices = range(x_first_index, x_last_index) + x_indices = list(range(x_first_index, x_last_index)) y_first_index = offset[1] // tile_size[1] # floor division y_last_index = np.math.ceil(max_indices[1] / tile_size[1]) - y_indices = range(y_first_index, y_last_index) + y_indices = list(range(y_first_index, y_last_index)) for x in x_indices: for y in y_indices: @@ -331,26 +367,25 @@ def calculate_relevant_slices(self, offset, shape): offset_in_input_data = shape_top_left_corner - offset[0:2] offset_in_output_data = tuple( - (offset[0:2] - shape_top_left_corner) + (np.array(offset[0:2]) - shape_top_left_corner) * np.equal((x, y), (x_first_index, y_first_index)) ) - tile_shape = tuple( - shape_bottom_right - shape_top_left_corner - ) + tuple(shape[3:4]) - - yield tuple( - ( - (x, y, z), - tile_shape, - offset_in_output_data, - offset_in_input_data, + tile_shape = ( + tuple(shape_bottom_right - shape_top_left_corner) + + shape[3:4] ) + + yield ( + (x, y, z), + tile_shape, + cast(Tuple[int, int], offset_in_output_data), + offset_in_input_data, ) def has_only_one_channel(self) -> bool: return self.header.num_channels == 1 - def assert_correct_data_format(self, data): + def assert_correct_data_format(self, data: np.ndarray) -> None: if self.has_only_one_channel(): if not len(data.shape) == 3: raise AttributeError( @@ -370,38 +405,31 @@ def assert_correct_data_format(self, data): f"The type of the provided data does not match the expected type. 
(Expected np.array of type {self.header.dtype_per_channel.name})" ) - def get_file_name_for_layer(self, xyz) -> str: + def get_file_name_for_layer( + self, xyz: Tuple[Optional[int], Optional[int], Optional[int]] + ) -> str: x, y, z = xyz return os.path.join(self.root, to_file_name(self.header.pattern, x, y, z)) @staticmethod - def open(root: str, header=None): + def open(root: str, header: TiffMagHeader = None) -> "TiffMag": if header is None: header = TiffMagHeader() return TiffMag(root, header) - def __enter__(self): + def __enter__(self) -> "TiffMag": return self - def __exit__(self, type, value, tb): - self.close() - - -class TiffMagHeader: - def __init__( + def __exit__( self, - pattern="{zzzzz}.tif", - dtype_per_channel=np.dtype("uint8"), - num_channels=1, - tile_size=(32, 32), - ): - self.pattern = pattern - self.dtype_per_channel = np.dtype(dtype_per_channel) - self.num_channels = num_channels - self.tile_size = tile_size + _type: Optional[Type[BaseException]], + _value: Optional[BaseException], + _tb: Optional[TracebackType], + ) -> None: + self.close() -def transpose_for_skimage(data): +def transpose_for_skimage(data: np.ndarray) -> np.ndarray: if len(data.shape) == 2: return data.transpose() elif len(data.shape) == 3: @@ -411,28 +439,30 @@ def transpose_for_skimage(data): class TiffReader: - def __init__(self, file_name): + def __init__(self, file_name: str): self.file_name = file_name @classmethod - def init_tiff(cls, pixels, file_name): + def init_tiff(cls, pixels: np.ndarray, file_name: str) -> "TiffReader": tr = TiffReader(file_name) tr.write(pixels) return tr @classmethod - def open(cls, file_name): + def open(cls, file_name: str) -> "TiffReader": return cls(file_name) def read(self) -> np.array: data = io.imread(self.file_name) return transpose_for_skimage(data) - def write(self, pixels): + def write(self, pixels: np.ndarray) -> None: os.makedirs(os.path.dirname(self.file_name), exist_ok=True) io.imsave(self.file_name, transpose_for_skimage(pixels), check_contrast=False) - def merge_with_image(self, foreground_pixels, offset): + def merge_with_image( + self, foreground_pixels: np.ndarray, offset: Tuple[int, int] + ) -> None: background_pixels = self.read() bg_shape = background_pixels.shape fg_shape = foreground_pixels.shape diff --git a/wkcuber/api/View.py b/wkcuber/api/View.py index 9b52a120e..3deaf6622 100644 --- a/wkcuber/api/View.py +++ b/wkcuber/api/View.py @@ -1,9 +1,15 @@ import math +from concurrent.futures._base import Executor +from pathlib import Path +from types import TracebackType +from typing import Tuple, Optional, Type, Callable, Any, Union, List, cast +import cluster_tools import numpy as np +from cluster_tools.schedulers.cluster_executor import ClusterExecutor from wkw import Dataset, wkw -from wkcuber.api.TiffData.TiffMag import TiffMag +from wkcuber.api.TiffData.TiffMag import TiffMag, TiffMagHeader from wkcuber.api.bounding_box import BoundingBox from wkcuber.utils import wait_and_ensure_success @@ -11,13 +17,13 @@ class View: def __init__( self, - path_to_mag_dataset, - header, - size, - global_offset=(0, 0, 0), - is_bounded=True, + path_to_mag_dataset: str, + header: Union[TiffMagHeader, wkw.Header], + size: Tuple[int, int, int], + global_offset: Tuple[int, int, int] = (0, 0, 0), + is_bounded: bool = True, ): - self.dataset = None + self.dataset: Optional[Dataset] = None self.path = path_to_mag_dataset self.header = header self.size = size @@ -25,26 +31,33 @@ def __init__( self.is_bounded = is_bounded self._is_opened = False - def open(self): - 
raise NotImplemented() + def open(self) -> "View": + pass - def close(self): + def close(self) -> None: if not self._is_opened: raise Exception("Cannot close View: the view is not opened") else: + assert self.dataset is not None # because the View was opened self.dataset.close() self.dataset = None self._is_opened = False - def write(self, data, relative_offset=(0, 0, 0), allow_compressed_write=False): + def write( + self, + data: np.ndarray, + relative_offset: Tuple[int, int, int] = (0, 0, 0), + allow_compressed_write: bool = False, + ) -> None: was_opened = self._is_opened # assert the size of the parameter data is not in conflict with the attribute self.size assert_non_negative_offset(relative_offset) self.assert_bounds(relative_offset, data.shape[-3:]) # calculate the absolute offset - absolute_offset = tuple( - sum(x) for x in zip(self.global_offset, relative_offset) + absolute_offset = cast( + Tuple[int, int, int], + tuple(sum(x) for x in zip(self.global_offset, relative_offset)), ) if self._is_compressed() and allow_compressed_write: @@ -52,13 +65,18 @@ def write(self, data, relative_offset=(0, 0, 0), allow_compressed_write=False): if not was_opened: self.open() + assert self.dataset is not None # because the View was opened self.dataset.write(absolute_offset, data) if not was_opened: self.close() - def read(self, offset=(0, 0, 0), size=None) -> np.array: + def read( + self, + offset: Tuple[int, int, int] = (0, 0, 0), + size: Tuple[int, int, int] = None, + ) -> np.array: was_opened = self._is_opened size = self.size if size is None else size @@ -70,6 +88,7 @@ def read(self, offset=(0, 0, 0), size=None) -> np.array: if not was_opened: self.open() + assert self.dataset is not None # because the View was opened data = self.dataset.read(absolute_offset, size) @@ -78,63 +97,94 @@ def read(self, offset=(0, 0, 0), size=None) -> np.array: return data - def get_view(self, size, relative_offset=(0, 0, 0)): + def get_view( + self, + size: Tuple[int, int, int], + relative_offset: Tuple[int, int, int] = (0, 0, 0), + ) -> "View": self.assert_bounds(relative_offset, size) - view_offset = self.global_offset + np.array(relative_offset) + view_offset = cast( + Tuple[int, int, int], tuple(self.global_offset + np.array(relative_offset)) + ) return type(self)( self.path, self.header, size=size, - global_offset=tuple(view_offset), + global_offset=view_offset, is_bounded=self.is_bounded, ) - def check_bounds(self, offset, size) -> bool: + def check_bounds( + self, offset: Tuple[int, int, int], size: Tuple[int, int, int] + ) -> bool: for s1, s2, off in zip(self.size, size, offset): if s2 + off > s1 and self.is_bounded: return False return True - def assert_bounds(self, offset, size): + def assert_bounds( + self, offset: Tuple[int, int, int], size: Tuple[int, int, int] + ) -> None: if not self.check_bounds(offset, size): raise AssertionError( f"Accessing data out of bounds: The passed parameter 'size' {size} exceeds the size of the current view ({self.size})" ) - def for_each_chunk(self, work_on_chunk, job_args_per_chunk, chunk_size, executor): + def for_each_chunk( + self, + work_on_chunk: Callable[[List[Any]], None], + job_args_per_chunk: Any, + chunk_size: Tuple[int, int, int], + executor: Union[ClusterExecutor, cluster_tools.WrappedProcessPoolExecutor], + ) -> None: self._check_chunk_size(chunk_size) job_args = [] for chunk in BoundingBox(self.global_offset, self.size).chunk( - chunk_size, chunk_size + chunk_size, list(chunk_size) ): - relative_offset = np.array(chunk.topleft) - 
np.array(self.global_offset) - view = self.get_view(size=chunk.size, relative_offset=relative_offset) + relative_offset = cast( + Tuple[int, int, int], + tuple(np.array(chunk.topleft) - np.array(self.global_offset)), + ) + view = self.get_view( + size=cast(Tuple[int, int, int], tuple(chunk.size)), + relative_offset=relative_offset, + ) view.is_bounded = True job_args.append((view, job_args_per_chunk)) # execute the work for each chunk wait_and_ensure_success(executor.map_to_futures(work_on_chunk, job_args)) - def _check_chunk_size(self, chunk_size): + def _check_chunk_size(self, chunk_size: Tuple[int, int, int]) -> None: raise NotImplementedError - def _is_compressed(self): + def _is_compressed(self) -> bool: return False - def _handle_compressed_write(self, absolute_offset, data): + def _handle_compressed_write( + self, absolute_offset: Tuple[int, int, int], data: np.ndarray + ) -> Tuple[Tuple[int, int, int], np.ndarray]: return absolute_offset, data - def __enter__(self): + def __enter__(self) -> "View": return self - def __exit__(self, type, value, tb): + def __exit__( + self, + _type: Optional[Type[BaseException]], + _value: Optional[BaseException], + _tb: Optional[TracebackType], + ) -> None: self.close() class WKView(View): - def open(self): + header: wkw.Header + + def open(self) -> "WKView": if self._is_opened: raise Exception("Cannot open view: the view is already opened") else: @@ -144,7 +194,7 @@ def open(self): self._is_opened = True return self - def _check_chunk_size(self, chunk_size): + def _check_chunk_size(self, chunk_size: Tuple[int, int, int]) -> None: assert chunk_size is not None if 0 in chunk_size: @@ -162,13 +212,15 @@ def _check_chunk_size(self, chunk_size): f"The passed parameter 'chunk_size' {chunk_size} must be a multiple of (32, 32, 32)." 
) - def _is_compressed(self): + def _is_compressed(self) -> bool: return ( self.header.block_type == wkw.Header.BLOCK_TYPE_LZ4 or self.header.block_type == wkw.Header.BLOCK_TYPE_LZ4HC ) - def _handle_compressed_write(self, absolute_offset, data): + def _handle_compressed_write( + self, absolute_offset: Tuple[int, int, int], data: np.ndarray + ) -> Tuple[Tuple[int, int, int], np.ndarray]: # calculate aligned bounding box file_bb = np.full(3, self.header.file_len * self.header.block_len) absolute_offset_np = np.array(absolute_offset) @@ -203,13 +255,13 @@ def _handle_compressed_write(self, absolute_offset, data): ) # overwrite the specified data aligned_data[tuple(index_slice)] = data - return tuple(aligned_offset), aligned_data + return cast(Tuple[int, int, int], tuple(aligned_offset)), aligned_data else: return absolute_offset, data class TiffView(View): - def open(self): + def open(self) -> "TiffView": if self._is_opened: raise Exception("Cannot open view: the view is already opened") else: @@ -217,7 +269,7 @@ def open(self): self._is_opened = True return self - def _check_chunk_size(self, chunk_size): + def _check_chunk_size(self, chunk_size: Tuple[int, int, int]) -> None: assert chunk_size is not None if 0 in chunk_size: @@ -242,7 +294,7 @@ def _check_chunk_size(self, chunk_size): ) -def assert_non_negative_offset(offset): +def assert_non_negative_offset(offset: Tuple[int, int, int]) -> None: all_positive = all(i >= 0 for i in offset) if not all_positive: raise Exception( diff --git a/wkcuber/api/bounding_box.py b/wkcuber/api/bounding_box.py index 69bcc40e1..e66a6a11f 100644 --- a/wkcuber/api/bounding_box.py +++ b/wkcuber/api/bounding_box.py @@ -1,7 +1,17 @@ # mypy: allow-untyped-defs import json import re -from typing import Dict, Generator, Iterable, List, Optional, Tuple, Union, NamedTuple +from typing import ( + Dict, + Generator, + Iterable, + List, + Optional, + Tuple, + Union, + NamedTuple, + cast, +) import numpy as np @@ -70,14 +80,16 @@ def from_checkpoint_name(checkpoint_name: str) -> "BoundingBox": match is not None ), f"Could not extract bounding box from {checkpoint_name}" bbox_tuple = tuple(int(value) for value in match.group().split("_")) - topleft = bbox_tuple[:3] - size = bbox_tuple[3:6] + topleft = cast(Tuple[int, int, int], bbox_tuple[:3]) + size = cast(Tuple[int, int, int], bbox_tuple[3:6]) return BoundingBox.from_tuple2((topleft, size)) @staticmethod def from_csv(csv_bbox: str) -> "BoundingBox": bbox_tuple = tuple(int(x) for x in csv_bbox.split(",")) - return BoundingBox.from_tuple6(bbox_tuple) + return BoundingBox.from_tuple6( + cast(Tuple[int, int, int, int, int, int], bbox_tuple) + ) @staticmethod def from_auto(obj) -> "BoundingBox": diff --git a/wkcuber/check_equality.py b/wkcuber/check_equality.py index d5652c00e..8eb9118f3 100644 --- a/wkcuber/check_equality.py +++ b/wkcuber/check_equality.py @@ -1,9 +1,12 @@ import logging -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace +from typing import Any, Callable + from wkcuber.api.Dataset import WKDataset -from wkcuber.api.bounding_box import BoundingBox +from wkcuber.api.bounding_box import BoundingBox, BoundingBoxNamedTuple import numpy as np +from wkcuber.mag import Mag from .utils import ( add_verbose_flag, open_wkw, @@ -20,7 +23,7 @@ CHUNK_SIZE = 1024 -def named_partial(func, *args, **kwargs): +def named_partial(func: Callable, *args: Any, **kwargs: Any) -> Callable: # Propagate __name__ and __doc__ attributes to partial function partial_func = functools.partial(func, *args, 
**kwargs) functools.update_wrapper(partial_func, func) @@ -30,7 +33,7 @@ def named_partial(func, *args, **kwargs): return partial_func -def create_parser(): +def create_parser() -> ArgumentParser: parser = ArgumentParser() parser.add_argument("source_path", help="Path to input WKW dataset") @@ -53,8 +56,12 @@ def create_parser(): def assert_equality_for_chunk( - source_path: str, target_path: str, layer_name: str, mag, sub_box -): + source_path: str, + target_path: str, + layer_name: str, + mag: Mag, + sub_box: BoundingBoxNamedTuple, +) -> None: wk_dataset = WKDataset(source_path) layer = wk_dataset.layers[layer_name] backup_wkw_info = WkwDatasetInfo(target_path, layer_name, mag, header=None) @@ -69,12 +76,12 @@ def assert_equality_for_chunk( ), f"Data differs in bounding box {sub_box} for layer {layer_name} with mag {mag}" -def check_equality(source_path: str, target_path: str, args=None): +def check_equality(source_path: str, target_path: str, args: Namespace = None) -> None: logging.info(f"Comparing {source_path} with {target_path}") wk_src_dataset = WKDataset(source_path) - src_layer_names = wk_src_dataset.layers.keys() + src_layer_names = list(wk_src_dataset.layers.keys()) target_layer_names = [ layer["name"] for layer in detect_layers(target_path, 0, False) ] @@ -84,7 +91,7 @@ def check_equality(source_path: str, target_path: str, args=None): existing_layer_names = src_layer_names - if args.layer_name is not None: + if args is not None and args.layer_name is not None: assert ( args.layer_name in existing_layer_names ), f"Provided layer {args.layer_name} does not exist in input dataset." diff --git a/wkcuber/compress.py b/wkcuber/compress.py index 0c1c48e3b..c9c0a7868 100644 --- a/wkcuber/compress.py +++ b/wkcuber/compress.py @@ -2,7 +2,7 @@ import wkw import shutil import logging -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace from os import path, makedirs from uuid import uuid4 from .mag import Mag @@ -17,12 +17,12 @@ setup_logging, ) from .metadata import detect_resolutions, convert_element_class_to_dtype -from typing import List +from typing import List, Tuple BACKUP_EXT = ".bak" -def create_parser(): +def create_parser() -> ArgumentParser: parser = ArgumentParser() parser.add_argument( @@ -53,7 +53,7 @@ def create_parser(): return parser -def compress_file_job(args): +def compress_file_job(args: Tuple[str, str]) -> None: source_path, target_path = args try: logging.debug("Compressing '{}' to '{}'".format(source_path, target_path)) @@ -75,7 +75,13 @@ def compress_file_job(args): raise exc -def compress_mag(source_path, layer_name, target_path, mag: Mag, args=None): +def compress_mag( + source_path: str, + layer_name: str, + target_path: str, + mag: Mag, + args: Namespace = None, +) -> None: if path.exists(path.join(target_path, layer_name, str(mag))): logging.error("Target path '{}' already exists".format(target_path)) exit(1) @@ -103,7 +109,9 @@ def compress_mag(source_path, layer_name, target_path, mag: Mag, args=None): logging.info("Mag {0} successfully compressed".format(str(mag))) -def compress_mag_inplace(target_path, layer_name, mag: Mag, args=None): +def compress_mag_inplace( + target_path: str, layer_name: str, mag: Mag, args: Namespace = None +) -> None: compress_target_path = "{}.compress-{}".format(target_path, uuid4()) compress_mag(target_path, layer_name, compress_target_path, mag, args) @@ -116,19 +124,25 @@ def compress_mag_inplace(target_path, layer_name, mag: Mag, args=None): def compress_mags( - source_path, layer_name, 
target_path=None, mags: List[Mag] = None, args=None -): - with_tmp_dir = target_path is None - target_path = source_path + ".tmp" if with_tmp_dir else target_path + source_path: str, + layer_name: str, + target_path: str = None, + mags: List[Mag] = None, + args: Namespace = None, +) -> None: + if target_path is None: + target = source_path + ".tmp" + else: + target = target_path if mags is None: mags = list(detect_resolutions(source_path, layer_name)) mags.sort() for mag in mags: - compress_mag(source_path, layer_name, target_path, mag, args) + compress_mag(source_path, layer_name, target, mag, args) - if with_tmp_dir: + if target_path is None: makedirs(path.join(source_path + BACKUP_EXT, layer_name), exist_ok=True) for mag in mags: shutil.move( @@ -136,10 +150,10 @@ def compress_mags( path.join(source_path + BACKUP_EXT, layer_name, str(mag)), ) shutil.move( - path.join(target_path, layer_name, str(mag)), + path.join(target, layer_name, str(mag)), path.join(source_path, layer_name, str(mag)), ) - shutil.rmtree(target_path) + shutil.rmtree(target) logging.info( "Old files are still present in '{0}.bak'. Please remove them when not required anymore.".format( source_path diff --git a/wkcuber/convert_knossos.py b/wkcuber/convert_knossos.py index 6831e3297..b2a3b6781 100644 --- a/wkcuber/convert_knossos.py +++ b/wkcuber/convert_knossos.py @@ -1,7 +1,9 @@ import time import logging +from typing import Tuple, cast + import wkw -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace from .utils import ( add_verbose_flag, @@ -19,7 +21,7 @@ from .metadata import convert_element_class_to_dtype -def create_parser(): +def create_parser() -> ArgumentParser: parser = ArgumentParser() parser.add_argument( @@ -52,12 +54,14 @@ def create_parser(): return parser -def convert_cube_job(args): +def convert_cube_job( + args: Tuple[Tuple[int, int, int], KnossosDatasetInfo, WkwDatasetInfo] +) -> None: cube_xyz, source_knossos_info, target_wkw_info = args logging.info("Converting {},{},{}".format(cube_xyz[0], cube_xyz[1], cube_xyz[2])) ref_time = time.time() - offset = tuple(x * CUBE_EDGE_LEN for x in cube_xyz) - size = (CUBE_EDGE_LEN,) * 3 + offset = cast(Tuple[int, int, int], tuple(x * CUBE_EDGE_LEN for x in cube_xyz)) + size = cast(Tuple[int, int, int], (CUBE_EDGE_LEN,) * 3) with open_knossos(source_knossos_info) as source_knossos, open_wkw( target_wkw_info @@ -71,7 +75,14 @@ def convert_cube_job(args): ) -def convert_knossos(source_path, target_path, layer_name, dtype, mag=1, args=None): +def convert_knossos( + source_path: str, + target_path: str, + layer_name: str, + dtype: str, + mag: int = 1, + args: Namespace = None, +) -> None: source_knossos_info = KnossosDatasetInfo(source_path, dtype) target_wkw_info = WkwDatasetInfo( target_path, layer_name, mag, wkw.Header(convert_element_class_to_dtype(dtype)) diff --git a/wkcuber/convert_nifti.py b/wkcuber/convert_nifti.py index a8d3af392..e1a9a1ec9 100644 --- a/wkcuber/convert_nifti.py +++ b/wkcuber/convert_nifti.py @@ -2,11 +2,13 @@ import time from argparse import ArgumentParser from pathlib import Path +from typing import Tuple, Optional, Union, cast import nibabel as nib import numpy as np from wkcuber.api.Dataset import TiffDataset, WKDataset +from wkcuber.api.bounding_box import BoundingBox from wkcuber.utils import ( DEFAULT_WKW_FILE_LEN, DEFAULT_WKW_VOXELS_PER_BLOCK, @@ -18,7 +20,7 @@ ) -def create_parser(): +def create_parser() -> ArgumentParser: parser = ArgumentParser() parser.add_argument( @@ -88,7 +90,7 @@ def 
create_parser(): def to_target_datatype( - data: np.ndarray, target_dtype, is_probably_binary: bool + data: np.ndarray, target_dtype: str, is_probably_binary: bool ) -> np.ndarray: if is_probably_binary: logging.info( @@ -111,18 +113,17 @@ def to_target_datatype( def convert_nifti( - source_nifti_path, - target_path, - layer_name, - dtype, - scale, - mag=1, - file_len=DEFAULT_WKW_FILE_LEN, - bbox_to_enforce=None, - write_tiff=False, - use_orientation_header=False, - flip_axes=None, -): + source_nifti_path: Path, + target_path: Path, + layer_name: str, + dtype: str, + scale: Tuple[float, ...], + file_len: int = DEFAULT_WKW_FILE_LEN, + bbox_to_enforce: BoundingBox = None, + write_tiff: bool = False, + use_orientation_header: bool = False, + flip_axes: Optional[Union[int, Tuple[int, ...]]] = None, +) -> None: voxels_per_cube = file_len * DEFAULT_WKW_VOXELS_PER_BLOCK ref_time = time.time() @@ -197,8 +198,10 @@ def convert_nifti( ) if write_tiff: - ds = TiffDataset.get_or_create(target_path, scale=scale or (1, 1, 1)) - layer = ds.get_or_add_layer( + tiff_ds = TiffDataset.get_or_create( + target_path, scale=cast(Tuple[float, float, float], scale or (1, 1, 1)) + ) + layer = tiff_ds.get_or_add_layer( layer_name, category_type, dtype_per_layer=np.dtype(dtype), @@ -208,8 +211,10 @@ def convert_nifti( mag.write(cube_data.squeeze()) else: - ds = WKDataset.get_or_create(target_path, scale=scale or (1, 1, 1)) - layer = ds.get_or_add_layer( + wk_ds = WKDataset.get_or_create( + target_path, scale=cast(Tuple[float, float, float], scale or (1, 1, 1)) + ) + layer = wk_ds.get_or_add_layer( layer_name, category_type, dtype_per_layer=np.dtype(dtype), @@ -226,16 +231,16 @@ def convert_nifti( def convert_folder_nifti( - source_folder_path, - target_path, - color_subpath, - segmentation_subpath, - scale, - use_orientation_header=False, - bbox_to_enforce=None, - write_tiff=False, - flip_axes=None, -): + source_folder_path: Path, + target_path: Path, + color_subpath: str, + segmentation_subpath: str, + scale: Tuple[float, ...], + use_orientation_header: bool = False, + bbox_to_enforce: BoundingBox = None, + write_tiff: bool = False, + flip_axes: Optional[Union[int, Tuple[int, ...]]] = None, +) -> None: paths = list(source_folder_path.rglob("**/*.nii")) color_path = None @@ -260,23 +265,46 @@ def convert_folder_nifti( logging.info("Segmentation file will also use uint8 as a datatype.") - conversion_args = { - "scale": scale, - "write_tiff": write_tiff, - "bbox_to_enforce": bbox_to_enforce, - "use_orientation_header": use_orientation_header, - "flip_axes": flip_axes, - } for path in paths: if path == color_path: - convert_nifti(path, target_path, "color", "uint8", **conversion_args) + convert_nifti( + path, + target_path, + "color", + "uint8", + scale, + write_tiff, + bbox_to_enforce, + use_orientation_header, + flip_axes=flip_axes, + ) elif path == segmentation_path: - convert_nifti(path, target_path, "segmentation", "uint8", **conversion_args) + convert_nifti( + path, + target_path, + "segmentation", + "uint8", + scale, + write_tiff, + bbox_to_enforce, + use_orientation_header, + flip_axes=flip_axes, + ) else: - convert_nifti(path, target_path, path.stem, "uint8", **conversion_args) + convert_nifti( + path, + target_path, + path.stem, + "uint8", + scale, + write_tiff, + bbox_to_enforce, + use_orientation_header, + flip_axes=flip_axes, + ) -def main(): +def main() -> None: args = create_parser().parse_args() setup_logging(args) diff --git a/wkcuber/cubing.py b/wkcuber/cubing.py index 25b35c7d7..7371b6151 100644 --- 
a/wkcuber/cubing.py +++ b/wkcuber/cubing.py @@ -1,13 +1,19 @@ import time import logging +from typing import List, Tuple + import numpy as np import wkw -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace from os import path from natsort import natsorted from .mag import Mag -from .downsampling import parse_interpolation_mode, downsample_unpadded_data +from .downsampling import ( + parse_interpolation_mode, + downsample_unpadded_data, + InterpolationModes, +) from .utils import ( get_chunks, find_files, @@ -28,7 +34,7 @@ BLOCK_LEN = 32 -def create_parser(): +def create_parser() -> ArgumentParser: parser = ArgumentParser() parser.add_argument("source_path", help="Directory containing the input images.") @@ -88,7 +94,7 @@ def create_parser(): return parser -def find_source_filenames(source_path): +def find_source_filenames(source_path: str) -> List[str]: # Find all files in a folder that have a matching file extension source_files = list( find_files(path.join(source_path, "*"), image_reader.readers.keys()) @@ -103,7 +109,7 @@ def find_source_filenames(source_path): return natsorted(source_files) -def read_image_file(file_name, dtype): +def read_image_file(file_name: str, dtype: type) -> np.ndarray: try: return image_reader.read_array(file_name, dtype) except Exception as exc: @@ -111,7 +117,9 @@ def read_image_file(file_name, dtype): raise exc -def prepare_slices_for_wkw(slices, num_channels=None): +def prepare_slices_for_wkw( + slices: List[np.ndarray], num_channels: int = None +) -> np.ndarray: # Write batch buffer which will have shape (x, y, channel_count, z) # since we concat along the last axis (z) buffer = np.concatenate(slices, axis=-1) @@ -125,7 +133,18 @@ def prepare_slices_for_wkw(slices, num_channels=None): return buffer -def cubing_job(args): +def cubing_job( + args: Tuple[ + WkwDatasetInfo, + List[int], + Mag, + InterpolationModes, + List[str], + int, + Tuple[int, int], + bool, + ] +) -> None: ( target_wkw_info, z_batches, @@ -159,6 +178,7 @@ def cubing_job(args): image = read_image_file( file_name, target_wkw_info.header.voxel_type ) + if not pad: assert ( image.shape[0:2] == image_size @@ -209,7 +229,14 @@ def cubing_job(args): raise exc -def cubing(source_path, target_path, layer_name, dtype, batch_size, args) -> dict: +def cubing( + source_path: str, + target_path: str, + layer_name: str, + dtype: str, + batch_size: int, + args: Namespace, +) -> dict: source_files = find_source_filenames(source_path) diff --git a/wkcuber/downsampling.py b/wkcuber/downsampling.py index 769f5e63b..c7d3d0651 100644 --- a/wkcuber/downsampling.py +++ b/wkcuber/downsampling.py @@ -1,9 +1,10 @@ import logging import math +from typing import Any, Tuple, Callable, List, cast import wkw import numpy as np -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace import os from scipy.ndimage.interpolation import zoom from itertools import product @@ -30,23 +31,25 @@ DEFAULT_EDGE_LEN = 256 -def determine_buffer_edge_len(dataset): +def determine_buffer_edge_len(dataset: wkw.Dataset) -> int: return min(DEFAULT_EDGE_LEN, dataset.header.file_len * dataset.header.block_len) -def extend_wkw_dataset_info_header(wkw_info, **kwargs): +def extend_wkw_dataset_info_header(wkw_info: WkwDatasetInfo, **kwargs: Any) -> None: for key, value in kwargs.items(): setattr(wkw_info.header, key, value) -def calculate_virtual_scale_for_target_mag(target_mag): +def calculate_virtual_scale_for_target_mag( + target_mag: Mag, +) -> Tuple[float, float, float]: """ This 
scale is not the actual scale of the dataset The virtual scale is used for downsample_mags_anisotropic. """ max_target_value = max(list(target_mag.to_array())) scale_array = max_target_value / np.array(target_mag.to_array()) - return tuple(scale_array) + return cast(Tuple[float, float, float], tuple(scale_array)) class InterpolationModes(Enum): @@ -59,7 +62,7 @@ class InterpolationModes(Enum): MIN = 6 -def create_parser(): +def create_parser() -> ArgumentParser: parser = ArgumentParser() parser.add_argument("path", help="Directory containing the dataset.") @@ -126,15 +129,15 @@ def create_parser(): def downsample( - source_wkw_info, - target_wkw_info, + source_wkw_info: WkwDatasetInfo, + target_wkw_info: WkwDatasetInfo, source_mag: Mag, target_mag: Mag, - interpolation_mode, - compress, - buffer_edge_len=None, - args=None, -): + interpolation_mode: InterpolationModes, + compress: bool, + buffer_edge_len: int = None, + args: Namespace = None, +) -> None: assert source_mag < target_mag logging.info("Downsampling mag {} from mag {}".format(target_mag, source_mag)) @@ -215,7 +218,18 @@ def downsample( logging.info("Mag {0} successfully cubed".format(target_mag)) -def downsample_cube_job(args): +def downsample_cube_job( + args: Tuple[ + WkwDatasetInfo, + WkwDatasetInfo, + List[int], + InterpolationModes, + Tuple[int, int, int], + int, + bool, + bool, + ] +) -> None: ( source_wkw_info, target_wkw_info, @@ -308,7 +322,9 @@ def downsample_cube_job(args): raise exc -def non_linear_filter_3d(data, factors, func): +def non_linear_filter_3d( + data: np.ndarray, factors: List[int], func: Callable[[np.ndarray], np.ndarray] +) -> np.ndarray: ds = data.shape assert not any((d % factor > 0 for (d, factor) in zip(ds, factors))) data = data.reshape((ds[0], factors[1], ds[1] // factors[1], ds[2]), order="F") @@ -337,10 +353,10 @@ def non_linear_filter_3d(data, factors, func): return data -def linear_filter_3d(data, factors, order): - factors = np.array(factors) +def linear_filter_3d(data: np.ndarray, factors: List[int], order: int) -> np.ndarray: + factors_np = np.array(factors) - if not np.all(factors == factors[0]): + if not np.all(factors_np == factors[0]): logging.debug( "the selected filtering strategy does not support anisotropic downsampling. 
Selecting {} as uniform downsampling factor".format( factors[0] @@ -365,19 +381,19 @@ def linear_filter_3d(data, factors, order): ) -def _max(x): +def _max(x: np.ndarray) -> np.ndarray: return np.max(x, axis=0) -def _min(x): +def _min(x: np.ndarray) -> np.ndarray: return np.min(x, axis=0) -def _median(x): +def _median(x: np.ndarray) -> np.ndarray: return np.median(x, axis=0).astype(x.dtype) -def _mode(x): +def _mode(x: np.ndarray) -> np.ndarray: """ Fast mode implementation from: https://stackoverflow.com/a/35674754 """ @@ -424,7 +440,9 @@ def _mode(x): return sort[tuple(index)] -def downsample_cube(cube_buffer, factors, interpolation_mode): +def downsample_cube( + cube_buffer: np.ndarray, factors: List[int], interpolation_mode: InterpolationModes +) -> np.ndarray: if interpolation_mode == InterpolationModes.MODE: return non_linear_filter_3d(cube_buffer, factors, _mode) elif interpolation_mode == InterpolationModes.MEDIAN: @@ -443,7 +461,9 @@ def downsample_cube(cube_buffer, factors, interpolation_mode): raise Exception("Invalid interpolation mode: {}".format(interpolation_mode)) -def downsample_unpadded_data(buffer, target_mag, interpolation_mode): +def downsample_unpadded_data( + buffer: np.ndarray, target_mag: Mag, interpolation_mode: InterpolationModes +) -> np.ndarray: logging.info( f"Downsampling buffer of size {buffer.shape} to mag {target_mag.to_layer_name()}" ) @@ -467,16 +487,16 @@ def downsample_unpadded_data(buffer, target_mag, interpolation_mode): def downsample_mag( - path, - layer_name, + path: str, + layer_name: str, source_mag: Mag, target_mag: Mag, - interpolation_mode="default", - compress=False, - buffer_edge_len=None, - args=None, -): - interpolation_mode = parse_interpolation_mode(interpolation_mode, layer_name) + interpolation_mode: str = "default", + compress: bool = False, + buffer_edge_len: int = None, + args: Namespace = None, +) -> None: + parsed_interpolation_mode = parse_interpolation_mode(interpolation_mode, layer_name) source_wkw_info = WkwDatasetInfo(path, layer_name, source_mag.to_layer_name(), None) with open_wkw(source_wkw_info) as source: @@ -492,14 +512,16 @@ def downsample_mag( target_wkw_info, source_mag, target_mag, - interpolation_mode, + parsed_interpolation_mode, compress, buffer_edge_len, args, ) -def parse_interpolation_mode(interpolation_mode, layer_name): +def parse_interpolation_mode( + interpolation_mode: str, layer_name: str +) -> InterpolationModes: if interpolation_mode.upper() == "DEFAULT": return ( InterpolationModes.MEDIAN @@ -518,9 +540,9 @@ def downsample_mags( interpolation_mode: str = "default", buffer_edge_len: int = None, compress: bool = True, - args=None, + args: Namespace = None, anisotropic: bool = True, -): +) -> None: assert layer_name and from_mag or not layer_name and not from_mag, ( "You provided only one of the following " "parameters: layer_name, from_mag but both " @@ -573,15 +595,15 @@ def downsample_mags( def downsample_mags_isotropic( - path, - layer_name, + path: str, + layer_name: str, from_mag: Mag, max_mag: Mag, - interpolation_mode, - compress, - buffer_edge_len=None, - args=None, -): + interpolation_mode: str, + compress: bool, + buffer_edge_len: int = None, + args: Namespace = None, +) -> None: target_mag = from_mag.scaled_by(2) while target_mag <= max_mag: @@ -600,16 +622,16 @@ def downsample_mags_isotropic( def downsample_mags_anisotropic( - path, - layer_name, + path: str, + layer_name: str, from_mag: Mag, max_mag: Mag, - scale, - interpolation_mode, - compress, - buffer_edge_len=None, - args=None, -): + 
scale: Tuple[float, float, float], + interpolation_mode: str, + compress: bool, + buffer_edge_len: int = None, + args: Namespace = None, +) -> None: prev_mag = from_mag target_mag = get_next_anisotropic_mag(from_mag, scale) @@ -629,7 +651,7 @@ def downsample_mags_anisotropic( target_mag = get_next_anisotropic_mag(target_mag, scale) -def get_next_anisotropic_mag(mag, scale): +def get_next_anisotropic_mag(mag: Mag, scale: Tuple[float, float, float]) -> Mag: max_index, min_index = detect_larger_and_smaller_dimension(scale) mag_array = mag.to_array() scale_increase = [1, 1, 1] @@ -651,7 +673,9 @@ def get_next_anisotropic_mag(mag, scale): ) -def detect_larger_and_smaller_dimension(scale): +def detect_larger_and_smaller_dimension( + scale: Tuple[float, float, float] +) -> Tuple[int, int]: scale_np = np.array(scale) return np.argmax(scale_np), np.argmin(scale_np) diff --git a/wkcuber/export_wkw_as_tiff.py b/wkcuber/export_wkw_as_tiff.py index 3d00c547a..00aea420a 100644 --- a/wkcuber/export_wkw_as_tiff.py +++ b/wkcuber/export_wkw_as_tiff.py @@ -1,4 +1,4 @@ -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace from functools import partial import logging import wkw @@ -21,7 +21,7 @@ from wkcuber.utils import wait_and_ensure_success -def create_parser(): +def create_parser() -> ArgumentParser: parser = ArgumentParser() parser.add_argument( @@ -127,22 +127,22 @@ def export_tiff_slice( downsample: int, mag: Mag, batch_number: int, -): +) -> None: tiff_bbox = tiff_bbox.copy() number_of_slices = ( min(tiff_bbox["size"][2] - batch_number * batch_size, batch_size) // mag.mag[2] ) - tiff_bbox["size"] = [ + tiff_bbox["size"] = ( tiff_bbox["size"][0] // mag.mag[0], tiff_bbox["size"][1] // mag.mag[1], number_of_slices, - ] - tiff_bbox["topleft"] = [ + ) + tiff_bbox["topleft"] = ( tiff_bbox["topleft"][0] // mag.mag[0], tiff_bbox["topleft"][1] // mag.mag[1], (tiff_bbox["topleft"][2] + batch_number * batch_size) // mag.mag[2], - ] + ) with wkw.Dataset.open(dataset_path) as dataset: if tiling_size is None: @@ -198,17 +198,17 @@ def export_tiff_slice( def export_tiff_stack( - wkw_file_path, - wkw_layer, - bbox, - mag, - destination_path, - name, - tiling_slice_size, - batch_size, - downsample, - args, -): + wkw_file_path: str, + wkw_layer: str, + bbox: Dict[str, List[int]], + mag: Mag, + destination_path: str, + name: str, + tiling_slice_size: Union[None, Tuple[int, int]], + batch_size: int, + downsample: int, + args: Namespace, +) -> None: os.makedirs(destination_path, exist_ok=True) dataset_path = os.path.join(wkw_file_path, wkw_layer, mag.to_layer_name()) @@ -234,13 +234,15 @@ def export_tiff_stack( wait_and_ensure_success(futures) -def export_wkw_as_tiff(args): +def export_wkw_as_tiff(args: Namespace) -> None: if args.verbose: logging.basicConfig(level=logging.DEBUG) if args.bbox is None: - _, _, bbox, origin = read_metadata_for_layer(args.source_path, args.layer_name) - bbox = {"topleft": origin, "size": bbox} + _, _, bbox_dim, origin = read_metadata_for_layer( + args.source_path, args.layer_name + ) + bbox = {"topleft": origin, "size": bbox_dim} else: bbox = {"topleft": list(args.bbox.topleft), "size": list(args.bbox.size)} @@ -277,7 +279,7 @@ def export_wkw_as_tiff(args): ) -def run(args_list: List): +def run(args_list: List) -> None: arguments = create_parser().parse_args(args_list) export_wkw_as_tiff(arguments) diff --git a/wkcuber/image_readers.py b/wkcuber/image_readers.py index a718bd499..4a28b11a2 100644 --- a/wkcuber/image_readers.py +++ b/wkcuber/image_readers.py @@ 
-1,27 +1,29 @@ +from typing import Tuple, Dict, Union + import numpy as np import logging from os import path from PIL import Image from .vendor.dm3 import DM3 -from .vendor.dm4 import DM4File +from .vendor.dm4 import DM4File, DM4TagHeader # Disable PIL's maximum image limit. Image.MAX_IMAGE_PIXELS = None class PillowImageReader: - def read_array(self, file_name, dtype): + def read_array(self, file_name: str, dtype: np.dtype) -> np.ndarray: this_layer = np.array(Image.open(file_name), dtype) this_layer = this_layer.swapaxes(0, 1) this_layer = this_layer.reshape(this_layer.shape + (1,)) return this_layer - def read_dimensions(self, file_name): + def read_dimensions(self, file_name: str) -> Tuple[int, int]: with Image.open(file_name) as test_img: - return (test_img.width, test_img.height) + return test_img.width, test_img.height - def read_channel_count(self, file_name): + def read_channel_count(self, file_name: str) -> int: with Image.open(file_name) as test_img: this_layer = np.array(test_img) if this_layer.ndim == 2: @@ -31,32 +33,30 @@ def read_channel_count(self, file_name): return this_layer.shape[-1] # pylint: disable=unsubscriptable-object -def to_target_datatype(data: np.ndarray, target_dtype) -> np.ndarray: - +def to_target_datatype(data: np.ndarray, target_dtype: np.dtype) -> np.ndarray: factor = (1 + np.iinfo(data.dtype).max) / (1 + np.iinfo(target_dtype).max) return (data / factor).astype(target_dtype) class Dm3ImageReader: - def read_array(self, file_name, dtype): + def read_array(self, file_name: str, dtype: np.dtype) -> np.ndarray: dm3_file = DM3(file_name) this_layer = to_target_datatype(dm3_file.imagedata, dtype) this_layer = this_layer.swapaxes(0, 1) this_layer = this_layer.reshape(this_layer.shape + (1,)) return this_layer - def read_dimensions(self, file_name): + def read_dimensions(self, file_name: str) -> Tuple[int, int]: test_img = DM3(file_name) - return (test_img.width, test_img.height) + return test_img.width, test_img.height - def read_channel_count(self, _file_name): + def read_channel_count(self, _file_name: str) -> int: logging.info("Assuming single channel for DM3 data") return 1 class Dm4ImageReader: - def _read_tags(self, dm4file): - + def _read_tags(self, dm4file: DM4File) -> Tuple[DM4File.DM4TagDir, DM4TagHeader]: tags = dm4file.read_directory() image_data_tag = ( tags.named_subdirs["ImageList"] @@ -67,8 +67,9 @@ def _read_tags(self, dm4file): return image_data_tag, image_tag - def _read_dimensions(self, dm4file, image_data_tag): - + def _read_dimensions( + self, dm4file: DM4File, image_data_tag: DM4File.DM4TagDir + ) -> Tuple[int, int]: width = dm4file.read_tag_data( image_data_tag.named_subdirs["Dimensions"].unnamed_tags[0] ) @@ -77,8 +78,7 @@ def _read_dimensions(self, dm4file, image_data_tag): ) return width, height - def read_array(self, file_name, dtype): - + def read_array(self, file_name: str, dtype: np.dtype) -> np.ndarray: dm4file = DM4File.open(file_name) image_data_tag, image_tag = self._read_tags(dm4file) width, height = self._read_dimensions(dm4file, image_data_tag) @@ -93,7 +93,7 @@ def read_array(self, file_name, dtype): return data - def read_dimensions(self, file_name): + def read_dimensions(self, file_name: str) -> Tuple[int, int]: dm4file = DM4File.open(file_name) image_data_tag, _ = self._read_tags(dm4file) @@ -102,14 +102,16 @@ def read_dimensions(self, file_name): return dimensions - def read_channel_count(self, _file_name): + def read_channel_count(self, _file_name: str) -> int: logging.info("Assuming single channel for DM4 data") 
return 1 class ImageReader: - def __init__(self): - self.readers = { + def __init__(self) -> None: + self.readers: Dict[ + str, Union[PillowImageReader, Dm3ImageReader, Dm4ImageReader] + ] = { ".tif": PillowImageReader(), ".tiff": PillowImageReader(), ".jpg": PillowImageReader(), @@ -119,7 +121,7 @@ def __init__(self): ".dm4": Dm4ImageReader(), } - def read_array(self, file_name, dtype): + def read_array(self, file_name: str, dtype: np.dtype) -> np.ndarray: _, ext = path.splitext(file_name) # Image shape will be (x, y, channel_count, z=1) or (x, y, z=1) @@ -130,11 +132,11 @@ def read_array(self, file_name, dtype): return image - def read_dimensions(self, file_name): + def read_dimensions(self, file_name: str) -> Tuple[int, int]: _, ext = path.splitext(file_name) return self.readers[ext].read_dimensions(file_name) - def read_channel_count(self, file_name): + def read_channel_count(self, file_name: str) -> int: _, ext = path.splitext(file_name) return self.readers[ext].read_channel_count(file_name) diff --git a/wkcuber/knossos.py b/wkcuber/knossos.py index 4f67b8776..48104f720 100644 --- a/wkcuber/knossos.py +++ b/wkcuber/knossos.py @@ -1,3 +1,6 @@ +from types import TracebackType +from typing import Tuple, Any, Generator, Iterator, Optional, Type + import numpy as np import os import re @@ -11,25 +14,27 @@ class KnossosDataset: - def __init__(self, root, dtype=np.uint8): + def __init__(self, root: str, dtype: np.dtype = np.uint8): self.root = root self.dtype = dtype - def read(self, offset, shape): + def read( + self, offset: Tuple[int, int, int], shape: Tuple[int, int, int] + ) -> np.ndarray: assert offset[0] % CUBE_EDGE_LEN == 0 assert offset[1] % CUBE_EDGE_LEN == 0 assert offset[2] % CUBE_EDGE_LEN == 0 assert shape == CUBE_SHAPE return self.read_cube(tuple(x // CUBE_EDGE_LEN for x in offset)) - def write(self, offset, data): + def write(self, offset: Tuple[int, int, int], data: np.ndarray) -> None: assert offset[0] % CUBE_EDGE_LEN == 0 assert offset[1] % CUBE_EDGE_LEN == 0 assert offset[2] % CUBE_EDGE_LEN == 0 assert data.shape == CUBE_SHAPE self.write_cube(tuple(x // CUBE_EDGE_LEN for x in offset), data) - def read_cube(self, cube_xyz): + def read_cube(self, cube_xyz: Tuple[int, ...]) -> np.ndarray: filename = self.__get_only_raw_file_path(cube_xyz) if filename is None: return np.zeros(CUBE_SHAPE, dtype=self.dtype) @@ -44,7 +49,7 @@ def read_cube(self, cube_xyz): cube_data = cube_data.reshape(CUBE_SHAPE, order="F") return cube_data - def write_cube(self, cube_xyz, cube_data): + def write_cube(self, cube_xyz: Tuple[int, ...], cube_data: np.ndarray) -> None: filename = self.__get_only_raw_file_path(cube_xyz) if filename is None: filename = path.join( @@ -54,17 +59,17 @@ def write_cube(self, cube_xyz, cube_data): with open(filename, "wb") as cube_file: cube_data.ravel(order="F").tofile(cube_file) - def __get_cube_folder(self, cube_xyz): + def __get_cube_folder(self, cube_xyz: Tuple[int, ...]) -> str: x, y, z = cube_xyz return path.join( self.root, "x{:04d}".format(x), "y{:04d}".format(y), "z{:04d}".format(z) ) - def __get_cube_file_name(self, cube_xyz): + def __get_cube_file_name(self, cube_xyz: Tuple[int, ...]) -> str: x, y, z = cube_xyz return "cube_x{:04d}_y{:04d}_z{:04d}.raw".format(x, y, z) - def __get_only_raw_file_path(self, cube_xyz): + def __get_only_raw_file_path(self, cube_xyz: Tuple[int, ...]) -> Optional[str]: cube_folder = self.__get_cube_folder(cube_xyz) raw_files = glob(path.join(cube_folder, "*.raw")) assert len(raw_files) <= 1, "Found %d .raw files in %s" % ( @@ -73,31 
+78,36 @@ def __get_only_raw_file_path(self, cube_xyz): ) return raw_files[0] if len(raw_files) > 0 else None - def list_files(self): + def list_files(self) -> Iterator[str]: return iglob(path.join(self.root, "**", "*.raw"), recursive=True) - def __parse_cube_file_name(self, filename): + def __parse_cube_file_name(self, filename: str) -> Optional[Tuple[int, int, int]]: m = CUBE_REGEX.search(filename) if m is None: return None - return (int(m.group(1)), int(m.group(2)), int(m.group(3))) + return int(m.group(1)), int(m.group(2)), int(m.group(3)) - def list_cubes(self): + def list_cubes(self) -> Generator[Tuple[int, int, int], Any, None]: return ( f for f in (self.__parse_cube_file_name(f) for f in self.list_files()) if f is not None ) - def close(self): + def close(self) -> None: pass @staticmethod - def open(root: str, dtype): + def open(root: str, dtype: Optional[np.dtype]) -> "KnossosDataset": return KnossosDataset(root, dtype) - def __enter__(self): + def __enter__(self) -> "KnossosDataset": return self - def __exit__(self, _type, _value, _tb): + def __exit__( + self, + _type: Optional[Type[BaseException]], + _value: Optional[BaseException], + _tb: Optional[TracebackType], + ) -> None: self.close() diff --git a/wkcuber/mag.py b/wkcuber/mag.py index 1ed520591..483aaa450 100644 --- a/wkcuber/mag.py +++ b/wkcuber/mag.py @@ -3,13 +3,13 @@ from functools import total_ordering import numpy as np -from typing import List +from typing import List, Any @total_ordering -class Mag: - def __init__(self, mag): - self.mag = None +class Mag(object): + def __init__(self, mag: Any): + self.mag: List[int] = [] if isinstance(mag, int): self.mag = [mag] * 3 @@ -36,48 +36,48 @@ def __init__(self, mag): for m in self.mag: assert log2(m) % 1 == 0, "magnification needs to be power of 2." 
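# Illustrative sketch, not part of the patch: the constructor annotated above still
# accepts an int, a list, or a layer-name string and normalizes the input to a
# three-component list (the "x-y-z" string form is assumed here, matching
# to_long_layer_name further down). Each component has to be a power of two,
# otherwise the assertion above fails.
#
#     Mag(2).to_array()               # [2, 2, 2]
#     Mag([4, 4, 2]).to_layer_name()  # "4-4-2"
#     Mag("2-2-1").to_array()         # [2, 2, 1]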
- def __lt__(self, other): - return max(self.mag) < (max(other.to_array())) + def __lt__(self, other: Any) -> bool: + return max(self.mag) < (max(Mag(other).to_array())) - def __le__(self, other): - return max(self.mag) <= (max(other.to_array())) + def __le__(self, other: Any) -> bool: + return max(self.mag) <= (max(Mag(other).to_array())) - def __eq__(self, other): - return all(m1 == m2 for m1, m2 in zip(self.mag, other.mag)) + def __eq__(self, other: Any) -> bool: + return all(m1 == m2 for m1, m2 in zip(self.mag, Mag(other).mag)) - def __str__(self): + def __str__(self) -> str: return self.to_layer_name() - def __expr__(self): + def __expr__(self) -> str: return f"Mag({self.to_layer_name()})" - def to_layer_name(self): + def to_layer_name(self) -> str: x, y, z = self.mag if x == y and y == z: return str(x) else: return self.to_long_layer_name() - def to_long_layer_name(self): + def to_long_layer_name(self) -> str: x, y, z = self.mag return "{}-{}-{}".format(x, y, z) - def to_array(self): + def to_array(self) -> List[int]: return self.mag - def scaled_by(self, factor: int): + def scaled_by(self, factor: int) -> "Mag": return Mag([mag * factor for mag in self.mag]) - def scale_by(self, factor: int): + def scale_by(self, factor: int) -> None: self.mag = [mag * factor for mag in self.mag] - def divided(self, coord: List[int]): + def divided(self, coord: List[int]) -> List[int]: return [c // m for c, m in zip(coord, self.mag)] - def divide_by(self, d: int): + def divide_by(self, d: int) -> None: self.mag = [mag // d for mag in self.mag] - def divided_by(self, d: int): + def divided_by(self, d: int) -> "Mag": return Mag([mag // d for mag in self.mag]) def as_np(self) -> np.ndarray: diff --git a/wkcuber/metadata.py b/wkcuber/metadata.py index e253bec91..c34f6cae0 100644 --- a/wkcuber/metadata.py +++ b/wkcuber/metadata.py @@ -8,7 +8,7 @@ from argparse import ArgumentParser from glob import iglob from os import path, listdir -from typing import Optional +from typing import Optional, Tuple, Iterable, Generator from .mag import Mag from typing import List from .utils import add_verbose_flag, setup_logging, add_scale_flag @@ -16,11 +16,11 @@ from os.path import basename, normpath -def get_datasource_path(dataset_path): +def get_datasource_path(dataset_path: str) -> str: return path.join(dataset_path, "datasource-properties.json") -def create_parser(): +def create_parser() -> ArgumentParser: parser = ArgumentParser() parser.add_argument("path", help="Directory containing the dataset.") @@ -45,13 +45,13 @@ def create_parser(): return parser -def write_datasource_properties(dataset_path, datasource_properties): +def write_datasource_properties(dataset_path: str, datasource_properties: dict) -> None: datasource_properties_path = get_datasource_path(dataset_path) with open(datasource_properties_path, "wt") as datasource_properties_file: json.dump(datasource_properties, datasource_properties_file, indent=2) -def read_datasource_properties(dataset_path): +def read_datasource_properties(dataset_path: str) -> dict: with open(get_datasource_path(dataset_path), "r") as datasource_properties_file: return json.load(datasource_properties_file) @@ -59,11 +59,11 @@ def read_datasource_properties(dataset_path): def write_webknossos_metadata( dataset_path: str, name: str, - scale, - max_id=0, - compute_max_id=False, + scale: Tuple[float, float, float], + max_id: int = 0, + compute_max_id: bool = False, exact_bounding_box: Optional[dict] = None, -): +) -> None: """ Creates a datasource-properties.json file with the 
specified properties for the given dataset path. Common layers are detected automatically. @@ -87,8 +87,11 @@ def write_webknossos_metadata( def refresh_metadata( - wkw_path, max_id=0, compute_max_id=False, exact_bounding_box: Optional[dict] = None -): + wkw_path: str, + max_id: int = 0, + compute_max_id: bool = False, + exact_bounding_box: Optional[dict] = None, +) -> None: """ Updates the datasource-properties.json file for a given dataset. Use this method if you added (or removed) layers and/or changed magnifications for @@ -130,7 +133,7 @@ def refresh_metadata( write_datasource_properties(wkw_path, datasource_properties) -def convert_element_class_to_dtype(elementClass): +def convert_element_class_to_dtype(elementClass: str) -> np.dtype: default_dtype = np.uint8 if "uint" in elementClass else np.dtype(elementClass) conversion_map = { "float": np.float32, @@ -143,7 +146,9 @@ def convert_element_class_to_dtype(elementClass): return conversion_map.get(elementClass, default_dtype) -def read_metadata_for_layer(wkw_path, layer_name): +def read_metadata_for_layer( + wkw_path: str, layer_name: str +) -> Tuple[dict, np.dtype, List[int], List[int]]: datasource_properties = read_datasource_properties(wkw_path) layers = datasource_properties["dataLayers"] @@ -161,7 +166,7 @@ def read_metadata_for_layer(wkw_path, layer_name): return layer_info, dtype, bounding_box, origin -def convert_dtype_to_element_class(dtype): +def convert_dtype_to_element_class(dtype: np.dtype) -> str: element_class_to_dtype_map = { "float": np.float32, "double": np.float64, @@ -174,7 +179,7 @@ def convert_dtype_to_element_class(dtype): return conversion_map.get(dtype, str(dtype)) -def detect_mag_path(dataset_path, layer, mag: Mag = Mag(1)): +def detect_mag_path(dataset_path: str, layer: str, mag: Mag = Mag(1)) -> Optional[str]: layer_path = path.join(dataset_path, layer, str(mag)) if path.exists(layer_path): return layer_path @@ -184,7 +189,7 @@ def detect_mag_path(dataset_path, layer, mag: Mag = Mag(1)): return None -def detect_dtype(dataset_path, layer, mag: Mag = Mag(1)): +def detect_dtype(dataset_path: str, layer: str, mag: Mag = Mag(1)) -> str: layer_path = detect_mag_path(dataset_path, layer, mag) if layer_path is not None: with wkw.Dataset.open(layer_path) as dataset: @@ -194,31 +199,39 @@ def detect_dtype(dataset_path, layer, mag: Mag = Mag(1)): return "uint" + str(8 * num_channels) else: return convert_dtype_to_element_class(voxel_size) + raise RuntimeError( + f"Failed to detect dtype (for {dataset_path}, {layer}, {mag}) because the layer_path is None" + ) -def detect_cubeLength(dataset_path, layer, mag: Mag = Mag(1)): +def detect_cubeLength(dataset_path: str, layer: str, mag: Mag = Mag(1)) -> int: layer_path = detect_mag_path(dataset_path, layer, mag) if layer_path is not None: with wkw.Dataset.open(layer_path) as dataset: return dataset.header.block_len * dataset.header.file_len + raise RuntimeError( + f"Failed to detect the cube length (for {dataset_path}, {layer}, {mag}) because the layer_path is None" + ) -def detect_bbox(dataset_path, layer, mag: Mag = Mag(1)): +def detect_bbox(dataset_path: str, layer: str, mag: Mag = Mag(1)) -> Optional[dict]: # Detect the coarse bounding box of a dataset by iterating # over the WKW cubes layer_path = detect_mag_path(dataset_path, layer, mag) if layer_path is None: return None - def list_files(layer_path): + def list_files(layer_path: str) -> Iterable[str]: return iglob(path.join(layer_path, "*", "*", "*.wkw"), recursive=True) - def parse_cube_file_name(filename): + def 
parse_cube_file_name(filename: str) -> Tuple[int, int, int]: CUBE_REGEX = re.compile(r"z(\d+)/y(\d+)/x(\d+)(\.wkw)$") m = CUBE_REGEX.search(filename) - return (int(m.group(3)), int(m.group(2)), int(m.group(1))) + if m is not None: + return int(m.group(3)), int(m.group(2)), int(m.group(1)) + raise RuntimeError(f"Failed to parse cube file name from {filename}") - def list_cubes(layer_path): + def list_cubes(layer_path: str) -> Iterable[Tuple[int, int, int]]: return (parse_cube_file_name(f) for f in list_files(layer_path)) xs, ys, zs = list(zip(*list_cubes(layer_path))) @@ -238,7 +251,7 @@ def list_cubes(layer_path): } -def detect_resolutions(dataset_path, layer) -> List[Mag]: +def detect_resolutions(dataset_path: str, layer: str) -> Generator: for mag in listdir(path.join(dataset_path, layer)): try: yield Mag(mag) @@ -247,8 +260,11 @@ def detect_resolutions(dataset_path, layer) -> List[Mag]: def detect_standard_layer( - dataset_path, layer_name, exact_bounding_box=None, category="color" -): + dataset_path: str, + layer_name: str, + exact_bounding_box: Optional[dict] = None, + category: str = "color", +) -> dict: # Perform metadata detection for well-known layers mags = list(detect_resolutions(dataset_path, layer_name)) @@ -297,7 +313,7 @@ def detect_standard_layer( } -def detect_mappings(dataset_path, layer_name): +def detect_mappings(dataset_path: str, layer_name: str) -> List[str]: pattern = path.join(dataset_path, layer_name, "mappings", "*.json") mapping_files = glob.glob(pattern) mapping_file_names = [path.basename(mapping_file) for mapping_file in mapping_files] @@ -305,8 +321,12 @@ def detect_mappings(dataset_path, layer_name): def detect_segmentation_layer( - dataset_path, layer_name, max_id, compute_max_id=False, exact_bounding_box=None -): + dataset_path: str, + layer_name: str, + max_id: int, + compute_max_id: bool = False, + exact_bounding_box: dict = None, +) -> dict: layer_info = detect_standard_layer( dataset_path, layer_name, exact_bounding_box, category="segmentation" ) @@ -336,7 +356,12 @@ def detect_segmentation_layer( return layer_info -def detect_layers(dataset_path: str, max_id, compute_max_id, exact_bounding_box=None): +def detect_layers( + dataset_path: str, + max_id: int, + compute_max_id: bool, + exact_bounding_box: Optional[dict] = None, +) -> Iterable[dict]: # Detect metadata for well-known layers (i.e., color, prediction and segmentation) if path.exists(path.join(dataset_path, "color")): yield detect_standard_layer(dataset_path, "color", exact_bounding_box) @@ -353,6 +378,7 @@ def detect_layers(dataset_path: str, max_id, compute_max_id, exact_bounding_box= for layer_name in available_layer_names: # color and segmentation are already checked explicitly to ensure downwards compatibility (some older datasets don't have the header.wkw file) if layer_name not in ["color", "segmentation"]: + layer_info = None try: layer_info = detect_standard_layer( dataset_path, layer_name, exact_bounding_box diff --git a/wkcuber/recubing.py b/wkcuber/recubing.py index b8cc64982..1eb061c20 100644 --- a/wkcuber/recubing.py +++ b/wkcuber/recubing.py @@ -1,4 +1,6 @@ import logging +from typing import List, Tuple + import wkw import numpy as np from argparse import ArgumentParser @@ -18,7 +20,7 @@ ) -def create_parser(): +def create_parser() -> ArgumentParser: parser = ArgumentParser() parser.add_argument( @@ -60,19 +62,24 @@ def create_parser(): return parser -def next_lower_divisible_by(number, divisor) -> int: +def next_lower_divisible_by(number: int, divisor: int) -> int: remainder = 
 
 
-def next_higher_divisible_by(number, divisor) -> int:
+def next_higher_divisible_by(number: int, divisor: int) -> int:
     remainder = number % divisor
     return number - remainder + divisor
 
 
 def recube(
-    source_path, target_path, layer_name, dtype, wkw_file_len=32, compression=True
-):
+    source_path: str,
+    target_path: str,
+    layer_name: str,
+    dtype: str,
+    wkw_file_len: int = 32,
+    compression: bool = True,
+) -> None:
     if compression:
         block_type = wkw.Header.BLOCK_TYPE_LZ4
     else:
@@ -88,6 +95,9 @@ def recube(
     ensure_wkw(target_wkw_info)
 
     bounding_box_dict = detect_bbox(source_wkw_info.dataset_path, layer_name)
+    if bounding_box_dict is None:
+        raise ValueError("Failed to detect bounding box.")
+
     bounding_box = (
         bounding_box_dict["topLeft"],
         [
@@ -138,7 +148,11 @@ def recube(
     logging.info(f"{layer_name} successfully resampled!")
 
 
-def recubing_cube_job(args):
+def recubing_cube_job(
+    args: Tuple[
+        WkwDatasetInfo, WkwDatasetInfo, List[int], List[int], int, Tuple[int, int, int]
+    ]
+) -> None:
     (
         source_wkw_info,
         target_wkw_info,
diff --git a/wkcuber/tile_cubing.py b/wkcuber/tile_cubing.py
index 54470cce0..165699daf 100644
--- a/wkcuber/tile_cubing.py
+++ b/wkcuber/tile_cubing.py
@@ -1,11 +1,11 @@
 import time
 import logging
 import numpy as np
-from typing import Dict, Tuple, Union
+from typing import Dict, Tuple, Union, List, Optional
 import os
 from glob import glob
 import re
-from argparse import ArgumentTypeError
+from argparse import ArgumentTypeError, ArgumentParser, Namespace
 import wkw
 
 from .utils import (
@@ -82,7 +82,7 @@ def replace_coordinates_with_glob_regex(pattern: str, coord_ids: Dict[str, int])
     return pattern
 
 
-def get_digit_counts_for_dimensions(pattern):
+def get_digit_counts_for_dimensions(pattern: str) -> Dict[str, int]:
     """ Counts how many digits the dimensions x, y and z occupy in the given pattern. """
     occurrences = re.findall("({x+}|{y+}|{z+})", pattern)
     decimal_lengths = {"x": 0, "y": 0, "z": 0}
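[Editor's note, not part of the patch: a small, hypothetical sketch of what get_digit_counts_for_dimensions is expected to return for a typical tile path pattern; the pattern string below is made up for illustration and the exact return value is inferred from the docstring, not shown in the hunk:]

    from wkcuber.tile_cubing import get_digit_counts_for_dimensions

    # Each run of x/y/z inside braces stands for one digit of that coordinate
    pattern = "{zzzzz}/{yyyyy}/{xxxxx}.jpg"
    print(get_digit_counts_for_dimensions(pattern))  # expected: {"x": 5, "y": 5, "z": 5}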
""" occurrences = re.findall("({x+}|{y+}|{z+})", pattern) decimal_lengths = {"x": 0, "y": 0, "z": 0} @@ -98,14 +98,14 @@ def get_digit_counts_for_dimensions(pattern): def detect_interval_for_dimensions( file_path_pattern: str, decimal_lengths: Dict[str, int] -) -> Tuple[Dict[str, int], Dict[str, int], str, int]: +) -> Tuple[Dict[str, int], Dict[str, int], Optional[str], int]: arbitrary_file = None file_count = 0 # dictionary that maps the dimension string to the current dimension length # used to avoid distinction of dimensions with if statements current_decimal_length = {"x": 0, "y": 0, "z": 0} max_dimensions = {"x": 0, "y": 0, "z": 0} - min_dimensions = {"x": None, "y": None, "z": None} + min_dimensions: Dict[str, int] = {} # find all files by trying all combinations of dimension lengths for x in range(decimal_lengths["x"] + 1): @@ -131,13 +131,13 @@ def detect_interval_for_dimensions( # Use that index to look up the actual value within the file name for current_dimension in ["x", "y", "z"]: idx = applied_fpp.index(current_dimension) - coordinate_value = file_name[ + coordinate_value_str = file_name[ idx : idx + current_decimal_length[current_dimension] ] - coordinate_value = int(coordinate_value) + coordinate_value = int(coordinate_value_str) assert coordinate_value min_dimensions[current_dimension] = min( - min_dimensions[current_dimension] or coordinate_value, + min_dimensions.get(current_dimension, coordinate_value), coordinate_value, ) max_dimensions[current_dimension] = max( @@ -178,7 +178,18 @@ def find_file_with_dimensions( return None -def tile_cubing_job(args): +def tile_cubing_job( + args: Tuple[ + WkwDatasetInfo, + List[int], + str, + int, + Tuple[int, int, int], + Dict[str, int], + Dict[str, int], + Dict[str, int], + ] +) -> None: ( target_wkw_info, z_batches, @@ -251,8 +262,13 @@ def tile_cubing_job(args): def tile_cubing( - target_path, layer_name, dtype, batch_size, input_path_pattern, args=None -): + target_path: str, + layer_name: str, + dtype: str, + batch_size: int, + input_path_pattern: str, + args: Namespace = None, +) -> None: decimal_lengths = get_digit_counts_for_dimensions(input_path_pattern) ( min_dimensions, @@ -270,7 +286,6 @@ def tile_cubing( # Determine tile size from first matching file tile_size = image_reader.read_dimensions(arbitrary_file) num_channels = image_reader.read_channel_count(arbitrary_file) - tile_size = (tile_size[0], tile_size[1], num_channels) logging.info( "Found source files: count={} with tile_size={}x{}".format( file_count, tile_size[0], tile_size[1] @@ -296,7 +311,7 @@ def tile_cubing( list(z_batch), input_path_pattern, batch_size, - tile_size, + (tile_size[0], tile_size[1], num_channels), min_dimensions, max_dimensions, decimal_lengths, @@ -305,7 +320,7 @@ def tile_cubing( wait_and_ensure_success(executor.map_to_futures(tile_cubing_job, job_args)) -def create_parser(): +def create_parser() -> ArgumentParser: parser = create_cubing_parser() parser.add_argument( diff --git a/wkcuber/utils.py b/wkcuber/utils.py index d8eedfd60..eb9d7884a 100644 --- a/wkcuber/utils.py +++ b/wkcuber/utils.py @@ -1,5 +1,8 @@ import re import time +from concurrent.futures._base import Future +from types import TracebackType + import logging import argparse import wkw @@ -9,12 +12,13 @@ import os import psutil import traceback -import concurrent +from cluster_tools.schedulers.cluster_executor import ClusterExecutor -from typing import List, Tuple, Union +from typing import List, Tuple, Union, Iterable, Generator, Any, Optional, Type from glob import iglob 
diff --git a/wkcuber/utils.py b/wkcuber/utils.py
index d8eedfd60..eb9d7884a 100644
--- a/wkcuber/utils.py
+++ b/wkcuber/utils.py
@@ -1,5 +1,8 @@
 import re
 import time
+from concurrent.futures._base import Future
+from types import TracebackType
+
 import logging
 import argparse
 import wkw
@@ -9,12 +12,13 @@
 import os
 import psutil
 import traceback
-import concurrent
+from cluster_tools.schedulers.cluster_executor import ClusterExecutor
 
-from typing import List, Tuple, Union
+from typing import List, Tuple, Union, Iterable, Generator, Any, Optional, Type
 from glob import iglob
 from collections import namedtuple
 from multiprocessing import cpu_count
+from concurrent.futures import as_completed
 from os import path, getpid
 from math import floor, ceil
 from logging import getLogger
@@ -39,21 +43,21 @@
 logger = getLogger(__name__)
 
 
-def open_wkw(info):
+def open_wkw(info: WkwDatasetInfo) -> wkw.Dataset:
     ds = wkw.Dataset.open(
         path.join(info.dataset_path, info.layer_name, str(info.mag)), info.header
     )
     return ds
 
 
-def ensure_wkw(target_wkw_info):
+def ensure_wkw(target_wkw_info: WkwDatasetInfo) -> None:
     assert target_wkw_info.header is not None
     # Open will create the dataset if it doesn't exist yet
     target_wkw = open_wkw(target_wkw_info)
     target_wkw.close()
 
 
-def cube_addresses(source_wkw_info):
+def cube_addresses(source_wkw_info: WkwDatasetInfo) -> List[Tuple[int, int, int]]:
     # Gathers all WKW cubes in the dataset
     with open_wkw(source_wkw_info) as source_wkw:
         wkw_addresses = list(parse_cube_file_name(f) for f in source_wkw.list_files())
@@ -61,31 +65,32 @@ def cube_addresses(source_wkw_info):
         return wkw_addresses
 
 
-def parse_cube_file_name(filename):
+def parse_cube_file_name(filename: str) -> Tuple[int, int, int]:
     m = CUBE_REGEX.search(filename)
+    if m is None:
+        raise ValueError(f"Failed to parse cube file name {filename}")
     return int(m.group(3)), int(m.group(2)), int(m.group(1))
 
 
-def parse_scale(scale):
+def parse_scale(scale: str) -> Tuple[float, ...]:
     try:
-        scale = tuple(float(x) for x in scale.split(","))
-        return scale
+        return tuple(float(x) for x in scale.split(","))
     except Exception as e:
         raise argparse.ArgumentTypeError("The scale could not be parsed") from e
 
 
-def parse_bounding_box(bbox_str):
+def parse_bounding_box(bbox_str: str) -> BoundingBox:
     try:
         return BoundingBox.from_csv(bbox_str)
     except Exception as e:
         raise argparse.ArgumentTypeError("The bounding box could not be parsed.") from e
 
 
-def open_knossos(info):
+def open_knossos(info: KnossosDatasetInfo) -> KnossosDataset:
     return KnossosDataset.open(info.dataset_path, np.dtype(info.dtype))
 
 
-def add_verbose_flag(parser):
+def add_verbose_flag(parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
         "--silent", help="Silent output", dest="verbose", action="store_false"
     )
@@ -93,7 +98,7 @@ def add_verbose_flag(parser):
     parser.set_defaults(verbose=True)
 
 
-def add_scale_flag(parser, required=True):
+def add_scale_flag(parser: argparse.ArgumentParser, required: bool = True) -> None:
     parser.add_argument(
         "--scale",
         "-s",
@@ -103,7 +108,7 @@ def add_scale_flag(parser, required=True):
     )
 
 
-def add_isotropic_flag(parser):
+def add_isotropic_flag(parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
         "--isotropic",
         help="Activates isotropic downsampling. The default is anisotropic downsampling. "
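[Editor's note, not part of the patch: parse_scale now returns the tuple directly instead of rebinding its string argument, which mypy would flag as an incompatible reassignment. A short illustrative call, assuming the usual comma-separated CLI format; the example value is made up:]

    from wkcuber.utils import parse_scale

    print(parse_scale("11.24,11.24,25.0"))  # -> (11.24, 11.24, 25.0)
    # A malformed value such as "11.24,abc" raises argparse.ArgumentTypeError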
" @@ -114,7 +119,7 @@ def add_isotropic_flag(parser): ) -def add_interpolation_flag(parser): +def add_interpolation_flag(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--interpolation_mode", "-i", @@ -123,15 +128,16 @@ def add_interpolation_flag(parser): ) -def setup_logging(args): - +def setup_logging(args: argparse.Namespace) -> None: logging.basicConfig( level=(logging.DEBUG if args.verbose else logging.INFO), format="%(asctime)s %(levelname)s %(message)s", ) -def find_files(source_path, extensions): +def find_files( + source_path: str, extensions: Iterable[str] +) -> Generator[str, Any, None]: # Find all files with a matching file extension return ( f @@ -140,20 +146,22 @@ def find_files(source_path, extensions): ) -def get_chunks(arr, chunk_size): +def get_chunks(arr: List[Any], chunk_size: int) -> Iterable[List[Any]]: for i in range(0, len(arr), chunk_size): yield arr[i : i + chunk_size] # min_z and max_z are both inclusive -def get_regular_chunks(min_z, max_z, chunk_size): +def get_regular_chunks( + min_z: int, max_z: int, chunk_size: int +) -> Iterable[Iterable[int]]: i = floor(min_z / chunk_size) * chunk_size while i < ceil((max_z + 1) / chunk_size) * chunk_size: yield range(i, i + chunk_size) i += chunk_size -def add_distribution_flags(parser): +def add_distribution_flags(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--jobs", "-j", @@ -176,7 +184,7 @@ def add_distribution_flags(parser): ) -def add_batch_size_flag(parser): +def add_batch_size_flag(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--batch_size", "-b", @@ -186,16 +194,18 @@ def add_batch_size_flag(parser): ) -def get_executor_for_args(args): +def get_executor_for_args( + args: Optional[argparse.Namespace], +) -> Union[ClusterExecutor, cluster_tools.WrappedProcessPoolExecutor]: + executor = None if args is None: # For backwards compatibility with code from other packages # we allow args to be None. In this case we are defaulting # to these values: - args = FallbackArgs("multiprocessing", cpu_count()) - - executor = None - - if args.distribution_strategy == "multiprocessing": + jobs = cpu_count() + executor = cluster_tools.get_executor("multiprocessing", max_workers=jobs) + logging.info("Using pool of {} workers.".format(jobs)) + elif args.distribution_strategy == "multiprocessing": # Also accept "processes" instead of job to be compatible with segmentation-tools. # In the long run, the args should be unified and provided by the clustertools. if "jobs" in args: @@ -231,19 +241,19 @@ def get_executor_for_args(args): times = {} -def time_start(identifier): +def time_start(identifier: str) -> None: times[identifier] = time.time() -def time_stop(identifier): +def time_stop(identifier: str) -> None: _time = times.pop(identifier) logging.debug("{} took {:.8f}s".format(identifier, time.time() - _time)) # Waits for all futures to complete and raises an exception # as soon as a future resolves with an error. -def wait_and_ensure_success(futures): - for fut in concurrent.futures.as_completed(futures): +def wait_and_ensure_success(futures: List[Future]) -> None: + for fut in as_completed(futures): fut.result() @@ -252,7 +262,7 @@ def __init__( self, dataset_path: str, layer_name: str, - dtype, + dtype: np.dtype, origin: Union[Tuple[int, int, int], List[int]], # buffer_size specifies, how many slices should be aggregated until they are flushed. 
@@ -274,11 +284,11 @@ def __init__(
         )
 
         self.origin = origin
-        self.buffer = []
-        self.current_z = None
-        self.buffer_start_z = None
+        self.buffer: List[np.ndarray] = []
+        self.current_z: Optional[int] = None
+        self.buffer_start_z: Optional[int] = None
 
-    def write_slice(self, z: int, data: np.ndarray):
+    def write_slice(self, z: int, data: np.ndarray) -> None:
         """Takes in a slice in [y, x] shape, writes to WKW file."""
 
         if len(self.buffer) == 0:
@@ -295,7 +305,7 @@ def write_slice(self, z: int, data: np.ndarray):
         if self.current_z % self.buffer_size == 0:
             self._write_buffer()
 
-    def _write_buffer(self):
+    def _write_buffer(self) -> None:
         if len(self.buffer) == 0:
             return
 
@@ -321,6 +331,9 @@ def _write_buffer(self):
         log_memory_consumption()
 
         try:
+            assert (
+                self.buffer_start_z is not None
+            ), "Failed to write buffer: The buffer_start_z is not set."
             origin_with_offset = list(self.origin)
             origin_with_offset[2] = self.buffer_start_z
             x_max = max(slice.shape[0] for slice in self.buffer)
@@ -361,19 +374,24 @@ def _write_buffer(self):
         finally:
             self.buffer = []
 
-    def close(self):
+    def close(self) -> None:
         self._write_buffer()
         self.dataset.close()
 
-    def __enter__(self):
+    def __enter__(self) -> "BufferedSliceWriter":
         return self
 
-    def __exit__(self, _type, _value, _tb):
+    def __exit__(
+        self,
+        _type: Optional[Type[BaseException]],
+        _value: Optional[BaseException],
+        _tb: Optional[TracebackType],
+    ) -> None:
         self.close()
 
 
-def log_memory_consumption(additional_output=""):
+def log_memory_consumption(additional_output: str = "") -> None:
     pid = os.getpid()
     process = psutil.Process(pid)
     logging.info(
@@ -387,7 +405,9 @@ def log_memory_consumption(additional_output=""):
     )
 
 
-def pad_or_crop_to_size_and_topleft(cube_data, target_size, target_topleft):
+def pad_or_crop_to_size_and_topleft(
+    cube_data: np.ndarray, target_size: np.ndarray, target_topleft: np.ndarray
+) -> np.ndarray:
     """
     Given an numpy array and a target_size/target_topleft, the array
     will be padded so that it is within the bounding box descriped by topleft and size.