From 21fb6bf2abc6c4acdf956b3944b8074d071c9ba0 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 16 May 2023 13:25:44 +0200 Subject: [PATCH 001/244] add Sha256NodeChecker --- bioimageio/core/resource_io/utils.py | 47 +++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/bioimageio/core/resource_io/utils.py b/bioimageio/core/resource_io/utils.py index 575f049b..f39d9818 100644 --- a/bioimageio/core/resource_io/utils.py +++ b/bioimageio/core/resource_io/utils.py @@ -1,4 +1,5 @@ import dataclasses +import hashlib import importlib.util import os import pathlib @@ -6,7 +7,7 @@ import typing from types import ModuleType -from bioimageio.spec.shared import raw_nodes, resolve_source, source_available +from bioimageio.spec.shared import get_resolved_source_path, raw_nodes, resolve_source, source_available from bioimageio.spec.shared.node_transformer import ( GenericRawNode, GenericResolvedNode, @@ -54,6 +55,50 @@ def generic_visit(self, node): super().generic_visit(node) +def get_sha256(path: os.PathLike) -> str: + h = hashlib.sha256() + with open(path, "rb") as f: + while True: + block = f.read(h.block_size) + if not block: + break + h.update(block) + + return h.hexdigest() + + +class Sha256NodeChecker(NodeVisitor): + """Check integrity of the source-like field for every sha256-like field encountered""" + + def __init__(self, *, root_path: os.PathLike): + self.root_path = root_path if isinstance(root_path, raw_nodes.URI) else pathlib.Path(root_path).resolve() + + def generic_visit(self, node): + if isinstance(node, raw_nodes.RawNode): + for field, expected_sha256 in iter_fields(node): + if field == "sha256": + source_name = "source" + elif field.endswith("_sha256"): + source_name = field[: -len("_sha256")] + elif "sha256" in field: + raise NotImplementedError(f"Don't know how to check integrity with {field}") + else: + continue + + if not hasattr(node, source_name): + raise ValueError( + f"Node {node} expected to have '{source_name}' field associated with '{expected_sha256}'" + ) + + source_node = getattr(node, source_name) + source = get_resolved_source_path(source_node, root_path=self.root_path) + actual_sha256 = get_sha256(source) + + if actual_sha256 != expected_sha256: + raise ValueError(f"SHA256 of {source_name} ") + super().generic_visit(node) + + class SourceNodeTransformer(NodeTransformer): """ Imports all source callables From c3bd9664057797c6d3cc2d95658af785f286fd70 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 16 May 2023 13:26:39 +0200 Subject: [PATCH 002/244] test_sha256_checker --- tests/resource_io/test_utils.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/resource_io/test_utils.py b/tests/resource_io/test_utils.py index 30889a1d..28cd351b 100644 --- a/tests/resource_io/test_utils.py +++ b/tests/resource_io/test_utils.py @@ -1,7 +1,13 @@ +import dataclasses from pathlib import Path from bioimageio.core.resource_io import nodes, utils +from bioimageio.core.resource_io.utils import Sha256NodeChecker from bioimageio.spec.shared import raw_nodes +from bioimageio.spec.shared.raw_nodes import RawNode + + +import pytest def test_resolve_import_path(tmpdir): @@ -55,3 +61,24 @@ def test_uri_node_transformer_is_ok_with_abs_path(): assert tree["rel_path"] == Path("/root/something/relative").absolute() assert tree["abs_path"].is_absolute() assert tree["abs_path"] == Path("/something/absolute").absolute() + + +def test_sha256_checker(tmpdir): + root = Path(tmpdir) + src1 = root / "meh.txt" + src2 = root / "meh.txt" + 
src1.write_text("meh", encoding="utf-8") + src2.write_text("muh", encoding="utf-8") + + @dataclasses.dataclass + class TestNode(RawNode): + src: Path = src1 + sha256: str = "f65255094d7773ed8dd417badc9fc045c1f80fdc5b2d25172b031ce6933e039a" + my_src: Path = src2 + my_src_sha256: str = "8cf5844c38045aa19aae00d689002549d308de07a777c2ea34355d65283255ac" + + checker = Sha256NodeChecker(root_path=root) + checker.visit(TestNode()) + + with pytest.raises(ValueError): + checker.visit(TestNode(my_src_sha256="nope")) From 93400a7ced07116581f81a80f02862333a0fabb0 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 16 May 2023 13:36:20 +0200 Subject: [PATCH 003/244] add _test_resource_integrity and improve error message --- bioimageio/core/resource_io/utils.py | 12 ++++++++- bioimageio/core/resource_tests.py | 39 ++++++++++++++++++++++++---- tests/resource_io/test_utils.py | 14 +++++----- 3 files changed, 52 insertions(+), 13 deletions(-) diff --git a/bioimageio/core/resource_io/utils.py b/bioimageio/core/resource_io/utils.py index f39d9818..4e5a1858 100644 --- a/bioimageio/core/resource_io/utils.py +++ b/bioimageio/core/resource_io/utils.py @@ -94,8 +94,18 @@ def generic_visit(self, node): source = get_resolved_source_path(source_node, root_path=self.root_path) actual_sha256 = get_sha256(source) + if not isinstance(expected_sha256, str): + raise TypeError(f"Expected '{field}' to hold string, not {type(expected_sha256)}") + if actual_sha256 != expected_sha256: - raise ValueError(f"SHA256 of {source_name} ") + if actual_sha256[:6] != expected_sha256[:6]: + actual_sha256 = actual_sha256[:6] + "..." + expected_sha256 = expected_sha256[:6] + "..." + + raise ValueError( + f"Determined {actual_sha256} for {source_name}={source}, but expected {field}={expected_sha256}" + ) + super().generic_visit(node) diff --git a/bioimageio/core/resource_tests.py b/bioimageio/core/resource_tests.py index d807bed0..47b1a663 100644 --- a/bioimageio/core/resource_tests.py +++ b/bioimageio/core/resource_tests.py @@ -22,7 +22,7 @@ ResourceDescription, URI, ) -from bioimageio.core.resource_io.utils import SourceNodeChecker +from bioimageio.core.resource_io.utils import Sha256NodeChecker, SourceNodeChecker from bioimageio.spec import __version__ as bioimageio_spec_version from bioimageio.spec.model.raw_nodes import WeightsFormat from bioimageio.spec.shared import resolve_source @@ -105,6 +105,31 @@ def _test_resource_urls(rd: ResourceDescription) -> TestSummary: ) +def _test_resource_integrity(rd: ResourceDescription) -> TestSummary: + assert isinstance(rd, ResourceDescription) + with warnings.catch_warnings(record=True) as all_warnings: + try: + Sha256NodeChecker(root_path=rd.root_path).visit(rd) + except FileNotFoundError as e: + error = str(e) + tb = traceback.format_tb(e.__traceback__) + else: + error = None + tb = None + + return dict( + name="Integrity of source files", + status="passed" if error is None else "failed", + error=error, + traceback=tb, + bioimageio_spec_version=bioimageio_spec_version, + bioimageio_core_version=bioimageio_core_version, + nested_errors=None, + source_name=rd.id or rd.id or rd.name if hasattr(rd, "id") else rd.name, + warnings={"Sha256NodeChecker": [str(w.message) for w in all_warnings]} if all_warnings else {}, + ) + + def _test_model_documentation(rd: ResourceDescription) -> TestSummary: assert isinstance(rd, Model) with warnings.catch_warnings(): @@ -173,7 +198,7 @@ def _test_model_inference(model: Model, weight_format: str, devices: Optional[Li tb = traceback.format_tb(e.__traceback__) return 
dict( - name=f"reproduce test outputs from test inputs (bioimageio.core {bioimageio_core_version})", + name=f"Reproduce test outputs from test inputs (bioimageio.core {bioimageio_core_version})", status="passed" if error is None else "failed", error=error, traceback=tb, @@ -212,7 +237,7 @@ def _test_load_resource( tb = None load_summary = TestSummary( - name="load resource description", + name="Load resource description", status="passed" if error is None else "failed", error=error, nested_errors=None, @@ -229,7 +254,7 @@ def _test_load_resource( def _test_expected_resource_type(rd: ResourceDescription, expected_type: str) -> TestSummary: has_expected_type = rd.type == expected_type return dict( - name="has expected resource type", + name="Has expected resource type", status="passed" if has_expected_type else "failed", error=None if has_expected_type else f"expected type {expected_type}, found {rd.type}", traceback=None, @@ -256,10 +281,14 @@ def test_resource( tests.append(_test_expected_resource_type(rd, expected_type)) tests.append(_test_resource_urls(rd)) + if tests[-1]["status"] == "passed": + tests.append(_test_resource_integrity(rd)) if isinstance(rd, Model): + if tests[-1]["status"] == "passed": # only run inference when source file hashes match + tests.append(_test_model_inference(rd, weight_format, devices, decimal)) + tests.append(_test_model_documentation(rd)) - tests.append(_test_model_inference(rd, weight_format, devices, decimal)) return tests diff --git a/tests/resource_io/test_utils.py b/tests/resource_io/test_utils.py index 28cd351b..198df5e5 100644 --- a/tests/resource_io/test_utils.py +++ b/tests/resource_io/test_utils.py @@ -19,10 +19,10 @@ def test_resolve_import_path(tmpdir): node = raw_nodes.ImportableSourceFile(source_file=source_file, callable_name="Foo") uri_transformed = utils.UriNodeTransformer(root_path=tmpdir).transform(node) source_transformed = utils.SourceNodeTransformer().transform(uri_transformed) - assert isinstance(source_transformed, nodes.ImportedSource) + assert isinstance(source_transformed, nodes.ImportedSource), type(source_transformed) Foo = source_transformed.factory - assert Foo.__name__ == "Foo" - assert isinstance(Foo, type) + assert Foo.__name__ == "Foo", Foo.__name__ + assert isinstance(Foo, type), type(Foo) def test_resolve_directory_uri(tmpdir): @@ -66,13 +66,13 @@ def test_uri_node_transformer_is_ok_with_abs_path(): def test_sha256_checker(tmpdir): root = Path(tmpdir) src1 = root / "meh.txt" - src2 = root / "meh.txt" - src1.write_text("meh", encoding="utf-8") - src2.write_text("muh", encoding="utf-8") + src2 = root / "muh.txt" + src1.write_text(src1.stem, encoding="utf-8") + src2.write_text(src2.stem, encoding="utf-8") @dataclasses.dataclass class TestNode(RawNode): - src: Path = src1 + source: Path = src1 sha256: str = "f65255094d7773ed8dd417badc9fc045c1f80fdc5b2d25172b031ce6933e039a" my_src: Path = src2 my_src_sha256: str = "8cf5844c38045aa19aae00d689002549d308de07a777c2ea34355d65283255ac" From 2c8acfeb92590024b32ba17f4a92da281d358ebf Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 16 May 2023 14:30:14 +0200 Subject: [PATCH 004/244] add file integrity to resource tests --- bioimageio/core/resource_tests.py | 82 +++++++++++++++++++++++-------- 1 file changed, 62 insertions(+), 20 deletions(-) diff --git a/bioimageio/core/resource_tests.py b/bioimageio/core/resource_tests.py index 47b1a663..09acd8fc 100644 --- a/bioimageio/core/resource_tests.py +++ b/bioimageio/core/resource_tests.py @@ -11,7 +11,11 @@ import xarray as xr from 
marshmallow import ValidationError -from bioimageio.core import __version__ as bioimageio_core_version, load_resource_description +from bioimageio.core import ( + __version__ as bioimageio_core_version, + load_raw_resource_description, + load_resource_description, +) from bioimageio.core.common import TestSummary from bioimageio.core.prediction import predict from bioimageio.core.prediction_pipeline import create_prediction_pipeline @@ -80,8 +84,8 @@ def check_output_shape(shape: Tuple[int, ...], shape_spec, input_shapes) -> bool raise TypeError(f"Encountered unexpected shape description of type {type(shape_spec)}") -def _test_resource_urls(rd: ResourceDescription) -> TestSummary: - assert isinstance(rd, ResourceDescription) +def _test_resource_urls(rd: RawResourceDescription) -> TestSummary: + assert isinstance(rd, RawResourceDescription), type(rd) with warnings.catch_warnings(record=True) as all_warnings: try: SourceNodeChecker(root_path=rd.root_path).visit(rd) @@ -105,8 +109,8 @@ def _test_resource_urls(rd: ResourceDescription) -> TestSummary: ) -def _test_resource_integrity(rd: ResourceDescription) -> TestSummary: - assert isinstance(rd, ResourceDescription) +def _test_resource_integrity(rd: RawResourceDescription) -> TestSummary: + assert isinstance(rd, RawResourceDescription) with warnings.catch_warnings(record=True) as all_warnings: try: Sha256NodeChecker(root_path=rd.root_path).visit(rd) @@ -209,9 +213,8 @@ def _test_model_inference(model: Model, weight_format: str, devices: Optional[Li ) -def _test_load_resource( - rdf: Union[RawResourceDescription, ResourceDescription, URI, Path, str], - weight_format: Optional[WeightsFormat] = None, +def _test_load_raw_resource( + rdf: Union[RawResourceDescription, ResourceDescription, URI, Path, str] ) -> Tuple[Optional[ResourceDescription], TestSummary]: if isinstance(rdf, (URI, os.PathLike)): source_name = str(rdf) @@ -220,11 +223,46 @@ def _test_load_resource( else: source_name = rdf.id if hasattr(rdf, "id") else rdf.name + main_test_warnings = [] + try: + with warnings.catch_warnings(record=True) as all_warnings: + rd: Optional[ResourceDescription] = load_raw_resource_description(rdf) + + main_test_warnings += list(all_warnings) + except Exception as e: + rd = None + error: Optional[str] = str(e) + tb: Optional = traceback.format_tb(e.__traceback__) + else: + error = None + tb = None + + load_summary = TestSummary( + name="Load raw resource description", + status="passed" if error is None else "failed", + error=error, + nested_errors=None, + traceback=tb, + bioimageio_spec_version=bioimageio_spec_version, + bioimageio_core_version=bioimageio_core_version, + warnings={}, + source_name=source_name, + ) + + return rd, load_summary + + +def _test_load_resource( + raw_rd: RawResourceDescription, + weight_format: Optional[WeightsFormat] = None, +) -> Tuple[Optional[ResourceDescription], TestSummary]: + source_name = getattr(raw_rd, "rdf_source", getattr(raw_rd, "id", raw_rd.name)) + main_test_warnings = [] try: with warnings.catch_warnings(record=True) as all_warnings: rd: Optional[ResourceDescription] = load_resource_description( - rdf, weights_priority_order=None if weight_format is None else [weight_format] + raw_rd, weights_priority_order=None if weight_format is None else [weight_format] ) main_test_warnings += list(all_warnings) @@ -251,7 +289,7 @@ def _test_load_resource( return rd, load_summary -def _test_expected_resource_type(rd: ResourceDescription, expected_type: str) -> TestSummary: +def _test_expected_resource_type(rd: 
RawResourceDescription, expected_type: str) -> TestSummary: has_expected_type = rd.type == expected_type return dict( name="Has expected resource type", @@ -274,21 +312,25 @@ def test_resource( Returns: summary dict with keys: name, status, error, traceback, bioimageio_spec_version, bioimageio_core_version """ - rd, load_test = _test_load_resource(rdf, weight_format) + raw_rd, load_test = _test_load_raw_resource(rdf) tests: List[TestSummary] = [load_test] - if rd is not None: - if expected_type is not None: - tests.append(_test_expected_resource_type(rd, expected_type)) + if raw_rd is None: + return tests - tests.append(_test_resource_urls(rd)) - if tests[-1]["status"] == "passed": - tests.append(_test_resource_integrity(rd)) + if expected_type is not None: + tests.append(_test_expected_resource_type(raw_rd, expected_type)) - if isinstance(rd, Model): - if tests[-1]["status"] == "passed": # only run inference when source file hashes match - tests.append(_test_model_inference(rd, weight_format, devices, decimal)) + tests.append(_test_resource_urls(raw_rd)) + if tests[-1]["status"] == "passed": + tests.append(_test_resource_integrity(raw_rd)) + if tests[-1]["status"] != "passed": + return tests # stop testing if resource availability/integrity is an issue + + rd = _test_load_resource(raw_rd, weight_format) + if isinstance(rd, Model): tests.append(_test_model_documentation(rd)) + tests.append(_test_model_inference(rd, weight_format, devices, decimal)) return tests From 3b1d8a5b808d5e2a47c48884e683cb3611d4666f Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 16 May 2023 14:35:14 +0200 Subject: [PATCH 005/244] account for missing sha value and sha for uri field --- bioimageio/core/resource_io/utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/bioimageio/core/resource_io/utils.py b/bioimageio/core/resource_io/utils.py index 4e5a1858..7b72602d 100644 --- a/bioimageio/core/resource_io/utils.py +++ b/bioimageio/core/resource_io/utils.py @@ -7,6 +7,8 @@ import typing from types import ModuleType +from marshmallow import missing + from bioimageio.spec.shared import get_resolved_source_path, raw_nodes, resolve_source, source_available from bioimageio.spec.shared.node_transformer import ( GenericRawNode, @@ -76,8 +78,16 @@ def __init__(self, *, root_path: os.PathLike): def generic_visit(self, node): if isinstance(node, raw_nodes.RawNode): for field, expected_sha256 in iter_fields(node): + if expected_sha256 is missing: + continue + if field == "sha256": source_name = "source" + for sn in ["source", "uri"]: + if hasattr(node, sn): + source_name = sn + break + elif field.endswith("_sha256"): source_name = field[: -len("_sha256")] elif "sha256" in field: From 1a61dad553705f215c4a9d2b2652324e80ea1d71 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 17 May 2023 21:02:23 +0200 Subject: [PATCH 006/244] improve get_sha256 implementation --- bioimageio/core/resource_io/utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/bioimageio/core/resource_io/utils.py b/bioimageio/core/resource_io/utils.py index 7b72602d..b4aac1da 100644 --- a/bioimageio/core/resource_io/utils.py +++ b/bioimageio/core/resource_io/utils.py @@ -58,13 +58,13 @@ def generic_visit(self, node): def get_sha256(path: os.PathLike) -> str: + """from https://stackoverflow.com/a/44873382""" h = hashlib.sha256() - with open(path, "rb") as f: - while True: - block = f.read(h.block_size) - if not block: - break - h.update(block) + b = bytearray(128 * 1024) + mv = memoryview(b) + with open(path, 
"rb", buffering=0) as f: + for n in iter(lambda: f.readinto(mv), 0): + h.update(mv[:n]) return h.hexdigest() From 0e48f70d86142e58fa78438fdd0fbe6f2c910715 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 17 May 2023 21:19:59 +0200 Subject: [PATCH 007/244] improve readibility of sha field iteration --- bioimageio/core/resource_io/utils.py | 37 +++++++++++----------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/bioimageio/core/resource_io/utils.py b/bioimageio/core/resource_io/utils.py index b4aac1da..4f48d0b4 100644 --- a/bioimageio/core/resource_io/utils.py +++ b/bioimageio/core/resource_io/utils.py @@ -77,43 +77,36 @@ def __init__(self, *, root_path: os.PathLike): def generic_visit(self, node): if isinstance(node, raw_nodes.RawNode): - for field, expected_sha256 in iter_fields(node): - if expected_sha256 is missing: - continue - - if field == "sha256": + for sha_field, expected in ((k, v) for (k, v) in iter_fields(node) if "sha256" in k and v is not missing): + if sha_field == "sha256": source_name = "source" - for sn in ["source", "uri"]: - if hasattr(node, sn): - source_name = sn - break - - elif field.endswith("_sha256"): - source_name = field[: -len("_sha256")] - elif "sha256" in field: - raise NotImplementedError(f"Don't know how to check integrity with {field}") + if not hasattr(node, "source") and hasattr(node, "uri"): + source_name = "uri" + + elif sha_field.endswith("_sha256"): + source_name = sha_field[: -len("_sha256")] else: - continue + raise NotImplementedError(f"Don't know how to check integrity with {sha_field}") if not hasattr(node, source_name): raise ValueError( - f"Node {node} expected to have '{source_name}' field associated with '{expected_sha256}'" + f"Node {node} expected to have '{source_name}' field associated with '{sha_field}'" ) source_node = getattr(node, source_name) source = get_resolved_source_path(source_node, root_path=self.root_path) actual_sha256 = get_sha256(source) - if not isinstance(expected_sha256, str): - raise TypeError(f"Expected '{field}' to hold string, not {type(expected_sha256)}") + if not isinstance(expected, str): + raise TypeError(f"Expected '{sha_field}' to hold string, not {type(expected)}") - if actual_sha256 != expected_sha256: - if actual_sha256[:6] != expected_sha256[:6]: + if actual_sha256 != expected: + if actual_sha256[:6] != expected[:6]: actual_sha256 = actual_sha256[:6] + "..." - expected_sha256 = expected_sha256[:6] + "..." + expected = expected[:6] + "..." 
raise ValueError( - f"Determined {actual_sha256} for {source_name}={source}, but expected {field}={expected_sha256}" + f"Determined {actual_sha256} for {source_name}={source}, but expected {sha_field}={expected}" ) super().generic_visit(node) From 286e16a2e834c3767bf58876752c6fe453c9ae6b Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 17 May 2023 21:28:55 +0200 Subject: [PATCH 008/244] shorten actual_sha256 to actual --- bioimageio/core/resource_io/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bioimageio/core/resource_io/utils.py b/bioimageio/core/resource_io/utils.py index 4f48d0b4..f5843083 100644 --- a/bioimageio/core/resource_io/utils.py +++ b/bioimageio/core/resource_io/utils.py @@ -95,18 +95,18 @@ def generic_visit(self, node): source_node = getattr(node, source_name) source = get_resolved_source_path(source_node, root_path=self.root_path) - actual_sha256 = get_sha256(source) + actual = get_sha256(source) if not isinstance(expected, str): raise TypeError(f"Expected '{sha_field}' to hold string, not {type(expected)}") - if actual_sha256 != expected: - if actual_sha256[:6] != expected[:6]: - actual_sha256 = actual_sha256[:6] + "..." + if actual != expected: + if actual[:6] != expected[:6]: + actual = actual[:6] + "..." expected = expected[:6] + "..." raise ValueError( - f"Determined {actual_sha256} for {source_name}={source}, but expected {sha_field}={expected}" + f"Determined {actual} for {source_name}={source}, but expected {sha_field}={expected}" ) super().generic_visit(node) From 57008479d8f761120c4e1b2cf34ebc329cdd64d7 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 23 May 2023 10:22:07 +0200 Subject: [PATCH 009/244] account for testing already loaded resources --- bioimageio/core/resource_io/utils.py | 5 +++++ bioimageio/core/resource_tests.py | 3 +++ 2 files changed, 8 insertions(+) diff --git a/bioimageio/core/resource_io/utils.py b/bioimageio/core/resource_io/utils.py index f5843083..338356e9 100644 --- a/bioimageio/core/resource_io/utils.py +++ b/bioimageio/core/resource_io/utils.py @@ -5,6 +5,7 @@ import pathlib import sys import typing +import warnings from types import ModuleType from marshmallow import missing @@ -18,6 +19,7 @@ UriNodeTransformer, ) from . import nodes +from .nodes import ImportedSource GenericNode = typing.Union[GenericRawNode, GenericResolvedNode] @@ -94,6 +96,9 @@ def generic_visit(self, node): ) source_node = getattr(node, source_name) + if isinstance(source_node, ImportedSource): + continue # test is run after loading. 
Warning issued in resource_tests._test_resource_integrity + source = get_resolved_source_path(source_node, root_path=self.root_path) actual = get_sha256(source) diff --git a/bioimageio/core/resource_tests.py b/bioimageio/core/resource_tests.py index 09acd8fc..a35566aa 100644 --- a/bioimageio/core/resource_tests.py +++ b/bioimageio/core/resource_tests.py @@ -112,6 +112,9 @@ def _test_resource_urls(rd: RawResourceDescription) -> TestSummary: def _test_resource_integrity(rd: RawResourceDescription) -> TestSummary: assert isinstance(rd, RawResourceDescription) with warnings.catch_warnings(record=True) as all_warnings: + if isinstance(rd, ResourceDescription): + warnings.warn("Testing source file integrity of an already loaded resource!") + try: Sha256NodeChecker(root_path=rd.root_path).visit(rd) except FileNotFoundError as e: From 338f201fa2c946e54840369beb58dbc8ad2b1f30 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 16 Aug 2023 10:07:48 +0200 Subject: [PATCH 010/244] bioimage.io name refactoring --- bioimageio/core/commands.py | 6 +++--- bioimageio/core/resource_io/io_.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/bioimageio/core/commands.py b/bioimageio/core/commands.py index 640493a6..b75343e2 100644 --- a/bioimageio/core/commands.py +++ b/bioimageio/core/commands.py @@ -15,11 +15,11 @@ def package( weights_priority_order: Optional[List[str]] = None, verbose: bool = False, ) -> int: - """Package a BioImage.IO resource described by a BioImage.IO Resource Description File (RDF).""" + """Package a bioimage.io resource described by a bioimage.io Resource Description File (RDF).""" code = validate(rdf_source, update_format=True, update_format_inner=True) source_name = rdf_source.get("name") if isinstance(rdf_source, dict) else rdf_source - if code["error"]: - print(f"Cannot package invalid BioImage.IO RDF {source_name}") + if code["status"] != "passed": + print(f"Cannot package invalid bioimage.io RDF {source_name}") return 1 try: diff --git a/bioimageio/core/resource_io/io_.py b/bioimageio/core/resource_io/io_.py index 3bb98ead..f6af4ca0 100644 --- a/bioimageio/core/resource_io/io_.py +++ b/bioimageio/core/resource_io/io_.py @@ -30,15 +30,15 @@ def load_resource_description( *, weights_priority_order: Optional[Sequence[str]] = None, # model only ) -> ResourceDescription: - """load a BioImage.IO resource description file (RDF). + """load a bioimage.io resource description file (RDF). This includes some transformations for convenience, e.g. importing `source`. Use `load_raw_resource_description` to obtain a raw representation instead. Args: - source: resource description file (RDF) or raw BioImage.IO resource + source: resource description file (RDF) or raw bioimage.io resource weights_priority_order: If given only the first weights format present in the model resource is included Returns: - BioImage.IO resource + bioimage.io resource """ source = deepcopy(source) if isinstance(source, ResourceDescription): @@ -101,7 +101,7 @@ def export_resource_package( update_to_format: Optional[str] = None, weights_priority_order: Optional[Sequence[Union[str]]] = None, ) -> pathlib.Path: - """Package a BioImage.IO resource as a zip file. + """Package a bioimage.io resource as a zip file. Args: source: raw resource description, path, URI or raw data as dict @@ -114,7 +114,7 @@ def export_resource_package( If none of the prioritized weights formats is found all are included. 
Returns: - path to zipped BioImage.IO package in BIOIMAGEIO_CACHE_PATH or 'output_path' + path to zipped bioimage.io package in BIOIMAGEIO_CACHE_PATH or 'output_path' """ raw_rd = load_raw_resource_description(source, update_to_format=update_to_format) package_content = get_local_resource_package_content( From 70669fbb633faac5bdecd049aba0ca8459924716 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 16 Aug 2023 14:10:06 +0200 Subject: [PATCH 011/244] update mamba-org/setup-micromamba action --- .github/workflows/build.yml | 41 ++++++++++++++++++++----------------- .pre-commit-config.yaml | 4 ++-- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 76a06b5d..c0b54a62 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -34,12 +34,12 @@ jobs: steps: - uses: actions/checkout@v3 - name: Install Conda environment with Micromamba - uses: mamba-org/provision-with-micromamba@main + uses: mamba-org/setup-micromamba@v1 with: cache-downloads: true - cache-env: true + cache-environment: true environment-file: dev/environment-torch.yaml - extra-specs: | + create-args: >- python=${{ matrix.python-version }} - name: additional setup run: pip install --no-deps -e . @@ -54,12 +54,12 @@ jobs: steps: - uses: actions/checkout@v3 - name: Install Conda environment with Micromamba - uses: mamba-org/provision-with-micromamba@main + uses: mamba-org/setup-micromamba@v1 with: cache-downloads: true - cache-env: true + cache-environment: true environment-file: dev/environment-torch.yaml - extra-specs: | + create-args: >- python=${{ matrix.python-version }} - name: additional setup run: | @@ -77,13 +77,14 @@ jobs: steps: - uses: actions/checkout@v3 - name: Install Conda environment with Micromamba - uses: mamba-org/provision-with-micromamba@main + uses: mamba-org/setup-micromamba@v1 with: cache-downloads: true - cache-env: true + cache-environment: true environment-file: dev/environment-tf.yaml - channel-priority: flexible - extra-specs: | + condarc: | + channel-priority: flexible + create-args: >- python=${{ matrix.python-version }} - name: additional setup run: | @@ -101,13 +102,14 @@ jobs: steps: - uses: actions/checkout@v3 - name: Install Conda environment with Micromamba - uses: mamba-org/provision-with-micromamba@main + uses: mamba-org/setup-micromamba@v1 with: cache-downloads: true - cache-env: true + cache-environment: true environment-file: dev/environment-tf-legacy.yaml - channel-priority: flexible - extra-specs: | + condarc: | + channel_priority: flexible + create-args: | python=${{ matrix.python-version }} - name: additional setup run: | @@ -126,14 +128,15 @@ jobs: with: fetch-depth: 0 - name: Install Conda environment with Micromamba - uses: mamba-org/provision-with-micromamba@main + uses: mamba-org/setup-micromamba@v1 with: cache-downloads: true - cache-env: true - environment-file: false + cache-environment: true environment-name: build-env - channels: conda-forge - extra-specs: | + condarc: | + channels: + - conda-forge + create-args: | boa - name: linux conda build run: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f4256fb2..e8f64f50 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ repos: - repo: https://github.com/ambv/black - rev: 23.1.0 + rev: 23.7.0 hooks: - - id: black + - id: black From e2b4b713e679acbfb20d25c65bc6b9be19c1e8f8 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 16 Aug 2023 14:10:54 +0200 Subject: [PATCH 012/244] specify 
__all__ --- bioimageio/core/__init__.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index b47ac27d..80bed641 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -3,6 +3,13 @@ __version__ = json.loads((pathlib.Path(__file__).parent / "VERSION").read_text())["version"] +from .prediction import ( + predict_image, + predict_images, + predict_with_padding, + predict_with_tiling, +) +from .prediction_pipeline import create_prediction_pipeline from .resource_io import ( export_resource_package, load_raw_resource_description, @@ -10,6 +17,21 @@ save_raw_resource_description, serialize_raw_resource_description, ) -from .prediction_pipeline import create_prediction_pipeline -from .prediction import predict_image, predict_images, predict_with_padding, predict_with_tiling from .resource_tests import check_input_shape, check_output_shape, test_resource + +__all__ = [ + "__version__", + "check_input_shape", + "check_output_shape", + "create_prediction_pipeline", + "export_resource_package", + "load_raw_resource_description", + "load_resource_description", + "predict_image", + "predict_images", + "predict_with_padding", + "predict_with_tiling", + "save_raw_resource_description", + "serialize_raw_resource_description", + "test_resource", +] From 57179fd628eeb8fde28cc37ecfdd20830c9d533d Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 16 Aug 2023 16:18:57 +0200 Subject: [PATCH 013/244] WIP --- bioimageio/core/commands.py | 4 +- bioimageio/core/resource_io/io_.py | 63 +++------ bioimageio/core/resource_io/nodes.py | 187 --------------------------- pyproject.toml | 22 +++- setup.py | 13 +- 5 files changed, 49 insertions(+), 240 deletions(-) delete mode 100644 bioimageio/core/resource_io/nodes.py diff --git a/bioimageio/core/commands.py b/bioimageio/core/commands.py index b75343e2..c5d80fda 100644 --- a/bioimageio/core/commands.py +++ b/bioimageio/core/commands.py @@ -5,7 +5,7 @@ from bioimageio.core import export_resource_package from bioimageio.core.resource_io.utils import resolve_source -from bioimageio.spec.commands import validate +from bioimageio.spec import validate from bioimageio.spec.shared.raw_nodes import URI @@ -16,7 +16,7 @@ def package( verbose: bool = False, ) -> int: """Package a bioimage.io resource described by a bioimage.io Resource Description File (RDF).""" - code = validate(rdf_source, update_format=True, update_format_inner=True) + rd, summary = load_description(rdf_source, update_format=True, update_format_inner=True) source_name = rdf_source.get("name") if isinstance(rdf_source, dict) else rdf_source if code["status"] != "passed": print(f"Cannot package invalid bioimage.io RDF {source_name}") diff --git a/bioimageio/core/resource_io/io_.py b/bioimageio/core/resource_io/io_.py index f6af4ca0..cb7bc7d9 100644 --- a/bioimageio/core/resource_io/io_.py +++ b/bioimageio/core/resource_io/io_.py @@ -2,13 +2,10 @@ import pathlib from copy import deepcopy from tempfile import TemporaryDirectory -from typing import Dict, Optional, Sequence, Union +from typing import Dict, Literal, Optional, Sequence, Union from zipfile import ZIP_DEFLATED, ZipFile -from marshmallow import missing - -from bioimageio import spec -from bioimageio.core.resource_io.nodes import ResourceDescription +from bioimageio.spec._internal._constants import DISCOVER from bioimageio.spec import load_raw_resource_description from bioimageio.spec.shared import raw_nodes from 
bioimageio.spec.shared.common import ( @@ -17,7 +14,7 @@ get_class_name_from_type, no_cache_tmp_list, ) -from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription +from bioimageio.spec import ResourceDescription from . import nodes from .utils import resolve_raw_node, resolve_source @@ -25,45 +22,21 @@ save_raw_resource_description = spec.io_.save_raw_resource_description -def load_resource_description( - source: Union[RawResourceDescription, ResourceDescription, os.PathLike, str, dict, raw_nodes.URI], - *, - weights_priority_order: Optional[Sequence[str]] = None, # model only -) -> ResourceDescription: - """load a bioimage.io resource description file (RDF). - This includes some transformations for convenience, e.g. importing `source`. - Use `load_raw_resource_description` to obtain a raw representation instead. - - Args: - source: resource description file (RDF) or raw bioimage.io resource - weights_priority_order: If given only the first weights format present in the model resource is included - Returns: - bioimage.io resource - """ - source = deepcopy(source) - if isinstance(source, ResourceDescription): - return source - - raw_rd = load_raw_resource_description(source, update_to_format="latest") - - if raw_rd.type == "model" and weights_priority_order is not None: - for wf in weights_priority_order: - if wf in raw_rd.weights: - raw_rd.weights = {wf: raw_rd.weights[wf]} - break - else: - raise ValueError(f"Not found any of the specified weights formats {weights_priority_order}") - - rd: ResourceDescription = resolve_raw_node(raw_rd=raw_rd, nodes_module=nodes) - assert isinstance(rd, getattr(nodes, get_class_name_from_type(raw_rd.type))) - - return rd - - def get_local_resource_package_content( - source: RawResourceDescription, - weights_priority_order: Optional[Sequence[Union[str]]], - update_to_format: Optional[str] = None, + source: ResourceDescription, + weights_priority_order: Optional[ + Sequence[ + Literal[ + "keras_hdf5", + "onnx", + "pytorch_state_dict", + "tensorflow_js", + "tensorflow_saved_model_bundle", + "torchscript", + ] + ] + ], + format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, ) -> Dict[str, Union[pathlib.Path, str]]: """ @@ -77,7 +50,7 @@ def get_local_resource_package_content( Package content of local file paths or text content keyed by file names. 
""" - raw_rd = load_raw_resource_description(source, update_to_format=update_to_format) + rd = load_resource_description(source, update_to_format=update_to_format) package_content = spec.get_resource_package_content(raw_rd, weights_priority_order=weights_priority_order) local_package_content = {} diff --git a/bioimageio/core/resource_io/nodes.py b/bioimageio/core/resource_io/nodes.py deleted file mode 100644 index 47e2035f..00000000 --- a/bioimageio/core/resource_io/nodes.py +++ /dev/null @@ -1,187 +0,0 @@ -import pathlib -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Callable, Dict, List, Tuple, Union - -from marshmallow import missing -from marshmallow.utils import _Missing - -from bioimageio.spec.model import raw_nodes as model_raw_nodes -from bioimageio.spec.rdf import raw_nodes as rdf_raw_nodes -from bioimageio.spec.collection import raw_nodes as collection_raw_nodes -from bioimageio.spec.shared import raw_nodes - - -@dataclass -class Node(raw_nodes.RawNode): - pass - - -@dataclass -class ResourceDescription(Node, raw_nodes.ResourceDescription): - pass - - -@dataclass -class URI(Node, raw_nodes.URI): - pass - - -@dataclass -class ParametrizedInputShape(Node, raw_nodes.ParametrizedInputShape): - pass - - -@dataclass -class ImplicitOutputShape(Node, raw_nodes.ImplicitOutputShape): - pass - - -@dataclass -class Dependencies(Node, raw_nodes.Dependencies): - file: pathlib.Path = missing - - -@dataclass -class CiteEntry(Node, rdf_raw_nodes.CiteEntry): - pass - - -@dataclass -class Author(Node, model_raw_nodes.Author): - pass - - -@dataclass -class Maintainer(Node, model_raw_nodes.Maintainer): - pass - - -@dataclass -class Badge(Node, rdf_raw_nodes.Badge): - pass - - -@dataclass -class RDF(rdf_raw_nodes.RDF, ResourceDescription): - badges: Union[_Missing, List[Badge]] = missing - covers: Union[_Missing, List[Path]] = missing - - -@dataclass -class CollectionEntry(Node, collection_raw_nodes.CollectionEntry): - source: URI = missing - - -@dataclass -class LinkedDataset(Node, model_raw_nodes.LinkedDataset): - pass - - -@dataclass -class ModelParent(Node, model_raw_nodes.ModelParent): - pass - - -@dataclass -class Collection(collection_raw_nodes.Collection, RDF): - pass - - -@dataclass -class RunMode(Node, model_raw_nodes.RunMode): - pass - - -@dataclass -class Preprocessing(Node, model_raw_nodes.Preprocessing): - pass - - -@dataclass -class Postprocessing(Node, model_raw_nodes.Postprocessing): - pass - - -@dataclass -class InputTensor(Node, model_raw_nodes.InputTensor): - axes: Tuple[str, ...] = missing - - def __post_init__(self): - super().__post_init__() - # raw node has string with single letter axes names. Here we use a tuple of string here (like xr.DataArray). - self.axes = tuple(self.axes) - - -@dataclass -class OutputTensor(Node, model_raw_nodes.OutputTensor): - axes: Tuple[str, ...] = missing - - def __post_init__(self): - super().__post_init__() - # raw node has string with single letter axes names. Here we use a tuple of string here (like xr.DataArray). 
- self.axes = tuple(self.axes) - - -@dataclass -class ImportedSource(Node): - factory: Callable - - def __call__(self, *args, **kwargs): - return self.factory(*args, **kwargs) - - -@dataclass -class KerasHdf5WeightsEntry(Node, model_raw_nodes.KerasHdf5WeightsEntry): - source: Path = missing - - -@dataclass -class OnnxWeightsEntry(Node, model_raw_nodes.OnnxWeightsEntry): - source: Path = missing - - -@dataclass -class PytorchStateDictWeightsEntry(Node, model_raw_nodes.PytorchStateDictWeightsEntry): - source: Path = missing - architecture: Union[_Missing, ImportedSource] = missing - - -@dataclass -class TorchscriptWeightsEntry(Node, model_raw_nodes.TorchscriptWeightsEntry): - source: Path = missing - - -@dataclass -class TensorflowJsWeightsEntry(Node, model_raw_nodes.TensorflowJsWeightsEntry): - source: Path = missing - - -@dataclass -class TensorflowSavedModelBundleWeightsEntry(Node, model_raw_nodes.TensorflowSavedModelBundleWeightsEntry): - source: Path = missing - - -@dataclass -class Attachments(Node, rdf_raw_nodes.Attachments): - files: Union[_Missing, List[Path]] = missing - unknown: Union[_Missing, Dict[str, Any]] = missing - - -WeightsEntry = Union[ - KerasHdf5WeightsEntry, - OnnxWeightsEntry, - PytorchStateDictWeightsEntry, - TensorflowJsWeightsEntry, - TensorflowSavedModelBundleWeightsEntry, - TorchscriptWeightsEntry, -] - - -@dataclass -class Model(model_raw_nodes.Model, RDF): - authors: List[Author] = missing - maintainers: Union[_Missing, List[Maintainer]] = missing - test_inputs: List[Path] = missing - test_outputs: List[Path] = missing - weights: Dict[model_raw_nodes.WeightsFormat, WeightsEntry] = missing diff --git a/pyproject.toml b/pyproject.toml index 4e42573b..6cd771d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,23 @@ [tool.black] line-length = 120 -target-version = ['py38'] +target-version = ['py38', 'py39', 'py310', 'py311'] + +[tool.ruff] +line-length = 120 + +[tool.pyright] +include = ["bioimageio", "scripts", "tests"] +exclude = ["**/node_modules", "**/__pycache__", "tests/old_*"] +typeCheckingMode = "strict" +reportMissingSuperCall = "error" +reportUnnecessaryTypeIgnoreComment = "error" +reportUninitializedInstanceVariable = "error" +reportUnknownMemberType = false +reportIncompatibleMethodOverride = true +reportMissingTypeArgument = true +reportMissingTypeStubs = "warning" +useLibraryCodeForTypes = true +reportUnusedCallResult = "error" +reportUnusedVariable = "error" +pythonVersion = "3.9" +pythonPlatform = "All" diff --git a/setup.py b/setup.py index 3fc27a76..4de8a245 100644 --- a/setup.py +++ b/setup.py @@ -18,21 +18,24 @@ long_description_content_type="text/markdown", url="https://github.com/bioimage-io/core-bioimage-io-python", author="Bioimage Team", - classifiers=[ # Optional + classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", ], - packages=find_namespace_packages(exclude=["tests"]), # Required + packages=find_namespace_packages(exclude=["tests"]), install_requires=[ "bioimageio.spec==0.4.9.*", "imageio>=2.5", "numpy", "ruamel.yaml", + "tifffile", "tqdm", + "typer", "xarray", - "tifffile", ], include_package_data=True, extras_require={ @@ -42,7 +45,7 @@ "tensorflow": ["tensorflow"], "onnx": ["onnxruntime"], }, - project_urls={ # Optional + project_urls={ "Bug Reports": 
"https://github.com/bioimage-io/core-bioimage-io-python/issues", "Source": "https://github.com/bioimage-io/core-bioimage-io-python", }, From 6a6456a703148994283f4a9dd38bc8ef1b107f40 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 22 Aug 2023 13:58:40 +0200 Subject: [PATCH 014/244] ignore pooch stubs for now --- .gitignore | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index c75e8e7a..4edd992c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,9 @@ -build/ -dist/ .idea/ -*.egg-info/ -cache -**/tmp .tox/ +*.egg-info/ *.pyc +**/tmp +build/ +cache +dist/ +typings/pooch/ From c6c8f5c1adbb52d66c1be4159ae5978c3d3885e9 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 23 Aug 2023 16:11:10 +0200 Subject: [PATCH 015/244] update project setup --- .markdownlint.json | 8 +++++ .vscode/settings.json | 15 ++++++++ MANIFEST.in | 2 ++ README.md | 83 +++++++++++++++++++++++++++++-------------- pyproject.toml | 10 ++++-- pytest.ini | 5 --- setup.cfg | 10 ------ setup.py | 2 +- tox.ini | 8 ----- 9 files changed, 91 insertions(+), 52 deletions(-) create mode 100644 .markdownlint.json create mode 100644 .vscode/settings.json delete mode 100644 pytest.ini delete mode 100644 setup.cfg delete mode 100644 tox.ini diff --git a/.markdownlint.json b/.markdownlint.json new file mode 100644 index 00000000..8111539b --- /dev/null +++ b/.markdownlint.json @@ -0,0 +1,8 @@ +{ + "default": true, + "MD013": { + "line_length": 120 + }, + "MD033": false, + "MD041": false +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..9520c20f --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,15 @@ +{ + "python.languageServer": "Pylance", + "python.analysis.typeCheckingMode": "strict", + "python.linting.pylintEnabled": true, + "python.linting.enabled": false, + "python.testing.unittestArgs": [ + "-v", + "-s", + "./tests", + "-p", + "test_*.py" + ], + "python.testing.pytestEnabled": true, + "python.testing.unittestEnabled": false, +} \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in index 031d8dc7..e1d35f13 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,3 @@ include bioimageio/core/VERSION +include README.md +include LICENSE diff --git a/README.md b/README.md index 49a1831a..c78067e4 100644 --- a/README.md +++ b/README.md @@ -7,46 +7,60 @@ Python specific core utilities for running models in the [BioImage Model Zoo](ht ### Via Conda The `bioimageio.core` package can be installed from conda-forge via -``` + +```console conda install -c conda-forge bioimageio.core ``` -if you don't install any additional deep learning libraries, you will only be able to use general convenience functionality, but not any functionality for model prediction. + +If you do not install any additional deep learning libraries, you will only be able to use general convenience +functionality, but not any functionality for model prediction. 
To install additional deep learning libraries use:

* Pytorch/Torchscript:
-  ```bash
-  # cpu installation (if you don't have an nvidia graphics card)
+
+  CPU installation (if you don't have an nvidia graphics card):
+
+  ```console
   conda install -c pytorch -c conda-forge bioimageio.core pytorch torchvision cpuonly
+  ```
+
+  GPU installation (for cuda 11.6, please choose the appropriate cuda version for your system):

-  # gpu installation (for cuda 11.6, please choose the appropriate cuda version for your system)
-  conda install -c pytorch -c nvidia -c conda-forge bioimageio.core pytorch torchvision pytorch-cuda=11.6
+  ```console
+  conda install -c pytorch -c nvidia -c conda-forge bioimageio.core pytorch torchvision pytorch-cuda=11.6
   ```

-  Note that the pytorch installation instructions may change in the future. For the latest instructions please refer to [pytorch.org](https://pytorch.org/).
+  Note that the pytorch installation instructions may change in the future. For the latest instructions please refer to [pytorch.org](https://pytorch.org/).

* Tensorflow
-  ```bash
-  # currently only cpu version supported
+
+  Currently only CPU version supported
+
+  ```console
   conda install -c conda-forge bioimageio.core tensorflow
   ```

* ONNXRuntime
-  ```bash
-  # currently only cpu version supported
+
+  Currently only CPU version supported
+
+  ```console
   conda install -c conda-forge bioimageio.core onnxruntime
   ```
-
+
### Via pip

The package is also available via pip:
-```
+
+```console
 pip install bioimageio.core
 ```

### Set up Development Environment

To set up a development conda environment run the following commands:
-```
+
+```console
 conda env create -f dev/environment-base.yaml
 conda activate bio-core-dev
 pip install -e . --no-deps
 ```

There are different environment files available that only install tensorflow or pytorch as dependencies.

-## Command Line
+## 🏞 Environment variables

+| Name | Default | Description |
+|---|---|---|
+| BIOIMAGEIO_USE_CACHE | "true" | Enables simple URL to file cache. Possible, case-insensitive, positive values are: "true", "yes", "1". Any other value is interpreted as "false". |
+| BIOIMAGEIO_CACHE_PATH | generated tmp folder | File path for simple URL to file cache; changes of the URL source are not detected. |
+| BIOIMAGEIO_CACHE_WARNINGS_LIMIT | "3" | Maximum number of warnings generated for simple cache hits. |
+
+## 💻 Command Line
+
+`bioimageio.core` installs a command line interface (CLI) for testing models and other functionality.
+You can list all the available commands via:
+
+```console
 bioimageio
 ```

Check that a model adheres to the model spec:
-```
+
+```console
 bioimageio validate <model_source>
 ```

Test a model (including prediction for the test input):
-```
+
+```console
 bioimageio test-model <model_source>
 ```

Run prediction for an image stored on disc:
-```
+
+```console
 bioimageio predict-image -m <model_source> -i <input_image> -o <output_image>
 ```

Run prediction for multiple images stored on disc:
-```
+
+```console
 bioimageio predict-images -m <model_source> -i <input_pattern> -o <output_folder>
 ```

`<input_pattern>` is a `glob` pattern to select the desired images, e.g. `/path/to/my/images/*.tif`.

## From python

`bioimageio.core` is a python library that implements loading models, running prediction with them and more.
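+
+For example, a minimal prediction script can look like this (a sketch only: the
+file paths are placeholders, and optional arguments such as padding, tiling,
+weight format and devices are omitted; see `bioimageio.core.prediction` for the
+full signature):
+
+```python
+from bioimageio.core import predict_image
+
+# mirrors `bioimageio predict-image` from the CLI section above;
+# the model source can be a local RDF path or a URL
+predict_image("path/to/rdf.yaml", "path/to/input.tif", "path/to/output.tif")
+```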
To get an overview of this functionality, check out the example notebooks: -- [example/model_usage](https://github.com/bioimage-io/core-bioimage-io-python/blob/main/example/model_usage.ipynb) for how to load models and run prediction with them -- [example/model_creation](https://github.com/bioimage-io/core-bioimage-io-python/blob/main/example/model_creation.ipynb) for how to create bioimage.io compatible model packages -- [example/dataset_statistics_demo](https://github.com/bioimage-io/core-bioimage-io-python/blob/main/example/dataset_statistics_demo.ipynb) for how to use the dataset statistics for advanced pre-and-postprocessing + +* [example/model_usage](https://github.com/bioimage-io/core-bioimage-io-python/blob/main/example/model_usage.ipynb) for how to load models and run prediction with them +* [example/model_creation](https://github.com/bioimage-io/core-bioimage-io-python/blob/main/example/model_creation.ipynb) for how to create bioimage.io compatible model packages +* [example/dataset_statistics_demo](https://github.com/bioimage-io/core-bioimage-io-python/blob/main/example/dataset_statistics_demo.ipynb) for how to use the dataset statistics for advanced pre-and-postprocessing ## Model Specification -The model specification and its validation tools can be found at https://github.com/bioimage-io/spec-bioimage-io. +The model specification and its validation tools can be found at . diff --git a/pyproject.toml b/pyproject.toml index 6cd771d7..4263089c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,8 +2,10 @@ line-length = 120 target-version = ['py38', 'py39', 'py310', 'py311'] -[tool.ruff] -line-length = 120 +[tool.isort] +line_length = 120 +multi_line_output = 3 +include_trailing_comma = true [tool.pyright] include = ["bioimageio", "scripts", "tests"] @@ -21,3 +23,7 @@ reportUnusedCallResult = "error" reportUnusedVariable = "error" pythonVersion = "3.9" pythonPlatform = "All" + +[tool.pytest.ini_options] +addopts = "-s --doctest-modules" +# testpaths = ["bioimageio", "scripts", "example", "tests"] diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index c2c366b2..00000000 --- a/pytest.ini +++ /dev/null @@ -1,5 +0,0 @@ -[pytest] -add_opts = -s --doctest-modules -testpaths = tests -#log_format = %(asctime)s.%(msecs)03d %(levelname)s %(message)s -#log_date_format = %M:%S.%f diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index d8020326..00000000 --- a/setup.cfg +++ /dev/null @@ -1,10 +0,0 @@ -[tool:isort] -line_length = 120 -multi_line_output = 3 -include_trailing_comma = true - -[flake8] -max-line-length = 120 - -[pylint] -max-line-length = 120 diff --git a/setup.py b/setup.py index 4de8a245..dcca6bab 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ ], include_package_data=True, extras_require={ - "test": ["pytest", "black", "mypy"], + "test": ["pytest", "black"], "dev": ["pre-commit"], "pytorch": ["torch>=1.6", "torchvision"], "tensorflow": ["tensorflow"], diff --git a/tox.ini b/tox.ini deleted file mode 100644 index 7fea938f..00000000 --- a/tox.ini +++ /dev/null @@ -1,8 +0,0 @@ -[tox] -envlist = py38 - -[testenv] -deps = - pytest -commands = - pytest From 7c8adfd105f51550259f0f668fbb8741e5a82002 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 24 Aug 2023 12:59:30 +0200 Subject: [PATCH 016/244] add py.typed to indicate that the package is fully typed --- bioimageio/core/py.typed | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 bioimageio/core/py.typed diff --git a/bioimageio/core/py.typed b/bioimageio/core/py.typed new file 
mode 100644 index 00000000..e69de29b From 1e92a4f6728978e46cb15d5f3f24771c0cc0ca25 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 24 Aug 2023 13:01:07 +0200 Subject: [PATCH 017/244] WIP --- bioimageio/core/__init__.py | 72 ++--- .../{utils.py => _internal/pytest_utils.py} | 11 +- .../core/_internal/validation_visitors.py | 211 +++++++++++++++ bioimageio/core/_io.py | 246 ++++++++++++++++++ bioimageio/core/commands.py | 5 +- bioimageio/core/common.py | 17 +- bioimageio/core/prediction.py | 12 +- .../_prediction_pipeline.py | 5 +- bioimageio/core/resource_io/__init__.py | 9 - bioimageio/core/resource_io/io_.py | 158 ----------- bioimageio/core/resource_io/utils.py | 187 ------------- bioimageio/core/resource_tests.py | 21 +- tests/conftest.py | 4 +- .../test_device_management.py | 2 +- .../test_internal/test_validation_visitors.py | 40 +++ 15 files changed, 588 insertions(+), 412 deletions(-) rename bioimageio/core/{utils.py => _internal/pytest_utils.py} (74%) create mode 100644 bioimageio/core/_internal/validation_visitors.py create mode 100644 bioimageio/core/_io.py delete mode 100644 bioimageio/core/resource_io/__init__.py delete mode 100644 bioimageio/core/resource_io/io_.py delete mode 100644 bioimageio/core/resource_io/utils.py create mode 100644 tests/test_internal/test_validation_visitors.py diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index 80bed641..9e69bbb8 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -1,37 +1,51 @@ import json -import pathlib -__version__ = json.loads((pathlib.Path(__file__).parent / "VERSION").read_text())["version"] +from bioimageio.spec._internal.utils import files -from .prediction import ( - predict_image, - predict_images, - predict_with_padding, - predict_with_tiling, +from bioimageio.core._io import ( + load_description_and_validate, + read_rdf, + resolve_source, + validate, + write_rdf, + write_zipped_resource_package, ) -from .prediction_pipeline import create_prediction_pipeline -from .resource_io import ( - export_resource_package, - load_raw_resource_description, - load_resource_description, - save_raw_resource_description, - serialize_raw_resource_description, -) -from .resource_tests import check_input_shape, check_output_shape, test_resource + +with files("bioimageio.core").joinpath("VERSION").open("r", encoding="utf-8") as f: + __version__: str = json.load(f)["version"] + assert isinstance(__version__, str) + +# __version__ = json.loads((pathlib.Path(__file__).parent / "VERSION").read_text())["version"] +# from .prediction import predict_image, predict_images, predict_with_padding, predict_with_tiling +# from .prediction_pipeline import create_prediction_pipeline +# from .resource_io import ( +# export_resource_package, +# load_raw_resource_description, +# load_resource_description, +# save_raw_resource_description, +# serialize_raw_resource_description, +# ) +# from .resource_tests import check_input_shape, check_output_shape, test_resource __all__ = [ "__version__", - "check_input_shape", - "check_output_shape", - "create_prediction_pipeline", - "export_resource_package", - "load_raw_resource_description", - "load_resource_description", - "predict_image", - "predict_images", - "predict_with_padding", - "predict_with_tiling", - "save_raw_resource_description", - "serialize_raw_resource_description", - "test_resource", + "load_description_and_validate", + "read_rdf", + "resolve_source", + "validate", + "write_rdf", + "write_zipped_resource_package", + # "check_input_shape", + # 
"check_output_shape", + # "create_prediction_pipeline", + # "export_resource_package", + # "load_raw_resource_description", + # "load_resource_description", + # "predict_image", + # "predict_images", + # "predict_with_padding", + # "predict_with_tiling", + # "save_raw_resource_description", + # "serialize_raw_resource_description", + # "test_resource", ] diff --git a/bioimageio/core/utils.py b/bioimageio/core/_internal/pytest_utils.py similarity index 74% rename from bioimageio/core/utils.py rename to bioimageio/core/_internal/pytest_utils.py index 770f7d21..4ae0b8e4 100644 --- a/bioimageio/core/utils.py +++ b/bioimageio/core/_internal/pytest_utils.py @@ -1,5 +1,10 @@ from functools import wraps -from typing import Type +from typing import Any, Protocol, Type + + +class test_func(Protocol): + def __call__(*args: Any, **kwargs: Any): + ... def skip_on(exception: Type[Exception], reason: str): @@ -7,9 +12,9 @@ def skip_on(exception: Type[Exception], reason: str): import pytest # Func below is the real decorator and will receive the test function as param - def decorator_func(f): + def decorator_func(f: test_func): @wraps(f) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any): try: # Try to run the test return f(*args, **kwargs) diff --git a/bioimageio/core/_internal/validation_visitors.py b/bioimageio/core/_internal/validation_visitors.py new file mode 100644 index 00000000..6cb0b503 --- /dev/null +++ b/bioimageio/core/_internal/validation_visitors.py @@ -0,0 +1,211 @@ +import hashlib +import importlib.util +import os +import sys +from dataclasses import dataclass, replace +from functools import singledispatchmethod +from pathlib import Path, PosixPath, PurePath +from types import ModuleType +from typing import Any, Hashable, List, Optional, Tuple, TypedDict, Union + +from annotated_types import SLOTS +from bioimageio.spec._internal.base_nodes import Node +from bioimageio.spec._internal.constants import IN_PACKAGE_MESSAGE, KW_ONLY, SLOTS +from bioimageio.spec.description import ResourceDescription +from bioimageio.spec.summary import ErrorOutcome, WarningOutcome +from bioimageio.spec.types import Loc +from pydantic import AnyUrl, DirectoryPath +from pydantic.fields import FieldInfo +from typing_extensions import NotRequired, Unpack + + +class VisitorKwargs(TypedDict): + info: NotRequired[FieldInfo] + + +@dataclass(frozen=True, **SLOTS, **KW_ONLY) +class Note: + loc: Loc = () + info: Optional[FieldInfo] = None + + +class ValidationVisitor: + def __init__(self) -> None: + super().__init__() + self.errors: List[ErrorOutcome] = [] + self.warnings: List[WarningOutcome] = [] + + @singledispatchmethod + def visit(self, obj: type, /, note: Note = Note()): + pass + + @visit.register + def _visit_node(self, node: Node, note: Note = Note()): + for k, v in node: + self.visit(v, replace(note, loc=note.loc + (k,), info=node.model_fields[k])) + + @visit.register + def _visit_list(self, lst: list, note: Note = Note()): # type: ignore + for i, e in enumerate(lst): # type: ignore + self.visit(e, replace(note, loc=note.loc + (i,))) + + @visit.register + def _visit_tuple(self, tup: tuple, note: Note = Note()): # type: ignore + for i, e in enumerate(tup): # type: ignore + self.visit(e, replace(note, loc=note.loc + (i,))) + + @visit.register + def _visit_dict(self, dict_: dict, note: Note = Note()): # type: ignore + for k, v in dict_.items(): # type: ignore + self.visit(v, replace(note, loc=note.loc + (k,))) + + +class SourceValidator(ValidationVisitor): + def __init__(self, root: 
Union[DirectoryPath, AnyUrl]) -> None: + super().__init__() + self.root = root + + # def _visit_path(self, path: PurePath, info: FieldInfo): + # if not Path(path).exists(): + + +# # info.description.startswith(IN_PACKAGE_MESSAGE) +# if not source_available(leaf, self.root_path): +# raise FileNotFoundError(leaf) + +# def visit_URI(self, node: raw_nodes.URI): +# self._visit_source(node) + +# def visit_PosixPath(self, leaf: PosixPath): +# self._visit_source(leaf) + +# def visit_WindowsPath(self, leaf: pathlib.WindowsPath): +# self._visit_source(leaf) + +# def generic_visit(self, node): +# """Called if no explicit visitor function exists for a node.""" + +# if isinstance(node, raw_nodes.RawNode): +# for field, value in iter_fields(node): +# if field != "root_path": # do not visit root_path, as it might be an incomplete (non-available) URL +# self.visit(value) +# else: +# super().generic_visit(node) + + +# def get_sha256(path: os.PathLike) -> str: +# """from https://stackoverflow.com/a/44873382""" +# h = hashlib.sha256() +# b = bytearray(128 * 1024) +# mv = memoryview(b) +# with open(path, "rb", buffering=0) as f: +# for n in iter(lambda: f.readinto(mv), 0): +# h.update(mv[:n]) + +# return h.hexdigest() + + +# class Sha256NodeChecker(NodeVisitor): +# """Check integrity of the source-like field for every sha256-like field encountered""" + +# def __init__(self, *, root_path: os.PathLike): +# self.root_path = root_path if isinstance(root_path, raw_nodes.URI) else pathlib.Path(root_path).resolve() + +# def generic_visit(self, node): +# if isinstance(node, raw_nodes.RawNode): +# for sha_field, expected in ((k, v) for (k, v) in iter_fields(node) if "sha256" in k and v is not missing): +# if sha_field == "sha256": +# source_name = "source" +# if not hasattr(node, "source") and hasattr(node, "uri"): +# source_name = "uri" + +# elif sha_field.endswith("_sha256"): +# source_name = sha_field[: -len("_sha256")] +# else: +# raise NotImplementedError(f"Don't know how to check integrity with {sha_field}") + +# if not hasattr(node, source_name): +# raise ValueError( +# f"Node {node} expected to have '{source_name}' field associated with '{sha_field}'" +# ) + +# source_node = getattr(node, source_name) +# if isinstance(source_node, ImportedSource): +# continue # test is run after loading. Warning issued in resource_tests._test_resource_integrity + +# source = get_resolved_source_path(source_node, root_path=self.root_path) +# actual = get_sha256(source) + +# if not isinstance(expected, str): +# raise TypeError(f"Expected '{sha_field}' to hold string, not {type(expected)}") + +# if actual != expected: +# if actual[:6] != expected[:6]: +# actual = actual[:6] + "..." +# expected = expected[:6] + "..." 
+ +# raise ValueError( +# f"Determined {actual} for {source_name}={source}, but expected {sha_field}={expected}" +# ) + +# super().generic_visit(node) + + +# class SourceNodeTransformer(NodeTransformer): +# """ +# Imports all source callables +# note: Requires previous transformation by UriNodeTransformer +# """ + +# class TemporaryInsertionIntoPythonPath: +# def __init__(self, path: str): +# self.path = path + +# def __enter__(self): +# sys.path.insert(0, self.path) + +# def __exit__(self, exc_type, exc_value, traceback): +# sys.path.remove(self.path) + +# def transform_LocalImportableModule(self, node: raw_nodes.LocalImportableModule) -> nodes.ImportedSource: +# with self.TemporaryInsertionIntoPythonPath(str(node.root_path)): +# module = importlib.import_module(node.module_name) + +# return nodes.ImportedSource(factory=getattr(module, node.callable_name)) + +# @staticmethod +# def transform_ResolvedImportableSourceFile(node: raw_nodes.ResolvedImportableSourceFile) -> nodes.ImportedSource: +# module_path = resolve_source(node.source_file) +# module_name = f"module_from_source.{module_path.stem}" +# importlib_spec = importlib.util.spec_from_file_location(module_name, module_path) +# assert importlib_spec is not None +# dep = importlib.util.module_from_spec(importlib_spec) +# importlib_spec.loader.exec_module(dep) # type: ignore # todo: possible to use "loader.load_module"? +# return nodes.ImportedSource(factory=getattr(dep, node.callable_name)) + + +# class RawNodeTypeTransformer(NodeTransformer): +# def __init__(self, nodes_module: ModuleType): +# super().__init__() +# self.nodes = nodes_module + +# def generic_transformer(self, node: GenericRawNode) -> GenericResolvedNode: +# if isinstance(node, raw_nodes.RawNode): +# resolved_data = { +# field.name: self.transform(getattr(node, field.name)) for field in dataclasses.fields(node) +# } +# resolved_node_type: typing.Type[GenericResolvedNode] = getattr(self.nodes, node.__class__.__name__) +# return resolved_node_type(**resolved_data) # type: ignore +# else: +# return super().generic_transformer(node) + + +# def all_sources_available( +# node: typing.Union[GenericNode, list, tuple, dict], root_path: os.PathLike = pathlib.Path() +# ) -> bool: +# try: +# SourceNodeChecker(root_path=root_path).visit(node) +# except FileNotFoundError: +# return False +# else: +# return True diff --git a/bioimageio/core/_io.py b/bioimageio/core/_io.py new file mode 100644 index 00000000..5dbc8b63 --- /dev/null +++ b/bioimageio/core/_io.py @@ -0,0 +1,246 @@ +import collections.abc +import os +import shutil +from pathlib import Path +from tempfile import NamedTemporaryFile, TemporaryDirectory +from typing import Dict, Literal, Mapping, NamedTuple, Optional, Sequence, Tuple, Union, cast +from urllib.parse import urlsplit, urlunsplit +from zipfile import ZIP_DEFLATED, ZipFile + +import pooch +from bioimageio.spec import ResourceDescription, load_description +from bioimageio.spec._internal.base_nodes import ResourceDescriptionBase +from bioimageio.spec._internal.constants import LATEST +from bioimageio.spec.description import dump_description +from bioimageio.spec.model.v0_4 import WeightsFormat +from bioimageio.spec.package import extract_file_name, get_resource_package_content +from bioimageio.spec.summary import ValidationSummary +from bioimageio.spec.types import FileName, RawStringMapping, RawValue, RelativeFilePath, ValidationContext +from pydantic import AnyUrl, DirectoryPath, FilePath, HttpUrl, TypeAdapter +from ruamel.yaml import YAML + +yaml = 
YAML(typ="safe") + +# def _resolve_resource_description_source(rd: Union[ResourceDescriptionSource, ResourceDescription]) -> : +# HttpUrl, FilePath + + +def read_description( + rdf_source: Union[HttpUrl, FilePath, str], + format_version: Union[Literal["discover"], Literal["latest"], str] = LATEST, +) -> Tuple[Optional[ResourceDescription], ValidationSummary]: + rdf_content, root, file_name = read_rdf(rdf_source) + return load_description(rdf_content, context=dict(root=root, file_name=file_name), format_version=format_version) + + +def resolve_source( + source: Union[HttpUrl, FilePath, RelativeFilePath, str], + *, + known_hash: Optional[str] = None, + root: Union[DirectoryPath, AnyUrl, None] = None, +) -> FilePath: + if isinstance(source, str): + source = TypeAdapter(Union[HttpUrl, FilePath, RelativeFilePath]).validate_python(source) + + if isinstance(source, RelativeFilePath): + if root is None: + raise ValueError(f"Cannot resolve relative file path '{source}' without root.") + + source = source.get_absolute(root) + + if isinstance(source, AnyUrl): + source = Path(pooch.retrieve(source, known_hash=known_hash)) # type: ignore + + return source + + +def _get_parent_url(url: HttpUrl) -> HttpUrl: + parsed = urlsplit(str(url)) + return AnyUrl( + urlunsplit((parsed.scheme, parsed.netloc, "/".join(parsed.path.split("/")[:-1]), parsed.query, parsed.fragment)) + ) + + +def write_rdf(rd: Union[RawStringMapping, ResourceDescription], path: Path): + if isinstance(rd, ResourceDescriptionBase): + rdf_content = dump_description(rd) + else: + rdf_content = rd + + with path.open("w", encoding="utf-8") as f: + yaml.dump(rdf_content, f) + + +FileSource = Union[HttpUrl, FilePath] + + +class Rdf(NamedTuple): + content: RawStringMapping + root: Union[HttpUrl, DirectoryPath] + file_name: str + + +def read_rdf(source: Union[FileSource, str], known_hash: Optional[str] = None, encoding: Optional[str] = None) -> Rdf: + if isinstance(source, str): + source = TypeAdapter(FileSource).validate_python(source) + + src_msg = str(source) + if isinstance(source, AnyUrl): + cached_source: FilePath = Path(pooch.retrieve(url=str(source), known_hash=known_hash)) # type: ignore + src_msg += f" cached at {cached_source}" + local_source = cached_source + root: Union[HttpUrl, DirectoryPath] = _get_parent_url(source) + else: + local_source = source + root = source.parent + + with local_source.open(encoding=encoding) as f: + content: RawValue = yaml.load(f) + + if not isinstance(content, collections.abc.Mapping): + raise TypeError(f"Expected RDF content to be a mapping, but got '{type(content)}'.") + + if non_string_keys := [k for k in content if not isinstance(k, str)]: + raise TypeError(f"Got non-string keys {non_string_keys} in {src_msg}") + + return Rdf( + content=cast(RawStringMapping, content), + root=root, + file_name=extract_file_name(source), + ) + + +def load_description_and_validate( + rdf_content: RawStringMapping, + *, + context: Optional[ValidationContext] = None, +) -> Tuple[Optional[ResourceDescription], ValidationSummary]: + """load and validate a BioImage.IO description from the content of a resource description file (RDF)""" + rd, summary = load_description(rdf_content, context=context, format_version="latest") + # todo: add validation + return rd, summary + + +def validate( + rdf_content: RawStringMapping, + *, + context: Optional[ValidationContext] = None, +) -> ValidationSummary: + _rd, summary = load_description_and_validate(rdf_content, context=context) + return summary + + +def prepare_resource_package( + rd: 
ResourceDescription, + *, + root: Union[AnyUrl, DirectoryPath], + output_folder: DirectoryPath, + weights_priority_order: Optional[Sequence[WeightsFormat]] = None, +) -> Dict[FileName, FilePath]: + """Prepare to package a resource description; downloads all required files. + + Args: + rd: bioimage.io resource description + root: URL or path to resolve relative file paths in `rd` + weights_priority_order: If given only the first weights format present in the model is included. + If none of the prioritized weights formats is found all are included. + """ + package_content = get_resource_package_content(rd, weights_priority_order=weights_priority_order) + + output_folder.mkdir(parents=True, exist_ok=True) + local_package_content: Dict[FileName, FilePath] = {} + for k, v in package_content.items(): + in_package_path = output_folder / k + if isinstance(v, RelativeFilePath): + v = v.get_absolute(root) + + if isinstance(v, AnyUrl): + v = resolve_source(v, root=root) + + if isinstance(v, Path): + shutil.copy(str(v), str(in_package_path)) + else: + assert isinstance(v, collections.abc.Mapping) + write_rdf(v, in_package_path) + + local_package_content[k] = in_package_path + + return local_package_content + + +def write_zipped_resource_package( + rd: ResourceDescription, + *, + root: Union[AnyUrl, DirectoryPath], + compression: int = ZIP_DEFLATED, + compression_level: int = 1, + output_path: Optional[os.PathLike[str]] = None, + weights_priority_order: Optional[ # model only + Sequence[ + Literal[ + "keras_hdf5", + "onnx", + "pytorch_state_dict", + "tensorflow_js", + "tensorflow_saved_model_bundle", + "torchscript", + ] + ] + ] = None, +) -> FilePath: + """Package a bioimage.io resource as a zip file. + + Args: + rd: bioimage.io resource description + root: reference for any relative file paths in the bioimage.io resource description + compression: The numeric constant of compression method. + compression_level: Compression level to use when writing files to the archive. + See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile + output_path: file path to write package to + weights_priority_order: If given only the first weights format present in the model is included. + If none of the prioritized weights formats is found all are included. + + Returns: + path to zipped bioimage.io package in BIOIMAGEIO_CACHE_PATH or 'output_path' + """ + + with TemporaryDirectory() as tmp_dir: + package_content = prepare_resource_package( + rd, + root=root, + output_folder=Path(tmp_dir), + weights_priority_order=weights_priority_order, + ) + + if output_path is None: + output_path = Path(NamedTemporaryFile(suffix=".bioimageio.zip", delete=False).name) + else: + output_path = Path(output_path) + + _write_zip(output_path, package_content, compression=compression, compression_level=compression_level) + return output_path + + +def _write_zip( + path: os.PathLike[str], + content: Mapping[FileName, Union[str, FilePath]], + *, + compression: int, + compression_level: int, +) -> None: + """Write a zip archive. + + Args: + path: output path to write to. + content: dict with archive names and local file paths or strings for text files. + compression: The numeric constant of compression method. + compression_level: Compression level to use when writing files to the archive. 
+        See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile

+    """
+    with ZipFile(path, "w", compression=compression, compresslevel=compression_level) as myzip:
+        for arc_name, file_or_str_content in content.items():
+            if isinstance(file_or_str_content, str):
+                myzip.writestr(arc_name, file_or_str_content)
+            else:
+                myzip.write(file_or_str_content, arcname=arc_name)
diff --git a/bioimageio/core/commands.py b/bioimageio/core/commands.py
index c5d80fda..d1ebc2b5 100644
--- a/bioimageio/core/commands.py
+++ b/bioimageio/core/commands.py
@@ -3,11 +3,12 @@
 from pathlib import Path
 from typing import List, Optional, Union

-from bioimageio.core import export_resource_package
-from bioimageio.core.resource_io.utils import resolve_source
 from bioimageio.spec import validate
 from bioimageio.spec.shared.raw_nodes import URI

+from bioimageio.core import export_resource_package
+from bioimageio.core._internal.validation_visitors import resolve_source
+

 def package(
     rdf_source: Union[Path, str, URI, dict],
diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py
index 2a660d0c..846bb668 100644
--- a/bioimageio/core/common.py
+++ b/bioimageio/core/common.py
@@ -1,5 +1,20 @@
-from bioimageio.spec.shared.common import ValidationSummary
+import getpass
+import os
+import tempfile
+import warnings
+from pathlib import Path
+
+from bioimageio.spec.types import ValidationSummary


 class TestSummary(ValidationSummary):
     bioimageio_core_version: str
+
+
+# BIOIMAGEIO_CACHE_PATH = Path(
+#     os.getenv("BIOIMAGEIO_CACHE_PATH", Path(tempfile.gettempdir()) / getpass.getuser() / "bioimageio_cache")
+# )
+
+# BIOIMAGEIO_USE_CACHE = os.getenv("BIOIMAGEIO_USE_CACHE", "true").lower() in ("1", "yes", "true")
+# if (env_val := os.getenv("BIOIMAGEIO_USE_CACHE", "true").lower()) not in ("0", "1", "no", "yes", "false", "true"):
+#     warnings.warn(f"Unrecognized BIOIMAGEIO_USE_CACHE environment value '{env_val}'")
diff --git a/bioimageio/core/prediction.py b/bioimageio/core/prediction.py
index b7e195e1..30026b0a 100644
--- a/bioimageio/core/prediction.py
+++ b/bioimageio/core/prediction.py
@@ -3,17 +3,21 @@
 from fractions import Fraction
 from itertools import product
 from pathlib import Path
-from typing import Dict, Iterator, List, NamedTuple, Optional, OrderedDict, Sequence, Tuple, Union
+from typing import Any, Dict, Iterator, List, NamedTuple, Optional, OrderedDict, Sequence, Tuple, Union

 import numpy as np
+from pydantic import HttpUrl
 import xarray as xr

 from bioimageio.core import image_helper, load_resource_description
 from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline
 from bioimageio.core.resource_io.nodes import ImplicitOutputShape, Model, ResourceDescription
-from bioimageio.spec.shared import raw_nodes
-from bioimageio.spec.shared.common import tqdm
-from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription
+from tqdm import tqdm
+from bioimageio.spec import ResourceDescription
+
+# assumed definition: 'RdfSource' is used in annotations below but was never defined in this WIP patch;
+# the union mirrors the annotation it replaces, with HttpUrl standing in for the dropped raw_nodes.URI
+RdfSource = Union[ResourceDescription, os.PathLike, str, dict, HttpUrl]


 def _apply_crop(data, crop):
@@ -428,7 +428,7 @@ def _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling):

 def predict_image(
-    model_rdf: Union[RawResourceDescription, ResourceDescription, os.PathLike, str, dict, raw_nodes.URI],
+    model_rdf: RdfSource,
     inputs: Union[Tuple[Path, ...], List[Path], Path],
     outputs: Union[Tuple[Path, ...], List[Path], Path],
     padding: Optional[Union[bool, Dict[str, int]]] = None,
@@ -469,7 +469,7 @@

 def predict_images(
-    model_rdf: Union[RawResourceDescription, ResourceDescription, 
os.PathLike, str, dict, raw_nodes.URI], + model_rdf: RdfSource, inputs: Sequence[Union[Tuple[Path, ...], List[Path], Path]], outputs: Sequence[Union[Tuple[Path, ...], List[Path], Path]], padding: Optional[Union[bool, Dict[str, int]]] = None, diff --git a/bioimageio/core/prediction_pipeline/_prediction_pipeline.py b/bioimageio/core/prediction_pipeline/_prediction_pipeline.py index dc98a373..8d92ee62 100644 --- a/bioimageio/core/prediction_pipeline/_prediction_pipeline.py +++ b/bioimageio/core/prediction_pipeline/_prediction_pipeline.py @@ -4,11 +4,12 @@ from typing import Iterable, List, Optional, Sequence, Tuple, Union import xarray as xr +from bioimageio.spec.model import raw_nodes from marshmallow import missing +from bioimageio.core._internal.validation_visitors import resolve_raw_node from bioimageio.core.resource_io import nodes -from bioimageio.core.resource_io.utils import resolve_raw_node -from bioimageio.spec.model import raw_nodes + from ._combined_processing import CombinedProcessing from ._model_adapters import ModelAdapter, create_model_adapter from ._stat_state import StatsState diff --git a/bioimageio/core/resource_io/__init__.py b/bioimageio/core/resource_io/__init__.py deleted file mode 100644 index bdfb805f..00000000 --- a/bioimageio/core/resource_io/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -import bioimageio.spec -from .io_ import ( - export_resource_package, - load_resource_description, - save_raw_resource_description, - serialize_raw_resource_description, -) - -load_raw_resource_description = bioimageio.spec.load_raw_resource_description diff --git a/bioimageio/core/resource_io/io_.py b/bioimageio/core/resource_io/io_.py deleted file mode 100644 index cb7bc7d9..00000000 --- a/bioimageio/core/resource_io/io_.py +++ /dev/null @@ -1,158 +0,0 @@ -import os -import pathlib -from copy import deepcopy -from tempfile import TemporaryDirectory -from typing import Dict, Literal, Optional, Sequence, Union -from zipfile import ZIP_DEFLATED, ZipFile - -from bioimageio.spec._internal._constants import DISCOVER -from bioimageio.spec import load_raw_resource_description -from bioimageio.spec.shared import raw_nodes -from bioimageio.spec.shared.common import ( - BIOIMAGEIO_CACHE_PATH, - BIOIMAGEIO_USE_CACHE, - get_class_name_from_type, - no_cache_tmp_list, -) -from bioimageio.spec import ResourceDescription -from . import nodes -from .utils import resolve_raw_node, resolve_source - -serialize_raw_resource_description = spec.io_.serialize_raw_resource_description -save_raw_resource_description = spec.io_.save_raw_resource_description - - -def get_local_resource_package_content( - source: ResourceDescription, - weights_priority_order: Optional[ - Sequence[ - Literal[ - "keras_hdf5", - "onnx", - "pytorch_state_dict", - "tensorflow_js", - "tensorflow_saved_model_bundle", - "torchscript", - ] - ] - ], - format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, -) -> Dict[str, Union[pathlib.Path, str]]: - """ - - Args: - source: raw resource description - weights_priority_order: If given only the first weights format present in the model is included. - If none of the prioritized weights formats is found all are included. - update_to_format: update resource to specific major.minor format version; ignoring patch version. - - Returns: - Package content of local file paths or text content keyed by file names. 
- - """ - rd = load_resource_description(source, update_to_format=update_to_format) - package_content = spec.get_resource_package_content(raw_rd, weights_priority_order=weights_priority_order) - - local_package_content = {} - for k, v in package_content.items(): - if isinstance(v, raw_nodes.URI): - v = resolve_source(v, raw_rd.root_path) - elif isinstance(v, pathlib.Path): - v = raw_rd.root_path / v - - local_package_content[k] = v - - return local_package_content - - -def export_resource_package( - source: Union[RawResourceDescription, os.PathLike, str, dict, raw_nodes.URI], - *, - compression: int = ZIP_DEFLATED, - compression_level: int = 1, - output_path: Optional[os.PathLike] = None, - update_to_format: Optional[str] = None, - weights_priority_order: Optional[Sequence[Union[str]]] = None, -) -> pathlib.Path: - """Package a bioimage.io resource as a zip file. - - Args: - source: raw resource description, path, URI or raw data as dict - compression: The numeric constant of compression method. - compression_level: Compression level to use when writing files to the archive. - See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile - output_path: file path to write package to - update_to_format: update resource to specific "major.minor" or "latest" format version; ignoring patch version. - weights_priority_order: If given only the first weights format present in the model is included. - If none of the prioritized weights formats is found all are included. - - Returns: - path to zipped bioimage.io package in BIOIMAGEIO_CACHE_PATH or 'output_path' - """ - raw_rd = load_raw_resource_description(source, update_to_format=update_to_format) - package_content = get_local_resource_package_content( - raw_rd, weights_priority_order, update_to_format=update_to_format - ) - if output_path is None: - package_path = _get_tmp_package_path(raw_rd, weights_priority_order) - else: - package_path = output_path - - make_zip(package_path, package_content, compression=compression, compression_level=compression_level) - return package_path - - -def _get_package_base_name(raw_rd: RawResourceDescription, weights_priority_order: Optional[Sequence[str]]) -> str: - package_file_name = raw_rd.name - if raw_rd.version is not missing: - package_file_name += f"_{raw_rd.version}" - - package_file_name = package_file_name.replace(" ", "_").replace(".", "_") - - return package_file_name - - -def _get_tmp_package_path(raw_rd: RawResourceDescription, weights_priority_order: Optional[Sequence[str]]): - if BIOIMAGEIO_USE_CACHE: - package_file_name = _get_package_base_name(raw_rd, weights_priority_order) - cache_folder = BIOIMAGEIO_CACHE_PATH / "packages" - cache_folder.mkdir(exist_ok=True, parents=True) - - package_path = (cache_folder / package_file_name).with_suffix(".zip") - max_cached_packages_with_same_name = 100 - for p in range(max_cached_packages_with_same_name): - if package_path.exists(): - package_path = (cache_folder / f"{package_file_name}p{p}").with_suffix(".zip") - else: - break - else: - raise FileExistsError( - f"Already caching {max_cached_packages_with_same_name} versions of {cache_folder / package_file_name}!" - ) - else: - tmp_dir = TemporaryDirectory() - no_cache_tmp_list.append(tmp_dir) - package_path = pathlib.Path(tmp_dir.name) / "file" - - return package_path - - -def make_zip( - path: os.PathLike, content: Dict[str, Union[str, pathlib.Path]], *, compression: int, compression_level: int -) -> None: - """Write a zip archive. - - Args: - path: output path to write to. 
- content: dict with archive names and local file paths or strings for text files. - compression: The numeric constant of compression method. - compression_level: Compression level to use when writing files to the archive. - See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile - - """ - with ZipFile(path, "w", compression=compression, compresslevel=compression_level) as myzip: - for arc_name, file_or_str_content in content.items(): - if isinstance(file_or_str_content, str): - myzip.writestr(arc_name, file_or_str_content) - else: - myzip.write(file_or_str_content, arcname=arc_name) diff --git a/bioimageio/core/resource_io/utils.py b/bioimageio/core/resource_io/utils.py deleted file mode 100644 index 338356e9..00000000 --- a/bioimageio/core/resource_io/utils.py +++ /dev/null @@ -1,187 +0,0 @@ -import dataclasses -import hashlib -import importlib.util -import os -import pathlib -import sys -import typing -import warnings -from types import ModuleType - -from marshmallow import missing - -from bioimageio.spec.shared import get_resolved_source_path, raw_nodes, resolve_source, source_available -from bioimageio.spec.shared.node_transformer import ( - GenericRawNode, - GenericResolvedNode, - NodeTransformer, - NodeVisitor, - UriNodeTransformer, -) -from . import nodes -from .nodes import ImportedSource - -GenericNode = typing.Union[GenericRawNode, GenericResolvedNode] - - -def iter_fields(node: GenericNode): - for field in dataclasses.fields(node): - yield field.name, getattr(node, field.name) - - -class SourceNodeChecker(NodeVisitor): - """raises FileNotFoundError for unavailable URIs and paths""" - - def __init__(self, *, root_path: os.PathLike): - self.root_path = root_path if isinstance(root_path, raw_nodes.URI) else pathlib.Path(root_path).resolve() - - def _visit_source(self, leaf: typing.Union[pathlib.Path, raw_nodes.URI]): - if not source_available(leaf, self.root_path): - raise FileNotFoundError(leaf) - - def visit_URI(self, node: raw_nodes.URI): - self._visit_source(node) - - def visit_PosixPath(self, leaf: pathlib.PosixPath): - self._visit_source(leaf) - - def visit_WindowsPath(self, leaf: pathlib.WindowsPath): - self._visit_source(leaf) - - def generic_visit(self, node): - """Called if no explicit visitor function exists for a node.""" - - if isinstance(node, raw_nodes.RawNode): - for field, value in iter_fields(node): - if field != "root_path": # do not visit root_path, as it might be an incomplete (non-available) URL - self.visit(value) - else: - super().generic_visit(node) - - -def get_sha256(path: os.PathLike) -> str: - """from https://stackoverflow.com/a/44873382""" - h = hashlib.sha256() - b = bytearray(128 * 1024) - mv = memoryview(b) - with open(path, "rb", buffering=0) as f: - for n in iter(lambda: f.readinto(mv), 0): - h.update(mv[:n]) - - return h.hexdigest() - - -class Sha256NodeChecker(NodeVisitor): - """Check integrity of the source-like field for every sha256-like field encountered""" - - def __init__(self, *, root_path: os.PathLike): - self.root_path = root_path if isinstance(root_path, raw_nodes.URI) else pathlib.Path(root_path).resolve() - - def generic_visit(self, node): - if isinstance(node, raw_nodes.RawNode): - for sha_field, expected in ((k, v) for (k, v) in iter_fields(node) if "sha256" in k and v is not missing): - if sha_field == "sha256": - source_name = "source" - if not hasattr(node, "source") and hasattr(node, "uri"): - source_name = "uri" - - elif sha_field.endswith("_sha256"): - source_name = sha_field[: -len("_sha256")] - else: - raise 
NotImplementedError(f"Don't know how to check integrity with {sha_field}") - - if not hasattr(node, source_name): - raise ValueError( - f"Node {node} expected to have '{source_name}' field associated with '{sha_field}'" - ) - - source_node = getattr(node, source_name) - if isinstance(source_node, ImportedSource): - continue # test is run after loading. Warning issued in resource_tests._test_resource_integrity - - source = get_resolved_source_path(source_node, root_path=self.root_path) - actual = get_sha256(source) - - if not isinstance(expected, str): - raise TypeError(f"Expected '{sha_field}' to hold string, not {type(expected)}") - - if actual != expected: - if actual[:6] != expected[:6]: - actual = actual[:6] + "..." - expected = expected[:6] + "..." - - raise ValueError( - f"Determined {actual} for {source_name}={source}, but expected {sha_field}={expected}" - ) - - super().generic_visit(node) - - -class SourceNodeTransformer(NodeTransformer): - """ - Imports all source callables - note: Requires previous transformation by UriNodeTransformer - """ - - class TemporaryInsertionIntoPythonPath: - def __init__(self, path: str): - self.path = path - - def __enter__(self): - sys.path.insert(0, self.path) - - def __exit__(self, exc_type, exc_value, traceback): - sys.path.remove(self.path) - - def transform_LocalImportableModule(self, node: raw_nodes.LocalImportableModule) -> nodes.ImportedSource: - with self.TemporaryInsertionIntoPythonPath(str(node.root_path)): - module = importlib.import_module(node.module_name) - - return nodes.ImportedSource(factory=getattr(module, node.callable_name)) - - @staticmethod - def transform_ResolvedImportableSourceFile(node: raw_nodes.ResolvedImportableSourceFile) -> nodes.ImportedSource: - module_path = resolve_source(node.source_file) - module_name = f"module_from_source.{module_path.stem}" - importlib_spec = importlib.util.spec_from_file_location(module_name, module_path) - assert importlib_spec is not None - dep = importlib.util.module_from_spec(importlib_spec) - importlib_spec.loader.exec_module(dep) # type: ignore # todo: possible to use "loader.load_module"? 
- return nodes.ImportedSource(factory=getattr(dep, node.callable_name)) - - -class RawNodeTypeTransformer(NodeTransformer): - def __init__(self, nodes_module: ModuleType): - super().__init__() - self.nodes = nodes_module - - def generic_transformer(self, node: GenericRawNode) -> GenericResolvedNode: - if isinstance(node, raw_nodes.RawNode): - resolved_data = { - field.name: self.transform(getattr(node, field.name)) for field in dataclasses.fields(node) - } - resolved_node_type: typing.Type[GenericResolvedNode] = getattr(self.nodes, node.__class__.__name__) - return resolved_node_type(**resolved_data) # type: ignore - else: - return super().generic_transformer(node) - - -def all_sources_available( - node: typing.Union[GenericNode, list, tuple, dict], root_path: os.PathLike = pathlib.Path() -) -> bool: - try: - SourceNodeChecker(root_path=root_path).visit(node) - except FileNotFoundError: - return False - else: - return True - - -def resolve_raw_node( - raw_rd: GenericRawNode, nodes_module: typing.Any, uri_only_if_in_package: bool = True -) -> GenericResolvedNode: - """resolve all uris and paths (that are included when packaging)""" - rd = UriNodeTransformer(root_path=raw_rd.root_path, uri_only_if_in_package=uri_only_if_in_package).transform(raw_rd) - rd = SourceNodeTransformer().transform(rd) - rd = RawNodeTypeTransformer(nodes_module).transform(rd) - return rd diff --git a/bioimageio/core/resource_tests.py b/bioimageio/core/resource_tests.py index a35566aa..32556671 100644 --- a/bioimageio/core/resource_tests.py +++ b/bioimageio/core/resource_tests.py @@ -9,29 +9,26 @@ import numpy import numpy as np import xarray as xr +from bioimageio.spec import __version__ as bioimageio_spec_version +from bioimageio.spec.model.raw_nodes import WeightsFormat +from bioimageio.spec.shared import resolve_source +from bioimageio.spec.shared.common import ValidationWarning +from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription from marshmallow import ValidationError -from bioimageio.core import ( - __version__ as bioimageio_core_version, - load_raw_resource_description, - load_resource_description, -) +from bioimageio.core import __version__ as bioimageio_core_version +from bioimageio.core import load_raw_resource_description, load_resource_description +from bioimageio.core._internal.validation_visitors import Sha256NodeChecker, SourceNodeChecker from bioimageio.core.common import TestSummary from bioimageio.core.prediction import predict from bioimageio.core.prediction_pipeline import create_prediction_pipeline from bioimageio.core.resource_io.nodes import ( + URI, ImplicitOutputShape, Model, ParametrizedInputShape, ResourceDescription, - URI, ) -from bioimageio.core.resource_io.utils import Sha256NodeChecker, SourceNodeChecker -from bioimageio.spec import __version__ as bioimageio_spec_version -from bioimageio.spec.model.raw_nodes import WeightsFormat -from bioimageio.spec.shared import resolve_source -from bioimageio.spec.shared.common import ValidationWarning -from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription def test_model( diff --git a/tests/conftest.py b/tests/conftest.py index 10dfe53a..788899ca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,9 +6,9 @@ import pytest os.environ["BIOIMAGEIO_COUNT_RDF_DOWNLOADS"] = "false" # disable tracking before bioimageio imports -from bioimageio.core import export_resource_package from bioimageio.spec import __version__ as bioimageio_spec_version +from bioimageio.core import 
write_zipped_resource_package

 logger = logging.getLogger(__name__)
 warnings.warn(f"testing with bioimageio.spec {bioimageio_spec_version}")
@@ -132,7 +132,7 @@ def pytest_configure():
     pytest.skip_onnx = skip_onnx

     # load all model packages used in tests
-    pytest.model_packages = {name: export_resource_package(model_sources[name]) for name in load_model_packages}
+    pytest.model_packages = {name: write_zipped_resource_package(model_sources[name]) for name in load_model_packages}

     pytest.mamba_cmd = "micromamba"
     try:
diff --git a/tests/prediction_pipeline/test_device_management.py b/tests/prediction_pipeline/test_device_management.py
index 98ab13c2..bbe907ad 100644
--- a/tests/prediction_pipeline/test_device_management.py
+++ b/tests/prediction_pipeline/test_device_management.py
@@ -4,8 +4,8 @@
 from numpy.testing import assert_array_almost_equal

 from bioimageio.core import load_resource_description
+from bioimageio.core._internal.pytest_utils import skip_on
 from bioimageio.core.resource_io.nodes import Model
-from bioimageio.core.utils import skip_on


 class TooFewDevicesException(Exception):
diff --git a/tests/test_internal/test_validation_visitors.py b/tests/test_internal/test_validation_visitors.py
new file mode 100644
index 00000000..0aaa882c
--- /dev/null
+++ b/tests/test_internal/test_validation_visitors.py
@@ -0,0 +1,40 @@
+from functools import singledispatchmethod
+
+from bioimageio.spec._internal.base_nodes import Node
+from bioimageio.spec.summary import ErrorOutcome
+
+from bioimageio.core._internal.validation_visitors import Note, ValidationVisitor
+
+
+def test_traversing_nodes():
+    class MyVisitor(ValidationVisitor):
+        @singledispatchmethod
+        def visit(self, obj: type, note: Note = Note()):
+            super().visit(obj, note)
+
+        @visit.register
+        def _visit_int(self, nr: int, note: Note = Note()):
+            super().visit(nr, note)
+            self.errors.append(ErrorOutcome(loc=note.loc, msg=f"nr: {nr}", type="got-int"))
+
+    class NestedNode(Node):
+        leaf: int
+
+    class MyNode(Node):
+        nested: NestedNode
+
+    tree = {
+        "a": MyNode(nested=NestedNode(leaf=1)),
+        "b": [NestedNode(leaf=2), NestedNode(leaf=3)],
+        "c": (NestedNode(leaf=4),),
+        "d": {"deep": MyNode(nested=NestedNode(leaf=5))},
+    }
+    visitor = MyVisitor()
+    visitor.visit(tree)
+    assert visitor.errors == [
+        ErrorOutcome(loc=("a", "nested", "leaf"), msg="nr: 1", type="got-int"),
+        ErrorOutcome(loc=("b", 0, "leaf"), msg="nr: 2", type="got-int"),
+        ErrorOutcome(loc=("b", 1, "leaf"), msg="nr: 3", type="got-int"),
+        ErrorOutcome(loc=("c", 0, "leaf"), msg="nr: 4", type="got-int"),
+        ErrorOutcome(loc=("d", "deep", "nested", "leaf"), msg="nr: 5", type="got-int"),
+    ]
From 3083204bc290523dc1f043ee50c27c6f450c2b2c Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Thu, 24 Aug 2023 13:24:14 +0200
Subject: [PATCH 018/244] add install script draft

---
 scripts/setup_dev_env.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)
 create mode 100644 scripts/setup_dev_env.py

diff --git a/scripts/setup_dev_env.py b/scripts/setup_dev_env.py
new file mode 100644
index 00000000..fc107c33
--- /dev/null
+++ b/scripts/setup_dev_env.py
@@ -0,0 +1,24 @@
+# untested draft!
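+# assumptions, inferred from the commands below: this script lives in
+# 'core-bioimage-io/scripts/' with a 'spec-bioimage-io' checkout next to the
+# 'core-bioimage-io' repository, and 'mamba' and 'pip' are available on the PATH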
+import subprocess
+from os import chdir
+from pathlib import Path
+
+
+def run(command: str):
+    # without shell=True, subprocess.run expects an argument list, not a whole command-line string
+    _ = subprocess.run(command.split(), check=True, capture_output=True)
+
+
+repo_dir = Path(__file__).parent.parent.parent
+cur_dir = Path().resolve()
+chdir(str(repo_dir))
+try:
+    run("mamba env create --file core-bioimage-io/dev/env.yaml")
+    run("pip install --no-deps --config-settings editable_mode=compat -e spec-bioimage-io")
+    run("pip install --no-deps --config-settings editable_mode=compat -e core-bioimage-io")
+except Exception:
+    chdir(cur_dir)
+    raise
From 64910d9d92d0cd839a7232bf8b2a650fd86b4e05 Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Wed, 30 Aug 2023 09:16:05 +0200
Subject: [PATCH 019/244] update io funcs

---
 bioimageio/core/__init__.py                   |  11 +-
 bioimageio/core/_internal/utils.py            |  52 +++++
 .../core/_internal/validation_visitors.py     |  17 +-
 bioimageio/core/_io.py                        | 198 ++++++++++--------
 bioimageio/core/commands.py                   |  52 -----
 bioimageio/core/common.py                     |  20 --
 bioimageio/core/image_helper.py               |  16 +-
 7 files changed, 188 insertions(+), 178 deletions(-)
 create mode 100644 bioimageio/core/_internal/utils.py
 delete mode 100644 bioimageio/core/commands.py
 delete mode 100644 bioimageio/core/common.py

diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py
index 9e69bbb8..1cf498b8 100644
--- a/bioimageio/core/__init__.py
+++ b/bioimageio/core/__init__.py
@@ -1,14 +1,13 @@
 import json

-from bioimageio.spec._internal.utils import files
-
+from bioimageio.core._internal.utils import files
 from bioimageio.core._io import (
+    dump_description_to_file,
     load_description_and_validate,
     read_rdf,
     resolve_source,
     validate,
-    write_rdf,
-    write_zipped_resource_package,
+    write_package,
 )

 with files("bioimageio.core").joinpath("VERSION").open("r", encoding="utf-8") as f:
@@ -33,8 +32,8 @@
     "read_rdf",
     "resolve_source",
     "validate",
-    "write_rdf",
-    "write_zipped_resource_package",
+    "dump_description_to_file",
+    "write_package",
     # "check_input_shape",
     # "check_output_shape",
diff --git a/bioimageio/core/_internal/utils.py b/bioimageio/core/_internal/utils.py
new file mode 100644
index 00000000..13b6ac05
--- /dev/null
+++ b/bioimageio/core/_internal/utils.py
@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+import os
+import sys
+from pathlib import Path
+from typing import Mapping, Union
+from urllib.parse import urlsplit, urlunsplit
+from zipfile import ZipFile
+
+from bioimageio.spec.types import FileName
+from pydantic import AnyUrl, FilePath, HttpUrl
+
+if sys.version_info < (3, 9):
+
+    def files(package_name: str):
+        assert package_name == "bioimageio.core"
+        return Path(__file__).parent.parent
+
+else:
+    from importlib.resources import files as files
+
+
+def get_parent_url(url: HttpUrl) -> HttpUrl:
+    parsed = urlsplit(str(url))
+    return AnyUrl(
+        urlunsplit((parsed.scheme, parsed.netloc, "/".join(parsed.path.split("/")[:-1]), parsed.query, parsed.fragment))
+    )
+
+
+def write_zip(
+    path: os.PathLike[str],
+    content: Mapping[FileName, Union[str, FilePath]],
+    *,
+    compression: int,
+    compression_level: int,
+) -> None:
+    """Write a zip archive.
+
+    Args:
+        path: output path to write to.
+        content: dict with archive names and local file paths or strings for text files.
+        compression: The numeric constant of compression method.
+        compression_level: Compression level to use when writing files to the archive.
+        See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile
+
+    """
+    with ZipFile(path, "w", compression=compression, compresslevel=compression_level) as myzip:
+        for arc_name, file_or_str_content in content.items():
+            if isinstance(file_or_str_content, str):
+                myzip.writestr(arc_name, file_or_str_content)
+            else:
+                myzip.write(file_or_str_content, arcname=arc_name)
diff --git a/bioimageio/core/_internal/validation_visitors.py b/bioimageio/core/_internal/validation_visitors.py
index 6cb0b503..f27972d5 100644
--- a/bioimageio/core/_internal/validation_visitors.py
+++ b/bioimageio/core/_internal/validation_visitors.py
@@ -10,9 +10,9 @@
 from annotated_types import SLOTS
 from bioimageio.spec._internal.base_nodes import Node
-from bioimageio.spec._internal.constants import IN_PACKAGE_MESSAGE, KW_ONLY, SLOTS
+from bioimageio.spec._internal.constants import ALERT_TYPE, IN_PACKAGE_MESSAGE, KW_ONLY, SLOTS
 from bioimageio.spec.description import ResourceDescription
-from bioimageio.spec.summary import ErrorOutcome, WarningOutcome
+from bioimageio.spec.summary import ErrorEntry, WarningEntry
 from bioimageio.spec.types import Loc
 from pydantic import AnyUrl, DirectoryPath
 from pydantic.fields import FieldInfo
@@ -32,8 +32,8 @@ class Note:
 class ValidationVisitor:
     def __init__(self) -> None:
         super().__init__()
-        self.errors: List[ErrorOutcome] = []
-        self.warnings: List[WarningOutcome] = []
+        self.errors: List[ErrorEntry] = []
+        self.warnings: List[WarningEntry] = []

     @singledispatchmethod
     def visit(self, obj: type, /, note: Note = Note()):
@@ -65,8 +65,13 @@ def __init__(self, root: Union[DirectoryPath, AnyUrl]) -> None:
         super().__init__()
         self.root = root

-    # def _visit_path(self, path: PurePath, info: FieldInfo):
-    #     if not Path(path).exists():
+    def _visit_path(self, path: PurePath, note: Note):
+        if not Path(path).exists():
+            msg = f"{path} does not exist"
+            if note.info and note.info.description and note.info.description.startswith(IN_PACKAGE_MESSAGE):
+                self.errors.append(ErrorEntry(loc=note.loc, msg=msg, type="file-not-found"))
+            else:
+                self.warnings.append(WarningEntry(loc=note.loc, msg=msg, type=ALERT_TYPE))
diff --git a/bioimageio/core/_io.py b/bioimageio/core/_io.py
index 5dbc8b63..31240971 100644
--- a/bioimageio/core/_io.py
+++ b/bioimageio/core/_io.py
@@ -1,86 +1,63 @@
+from __future__ import annotations
+
 import collections.abc
 import os
 import shutil
 from pathlib import Path
 from tempfile import NamedTemporaryFile, TemporaryDirectory
-from typing import Dict, Literal, Mapping, NamedTuple, Optional, Sequence, Tuple, Union, cast
-from urllib.parse import urlsplit, urlunsplit
-from zipfile import ZIP_DEFLATED, ZipFile
+from typing import Dict, Literal, NamedTuple, Optional, Sequence, Tuple, Union, cast
+from zipfile import ZIP_DEFLATED

 import pooch
 from bioimageio.spec import ResourceDescription, load_description
 from bioimageio.spec._internal.base_nodes import ResourceDescriptionBase
-from bioimageio.spec._internal.constants import LATEST
+from bioimageio.spec._internal.constants import DISCOVER, ERROR, LATEST
 from bioimageio.spec.description import dump_description
 from bioimageio.spec.model.v0_4 import WeightsFormat
 from bioimageio.spec.package import extract_file_name, get_resource_package_content
 from bioimageio.spec.summary import ValidationSummary
-from bioimageio.spec.types import FileName, RawStringMapping, RawValue, RelativeFilePath, ValidationContext
+from bioimageio.spec.types import (
+    FileName,
+    RawStringMapping,
+    RawValue,
+    RelativeFilePath,
+ ValidationContext, + WarningLevel, +) from pydantic import AnyUrl, DirectoryPath, FilePath, HttpUrl, TypeAdapter from ruamel.yaml import YAML -yaml = YAML(typ="safe") - -# def _resolve_resource_description_source(rd: Union[ResourceDescriptionSource, ResourceDescription]) -> : -# HttpUrl, FilePath - - -def read_description( - rdf_source: Union[HttpUrl, FilePath, str], - format_version: Union[Literal["discover"], Literal["latest"], str] = LATEST, -) -> Tuple[Optional[ResourceDescription], ValidationSummary]: - rdf_content, root, file_name = read_rdf(rdf_source) - return load_description(rdf_content, context=dict(root=root, file_name=file_name), format_version=format_version) - - -def resolve_source( - source: Union[HttpUrl, FilePath, RelativeFilePath, str], - *, - known_hash: Optional[str] = None, - root: Union[DirectoryPath, AnyUrl, None] = None, -) -> FilePath: - if isinstance(source, str): - source = TypeAdapter(Union[HttpUrl, FilePath, RelativeFilePath]).validate_python(source) - - if isinstance(source, RelativeFilePath): - if root is None: - raise ValueError(f"Cannot resolve relative file path '{source}' without root.") - - source = source.get_absolute(root) - - if isinstance(source, AnyUrl): - source = Path(pooch.retrieve(source, known_hash=known_hash)) # type: ignore - - return source - - -def _get_parent_url(url: HttpUrl) -> HttpUrl: - parsed = urlsplit(str(url)) - return AnyUrl( - urlunsplit((parsed.scheme, parsed.netloc, "/".join(parsed.path.split("/")[:-1]), parsed.query, parsed.fragment)) - ) - - -def write_rdf(rd: Union[RawStringMapping, ResourceDescription], path: Path): - if isinstance(rd, ResourceDescriptionBase): - rdf_content = dump_description(rd) - else: - rdf_content = rd - - with path.open("w", encoding="utf-8") as f: - yaml.dump(rdf_content, f) +from bioimageio.core._internal.utils import get_parent_url, write_zip +yaml = YAML(typ="safe") FileSource = Union[HttpUrl, FilePath] -class Rdf(NamedTuple): +class ReadRdf(NamedTuple): content: RawStringMapping root: Union[HttpUrl, DirectoryPath] file_name: str -def read_rdf(source: Union[FileSource, str], known_hash: Optional[str] = None, encoding: Optional[str] = None) -> Rdf: +def load_description_from_file( + source: Union[FileSource, str], + /, + *, + warning_level: WarningLevel = ERROR, + format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, +) -> Tuple[Optional[ResourceDescription], ValidationSummary]: + rdf = read_rdf(source) + return load_description( + rdf.content, + context=ValidationContext(root=rdf.root, file_name=rdf.file_name, warning_level=warning_level), + format_version=format_version, + ) + + +def read_rdf( + source: Union[FileSource, str], /, *, known_hash: Optional[str] = None, encoding: Optional[str] = None +) -> ReadRdf: if isinstance(source, str): source = TypeAdapter(FileSource).validate_python(source) @@ -89,7 +66,7 @@ def read_rdf(source: Union[FileSource, str], known_hash: Optional[str] = None, e cached_source: FilePath = Path(pooch.retrieve(url=str(source), known_hash=known_hash)) # type: ignore src_msg += f" cached at {cached_source}" local_source = cached_source - root: Union[HttpUrl, DirectoryPath] = _get_parent_url(source) + root: Union[HttpUrl, DirectoryPath] = get_parent_url(source) else: local_source = source root = source.parent @@ -103,26 +80,72 @@ def read_rdf(source: Union[FileSource, str], known_hash: Optional[str] = None, e if non_string_keys := [k for k in content if not isinstance(k, str)]: raise TypeError(f"Got non-string keys {non_string_keys} in 
{src_msg}") - return Rdf( + return ReadRdf( content=cast(RawStringMapping, content), root=root, file_name=extract_file_name(source), ) +def resolve_source( + source: Union[HttpUrl, FilePath, RelativeFilePath, str], + /, + *, + known_hash: Optional[str] = None, + root: Union[DirectoryPath, AnyUrl, None] = None, +) -> FilePath: + if isinstance(source, str): + source = TypeAdapter(Union[HttpUrl, FilePath, RelativeFilePath]).validate_python(source) + + if isinstance(source, RelativeFilePath): + if root is None: + raise ValueError(f"Cannot resolve relative file path '{source}' without root.") + + source = source.get_absolute(root) + + if isinstance(source, AnyUrl): + source = Path(pooch.retrieve(source, known_hash=known_hash)) # type: ignore + + return source + + +def dump_description_to_file(rd: Union[ResourceDescription, RawStringMapping], /, file_path: Path): + if isinstance(rd, ResourceDescriptionBase): + content = dump_description(rd) + else: + content = rd + + with file_path.open("w", encoding="utf-8") as f: + yaml.dump(content, f) + + +def load_description_from_file_and_validate( + rdf_source: Union[FileSource, str], + /, + *, + warning_level: WarningLevel = ERROR, +) -> Tuple[Optional[ResourceDescription], ValidationSummary]: + rdf = read_rdf(rdf_source) + return load_description_and_validate( + rdf.content, context=ValidationContext(root=rdf.root, file_name=rdf.file_name, warning_level=warning_level) + ) + + def load_description_and_validate( rdf_content: RawStringMapping, + /, *, context: Optional[ValidationContext] = None, ) -> Tuple[Optional[ResourceDescription], ValidationSummary]: """load and validate a BioImage.IO description from the content of a resource description file (RDF)""" - rd, summary = load_description(rdf_content, context=context, format_version="latest") + rd, summary = load_description(rdf_content, context=context, format_version=LATEST) # todo: add validation return rd, summary def validate( rdf_content: RawStringMapping, + /, *, context: Optional[ValidationContext] = None, ) -> ValidationSummary: @@ -130,8 +153,21 @@ def validate( return summary +def validate_rdf(rdf_source: Union[FileSource, str], /, *, warning_level: WarningLevel = ERROR) -> ValidationSummary: + _rd, summary = load_description_from_file_and_validate(rdf_source, warning_level=warning_level) + return summary + + +def validate_rdf_format( + rdf_source: Union[FileSource, str], /, *, warning_level: WarningLevel = ERROR +) -> ValidationSummary: + _rd, summary = load_description_from_file(rdf_source, warning_level=warning_level) + return summary + + def prepare_resource_package( rd: ResourceDescription, + /, *, root: Union[AnyUrl, DirectoryPath], output_folder: DirectoryPath, @@ -161,17 +197,18 @@ def prepare_resource_package( shutil.copy(str(v), str(in_package_path)) else: assert isinstance(v, collections.abc.Mapping) - write_rdf(v, in_package_path) + dump_description_to_file(v, in_package_path) local_package_content[k] = in_package_path return local_package_content -def write_zipped_resource_package( - rd: ResourceDescription, +def write_package( + rd: Union[ResourceDescription, FileSource, str], + /, *, - root: Union[AnyUrl, DirectoryPath], + root: Union[AnyUrl, DirectoryPath] = Path(), compression: int = ZIP_DEFLATED, compression_level: int = 1, output_path: Optional[os.PathLike[str]] = None, @@ -203,6 +240,12 @@ def write_zipped_resource_package( Returns: path to zipped bioimage.io package in BIOIMAGEIO_CACHE_PATH or 'output_path' """ + if isinstance(rd, (AnyUrl, os.PathLike, str)): + rd_, 
summary = load_description_from_file(rd) + if rd_ is None: + raise ValueError(summary.format()) + else: + rd = rd_ with TemporaryDirectory() as tmp_dir: package_content = prepare_resource_package( @@ -217,30 +260,5 @@ def write_zipped_resource_package( else: output_path = Path(output_path) - _write_zip(output_path, package_content, compression=compression, compression_level=compression_level) + write_zip(output_path, package_content, compression=compression, compression_level=compression_level) return output_path - - -def _write_zip( - path: os.PathLike[str], - content: Mapping[FileName, Union[str, FilePath]], - *, - compression: int, - compression_level: int, -) -> None: - """Write a zip archive. - - Args: - path: output path to write to. - content: dict with archive names and local file paths or strings for text files. - compression: The numeric constant of compression method. - compression_level: Compression level to use when writing files to the archive. - See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile - - """ - with ZipFile(path, "w", compression=compression, compresslevel=compression_level) as myzip: - for arc_name, file_or_str_content in content.items(): - if isinstance(file_or_str_content, str): - myzip.writestr(arc_name, file_or_str_content) - else: - myzip.write(file_or_str_content, arcname=arc_name) diff --git a/bioimageio/core/commands.py b/bioimageio/core/commands.py deleted file mode 100644 index d1ebc2b5..00000000 --- a/bioimageio/core/commands.py +++ /dev/null @@ -1,52 +0,0 @@ -import shutil -import traceback -from pathlib import Path -from typing import List, Optional, Union - -from bioimageio.spec import validate -from bioimageio.spec.shared.raw_nodes import URI - -from bioimageio.core import export_resource_package -from bioimageio.core._internal.validation_visitors import resolve_source - - -def package( - rdf_source: Union[Path, str, URI, dict], - path: Path = Path() / "{src_name}-package.zip", - weights_priority_order: Optional[List[str]] = None, - verbose: bool = False, -) -> int: - """Package a bioimage.io resource described by a bioimage.io Resource Description File (RDF).""" - rd, summary = load_description(rdf_source, update_format=True, update_format_inner=True) - source_name = rdf_source.get("name") if isinstance(rdf_source, dict) else rdf_source - if code["status"] != "passed": - print(f"Cannot package invalid bioimage.io RDF {source_name}") - return 1 - - try: - tmp_package_path = export_resource_package(rdf_source, weights_priority_order=weights_priority_order) - except Exception as e: - print(f"Failed to package {source_name} due to: {e}") - if verbose: - traceback.print_exc() - return 1 - - try: - rdf_local_source = resolve_source(rdf_source) - except Exception as e: - print(f"Failed to resolve RDF source {rdf_source}: {e}") - if verbose: - traceback.print_exc() - return 1 - - try: - path = path.with_name(path.name.format(src_name=rdf_local_source.stem)) - shutil.move(tmp_package_path, path) - except Exception as e: - print(f"Failed to move package from {tmp_package_path} to {path} due to: {e}") - if verbose: - traceback.print_exc() - return 1 - - print(f"exported bioimageio package from {source_name} to {path}") - return 0 diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py deleted file mode 100644 index 846bb668..00000000 --- a/bioimageio/core/common.py +++ /dev/null @@ -1,20 +0,0 @@ -import getpass -import os -import tempfile -import warnings -from pathlib import Path - -from bioimageio.spec.types import 
ValidationSummary - - -class TestSummary(ValidationSummary): - bioimageio_core_version: str - - -# BIOIMAGEIO_CACHE_PATH = Path( -# os.getenv("BIOIMAGEIO_CACHE_PATH", Path(tempfile.gettempdir()) / getpass.getuser() / "bioimageio_cache") -# ) - -# BIOIMAGEIO_USE_CACHE = os.getenv("BIOIMAGEIO_USE_CACHE", "true").lower() in ("1", "yes", "true") -# if (env_val := os.getenv("BIOIMAGEIO_USE_CACHE", "true").lower()) not in ("0", "1", "no", "yes", "false", "true"): -# warnings.warn(f"Unrecognized BIOIMAGEIO_USE_CACHE environment value '{env_val}'") diff --git a/bioimageio/core/image_helper.py b/bioimageio/core/image_helper.py index 0468b61f..e26d5e4c 100644 --- a/bioimageio/core/image_helper.py +++ b/bioimageio/core/image_helper.py @@ -1,11 +1,20 @@ +from __future__ import annotations + import os from copy import deepcopy from typing import Dict, List, Optional, Sequence, Tuple, Union import imageio import numpy as np +from bioimageio.spec.model.v0_4 import InputTensor as InputTensor04 +from bioimageio.spec.model.v0_4 import OutputTensor as OutputTensor04 +from bioimageio.spec.model.v0_5 import InputTensor as InputTensor05 +from bioimageio.spec.model.v0_5 import OutputTensor as OutputTensor05 +from numpy.typing import NDArray from xarray import DataArray -from bioimageio.core.resource_io.nodes import InputTensor, OutputTensor + +InputTensor = Union[InputTensor04, InputTensor05] +OutputTensor = Union[OutputTensor04, OutputTensor05] # @@ -13,7 +22,7 @@ # -def transform_input_image(image: np.ndarray, tensor_axes: str, image_axes: Optional[str] = None): +def transform_input_image(image: NDArray, tensor_axes: str, image_axes: Optional[str] = None): """Transform input image into output tensor with desired axes. Args: @@ -51,7 +60,7 @@ def _drop_axis_default(axis_name, axis_len): return axis_len // 2 if axis_name in "zyx" else 0 -def transform_output_tensor(tensor: np.ndarray, tensor_axes: str, output_axes: str, drop_function=_drop_axis_default): +def transform_output_tensor(tensor: NDArray, tensor_axes: str, output_axes: str, drop_function=_drop_axis_default): """Transform output tensor into image with desired axes. 
Args: @@ -157,7 +166,6 @@ def pad(image, axes: Sequence[str], padding, pad_right=True) -> Tuple[np.ndarray pad_width = [] crop = {} for ax, dlen, pr in zip(axes, image.shape, pad_right): - if ax in "zyx": pad_to = padding_[ax] From e463a74e1699c0be94aeff27fa8e27f8dd999a22 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 30 Aug 2023 09:16:34 +0200 Subject: [PATCH 020/244] improve tests with type annotations --- pyproject.toml | 2 +- tests/build_spec/test_build_spec.py | 4 +- tests/conftest.py | 181 +++++++++++++++------------- tests/resource_io/test_load_rdf.py | 7 +- tests/resource_io/test_utils.py | 18 +-- 5 files changed, 110 insertions(+), 102 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4263089c..a8002b85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,5 +25,5 @@ pythonVersion = "3.9" pythonPlatform = "All" [tool.pytest.ini_options] -addopts = "-s --doctest-modules" +addopts = "--capture=no --doctest-modules --failed-first" # testpaths = ["bioimageio", "scripts", "example", "tests"] diff --git a/tests/build_spec/test_build_spec.py b/tests/build_spec/test_build_spec.py index 8edd9436..7d842509 100644 --- a/tests/build_spec/test_build_spec.py +++ b/tests/build_spec/test_build_spec.py @@ -1,11 +1,11 @@ from typing import Optional +import bioimageio.spec as spec from marshmallow import missing -import bioimageio.spec as spec from bioimageio.core import load_raw_resource_description, load_resource_description +from bioimageio.core._internal.validation_visitors import resolve_source from bioimageio.core.resource_io import nodes -from bioimageio.core.resource_io.utils import resolve_source from bioimageio.core.resource_tests import test_model as _test_model try: diff --git a/tests/conftest.py b/tests/conftest.py index 788899ca..db626a10 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,36 +1,41 @@ +from __future__ import annotations + import logging import os import subprocess import warnings +from types import MappingProxyType +from typing import Set -import pytest +from pydantic import FilePath +from pytest import FixtureRequest, fixture os.environ["BIOIMAGEIO_COUNT_RDF_DOWNLOADS"] = "false" # disable tracking before bioimageio imports from bioimageio.spec import __version__ as bioimageio_spec_version -from bioimageio.core import write_zipped_resource_package +from bioimageio.core import write_package logger = logging.getLogger(__name__) warnings.warn(f"testing with bioimageio.spec {bioimageio_spec_version}") # test models for various frameworks -torch_models = [ +TORCH_MODELS = [ "unet2d_fixed_shape", "unet2d_multi_tensor", "unet2d_nuclei_broad_model", "unet2d_diff_output_shape", "shape_change", ] -torchscript_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model"] -onnx_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model", "hpa_densenet"] -tensorflow1_models = ["stardist"] -tensorflow2_models = ["unet2d_keras_tf2"] -keras_tf1_models = ["unet2d_keras"] -keras_tf2_models = ["unet2d_keras_tf2"] -tensorflow_js_models = [] +TORCHSCRIPT_MODELS = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model"] +ONNX_MODELS = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model", "hpa_densenet"] +TENSORFLOW1_MODELS = ["stardist"] +TENSORFLOW2_MODELS = ["unet2d_keras_tf2"] +KERAS_TF1_MODELS = ["unet2d_keras"] +KERAS_TF2_MODELS = ["unet2d_keras_tf2"] +TENSORFLOW_JS_MODELS = [] -model_sources = { +MODEL_SOURCES = { "unet2d_keras": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" "unet2d_keras_tf/rdf.yaml" @@ -91,58 
+96,60 @@ skip_torch = torch is None try: - import onnxruntime + import onnxruntime # type: ignore except ImportError: onnxruntime = None skip_onnx = onnxruntime is None try: - import tensorflow + import tensorflow # type: ignore - tf_major_version = int(tensorflow.__version__.split(".")[0]) + tf_major_version = int(tensorflow.__version__.split(".")[0]) # type: ignore except ImportError: tensorflow = None tf_major_version = None + skip_tensorflow = tensorflow is None skip_tensorflow_js = True # TODO: add a tensorflow_js example model # load all model packages we need for testing -load_model_packages = set() +load_model_packages: Set[str] = set() if not skip_torch: - load_model_packages |= set(torch_models + torchscript_models) + load_model_packages |= set(TORCH_MODELS + TORCHSCRIPT_MODELS) if not skip_onnx: - load_model_packages |= set(onnx_models) + load_model_packages |= set(ONNX_MODELS) if not skip_tensorflow: - load_model_packages |= set(tensorflow_js_models) + load_model_packages |= set(TENSORFLOW_JS_MODELS) if tf_major_version == 1: - load_model_packages |= set(keras_tf1_models) - load_model_packages |= set(tensorflow1_models) + load_model_packages |= set(KERAS_TF1_MODELS) + load_model_packages |= set(TENSORFLOW1_MODELS) load_model_packages.add("stardist_wrong_shape") load_model_packages.add("stardist_wrong_shape2") elif tf_major_version == 2: - load_model_packages |= set(keras_tf2_models) - load_model_packages |= set(tensorflow2_models) + load_model_packages |= set(KERAS_TF2_MODELS) + load_model_packages |= set(TENSORFLOW2_MODELS) -def pytest_configure(): - # explicit skip flags needed for some tests - pytest.skip_torch = skip_torch - pytest.skip_onnx = skip_onnx +@fixture(scope="session") +def model_packages(): + return MappingProxyType({name: write_package(MODEL_SOURCES[name]) for name in load_model_packages}) - # load all model packages used in tests - pytest.model_packages = {name: write_zipped_resource_package(model_sources[name]) for name in load_model_packages} - pytest.mamba_cmd = "micromamba" +@fixture(scope="session") +def mamba_cmd(): + mamba_cmd = "micromamba" try: - subprocess.run(["which", pytest.mamba_cmd], check=True) + _ = subprocess.run(["which", mamba_cmd], check=True) except (subprocess.CalledProcessError, FileNotFoundError): - pytest.mamba_cmd = "mamba" + mamba_cmd = "mamba" try: - subprocess.run(["which", pytest.mamba_cmd], check=True) + _ = subprocess.run(["which", mamba_cmd], check=True) except (subprocess.CalledProcessError, FileNotFoundError): - pytest.mamba_cmd = None + mamba_cmd = None + + return mamba_cmd # @@ -150,42 +157,42 @@ def pytest_configure(): # -@pytest.fixture(params=[] if skip_torch else torch_models) -def any_torch_model(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_torch else TORCH_MODELS) +def any_torch_model(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): + return model_packages[request.param] -@pytest.fixture(params=[] if skip_torch else torchscript_models) -def any_torchscript_model(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_torch else TORCHSCRIPT_MODELS) +def any_torchscript_model(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): + return model_packages[request.param] -@pytest.fixture(params=[] if skip_onnx else onnx_models) -def any_onnx_model(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_onnx else ONNX_MODELS) +def any_onnx_model(request: FixtureRequest, 
model_packages: MappingProxyType[str, FilePath]): + return model_packages[request.param] -@pytest.fixture(params=[] if skip_tensorflow else tensorflow1_models if tf_major_version == 1 else tensorflow2_models) -def any_tensorflow_model(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_tensorflow else TENSORFLOW1_MODELS if tf_major_version == 1 else TENSORFLOW2_MODELS) +def any_tensorflow_model(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): + return model_packages[request.param] -@pytest.fixture(params=[] if skip_tensorflow else keras_tf1_models if tf_major_version == 1 else keras_tf2_models) -def any_keras_model(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_tensorflow else KERAS_TF1_MODELS if tf_major_version == 1 else KERAS_TF2_MODELS) +def any_keras_model(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): + return model_packages[request.param] -@pytest.fixture(params=[] if skip_tensorflow_js else tensorflow_js_models) -def any_tensorflow_js_model(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_tensorflow_js else TENSORFLOW_JS_MODELS) +def any_tensorflow_js_model(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): + return model_packages[request.param] # fixture to test with all models that should run in the current environment # we exclude stardist_wrong_shape here because it is not a valid model # and included only to test that validation for this model fails -@pytest.fixture(params=load_model_packages - {"stardist_wrong_shape", "stardist_wrong_shape2"}) -def any_model(request): - return pytest.model_packages[request.param] +@fixture(params=load_model_packages - {"stardist_wrong_shape", "stardist_wrong_shape2"}) +def any_model(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): + return model_packages[request.param] # TODO it would be nice to just generate fixtures for all the individual models dynamically @@ -195,64 +202,64 @@ def any_model(request): # -@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_fixed_shape"]) -def unet2d_fixed_shape_or_not(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_fixed_shape"]) +def unet2d_fixed_shape_or_not(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): + return model_packages[request.param] -@pytest.fixture(params=[] if skip_onnx or skip_torch else ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"]) -def convert_to_onnx(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_onnx or skip_torch else ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"]) +def convert_to_onnx(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): + return model_packages[request.param] -@pytest.fixture(params=[] if skip_tensorflow else ["unet2d_keras" if tf_major_version == 1 else "unet2d_keras_tf2"]) -def unet2d_keras(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_tensorflow else ["unet2d_keras" if tf_major_version == 1 else "unet2d_keras_tf2"]) +def unet2d_keras(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): + return model_packages[request.param] # written as model group to automatically skip on missing torch -@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model"]) -def 
unet2d_nuclei_broad_model(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model"]) +def unet2d_nuclei_broad_model(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): + return model_packages[request.param] # written as model group to automatically skip on missing torch -@pytest.fixture(params=[] if skip_torch else ["unet2d_diff_output_shape"]) -def unet2d_diff_output_shape(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_torch else ["unet2d_diff_output_shape"]) +def unet2d_diff_output_shape(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): + return model_packages[request.param] # written as model group to automatically skip on missing torch -@pytest.fixture(params=[] if skip_torch else ["unet2d_expand_output_shape"]) -def unet2d_expand_output_shape(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_torch else ["unet2d_expand_output_shape"]) +def unet2d_expand_output_shape(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): + return model_packages[request.param] # written as model group to automatically skip on missing torch -@pytest.fixture(params=[] if skip_torch else ["unet2d_fixed_shape"]) -def unet2d_fixed_shape(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_torch else ["unet2d_fixed_shape"]) +def unet2d_fixed_shape(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): + return model_packages[request.param] # written as model group to automatically skip on missing torch -@pytest.fixture(params=[] if skip_torch else ["shape_change"]) -def shape_change_model(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_torch else ["shape_change"]) +def shape_change_model(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): + return model_packages[request.param] # written as model group to automatically skip on missing tensorflow 1 -@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape"]) -def stardist_wrong_shape(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape"]) +def stardist_wrong_shape(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): + return model_packages[request.param] # written as model group to automatically skip on missing tensorflow 1 -@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape2"]) -def stardist_wrong_shape2(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape2"]) +def stardist_wrong_shape2(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): + return model_packages[request.param] # written as model group to automatically skip on missing tensorflow 1 -@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist"]) -def stardist(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist"]) +def stardist(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): + return model_packages[request.param] diff --git a/tests/resource_io/test_load_rdf.py b/tests/resource_io/test_load_rdf.py index ca86750a..a9ea2441 100644 --- 
a/tests/resource_io/test_load_rdf.py +++ b/tests/resource_io/test_load_rdf.py @@ -4,7 +4,7 @@ import pytest -from bioimageio.core.resource_io.utils import resolve_source +from bioimageio.core._internal.validation_visitors import resolve_source def test_load_non_existing_rdf(): @@ -82,10 +82,11 @@ def test_load_remote_rdf(): @pytest.mark.skipif(True, reason="No suitable test model available yet") def test_load_remote_rdf_with_folders(): - from bioimageio.core import load_resource_description, load_raw_resource_description - from bioimageio.core.resource_io import nodes from bioimageio.spec.model import raw_nodes + from bioimageio.core import load_raw_resource_description, load_resource_description + from bioimageio.core.resource_io import nodes + rdf_doi = "" raw_model = load_raw_resource_description(rdf_doi, update_to_format="latest") assert isinstance(raw_model, raw_nodes.Model) diff --git a/tests/resource_io/test_utils.py b/tests/resource_io/test_utils.py index 198df5e5..d1e570cc 100644 --- a/tests/resource_io/test_utils.py +++ b/tests/resource_io/test_utils.py @@ -1,13 +1,13 @@ import dataclasses from pathlib import Path -from bioimageio.core.resource_io import nodes, utils -from bioimageio.core.resource_io.utils import Sha256NodeChecker +import pytest from bioimageio.spec.shared import raw_nodes from bioimageio.spec.shared.raw_nodes import RawNode - -import pytest +from bioimageio.core._internal import validation_visitors +from bioimageio.core._internal.validation_visitors import Sha256NodeChecker +from bioimageio.core.resource_io import nodes def test_resolve_import_path(tmpdir): @@ -17,8 +17,8 @@ def test_resolve_import_path(tmpdir): source_file = Path("my_mod.py") (tmpdir / str(source_file)).write_text("class Foo: pass", encoding="utf8") node = raw_nodes.ImportableSourceFile(source_file=source_file, callable_name="Foo") - uri_transformed = utils.UriNodeTransformer(root_path=tmpdir).transform(node) - source_transformed = utils.SourceNodeTransformer().transform(uri_transformed) + uri_transformed = validation_visitors.UriNodeTransformer(root_path=tmpdir).transform(node) + source_transformed = validation_visitors.SourceNodeTransformer().transform(uri_transformed) assert isinstance(source_transformed, nodes.ImportedSource), type(source_transformed) Foo = source_transformed.factory assert Foo.__name__ == "Foo", Foo.__name__ @@ -27,7 +27,7 @@ def test_resolve_import_path(tmpdir): def test_resolve_directory_uri(tmpdir): node = raw_nodes.URI(Path(tmpdir).as_uri()) - uri_transformed = utils.UriNodeTransformer(root_path=Path(tmpdir)).transform(node) + uri_transformed = validation_visitors.UriNodeTransformer(root_path=Path(tmpdir)).transform(node) assert uri_transformed == Path(tmpdir) @@ -36,7 +36,7 @@ def test_uri_available(): def test_all_uris_available(): - from bioimageio.core.resource_io.utils import all_sources_available + from bioimageio.core._internal.validation_visitors import all_sources_available not_available = { "uri": raw_nodes.URI(scheme="file", path="non_existing_file_in/non_existing_dir/ftw"), @@ -46,7 +46,7 @@ def test_all_uris_available(): def test_uri_node_transformer_is_ok_with_abs_path(): - from bioimageio.core.resource_io.utils import UriNodeTransformer + from bioimageio.core._internal.validation_visitors import UriNodeTransformer # note: the call of .absolute() is required to add the drive letter for windows paths, which are relative otherwise tree = {"rel_path": Path("something/relative"), "abs_path": Path("/something/absolute").absolute()} From 
d4e382b48d0ea558560c61a5192a8137104cc0db Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 30 Aug 2023 11:06:21 +0200 Subject: [PATCH 021/244] improve type names --- bioimageio/core/_io.py | 45 +++++++++++++++++------------------------- 1 file changed, 18 insertions(+), 27 deletions(-) diff --git a/bioimageio/core/_io.py b/bioimageio/core/_io.py index 31240971..62a21489 100644 --- a/bioimageio/core/_io.py +++ b/bioimageio/core/_io.py @@ -5,7 +5,7 @@ import shutil from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory -from typing import Dict, Literal, NamedTuple, Optional, Sequence, Tuple, Union, cast +from typing import Dict, Literal, NamedTuple, Optional, Sequence, Tuple, Union from zipfile import ZIP_DEFLATED import pooch @@ -16,14 +16,7 @@ from bioimageio.spec.model.v0_4 import WeightsFormat from bioimageio.spec.package import extract_file_name, get_resource_package_content from bioimageio.spec.summary import ValidationSummary -from bioimageio.spec.types import ( - FileName, - RawStringMapping, - RawValue, - RelativeFilePath, - ValidationContext, - WarningLevel, -) +from bioimageio.spec.types import FileName, RelativeFilePath, ValidationContext, WarningLevel, YamlMapping, YamlValue from pydantic import AnyUrl, DirectoryPath, FilePath, HttpUrl, TypeAdapter from ruamel.yaml import YAML @@ -31,17 +24,19 @@ yaml = YAML(typ="safe") -FileSource = Union[HttpUrl, FilePath] +StrictFileSource = Union[HttpUrl, FilePath] +FileSource = Union[StrictFileSource, str] +DescriptionSource = Union[ResourceDescription, YamlMapping, FileSource] class ReadRdf(NamedTuple): - content: RawStringMapping + content: YamlMapping root: Union[HttpUrl, DirectoryPath] file_name: str def load_description_from_file( - source: Union[FileSource, str], + source: FileSource, /, *, warning_level: WarningLevel = ERROR, @@ -55,11 +50,9 @@ def load_description_from_file( ) -def read_rdf( - source: Union[FileSource, str], /, *, known_hash: Optional[str] = None, encoding: Optional[str] = None -) -> ReadRdf: +def read_rdf(source: FileSource, /, *, known_hash: Optional[str] = None, encoding: Optional[str] = None) -> ReadRdf: if isinstance(source, str): - source = TypeAdapter(FileSource).validate_python(source) + source = TypeAdapter(StrictFileSource).validate_python(source) src_msg = str(source) if isinstance(source, AnyUrl): @@ -72,7 +65,7 @@ def read_rdf( root = source.parent with local_source.open(encoding=encoding) as f: - content: RawValue = yaml.load(f) + content: YamlValue = yaml.load(f) if not isinstance(content, collections.abc.Mapping): raise TypeError(f"Expected RDF content to be a mapping, but got '{type(content)}'.") @@ -81,7 +74,7 @@ def read_rdf( raise TypeError(f"Got non-string keys {non_string_keys} in {src_msg}") return ReadRdf( - content=cast(RawStringMapping, content), + content=content, root=root, file_name=extract_file_name(source), ) @@ -109,7 +102,7 @@ def resolve_source( return source -def dump_description_to_file(rd: Union[ResourceDescription, RawStringMapping], /, file_path: Path): +def dump_description_to_file(rd: Union[ResourceDescription, YamlMapping], /, file_path: Path): if isinstance(rd, ResourceDescriptionBase): content = dump_description(rd) else: @@ -120,7 +113,7 @@ def dump_description_to_file(rd: Union[ResourceDescription, RawStringMapping], / def load_description_from_file_and_validate( - rdf_source: Union[FileSource, str], + rdf_source: FileSource, /, *, warning_level: WarningLevel = ERROR, @@ -132,7 +125,7 @@ def load_description_from_file_and_validate( def 
load_description_and_validate( - rdf_content: RawStringMapping, + rdf_content: YamlMapping, /, *, context: Optional[ValidationContext] = None, @@ -144,7 +137,7 @@ def load_description_and_validate( def validate( - rdf_content: RawStringMapping, + rdf_content: YamlMapping, /, *, context: Optional[ValidationContext] = None, @@ -153,14 +146,12 @@ def validate( return summary -def validate_rdf(rdf_source: Union[FileSource, str], /, *, warning_level: WarningLevel = ERROR) -> ValidationSummary: +def validate_rdf(rdf_source: FileSource, /, *, warning_level: WarningLevel = ERROR) -> ValidationSummary: _rd, summary = load_description_from_file_and_validate(rdf_source, warning_level=warning_level) return summary -def validate_rdf_format( - rdf_source: Union[FileSource, str], /, *, warning_level: WarningLevel = ERROR -) -> ValidationSummary: +def validate_rdf_format(rdf_source: FileSource, /, *, warning_level: WarningLevel = ERROR) -> ValidationSummary: _rd, summary = load_description_from_file(rdf_source, warning_level=warning_level) return summary @@ -205,7 +196,7 @@ def prepare_resource_package( def write_package( - rd: Union[ResourceDescription, FileSource, str], + rd: Union[ResourceDescription, FileSource], /, *, root: Union[AnyUrl, DirectoryPath] = Path(), From d6c63e2b7f150334aac26c526698e32e5cae92ab Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 7 Sep 2023 22:24:50 +0200 Subject: [PATCH 022/244] black also for notebooks --- .github/workflows/build.yml | 14 +++++--------- .pre-commit-config.yaml | 7 ++++++- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c0b54a62..9dbb37b1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -15,16 +15,12 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Check files using the black formatter - uses: rickstaa/action-black@v1 - id: action_black + - uses: psf/black@stable with: - black_args: "." - - name: Annotate diff changes using reviewdog - if: steps.action_black.outputs.is_formatted == 'true' - uses: reviewdog/action-suggester@v1 - with: - tool_name: blackfmt + options: "--check --verbose" + src: "." 
+ jupyter: true + version: "23.7" test-spec-conda: runs-on: ubuntu-latest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e8f64f50..faeee4a4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,4 +2,9 @@ repos: - repo: https://github.com/ambv/black rev: 23.7.0 hooks: - - id: black + - id: black-jupyter + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + name: isort From 3483087f3f438183e61c0da77a9447423f12ac6e Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 7 Sep 2023 22:26:27 +0200 Subject: [PATCH 023/244] update io funcs --- bioimageio/core/__init__.py | 10 +- bioimageio/core/_internal/utils.py | 112 ++++---- .../core/_internal/validation_visitors.py | 2 +- bioimageio/core/_io.py | 240 ++++++++++-------- bioimageio/core/prediction.py | 72 +++--- 5 files changed, 233 insertions(+), 203 deletions(-) diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index 1cf498b8..ead66cb0 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -1,14 +1,7 @@ import json from bioimageio.core._internal.utils import files -from bioimageio.core._io import ( - dump_description_to_file, - load_description_and_validate, - read_rdf, - resolve_source, - validate, - write_package, -) +from bioimageio.core._io import load_description_and_validate, resolve_source, validate, write_package with files("bioimageio.core").joinpath("VERSION").open("r", encoding="utf-8") as f: __version__: str = json.load(f)["version"] @@ -32,7 +25,6 @@ "read_rdf", "resolve_source", "validate", - "dump_description_to_file", "write_package", # "check_input_shape", # "check_output_shape", diff --git a/bioimageio/core/_internal/utils.py b/bioimageio/core/_internal/utils.py index 13b6ac05..f8424d6c 100644 --- a/bioimageio/core/_internal/utils.py +++ b/bioimageio/core/_internal/utils.py @@ -1,52 +1,60 @@ -from __future__ import annotations - -import os -import sys -from pathlib import Path -from typing import Mapping, Union -from urllib.parse import urlsplit, urlunsplit -from zipfile import ZipFile - -from bioimageio.spec.types import FileName -from pydantic import AnyUrl, FilePath, HttpUrl - -if sys.version_info < (3, 9): - - def files(package_name: str): - assert package_name == "bioimageio.core" - return Path(__file__).parent.parent - -else: - from importlib.resources import files as files - - -def get_parent_url(url: HttpUrl) -> HttpUrl: - parsed = urlsplit(str(url)) - return AnyUrl( - urlunsplit((parsed.scheme, parsed.netloc, "/".join(parsed.path.split("/")[:-1]), parsed.query, parsed.fragment)) - ) - - -def write_zip( - path: os.PathLike[str], - content: Mapping[FileName, Union[str, FilePath]], - *, - compression: int, - compression_level: int, -) -> None: - """Write a zip archive. - - Args: - path: output path to write to. - content: dict with archive names and local file paths or strings for text files. - compression: The numeric constant of compression method. - compression_level: Compression level to use when writing files to the archive. 
-        See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile
-
-    """
-    with ZipFile(path, "w", compression=compression, compresslevel=compression_level) as myzip:
-        for arc_name, file_or_str_content in content.items():
-            if isinstance(file_or_str_content, str):
-                myzip.writestr(arc_name, file_or_str_content)
-            else:
-                myzip.write(file_or_str_content, arcname=arc_name)
+from __future__ import annotations
+
+import io
+import os
+import sys
+from pathlib import Path
+from typing import Any, Dict, Mapping, Union
+from urllib.parse import urlsplit, urlunsplit
+from zipfile import ZipFile
+
+from bioimageio.spec._internal.types import FileName
+from pydantic import AnyUrl, FilePath, HttpUrl
+from ruamel.yaml import YAML
+
+yaml = YAML(typ="safe")
+
+if sys.version_info < (3, 9):
+
+    def files(package_name: str):
+        assert package_name == "bioimageio.core"
+        return Path(__file__).parent.parent
+
+else:
+    from importlib.resources import files as files
+
+
+def get_parent_url(url: HttpUrl) -> HttpUrl:
+    parsed = urlsplit(str(url))
+    return AnyUrl(
+        urlunsplit((parsed.scheme, parsed.netloc, "/".join(parsed.path.split("/")[:-1]), parsed.query, parsed.fragment))
+    )
+
+
+def write_zip(
+    path: os.PathLike[str],
+    content: Mapping[FileName, Union[str, FilePath, Dict[Any, Any]]],
+    *,
+    compression: int,
+    compression_level: int,
+) -> None:
+    """Write a zip archive.
+
+    Args:
+        path: output path to write to.
+        content: dict mapping archive names to local file paths, strings (for text files), or dicts (for yaml files).
+        compression: The numeric constant of compression method.
+        compression_level: Compression level to use when writing files to the archive.
+            See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile
+
+    """
+    with ZipFile(path, "w", compression=compression, compresslevel=compression_level) as myzip:
+        for arc_name, file in content.items():
+            if isinstance(file, dict):
+                buf = io.StringIO()
+                yaml.dump(file, buf)  # serialize the mapping with the module level YAML instance
+                file = buf.getvalue()
+
+            if isinstance(file, str):
+                myzip.writestr(arc_name, file.encode("utf-8"))
+            else:
+                myzip.write(file, arcname=arc_name)
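
A minimal usage sketch for the updated write_zip (hypothetical archive content and file names; ZIP_DEFLATED comes from the standard zipfile module): text entries, YAML mappings and local files can be mixed in a single archive.

    from pathlib import Path
    from zipfile import ZIP_DEFLATED

    write_zip(
        "model.bioimageio.zip",
        {
            "rdf.yaml": {"name": "my-model", "type": "model"},  # mapping, dumped as YAML
            "README.md": "# my model\n",  # str, written as utf-8 text
            "weights.pt": Path("weights.pt"),  # local file, copied into the archive
        },
        compression=ZIP_DEFLATED,
        compression_level=1,
    )
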
diff --git a/bioimageio/core/_internal/validation_visitors.py b/bioimageio/core/_internal/validation_visitors.py
index f27972d5..bb7de16a 100644
--- a/bioimageio/core/_internal/validation_visitors.py
+++ b/bioimageio/core/_internal/validation_visitors.py
@@ -11,9 +11,9 @@
 from annotated_types import SLOTS
 from bioimageio.spec._internal.base_nodes import Node
 from bioimageio.spec._internal.constants import ALERT_TYPE, IN_PACKAGE_MESSAGE, KW_ONLY, SLOTS
+from bioimageio.spec._internal.types import Loc
 from bioimageio.spec.description import ResourceDescription
 from bioimageio.spec.summary import ErrorEntry, WarningEntry
-from bioimageio.spec.types import Loc
 from pydantic import AnyUrl, DirectoryPath
 from pydantic.fields import FieldInfo
 from typing_extensions import NotRequired, Unpack
diff --git a/bioimageio/core/_io.py b/bioimageio/core/_io.py
index 62a21489..8e85078b 100644
--- a/bioimageio/core/_io.py
+++ b/bioimageio/core/_io.py
@@ -2,21 +2,21 @@
 
 import collections.abc
 import os
-import shutil
 from pathlib import Path
-from tempfile import NamedTemporaryFile, TemporaryDirectory
-from typing import Dict, Literal, NamedTuple, Optional, Sequence, Tuple, Union
-from zipfile import ZIP_DEFLATED
+from tempfile import NamedTemporaryFile
+from typing import Any, Dict, Literal, NamedTuple, Optional, Sequence, Tuple, Union, cast
+from zipfile import ZIP_DEFLATED, ZipFile, is_zipfile
 
 import pooch
-from bioimageio.spec import ResourceDescription, load_description
+from bioimageio.spec import ResourceDescription
+from bioimageio.spec import load_description as load_description_from_content
 from bioimageio.spec._internal.base_nodes import ResourceDescriptionBase
-from bioimageio.spec._internal.constants import DISCOVER, ERROR, LATEST
+from bioimageio.spec._internal.constants import DISCOVER, LATEST
+from bioimageio.spec._internal.types import FileName, RdfContent, RelativeFilePath, ValidationContext, YamlValue
 from bioimageio.spec.description import dump_description
 from bioimageio.spec.model.v0_4 import WeightsFormat
 from bioimageio.spec.package import extract_file_name, get_resource_package_content
 from bioimageio.spec.summary import ValidationSummary
-from bioimageio.spec.types import FileName, RelativeFilePath, ValidationContext, WarningLevel, YamlMapping, YamlValue
 from pydantic import AnyUrl, DirectoryPath, FilePath, HttpUrl, TypeAdapter
 from ruamel.yaml import YAML
 
@@ -26,57 +26,83 @@
 
 StrictFileSource = Union[HttpUrl, FilePath]
 FileSource = Union[StrictFileSource, str]
-DescriptionSource = Union[ResourceDescription, YamlMapping, FileSource]
+StrictRdfSource = Union[StrictFileSource, RdfContent, ResourceDescription]
+RdfSource = Union[StrictRdfSource, str]
 
 
-class ReadRdf(NamedTuple):
-    content: YamlMapping
+class RawRdf(NamedTuple):
+    content: RdfContent
     root: Union[HttpUrl, DirectoryPath]
     file_name: str
 
 
-def load_description_from_file(
-    source: FileSource,
+def load_description(
+    rdf_source: RdfSource,
     /,
     *,
-    warning_level: WarningLevel = ERROR,
+    context: Optional[ValidationContext] = None,
     format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER,
 ) -> Tuple[Optional[ResourceDescription], ValidationSummary]:
-    rdf = read_rdf(source)
-    return load_description(
-        rdf.content,
-        context=ValidationContext(root=rdf.root, file_name=rdf.file_name, warning_level=warning_level),
+    context = context or ValidationContext()
+    rdf_content = _get_rdf_content_and_update_context(rdf_source, context)
+    return load_description_from_content(
+        rdf_content,
+        context=context,
         format_version=format_version,
     )
 
 
+LEGACY_RDF_NAME = "rdf.yaml"
+
+
-def read_rdf(
-    source: Union[FileSource, str], /, *, known_hash: Optional[str] = None, encoding: Optional[str] = None
-) -> ReadRdf:
-    if isinstance(source, str):
-        source = TypeAdapter(StrictFileSource).validate_python(source)
-
-    src_msg = str(source)
-    if isinstance(source, AnyUrl):
-        cached_source: FilePath = Path(pooch.retrieve(url=str(source), known_hash=known_hash))  # type: ignore
-        src_msg += f" cached at {cached_source}"
-        local_source = cached_source
-        root: Union[HttpUrl, DirectoryPath] = get_parent_url(source)
-    else:
-        local_source = source
-        root = source.parent
+def read_rdf_content(
+    rdf_source: FileSource,
+    /,
+    *,
+    known_hash: Optional[str] = None,
+    rdf_encoding: str = "utf-8",
+) -> RawRdf:
+    rdf_source = TypeAdapter(StrictFileSource).validate_python(rdf_source)
+
+    if isinstance(rdf_source, AnyUrl):
+        _ls: Any = pooch.retrieve(url=str(rdf_source), known_hash=known_hash)
+        local_source = Path(_ls)
+        root: Union[HttpUrl, DirectoryPath] = get_parent_url(rdf_source)
+    else:
+        local_source = rdf_source
+        root = rdf_source.parent
+
+    if is_zipfile(local_source):
+        out_path = local_source.with_suffix(local_source.suffix + ".unzip")
+        with ZipFile(local_source, "r") as f:
+            rdfs = [fname for fname in f.namelist() if fname.endswith(".bioimageio.yaml")]
+            if len(rdfs) > 1:
+                raise ValueError(f"Multiple RDFs in one package not yet supported (found {rdfs}).")
+            elif len(rdfs) == 1:
+                rdf_file_name = rdfs[0]
+            elif LEGACY_RDF_NAME in f.namelist():
+                rdf_file_name = LEGACY_RDF_NAME
+            else:
+                raise ValueError(
+                    f"No RDF found in {local_source}. (Looking for any '*.bioimageio.yaml' file or an 'rdf.yaml' file)."
+                )
+
+            f.extractall(out_path)
+            local_source = out_path / rdf_file_name
 
-    with local_source.open(encoding=encoding) as f:
+    with local_source.open(encoding=rdf_encoding) as f:
         content: YamlValue = yaml.load(f)
 
     if not isinstance(content, collections.abc.Mapping):
         raise TypeError(f"Expected RDF content to be a mapping, but got '{type(content)}'.")
 
-    if non_string_keys := [k for k in content if not isinstance(k, str)]:
-        raise TypeError(f"Got non-string keys {non_string_keys} in {src_msg}")
-
-    return ReadRdf(
-        content=content,
+    return RawRdf(
+        content=cast(RdfContent, content),
         root=root,
-        file_name=extract_file_name(source),
+        file_name=extract_file_name(rdf_source),
     )
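
As a rough usage sketch (assumed call forms, mirroring the RdfSource union above), the reworked load_description accepts a file source, raw RDF content, or an already loaded description:

    rd, summary = load_description("https://example.com/rdf.yaml")   # FileSource (URL or path)
    rd, summary = load_description({"type": "model", "name": "m"})   # RdfContent mapping
    rd, summary = load_description(rd)                               # ResourceDescription instance
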
@@ -97,12 +123,13 @@ def resolve_source(
     source = source.get_absolute(root)
 
     if isinstance(source, AnyUrl):
-        source = Path(pooch.retrieve(source, known_hash=known_hash))  # type: ignore
+        _s: Any = pooch.retrieve(str(source), known_hash=known_hash)
+        source = Path(_s)
 
     return source
 
 
-def dump_description_to_file(rd: Union[ResourceDescription, YamlMapping], /, file_path: Path):
+def write_description(rd: Union[ResourceDescription, RdfContent], /, file_path: FilePath):
     if isinstance(rd, ResourceDescriptionBase):
         content = dump_description(rd)
     else:
@@ -112,94 +139,115 @@
     yaml.dump(content, f)
 
 
-def load_description_from_file_and_validate(
-    rdf_source: FileSource,
-    /,
-    *,
-    warning_level: WarningLevel = ERROR,
-) -> Tuple[Optional[ResourceDescription], ValidationSummary]:
-    rdf = read_rdf(rdf_source)
-    return load_description_and_validate(
-        rdf.content, context=ValidationContext(root=rdf.root, file_name=rdf.file_name, warning_level=warning_level)
-    )
-
-
 def load_description_and_validate(
-    rdf_content: YamlMapping,
+    rdf_source: RdfSource,
     /,
     *,
     context: Optional[ValidationContext] = None,
 ) -> Tuple[Optional[ResourceDescription], ValidationSummary]:
     """load and validate a BioImage.IO description from the content of a resource description file (RDF)"""
-    rd, summary = load_description(rdf_content, context=context, format_version=LATEST)
-    # todo: add validation
+    context = context or ValidationContext()
+    rdf_content = _get_rdf_content_and_update_context(rdf_source, context)
+    rd, summary = load_description_from_content(rdf_content, context=context, format_version=LATEST)
+    # todo: add dynamic validation
     return rd, summary
 
 
+# def _get_default_io_context(context: Union[ValidationContext, CompleteValidationContext, None]) -> Union[ValidationContext, CompleteValidationContext]:
+#     if context is None:
+#         context = ValidationContext()
+
+#     if "warning_level" not in context:
+#         context["warning_level"] = INFO
+
+#     return context
+
+
+def _get_rdf_content_and_update_context(rdf_source: RdfSource, context: ValidationContext) -> RdfContent:
+    rdf_source = TypeAdapter(RdfSource).validate_python(rdf_source)
+
+    if isinstance(rdf_source, (AnyUrl, Path, str)):
+        rdf = read_rdf_content(rdf_source)
+        rdf_source = rdf.content
+        context.root = rdf.root
+        context.file_name = rdf.file_name
+    elif isinstance(rdf_source, ResourceDescriptionBase):
+        rdf_source = dump_description(rdf_source, exclude_unset=False)
+
+    return rdf_source
+def _get_description_and_update_context(rdf_source: RdfSource, context: ValidationContext) -> ResourceDescription:
+    if not isinstance(rdf_source, ResourceDescriptionBase):
+        descr, summary = load_description(rdf_source, context=context)
+        if descr is None:
+            rdf_source_msg = (
+                f"{{name={rdf_source.get('name', 'missing')}, ...}}"
+                if isinstance(rdf_source, collections.abc.Mapping)
+                else rdf_source
+            )
+            raise ValueError(f"Failed to load {rdf_source_msg}:\n{summary.format()}")
+        rdf_source = descr
+
+    return rdf_source
+
+
 def validate(
-    rdf_content: YamlMapping,
+    rdf_source: RdfSource,
     /,
     *,
     context: Optional[ValidationContext] = None,
 ) -> ValidationSummary:
-    _rd, summary = load_description_and_validate(rdf_content, context=context)
+    _rd, summary = load_description_and_validate(rdf_source, context=context)
     return summary
 
 
-def validate_rdf(rdf_source: FileSource, /, *, warning_level: WarningLevel = ERROR) -> ValidationSummary:
-    _rd, summary = load_description_from_file_and_validate(rdf_source, warning_level=warning_level)
-    return summary
-
-
-def validate_rdf_format(rdf_source: FileSource, /, *, warning_level: WarningLevel = ERROR) -> ValidationSummary:
-    _rd, summary = load_description_from_file(rdf_source, warning_level=warning_level)
+def validate_format_only(
+    rdf_source: Union[ResourceDescription, RdfContent, FileSource], context: Optional[ValidationContext] = None
+) -> ValidationSummary:
+    _rd, summary = load_description(rdf_source, context=context)
     return summary
 
 
 def prepare_resource_package(
-    rd: ResourceDescription,
+    rdf_source: RdfSource,
     /,
     *,
-    root: Union[AnyUrl, DirectoryPath],
-    output_folder: DirectoryPath,
+    context: Optional[ValidationContext] = None,
     weights_priority_order: Optional[Sequence[WeightsFormat]] = None,
-) -> Dict[FileName, FilePath]:
+) -> Dict[FileName, Union[FilePath, RdfContent]]:
     """Prepare to package a resource description; downloads all required files.
 
     Args:
-        rd: bioimage.io resource description
-        root: URL or path to resolve relative file paths in `rd`
+        rdf_source: A bioimage.io resource description (as file, raw YAML content or description class)
+        context: validation context
         weights_priority_order: If given, only the first weights format present in the model is included.
                                 If none of the prioritized weights formats is found, all are included.
     """
+    context = context or ValidationContext()
+    rd = _get_description_and_update_context(rdf_source, context)
     package_content = get_resource_package_content(rd, weights_priority_order=weights_priority_order)
 
-    output_folder.mkdir(parents=True, exist_ok=True)
-    local_package_content: Dict[FileName, FilePath] = {}
+    local_package_content: Dict[FileName, Union[FilePath, RdfContent]] = {}
     for k, v in package_content.items():
-        in_package_path = output_folder / k
-        if isinstance(v, RelativeFilePath):
-            v = v.get_absolute(root)
-
-        if isinstance(v, AnyUrl):
-            v = resolve_source(v, root=root)
+        if not isinstance(v, collections.abc.Mapping):
+            v = resolve_source(v, root=context.root)
 
-        if isinstance(v, Path):
-            shutil.copy(str(v), str(in_package_path))
-        else:
-            assert isinstance(v, collections.abc.Mapping)
-            dump_description_to_file(v, in_package_path)
-
-        local_package_content[k] = in_package_path
+        local_package_content[k] = v
 
     return local_package_content
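
A hedged usage sketch of the two packaging entry points (hypothetical URL and output name; write_package is defined just below):

    from pathlib import Path

    # gather all package files (archive names mapped to local paths or RDF content)
    content = prepare_resource_package("https://example.com/rdf.yaml")

    # or produce the zipped package directly
    zip_path = write_package("https://example.com/rdf.yaml", output_path=Path("model.bioimageio.zip"))
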
""" + context = context or ValidationContext() + rd = _get_description_and_update_context(rdf_source, context) package_content = get_resource_package_content(rd, weights_priority_order=weights_priority_order) - output_folder.mkdir(parents=True, exist_ok=True) - local_package_content: Dict[FileName, FilePath] = {} + local_package_content: Dict[FileName, Union[FilePath, RdfContent]] = {} for k, v in package_content.items(): - in_package_path = output_folder / k - if isinstance(v, RelativeFilePath): - v = v.get_absolute(root) - - if isinstance(v, AnyUrl): - v = resolve_source(v, root=root) + if not isinstance(v, collections.abc.Mapping): + v = resolve_source(v, root=context.root) - if isinstance(v, Path): - shutil.copy(str(v), str(in_package_path)) - else: - assert isinstance(v, collections.abc.Mapping) - dump_description_to_file(v, in_package_path) - - local_package_content[k] = in_package_path + local_package_content[k] = v return local_package_content + # output_folder.mkdir(parents=True, exist_ok=True) + def write_package( - rd: Union[ResourceDescription, FileSource], + rdf_source: RdfSource, /, *, - root: Union[AnyUrl, DirectoryPath] = Path(), + context: Optional[ValidationContext] = None, compression: int = ZIP_DEFLATED, compression_level: int = 1, output_path: Optional[os.PathLike[str]] = None, @@ -220,7 +268,7 @@ def write_package( Args: rd: bioimage.io resource description - root: reference for any relative file paths in the bioimage.io resource description + context: compression: The numeric constant of compression method. compression_level: Compression level to use when writing files to the archive. See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile @@ -231,21 +279,11 @@ def write_package( Returns: path to zipped bioimage.io package in BIOIMAGEIO_CACHE_PATH or 'output_path' """ - if isinstance(rd, (AnyUrl, os.PathLike, str)): - rd_, summary = load_description_from_file(rd) - if rd_ is None: - raise ValueError(summary.format()) - else: - rd = rd_ - - with TemporaryDirectory() as tmp_dir: - package_content = prepare_resource_package( - rd, - root=root, - output_folder=Path(tmp_dir), - weights_priority_order=weights_priority_order, - ) - + package_content = prepare_resource_package( + rdf_source, + context=context, + weights_priority_order=weights_priority_order, + ) if output_path is None: output_path = Path(NamedTemporaryFile(suffix=".bioimageio.zip", delete=False).name) else: diff --git a/bioimageio/core/prediction.py b/bioimageio/core/prediction.py index 30026b0a..9628a356 100644 --- a/bioimageio/core/prediction.py +++ b/bioimageio/core/prediction.py @@ -3,46 +3,46 @@ from fractions import Fraction from itertools import product from pathlib import Path -from typing import Any, Dict, Iterator, List, NamedTuple, Optional, OrderedDict, Sequence, Tuple, Union +from typing import Any, Dict, Hashable, Iterator, List, NamedTuple, Optional, OrderedDict, Sequence, Tuple, Union import numpy as np -from pydantic import HttpUrl import xarray as xr +from bioimageio.spec import ResourceDescription +from bioimageio.spec.model.v0_5 import AxisType +from numpy.typing import NDArray +from pydantic import HttpUrl +from tqdm import tqdm from bioimageio.core import image_helper, load_resource_description from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline from bioimageio.core.resource_io.nodes import ImplicitOutputShape, Model, ResourceDescription -from tqdm import tqdm -from bioimageio.spec import ResourceDescription - -def _apply_crop(data, 
crop): - crop = tuple(crop[ax] for ax in data.dims) - return data[crop] +Axis = Hashable class TileDef(NamedTuple): - outer: Dict[str, slice] - inner: Dict[str, slice] - local: Dict[str, slice] + outer: Dict[Axis, slice] + inner: Dict[Axis, slice] + local: Dict[Axis, slice] def get_tiling( shape: Sequence[int], - tile_shape: Dict[str, int], - halo: Dict[str, int], - input_axes: Sequence[str], - scaling: Dict[str, float], + tile_shape: Dict[Axis, int], + halo: Dict[Axis, int], + input_axes: Sequence[Axis], + axis_types: Dict[Axis, AxisType], + scaling: Dict[Axis, float], ) -> Iterator[TileDef]: # outer_tile is the "input" tile, inner_tile is the "output" tile with the halo removed # tile_shape is the shape of the outer_tile assert len(shape) == len(input_axes) - scaling = {ax: Fraction(sc).limit_denominator() for ax, sc in scaling.items()} + scaling_fractions = {ax: Fraction(sc).limit_denominator() for ax, sc in scaling.items()} - shape_ = [sh for sh, ax in zip(shape, input_axes) if ax in "xyz"] - spatial_axes = [ax for ax in input_axes if ax in "xyz"] + shape_ = [sh for sh, ax in zip(shape, input_axes) if axis_types[ax] == "space"] + spatial_axes = [ax for ax in input_axes if axis_types[ax] == "space"] inner_tile_shape_ = [tile_shape[ax] - 2 * halo[ax] for ax in spatial_axes] - scaling_ = [scaling[ax] for ax in spatial_axes] + scaling_ = [scaling_fractions[ax] for ax in spatial_axes] assert all([sh % fr.denominator == 0 for sh, fr in zip(shape_, scaling_)]) assert all([ish % fr.denominator == 0 for ish, fr in zip(inner_tile_shape_, scaling_)]) halo_ = [halo[ax] for ax in spatial_axes] @@ -58,15 +58,15 @@ def get_tiling( ax: slice(int(pos * fr), int(min(pos + tsh, sh) * fr)) for ax, pos, tsh, sh, fr in zip(spatial_axes, positions, inner_tile_shape_, shape_, scaling_) } - inner_tile["b"] = slice(None) - inner_tile["c"] = slice(None) + # inner_tile["b"] = slice(None) + # inner_tile["c"] = slice(None) outer_tile = { ax: slice(max(pos - ha, 0), min(pos + tsh + ha, sh)) for ax, pos, tsh, sh, ha in zip(spatial_axes, positions, inner_tile_shape_, shape_, halo_) } - outer_tile["b"] = slice(None) - outer_tile["c"] = slice(None) + # outer_tile["b"] = slice(None) + # outer_tile["c"] = slice(None) local_tile = { ax: slice( @@ -77,8 +77,8 @@ def get_tiling( ) for ax in spatial_axes } - local_tile["b"] = slice(None) - local_tile["c"] = slice(None) + # local_tile["b"] = slice(None) + # local_tile["c"] = slice(None) yield TileDef(outer_tile, inner_tile, local_tile) @@ -109,14 +109,11 @@ def _predict_with_tiling_impl( tiles = get_tiling(shape=input_.shape, tile_shape=tile_shape, halo=halo, input_axes=input_.dims, scaling=scaling) - assert all(isinstance(ax, str) for ax in input_.dims) - input_axes: Tuple[str, ...] 
= input_.dims # noqa - def load_tile(tile): inp = input_[tile] # whether to pad on the right or left of the dim for the spatial dims # + placeholders for batch and axis dimension, where we don't pad - pad_right = [tile[ax].start == 0 if ax in "xyz" else None for ax in input_axes] + pad_right = [tile[ax].start == 0 if ax in "xyz" else None for ax in input_.dims] return inp, pad_right if verbose: @@ -136,15 +133,10 @@ def load_tile(tile): output[inner_tile] = out[local_tile] -# -# prediction functions -# - - def predict( prediction_pipeline: PredictionPipeline, inputs: Union[ - xr.DataArray, List[xr.DataArray], Tuple[xr.DataArray], np.ndarray, List[np.ndarray], Tuple[np.ndarray] + xr.DataArray, List[xr.DataArray], Tuple[xr.DataArray], NDArray[Any], List[NDArray[Any]], Tuple[NDArray[Any]] ], ) -> List[xr.DataArray]: """Run prediction for a single set of input(s) with a bioimage.io model @@ -239,7 +231,7 @@ def predict_with_padding( ) result = predict(prediction_pipeline, inputs) if network_resizes: - crops = tuple( + crops = [ { ax: slice( crp.start if crp.start is None else int(crp.start * scale[ax] + 2 * offset[ax]), @@ -250,8 +242,8 @@ def predict_with_padding( for ax, crp in crop.items() } for crop in crops - ) - return [_apply_crop(res, crop) for res, crop in zip(result, crops)] + ] + return [res[crop] for res, crop in zip(result, crops)] # simple heuristic to determine suitable shape from min and step @@ -428,7 +420,7 @@ def _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling): def predict_image( - model_rdf: RdfSource, + model_rdf: DescriptionSource, inputs: Union[Tuple[Path, ...], List[Path], Path], outputs: Union[Tuple[Path, ...], List[Path], Path], padding: Optional[Union[bool, Dict[str, int]]] = None, @@ -469,7 +461,7 @@ def predict_image( def predict_images( - model_rdf: RdfSource, + model_rdf: DescriptionSource, inputs: Sequence[Union[Tuple[Path, ...], List[Path], Path]], outputs: Sequence[Union[Tuple[Path, ...], List[Path], Path]], padding: Optional[Union[bool, Dict[str, int]]] = None, From 07cdcf9d6ce6d53a937139706202cdc13382dab1 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 7 Sep 2023 22:27:24 +0200 Subject: [PATCH 024/244] WIP update Processing --- .../_combined_processing.py | 14 +- .../prediction_pipeline/_measure_groups.py | 86 +++++----- .../core/prediction_pipeline/_processing.py | 148 +++++++++--------- .../core/prediction_pipeline/_stat_state.py | 11 +- bioimageio/core/prediction_pipeline/_utils.py | 90 +++++++++-- bioimageio/core/statistical_measures.py | 15 +- setup.py | 2 +- .../test_internal/test_validation_visitors.py | 80 +++++----- 8 files changed, 251 insertions(+), 195 deletions(-) diff --git a/bioimageio/core/prediction_pipeline/_combined_processing.py b/bioimageio/core/prediction_pipeline/_combined_processing.py index bbd3e354..71c0693d 100644 --- a/bioimageio/core/prediction_pipeline/_combined_processing.py +++ b/bioimageio/core/prediction_pipeline/_combined_processing.py @@ -1,14 +1,10 @@ import dataclasses -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Literal, Optional, Sequence, Union from bioimageio.core.resource_io import nodes -from ._processing import AssertDtype, EnsureDtype, KNOWN_PROCESSING, Processing, TensorName -from ._utils import ComputedMeasures, PER_DATASET, PER_SAMPLE, RequiredMeasures, Sample -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal # type: ignore +from ._processing import KNOWN_PROCESSING, AssertDtype, 
EnsureDtype, Processing, TensorName
+from ._utils import PER_DATASET, PER_SAMPLE, ComputedMeasures, RequiredMeasures, Sample
 
 
 @dataclasses.dataclass
@@ -20,9 +16,9 @@ class ProcessingInfoStep:
 @dataclasses.dataclass
 class ProcessingInfo:
     steps: List[ProcessingInfoStep]
-    assert_dtype_before: Optional[Union[str, Sequence[str]]] = None  # throw AssertionError if data type doesn't match
+    # assert_dtype_before: Optional[Union[str, Sequence[str]]] = None  # throw AssertionError if data type doesn't match
     ensure_dtype_before: Optional[str] = None  # cast data type if needed
-    assert_dtype_after: Optional[Union[str, Sequence[str]]] = None  # throw AssertionError if data type doesn't match
+    # assert_dtype_after: Optional[Union[str, Sequence[str]]] = None  # throw AssertionError if data type doesn't match
-    ensure_dtype_after: Optional[str] = None  # throw AssertionError if data type doesn't match
+    ensure_dtype_after: Optional[str] = None  # cast data type if needed
diff --git a/bioimageio/core/prediction_pipeline/_measure_groups.py b/bioimageio/core/prediction_pipeline/_measure_groups.py
index 59e0ce74..cc2c0646 100644
--- a/bioimageio/core/prediction_pipeline/_measure_groups.py
+++ b/bioimageio/core/prediction_pipeline/_measure_groups.py
@@ -3,36 +3,35 @@
 import collections
 import warnings
 from collections import defaultdict
+from dataclasses import dataclass, field
 from itertools import product
 from typing import DefaultDict, Dict, Hashable, Iterator, List, Mapping, Optional, Sequence, Set, Tuple, Type, Union
 
 import numpy
 import xarray as xr
+from bioimageio.spec.model.v0_5 import AxisName
 
 from bioimageio.core.statistical_measures import Mean, Measure, Percentile, Std, Var
-from ._utils import ComputedMeasures, PER_DATASET, PER_SAMPLE, RequiredMeasures, Sample, TensorName
 
-try:
-    from typing import Literal, TypedDict
-except ImportError:
-    from typing_extensions import Literal, TypedDict  # type: ignore
+from ._utils import PER_DATASET, PER_SAMPLE, ComputedMeasures, RequiredMeasures, Sample, TensorName
 
 try:
-    import crick
+    import crick  # type: ignore
 except ImportError:
     crick = None
 
 MeasureValue = xr.DataArray
 
 
-class SampleMeasureGroup:
+class SampleMeasureCalculator:
     """group of measures for more efficient computation of multiple measures per sample"""
 
     def compute(self, sample: Sample) -> Dict[TensorName, Dict[Measure, MeasureValue]]:
         raise NotImplementedError
 
 
-class DatasetMeasureGroup:
+class DatasetMeasureCalculator:
     """group of measures for more efficient computation of multiple measures per dataset"""
 
     def reset(self) -> None:
@@ -48,19 +47,19 @@ def finalize(self) -> Dict[TensorName, Dict[Measure, MeasureValue]]:
         raise NotImplementedError
 
 
-MeasureGroups = TypedDict(
-    "MeasureGroups", {PER_SAMPLE: Sequence[SampleMeasureGroup], PER_DATASET: Sequence[DatasetMeasureGroup]}
-)
+@dataclass
+class MeasureGroups:
+    per_sample: List[SampleMeasureCalculator] = field(default_factory=list)
+    per_dataset: List[DatasetMeasureCalculator] = field(default_factory=list)
 
 
-class DatasetMean(DatasetMeasureGroup):
-    n: int
-    mean: Optional[xr.DataArray]
-
-    def __init__(self, tensor_name: TensorName, axes: Optional[Tuple[int]]):
-        self.axes: Optional[Tuple[str]] = axes
+class DatasetMean(DatasetMeasureCalculator):
+    def __init__(self, tensor_name: TensorName, axes: Optional[Sequence[AxisName]]):
+        super().__init__()
+        self.axes = None if axes is None else tuple(axes)
         self.tensor_name = tensor_name
-        self.reset()
+        self.n: int = 0
+        self.mean: Optional[xr.DataArray] = None
 
     def reset(self):
         self.n = 0
@@ -89,15 +88,18 @@ def finalize(self) -> Dict[TensorName, 
Dict[Measure, MeasureValue]]: return {self.tensor_name: {Mean(axes=self.axes): self.mean}} -class MeanVarStd(SampleMeasureGroup, DatasetMeasureGroup): - n: int - mean: Optional[xr.DataArray] - m2: Optional[xr.DataArray] - - def __init__(self, tensor_name: TensorName, axes: Optional[Tuple[int]]): - self.axes: Optional[Tuple[str]] = axes +class MeanVarStd(SampleMeasureCalculator, DatasetMeasureCalculator): + def __init__(self, tensor_name: TensorName, axes: Optional[Sequence[AxisName]]): + self.axes = None if axes is None else tuple(axes) self.tensor_name = tensor_name - self.reset() + self.n: int = 0 + self.mean: Optional[xr.DataArray] = None + self.m2: Optional[xr.DataArray] = None + + def reset(self): + self.n = 0 + self.mean = None + self.m2 = None def compute(self, sample: Sample) -> Dict[TensorName, Dict[Measure, MeasureValue]]: tensor = sample[self.tensor_name] @@ -108,11 +110,6 @@ def compute(self, sample: Sample) -> Dict[TensorName, Dict[Measure, MeasureValue std = numpy.sqrt(var) return {self.tensor_name: {Mean(axes=self.axes): mean, Var(axes=self.axes): var, Std(axes=self.axes): std}} - def reset(self): - self.n = 0 - self.mean = None - self.m2 = None - def update_with_sample(self, sample: Sample): tensor = sample[self.tensor_name].astype(numpy.float64, copy=False) mean_b = tensor.mean(dim=self.axes) @@ -134,7 +131,7 @@ def update_with_sample(self, sample: Sample): self.mean = (n_a * mean_a + n_b * mean_b) / n assert self.mean.dtype == numpy.float64 d = mean_b - mean_a - self.m2 = m2_a + m2_b + d ** 2 * n_a * n_b / n + self.m2 = m2_a + m2_b + d**2 * n_a * n_b / n assert self.m2.dtype == numpy.float64 def finalize(self) -> Dict[TensorName, Dict[Measure, MeasureValue]]: @@ -151,7 +148,7 @@ def finalize(self) -> Dict[TensorName, Dict[Measure, MeasureValue]]: } -class SamplePercentiles(SampleMeasureGroup): +class SamplePercentiles(SampleMeasureCalculator): def __init__(self, tensor_name: TensorName, axes: Optional[Tuple[str]], ns: Sequence[float]): assert all(0 <= n <= 100 for n in ns) self.ns = ns @@ -165,7 +162,7 @@ def compute(self, sample: Sample) -> Dict[TensorName, Dict[Measure, MeasureValue return {self.tensor_name: {Percentile(n=n, axes=self.axes): p for n, p in zip(self.ns, ps)}} -class MeanPercentiles(DatasetMeasureGroup): +class MeanPercentiles(DatasetMeasureCalculator): n: int estimates: Optional[xr.DataArray] @@ -203,7 +200,7 @@ def finalize(self) -> Dict[TensorName, Dict[Percentile, MeasureValue]]: return {self.tensor_name: {Percentile(n=n, axes=self.axes): e for n, e in zip(self.ns, self.estimates)}} -class CrickPercentiles(DatasetMeasureGroup): +class CrickPercentiles(DatasetMeasureCalculator): digest: Optional[List["crick.TDigest"]] dims: Optional[Tuple[Hashable, ...]] indices: Optional[Iterator[Tuple[int, ...]]] @@ -259,15 +256,16 @@ def finalize(self) -> Dict[TensorName, Dict[Measure, MeasureValue]]: if crick is None: - DatasetPercentileGroup: Union[Type[MeanPercentiles], Type[CrickPercentiles]] = MeanPercentiles + DatasetPercentileGroup: Type[Union[MeanPercentiles, CrickPercentiles]] = MeanPercentiles else: DatasetPercentileGroup = CrickPercentiles -class SingleMeasureAsGroup(SampleMeasureGroup): +class SingleMeasureAsGroup(SampleMeasureCalculator): """wrapper for measures to match interface of SampleMeasureGroup""" def __init__(self, tensor_name: TensorName, measure: Measure): + super().__init__() self.tensor_name = tensor_name self.measure = measure @@ -278,7 +276,7 @@ def compute(self, sample: Sample) -> Dict[TensorName, Dict[Measure, MeasureValue def 
get_measure_groups(measures: RequiredMeasures) -> MeasureGroups: """find a list of MeasureGroups to compute measures efficiently""" - measure_groups = {PER_SAMPLE: [], PER_DATASET: []} + measure_groups = MeasureGroups() means: Set[Tuple[TensorName, Mean]] = set() mean_var_std_groups: Set[Tuple[TensorName, Optional[Tuple[str, ...]]]] = set() percentile_groups: DefaultDict[Tuple[TensorName, Optional[Tuple[str, ...]]], List[float]] = defaultdict(list) @@ -292,12 +290,12 @@ def get_measure_groups(measures: RequiredMeasures) -> MeasureGroups: elif isinstance(m, Percentile): percentile_groups[(tn, m.axes)].append(m.n) elif mode == PER_SAMPLE: - measure_groups[mode].append(SingleMeasureAsGroup(tensor_name=tn, measure=m)) + measure_groups.per_sample.append(SingleMeasureAsGroup(tensor_name=tn, measure=m)) else: raise NotImplementedError(f"Computing statistics for {m} {mode} not yet implemented") # add all mean measures that are not included in a mean/var/std group - for (tn, m) in means: + for tn, m in means: if (tn, m.axes) not in mean_var_std_groups: # compute only mean if mode == PER_SAMPLE: @@ -307,7 +305,7 @@ def get_measure_groups(measures: RequiredMeasures) -> MeasureGroups: else: raise NotImplementedError(mode) - for (tn, axes) in mean_var_std_groups: + for tn, axes in mean_var_std_groups: measure_groups[mode].append(MeanVarStd(tensor_name=tn, axes=axes)) for (tn, axes), ns in percentile_groups.items(): @@ -328,16 +326,16 @@ def compute_measures( ret = {PER_SAMPLE: {}, PER_DATASET: {}} if sample is not None: for mg in ms_groups[PER_SAMPLE]: - assert isinstance(mg, SampleMeasureGroup) + assert isinstance(mg, SampleMeasureCalculator) ret[PER_SAMPLE].update(mg.compute(sample)) for sample in dataset: for mg in ms_groups[PER_DATASET]: - assert isinstance(mg, DatasetMeasureGroup) + assert isinstance(mg, DatasetMeasureCalculator) mg.update_with_sample(sample) for mg in ms_groups[PER_DATASET]: - assert isinstance(mg, DatasetMeasureGroup) + assert isinstance(mg, DatasetMeasureCalculator) ret[PER_DATASET].update(mg.finalize()) return ret diff --git a/bioimageio/core/prediction_pipeline/_processing.py b/bioimageio/core/prediction_pipeline/_processing.py index 6fbea8c6..ecf40efc 100644 --- a/bioimageio/core/prediction_pipeline/_processing.py +++ b/bioimageio/core/prediction_pipeline/_processing.py @@ -2,23 +2,26 @@ see https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/preprocessing_spec_latest.md and https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/postprocessing_spec_latest.md """ +from abc import ABC, abstractmethod import numbers from dataclasses import InitVar, dataclass, field, fields -from typing import List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Dict, Generic, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing_extensions import Self import numpy import numpy as np +from pydantic import model_validator # type: ignore +from pydantic import field_validator import xarray as xr +from bioimageio.spec._internal.base_nodes import Node +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.model.v0_5 import Processing as ProcessingSpec, ProcessingKwargs, Binarize, Clip +from bioimageio.spec.model.v0_5 import TensorId +from numpy.typing import DTypeLike +from bioimageio.core.statistical_measures import Mean, Measure, Percentile, Std, MeasureValue +from ._utils import FIXED, PER_DATASET, PER_SAMPLE, DatasetMode, Mode, RequiredMeasure, SampleMode, Sample -from bioimageio.core.statistical_measures import Mean, Measure, 
Percentile, Std -from bioimageio.spec.model.raw_nodes import PostprocessingName, PreprocessingName -from ._utils import ComputedMeasures, DatasetMode, FIXED, Mode, PER_DATASET, PER_SAMPLE, RequiredMeasures, SampleMode - -try: - from typing import Literal, get_args, TypedDict -except ImportError: - from typing_extensions import Literal, get_args, TypedDict # type: ignore - +from typing import Literal, TypedDict, get_args def _get_fixed( fixed: Union[float, Sequence[float]], tensor: xr.DataArray, axes: Optional[Sequence[str]] @@ -32,93 +35,81 @@ def _get_fixed( return xr.DataArray(fixed, dims=fixed_dims) -TensorName = str - -MISSING = "MISSING" +PKwargs = TypeVar("PKwargs", bound=ProcessingKwargs) +ProcInput = TypeVar("ProcInput", xr.DataArray, Sample) -@dataclass -class Processing: +class ProcessingBase(Node, Generic[PKwargs], ABC, frozen=True): """base class for all Pre- and Postprocessing transformations.""" - tensor_name: str - # todo: in python>=3.10 we should use dataclasses.KW_ONLY instead of MISSING (see child classes) to make inheritance work properly - computed_measures: ComputedMeasures = field(default_factory=dict) - mode: Mode = FIXED + tensor_id: TensorId + """id of tensor to operate on""" + kwargs: PKwargs + computed_measures: Dict[RequiredMeasure, MeasureValue] = field(default_factory=dict) - def get_required_measures(self) -> RequiredMeasures: - return {} + @model_validator(mode="after") + def check_required_measures_in_computed(self) -> Self: + for req in self.required_measures: + if req not in self.computed_measures: + raise ValueError(f"Missing computed {req}.") - def set_computed_measures(self, computed: ComputedMeasures): - # check if computed contains all required measures - for mode, req_per_mode in self.get_required_measures().items(): - for tn, req_per_tn in req_per_mode.items(): - comp_measures = computed.get(mode, {}).get(tn, {}) - for req_measure in req_per_tn: - if req_measure not in comp_measures: - raise ValueError(f"Missing required {req_measure} for {tn} {mode}.") + return self - self.computed_measures = computed + @classmethod + def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> Tuple[RequiredMeasure, ...]: + return () - def get_computed_measure(self, tensor_name: TensorName, measure: Measure, *, mode: Optional[Mode] = None): - """helper to unpack self.computed_measures""" - ret = self.computed_measures.get(mode or self.mode, {}).get(tensor_name, {}).get(measure) - if ret is None: - raise RuntimeError(f"Missing computed {measure} for {tensor_name} {mode}.") + @property + def required_measures(self) -> Tuple[RequiredMeasure, ...]: + return self.get_required_measures(tensor_id=self.tensor_id, kwargs=self.kwargs) - return ret - - def __call__(self, tensor: xr.DataArray) -> xr.DataArray: - return self.apply(tensor) + def __call__(self, __input: ProcInput, /) -> ProcInput: + if isinstance(__input, xr.DataArray): + return self.apply(__input) + else: + return self.apply_to_sample(__input) + @abstractmethod def apply(self, tensor: xr.DataArray) -> xr.DataArray: """apply processing""" - raise NotImplementedError - - def __post_init__(self): - # validate common kwargs by their annotations - for f in fields(self): - # check MISSING - if getattr(self, f.name) is MISSING: - raise TypeError(f"missing required argument {f.name}") - - if f.name == "mode": - # mode is always annotated as literals (or literals of literals) - valid_modes = get_args(f.type) - for inner in get_args(f.type): - valid_modes += get_args(inner) - - if self.mode not in 
valid_modes:
-                raise NotImplementedError(f"Unsupported mode {self.mode} for {self.__class__.__name__}")
+        ...
 
+    def apply_to_sample(self, sample: Sample) -> Sample:
+        ret = dict(sample)
+        ret[self.tensor_id] = self.apply(sample[self.tensor_id])
+        return ret
+
+
+class Processing(ProcessingSpec, ProcessingBase[PKwargs], frozen=True):
+    pass
 
 
 #
 # Pre- and Postprocessing implementations
 #
+class NonSpecProcessing(ProcessingBase[PKwargs], frozen=True):
+    """processing operations beyond what is currently defined in bioimageio.spec"""
+
+    pass
 
 
-@dataclass
-class AssertDtype(Processing):
+class AssertDtype(NonSpecProcessing[ProcessingKwargs], frozen=True):
     """Helper Processing to assert dtype."""
 
+    id: Literal["assert_dtype"] = "assert_dtype"
-    dtype: Union[str, Sequence[str]] = MISSING
-    assert_with: Tuple[Type[numpy.dtype], ...] = field(init=False)
+    dtype: Union[str, Sequence[str]]
+    _assert_with: Tuple[Type[DTypeLike], ...]
 
-    def __post_init__(self):
+    def model_post_init(self, __context: object):  # pydantic v2 post-init hook
         if isinstance(self.dtype, str):
             dtype = [self.dtype]
         else:
             dtype = self.dtype
 
-        self.assert_with = tuple(type(numpy.dtype(dt)) for dt in dtype)
+        object.__setattr__(self, "_assert_with", tuple(type(numpy.dtype(dt)) for dt in dtype))
 
     def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        assert isinstance(tensor.dtype, self.assert_with)
+        assert isinstance(tensor.dtype, self._assert_with)
         return tensor
 
 
-@dataclass
-class Binarize(Processing):
+class Binarize(Processing[BinarizeKwargs]):
     """'output = tensor > threshold'."""
 
     threshold: float = MISSING  # make dataclass inheritance work for py<3.10 by using an explicit MISSING value.
@@ -187,7 +178,7 @@ def get_required_measures(self) -> RequiredMeasures:
         axes = None if self.axes is None else tuple(self.axes)
         return {
             self.mode: {
-                self.tensor_name: {Mean(axes=axes), Std(axes=axes)},
+                self.tensor_id: {Mean(axes=axes), Std(axes=axes)},
                 self.reference_tensor: {Mean(axes=axes), Std(axes=axes)},
             }
         }
@@ -195,8 +186,8 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
         axes = None if self.axes is None else tuple(self.axes)
         assert self.mode in (PER_SAMPLE, PER_DATASET)
-        mean = self.get_computed_measure(self.tensor_name, Mean(axes), mode=self.mode)
-        std = self.get_computed_measure(self.tensor_name, Std(axes), mode=self.mode)
+        mean = self.get_computed_measure(self.tensor_id, Mean(axes), mode=self.mode)
+        std = self.get_computed_measure(self.tensor_id, Std(axes), mode=self.mode)
         ref_mean = self.get_computed_measure(self.reference_tensor, Mean(axes), mode=self.mode)
         ref_std = self.get_computed_measure(self.reference_tensor, Std(axes), mode=self.mode)
 
@@ -217,10 +208,10 @@ class ScaleRange(Processing):
     def get_required_measures(self) -> RequiredMeasures:
         axes = None if self.axes is None else tuple(self.axes)
         measures = {Percentile(self.min_percentile, axes=axes), Percentile(self.max_percentile, axes=axes)}
-        return {self.mode: {self.reference_tensor or self.tensor_name: measures}}
+        return {self.mode: {self.reference_tensor or self.tensor_id: measures}}
 
     def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        ref_name = self.reference_tensor or self.tensor_name
+        ref_name = self.reference_tensor or self.tensor_id
         axes = None if self.axes is None else tuple(self.axes)
         v_lower = self.get_computed_measure(ref_name, Percentile(self.min_percentile, axes=axes))
         v_upper = self.get_computed_measure(ref_name, Percentile(self.max_percentile, axes=axes))
@@ -255,7 +246,7 @@ def get_required_measures(self) -> RequiredMeasures:
             return {}
         else:
             axes = None if self.axes is None else tuple(self.axes)
-            return {self.mode: {self.tensor_name: {Mean(axes=axes), Std(axes=axes)}}}
+            return {self.mode: {self.tensor_id: {Mean(axes=axes), Std(axes=axes)}}}
 
@@ -265,20 +256,21 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
         axes = None if self.axes is None else tuple(self.axes)
         if self.mode == FIXED:
             mean = _get_fixed(self.mean, tensor, axes)
             std = _get_fixed(self.std, tensor, axes)
         elif self.mode in (PER_SAMPLE, PER_DATASET):
             assert self.mean is None and self.std is None
-            mean = self.get_computed_measure(self.tensor_name, Mean(axes), mode=self.mode)
-            std = self.get_computed_measure(self.tensor_name, Std(axes), mode=self.mode)
+            mean = self.get_computed_measure(self.tensor_id, Mean(axes), mode=self.mode)
+            std = self.get_computed_measure(self.tensor_id, Std(axes), mode=self.mode)
         else:
             raise ValueError(self.mode)
 
         return (tensor - mean) / (std + self.eps)
 
 
-_KnownProcessing = TypedDict(
-    "_KnownProcessing",
-    dict(pre=Mapping[PreprocessingName, Type[Processing]], post=Mapping[PostprocessingName, Type[Processing]]),
-)
+class _KnownProcessing(TypedDict):
+    pre: Mapping[PreprocessingName, Type[Processing]]
+    post: Mapping[PostprocessingName, Type[Processing]]
 
 
-KNOWN_PROCESSING: _KnownProcessing = dict(
+KNOWN_PROCESSING = _KnownProcessing(
     pre={
         "binarize": Binarize,
         "clip": Clip,
diff --git a/bioimageio/core/prediction_pipeline/_stat_state.py b/bioimageio/core/prediction_pipeline/_stat_state.py
index 6de4d68d..cf5be64c 100644
--- a/bioimageio/core/prediction_pipeline/_stat_state.py
+++ b/bioimageio/core/prediction_pipeline/_stat_state.py
@@ -1,14 +1,11 @@
 from typing import Dict, Iterable, Optional
 
+from tqdm import tqdm
+
 from bioimageio.core.statistical_measures import Measure
-from bioimageio.spec.shared.common import tqdm
-from ._measure_groups import MeasureGroups, MeasureValue, get_measure_groups
-from ._utils import ComputedMeasures, PER_DATASET, PER_SAMPLE, RequiredMeasures, Sample, TensorName
 
-try:
-    from typing import Literal
-except ImportError:
-    from typing_extensions import Literal  # type: ignore
+from ._measure_groups import MeasureGroups, MeasureValue, get_measure_groups
+from ._utils import PER_DATASET, PER_SAMPLE, ComputedMeasures, RequiredMeasures, Sample, TensorName
 
 
 class StatsState:
diff --git a/bioimageio/core/prediction_pipeline/_utils.py b/bioimageio/core/prediction_pipeline/_utils.py
index 8b39753d..78d0e478 100644
--- a/bioimageio/core/prediction_pipeline/_utils.py
+++ b/bioimageio/core/prediction_pipeline/_utils.py
@@ -1,15 +1,14 @@
-from typing import Dict, Set
+from __future__ import annotations
+
+import collections.abc
+from dataclasses import dataclass, field
+from typing import Any, Dict, Iterator, List, Literal, NamedTuple, Set, Union
 
 import xarray as xr
 
+from bioimageio.spec.model.v0_5 import TensorId
 from bioimageio.core.statistical_measures import Measure, MeasureValue
 
-try:
-    from typing import Literal
-except ImportError:
-    from typing_extensions import Literal  # type: ignore
-
-TensorName = str
 FixedMode = Literal["fixed"]
 SampleMode = Literal["per_sample"]
 DatasetMode = Literal["per_dataset"]
@@ -20,6 +19,77 @@
 PER_DATASET: DatasetMode = "per_dataset"
 MODES: Set[Mode] = {FIXED, PER_SAMPLE, PER_DATASET}
 
-Sample = Dict[TensorName, xr.DataArray]
-RequiredMeasures = Dict[Literal[SampleMode, DatasetMode], Dict[TensorName, Set[Measure]]]
-ComputedMeasures = Dict[Literal[SampleMode, DatasetMode], Dict[TensorName, Dict[Measure, MeasureValue]]]
+
+Sample
= Dict[TensorId, xr.DataArray] + + +class RequiredMeasure(NamedTuple): + measure: Measure + tensor_id: TensorId + mode: Mode + + # def __repr__(self) -> str: + # return f"{self.measure} of {self.tensor_id} ({self.mode})" + + +# RequiredMeasures = List[ReqMeasure] +# @dataclass +# class RequiredMeasures(collections.abc.Iterator[ReqMeasureEntry]): +# per_sample: Dict[TensorId, Set[Measure]] = field(default_factory=dict) +# per_dataset: Dict[TensorId, Set[Measure]] = field(default_factory=dict) + +# def update(self, *others: RequiredMeasures): +# for other in others: +# for t, ms in other.per_sample.items(): +# self.per_sample.setdefault(t, set()).update(ms) + +# for t, ms in other.per_dataset.items(): +# self.per_dataset.setdefault(t, set()).update(ms) + +# def __iter__(self) -> Iterator[ReqMeasureEntry]: +# for t, ms in self.per_sample.items(): +# for m in ms: +# yield ReqMeasureEntry("per_sample", t, m) + +# for t, ms in self.per_dataset.items(): +# for m in ms: +# yield ReqMeasureEntry("per_dataset", t, m) + + +# class ComputedMeasure(NamedTuple): +# measure: Measure +# tensor_id: TensorId +# mode: Mode +# value: MeasureValue +# def __repr__(self) -> str: +# return f"{self.measure} of {self.tensor_id} ({self.mode}) is {self.value}" + + +# @dataclass +# class ComputedMeasures(collections.abc.Container[CompMeasureEntry]): +# per_sample: Dict[TensorId, Dict[Measure, MeasureValue]] = field(default_factory=dict) +# per_dataset: Dict[TensorId, Dict[Measure, MeasureValue]] = field(default_factory=dict) + +# def update(self, other: ComputedMeasures) -> None: +# for t, ms in other.per_sample.items(): +# self.per_sample.setdefault(t, {}).update(ms) + +# for t, ms in other.per_dataset.items(): +# self.per_dataset.setdefault(t, {}).update(ms) + +# def __iter__(self) -> Iterator[CompMeasureEntry]: +# for t, ms in self.per_sample.items(): +# for m, v in ms.items(): +# yield CompMeasureEntry("per_sample", t, m, v) + +# for t, ms in self.per_dataset.items(): +# for m, v in ms.items(): +# yield CompMeasureEntry("per_dataset", t, m, v) + +# def __contains__(self, __x: Any) -> bool: +# if isinstance(__x, CompMeasureEntry): + +# elif isinstance(__x, ReqMeasureEntry): + +# else: +# return super().__contains__(__x) \ No newline at end of file diff --git a/bioimageio/core/statistical_measures.py b/bioimageio/core/statistical_measures.py index 0a3df99b..0c9c94ec 100644 --- a/bioimageio/core/statistical_measures.py +++ b/bioimageio/core/statistical_measures.py @@ -1,23 +1,26 @@ from __future__ import annotations +from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Optional, Tuple import xarray as xr +from bioimageio.spec.model.v0_5 import AxisName MeasureValue = xr.DataArray @dataclass(frozen=True) -class Measure: +class Measure(ABC): + @abstractmethod def compute(self, tensor: xr.DataArray) -> MeasureValue: """compute the measure (and also associated other Measures)""" - raise NotImplementedError(self.__class__.__name__) + ... 
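# A minimal usage sketch for the concrete Measure subclasses below, assuming an
# in-memory xarray tensor (the values and axis names are made up):
#
#     import xarray as xr
#     data = xr.DataArray([[0.0, 2.0], [4.0, 6.0]], dims=("x", "y"))
#     Mean(axes=("x",)).compute(data)  # DataArray([2.0, 4.0]) with dims ("y",)
#     Std(axes=("x",)).compute(data)   # DataArray([2.0, 2.0]) with dims ("y",)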
@dataclass(frozen=True) class Mean(Measure): - axes: Optional[Tuple[str, ...]] = None + axes: Optional[Tuple[AxisName, ...]] = None def compute(self, tensor: xr.DataArray) -> xr.DataArray: return tensor.mean(dim=self.axes) @@ -25,7 +28,7 @@ def compute(self, tensor: xr.DataArray) -> xr.DataArray: @dataclass(frozen=True) class Std(Measure): - axes: Optional[Tuple[str, ...]] = None + axes: Optional[Tuple[AxisName, ...]] = None def compute(self, tensor: xr.DataArray) -> xr.DataArray: return tensor.std(dim=self.axes) @@ -33,7 +36,7 @@ def compute(self, tensor: xr.DataArray) -> xr.DataArray: @dataclass(frozen=True) class Var(Measure): - axes: Optional[Tuple[str, ...]] = None + axes: Optional[Tuple[AxisName, ...]] = None def compute(self, tensor: xr.DataArray) -> xr.DataArray: return tensor.var(dim=self.axes) @@ -42,7 +45,7 @@ def compute(self, tensor: xr.DataArray) -> xr.DataArray: @dataclass(frozen=True) class Percentile(Measure): n: float - axes: Optional[Tuple[str, ...]] = None + axes: Optional[Tuple[AxisName, ...]] = None def __post_init__(self): assert self.n >= 0 diff --git a/setup.py b/setup.py index dcca6bab..e140ffac 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ ], include_package_data=True, extras_require={ - "test": ["pytest", "black"], + "test": ["pytest", "black[jupyter]"], "dev": ["pre-commit"], "pytorch": ["torch>=1.6", "torchvision"], "tensorflow": ["tensorflow"], diff --git a/tests/test_internal/test_validation_visitors.py b/tests/test_internal/test_validation_visitors.py index 0aaa882c..cc702755 100644 --- a/tests/test_internal/test_validation_visitors.py +++ b/tests/test_internal/test_validation_visitors.py @@ -1,40 +1,40 @@ -from functools import singledispatchmethod - -from bioimageio.spec._internal.base_nodes import Node -from bioimageio.spec.summary import ErrorOutcome - -from bioimageio.core._internal.validation_visitors import Note, ValidationVisitor - - -def test_traversing_nodes(): - class MyVisitor(ValidationVisitor): - @singledispatchmethod - def visit(self, obj: type, note: Note = Note()): - super().visit(obj, note) - - @visit.register - def _visit_int(self, nr: int, note: Note = Note()): - super().visit(nr, note) - self.errors.append(ErrorOutcome(loc=note.loc, msg=f"nr: {nr}", type="got-int")) - - class NestedNode(Node): - leaf: int - - class MyNode(Node): - nested: NestedNode - - tree = { - "a": MyNode(nested=NestedNode(leaf=1)), - "b": [NestedNode(leaf=2), NestedNode(leaf=3)], - "c": (NestedNode(leaf=4),), - "d": {"deep": MyNode(nested=NestedNode(leaf=5))}, - } - visitor = MyVisitor() - visitor.visit(tree) - assert len(visitor.errors) == [ - ErrorOutcome(loc=("a", "nested", "leaf"), msg="nr: 1", type="got-int"), - ErrorOutcome(loc=("b", 0, "leaf"), msg="nr: 2", type="got-int"), - ErrorOutcome(loc=("b", 1, "leaf"), msg="nr: 3", type="got-int"), - ErrorOutcome(loc=("c", 0, "leaf"), msg="nr: 4", type="got-int"), - ErrorOutcome(loc=("d", "deep", "nested", "leaf"), msg="nr: 5", type="got-int"), - ] +from functools import singledispatchmethod + +from bioimageio.spec._internal.base_nodes import Node +from bioimageio.spec.summary import ErrorOutcome + +from bioimageio.core._internal.validation_visitors import Note, ValidationVisitor + + +def test_traversing_nodes(): + class MyVisitor(ValidationVisitor): + @singledispatchmethod + def visit(self, obj: type, note: Note = Note()): + super().visit(obj, note) + + @visit.register + def _visit_int(self, nr: int, note: Note = Note()): + super().visit(nr, note) + self.errors.append(ErrorOutcome(loc=note.loc, msg=f"nr: 
{nr}", type="got-int"))
+
+    class NestedNode(Node, frozen=True):
+        leaf: int
+
+    class MyNode(Node, frozen=True):
+        nested: NestedNode
+
+    tree = {
+        "a": MyNode(nested=NestedNode(leaf=1)),
+        "b": [NestedNode(leaf=2), NestedNode(leaf=3)],
+        "c": (NestedNode(leaf=4),),
+        "d": {"deep": MyNode(nested=NestedNode(leaf=5))},
+    }
+    visitor = MyVisitor()
+    visitor.visit(tree)
+    assert visitor.errors == [
+        ErrorOutcome(loc=("a", "nested", "leaf"), msg="nr: 1", type="got-int"),
+        ErrorOutcome(loc=("b", 0, "leaf"), msg="nr: 2", type="got-int"),
+        ErrorOutcome(loc=("b", 1, "leaf"), msg="nr: 3", type="got-int"),
+        ErrorOutcome(loc=("c", 0, "leaf"), msg="nr: 4", type="got-int"),
+        ErrorOutcome(loc=("d", "deep", "nested", "leaf"), msg="nr: 5", type="got-int"),
+    ]

From c1caee17acada15f4fe0540969112a5e66cb5f9a Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Thu, 7 Sep 2023 22:36:10 +0200
Subject: [PATCH 025/244] WIP add new example notebooks

---
 example/dataset_creation.ipynb |  89 ++++++++++++
 example/demo.ipynb             | 249 +++++++++++++++++++++++++++++++++
 2 files changed, 338 insertions(+)
 create mode 100644 example/dataset_creation.ipynb
 create mode 100644 example/demo.ipynb

diff --git a/example/dataset_creation.ipynb b/example/dataset_creation.ipynb
new file mode 100644
index 00000000..dfefb681
--- /dev/null
+++ b/example/dataset_creation.ipynb
@@ -0,0 +1,89 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from bioimageio.spec.pretty_validation_errors import enable_pretty_validation_errors_in_ipynb\n",
+    "\n",
+    "enable_pretty_validation_errors_in_ipynb()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datetime import datetime\n",
+    "\n",
+    "from bioimageio.spec.dataset.v0_3 import Author, CiteEntry, Dataset\n",
+    "\n",
+    "nuclei_broad_data = Dataset(\n",
+    "    name=\"Kaggle 2018 Data Science Bowl\",\n",
+    "    description=\"This image data set contains a large number of segmented nuclei images and was created for the Kaggle \"\n",
+    "    \"2018 Data Science Bowl sponsored by Booz Allen Hamilton with cash prizes. The image set was a testing ground \"\n",
+    "    \"for the application of novel and cutting edge approaches in computer vision and machine learning to the \"\n",
+    "    \"segmentation of the nuclei belonging to cells from a breadth of biological contexts.\",\n",
+    "    documentation=\"README.md\",\n",
+    "    covers=(\n",
+    "        \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage1.png\",\n",
+    "        \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage2.png\",\n",
+    "        \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage3.png\",\n",
+    "        \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage4.png\",\n",
+    "        \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage5.png\",\n",
+    "    ),\n",
+    "    authors=(\n",
+    "        Author(name=\"Fynn Beuttenmueller\", affiliation=\"EMBL\", github_user=\"fynnbe\", orcid=\"0000-0002-8567-6389\"),\n",
+    "    ),\n",
+    "    source=\"https://bbbc.lbroadinstitute.org/BBBC038/\",\n",
+    "    cite=(\n",
+    "        CiteEntry(\n",
+    "            text=\"Caicedo, J.C., Goodman, A., Karhohs, K.W. et al. Nucleus segmentation across imaging experiments: \"\n",
+    "            \"the 2018 Data Science Bowl. 
Nat Methods 16, 1247–1253 (2019).\",\n", + " url=\"10.1038/s41592-019-0612-7\",\n", + " ),\n", + " CiteEntry(\n", + " text=\"Allen Goodman, Anne Carpenter, Elizabeth Park, jlefman-nvidia, Josette_BoozAllen, Kyle, Maggie, \"\n", + " \"Nilofer, Peter Sedivec, Will Cukierski. (2018). 2018 Data Science Bowl . Kaggle.\",\n", + " url=\"https://kaggle.com/competitions/data-science-bowl-2018\",\n", + " ),\n", + " ),\n", + " timestamp=datetime.today(),\n", + " license=\"CC0-1.0\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "bio38", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.17" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/example/demo.ipynb b/example/demo.ipynb new file mode 100644 index 00000000..c12aff25 --- /dev/null +++ b/example/demo.ipynb @@ -0,0 +1,249 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from bioimageio.spec.pretty_validation_errors import enable_pretty_validation_errors_in_ipynb\n", + "\n", + "enable_pretty_validation_errors_in_ipynb()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from bioimageio.core import load_description_and_validate" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "descr, summary = load_description_and_validate(\"10.5281/zenodo.6559929/6559930/rdf.yaml\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# from bioimageio.spec.model.v0_5 import Author, CiteEntry, Model\n", + "\n", + "# # id: raw\n", + "# # description: raw input\n", + "# # axes:\n", + "# # - type: batch\n", + "# # - type: channel\n", + "# # channel_names: [raw_intensity]\n", + "# # - type: space # todo: scale/unit\n", + "# # name: y\n", + "# # size: 512\n", + "# # - type: space\n", + "# # name: x\n", + "# # size: 512\n", + "# # test_tensor: test_input.npy\n", + "# # sample_tensor: test_input.npy\n", + "# # preprocessing: # list of preprocessing steps\n", + "# # - id: zero_mean_unit_variance # name of preprocessing step\n", + "# # kwargs:\n", + "# # mode: per_sample\n", + "# # axes: [x, y]\n", + "\n", + "# # outputs:\n", + "# # - id: probability\n", + "# # description: probability in [0,1]\n", + "# # data:\n", + "# # type: float32\n", + "# # range:\n", + "# # - 0.0\n", + "# # - 1.0\n", + "# # axes:\n", + "# # - type: batch\n", + "# # - type: channel\n", + "# # channel_names: [probability]\n", + "# # - type: space\n", + "# # name: y\n", + "# # size: raw.y\n", + "# # halo: 32\n", + "# # - type: space\n", + "# # size: raw.x\n", + "# # name: x\n", + "# # halo: 32\n", + "# # test_tensor: test_output.npy\n", + "# # sample_tensor: test_output.npy\n", + "\n", + "# # weights:\n", + "# # pytorch_state_dict:\n", + "# # authors:\n", + "# # - name: \"Constantin Pape;@bioimage-io\"\n", + "# 
# affiliation: \"EMBL Heidelberg\"\n", + "# # orcid: \"0000-0001-6562-7187\"\n", + "# # sha256: e4d3885bccbe41cbf6c1d825f3cd2b707c7021ead5593156007e407a16b27cf2\n", + "# # source: https://zenodo.org/record/3446812/files/unet2d_weights.torch\n", + "# # architecture:\n", + "# # callable: unet2d.py:UNet2d\n", + "# # sha256: cf42a6d86adeb4eb6e8e37b539a20e5413866b183bed88f4e2e26ad1639761ed\n", + "# # kwargs: { input_channels: 1, output_channels: 1 }\n", + "# # dependencies: conda:environment.yaml\n", + "# # pytorch_version: \"1.5.1\"\n", + "# # onnx:\n", + "# # sha256: f1f086d5e340f9d4d7001a1b62a2b835f9b87a2fb5452c4fe7d8cc821bdf539c\n", + "# # source: weights.onnx\n", + "# # opset_version: 12\n", + "# # parent: pytorch_state_dict\n", + "# # torchscript:\n", + "# # sha256: 62fa1c39923bee7d58a192277e0dd58f2da9ee810662addadd0f44a3784d9210\n", + "# # source: weights.pt\n", + "# # parent: pytorch_state_dict\n", + "# # pytorch_version: \"1.5.1\"\n", + "\n", + "\n", + "# my_model = Model(\n", + "# name=\"UNet 2D Nuclei Broad\",\n", + "# version=\"0.2.0\",\n", + "# description=\"A 2d U-Net trained on the nuclei broad dataset.\",\n", + "# documentation=\"README.md\",\n", + "# authors=(\n", + "# Author(\n", + "# name=\"Constantin Pape\",\n", + "# affiliation=\"EMBL Heidelberg\",\n", + "# orcid=\"0000-0001-6562-7187\",\n", + "# ),\n", + "# Author(\n", + "# name=\"Fynn Beuttenmueller\",\n", + "# affiliation=\"EMBL Heidelberg\",\n", + "# orcid=\"0000-0002-8567-6389\",\n", + "# ),\n", + "# ),\n", + "# cite=(CiteEntry(text=\"bioimage.io\", doi=\"10.1101/2022.06.07.495102\"),),\n", + "# inputs=(),\n", + "# outputs=(),\n", + "# timestamp=\"2019-12-11T12:22:32\",\n", + "# training_data={\"id\": \"ilastik/covid_if_training_data\"}, # note: not the real training data\n", + "# license=\"MIT\",\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "my_model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import xarray as xr\n", + "\n", + "a = xr.DataArray([[1, 2], [3, 4]], dims=(\"x\", \"y\"))\n", + "a[{\"x\": slice(None)}]" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "bio38", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.17" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} From bfe28e6129cde8f959895306b9a306ddd134e522 Mon Sep 17 
00:00:00 2001 From: fynnbe Date: Thu, 7 Sep 2023 22:36:23 +0200 Subject: [PATCH 026/244] update isort settings --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a8002b85..e54676f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,8 +4,8 @@ target-version = ['py38', 'py39', 'py310', 'py311'] [tool.isort] line_length = 120 -multi_line_output = 3 -include_trailing_comma = true +profile = "black" +# sort_reexports = true # buggy [tool.pyright] include = ["bioimageio", "scripts", "tests"] From 0f0d6ae966f0badc4dde0e6034b63806d91da16f Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 7 Sep 2023 22:43:47 +0200 Subject: [PATCH 027/244] crlf to lf --- .markdownlint.json | 14 +++--- .vscode/settings.json | 28 ++++++------ bioimageio/core/_internal/pytest_utils.py | 56 +++++++++++------------ scripts/setup_dev_env.py | 40 ++++++++-------- 4 files changed, 69 insertions(+), 69 deletions(-) diff --git a/.markdownlint.json b/.markdownlint.json index 8111539b..e3494375 100644 --- a/.markdownlint.json +++ b/.markdownlint.json @@ -1,8 +1,8 @@ -{ - "default": true, - "MD013": { - "line_length": 120 - }, - "MD033": false, - "MD041": false +{ + "default": true, + "MD013": { + "line_length": 120 + }, + "MD033": false, + "MD041": false } \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 9520c20f..771ec3eb 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,15 +1,15 @@ -{ - "python.languageServer": "Pylance", - "python.analysis.typeCheckingMode": "strict", - "python.linting.pylintEnabled": true, - "python.linting.enabled": false, - "python.testing.unittestArgs": [ - "-v", - "-s", - "./tests", - "-p", - "test_*.py" - ], - "python.testing.pytestEnabled": true, - "python.testing.unittestEnabled": false, +{ + "python.languageServer": "Pylance", + "python.analysis.typeCheckingMode": "strict", + "python.linting.pylintEnabled": true, + "python.linting.enabled": false, + "python.testing.unittestArgs": [ + "-v", + "-s", + "./tests", + "-p", + "test_*.py" + ], + "python.testing.pytestEnabled": true, + "python.testing.unittestEnabled": false, } \ No newline at end of file diff --git a/bioimageio/core/_internal/pytest_utils.py b/bioimageio/core/_internal/pytest_utils.py index 4ae0b8e4..c61fa62f 100644 --- a/bioimageio/core/_internal/pytest_utils.py +++ b/bioimageio/core/_internal/pytest_utils.py @@ -1,28 +1,28 @@ -from functools import wraps -from typing import Any, Protocol, Type - - -class test_func(Protocol): - def __call__(*args: Any, **kwargs: Any): - ... - - -def skip_on(exception: Type[Exception], reason: str): - """adapted from https://stackoverflow.com/a/63522579""" - import pytest - - # Func below is the real decorator and will receive the test function as param - def decorator_func(f: test_func): - @wraps(f) - def wrapper(*args: Any, **kwargs: Any): - try: - # Try to run the test - return f(*args, **kwargs) - except exception: - # If exception of given type happens - # just swallow it and raise pytest.Skip with given reason - pytest.skip(reason) - - return wrapper - - return decorator_func +from functools import wraps +from typing import Any, Protocol, Type + + +class test_func(Protocol): + def __call__(*args: Any, **kwargs: Any): + ... 
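# A usage sketch for the skip_on decorator defined just below; the test body
# and module name are hypothetical:
#
#     @skip_on(ImportError, reason="optional dependency missing")
#     def test_with_optional_dependency():
#         import some_optional_dependency  # hypothetical optional import
#         assert some_optional_dependency is not None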
+ + +def skip_on(exception: Type[Exception], reason: str): + """adapted from https://stackoverflow.com/a/63522579""" + import pytest + + # Func below is the real decorator and will receive the test function as param + def decorator_func(f: test_func): + @wraps(f) + def wrapper(*args: Any, **kwargs: Any): + try: + # Try to run the test + return f(*args, **kwargs) + except exception: + # If exception of given type happens + # just swallow it and raise pytest.Skip with given reason + pytest.skip(reason) + + return wrapper + + return decorator_func diff --git a/scripts/setup_dev_env.py b/scripts/setup_dev_env.py index fc107c33..ed4502cc 100644 --- a/scripts/setup_dev_env.py +++ b/scripts/setup_dev_env.py @@ -1,20 +1,20 @@ -# untested draft! -import subprocess -from os import chdir -from pathlib import Path - - -def run(prompt: str): - _ = subprocess.run(prompt, check=True, capture_output=True) - - -repo_dir = Path(__file__).parent.parent.parent -cur_dir = Path().resolve() -chdir(str(repo_dir)) -try: - run("mamba env create --file core-bioimage-io/dev/env.yaml") - run("pip install --no-deps --config-settings editable_mode=compat -e spec-bioimage-io") - run("pip install --no-deps --config-settings editable_mode=compat -e core-bioimage-io") -except Exception: - chdir(cur_dir) - raise +# untested draft! +import subprocess +from os import chdir +from pathlib import Path + + +def run(prompt: str): + _ = subprocess.run(prompt, check=True, capture_output=True) + + +repo_dir = Path(__file__).parent.parent.parent +cur_dir = Path().resolve() +chdir(str(repo_dir)) +try: + run("mamba env create --file core-bioimage-io/dev/env.yaml") + run("pip install --no-deps --config-settings editable_mode=compat -e spec-bioimage-io") + run("pip install --no-deps --config-settings editable_mode=compat -e core-bioimage-io") +except Exception: + chdir(cur_dir) + raise From dda9062a4450372e8cccea3deb7a3e1c18776fcb Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 8 Sep 2023 20:45:15 +0200 Subject: [PATCH 028/244] WIP clean up --- .vscode/settings.json | 4 -- bioimageio/core/__init__.py | 43 +++++++++--------- bioimageio/core/_io.py | 45 +++++++------------ bioimageio/core/build_spec/build_model.py | 10 +---- .../_tensorflow_model_adapter.py | 8 +--- pyproject.toml | 11 +++-- tests/conftest.py | 3 +- 7 files changed, 46 insertions(+), 78 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 771ec3eb..a328cfa1 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,8 +1,4 @@ { - "python.languageServer": "Pylance", - "python.analysis.typeCheckingMode": "strict", - "python.linting.pylintEnabled": true, - "python.linting.enabled": false, "python.testing.unittestArgs": [ "-v", "-s", diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index ead66cb0..36c1b290 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -1,7 +1,13 @@ import json from bioimageio.core._internal.utils import files -from bioimageio.core._io import load_description_and_validate, resolve_source, validate, write_package +from bioimageio.core._io import load_description as load_description +from bioimageio.core._io import load_description_and_validate as load_description_and_validate +from bioimageio.core._io import read_rdf_content as read_rdf_content +from bioimageio.core._io import resolve_source as resolve_source +from bioimageio.core._io import validate as validate +from bioimageio.core._io import write_description as write_description +from bioimageio.core._io import 
write_package as write_package with files("bioimageio.core").joinpath("VERSION").open("r", encoding="utf-8") as f: __version__: str = json.load(f)["version"] @@ -19,24 +25,17 @@ # ) # from .resource_tests import check_input_shape, check_output_shape, test_resource -__all__ = [ - "__version__", - "load_description_and_validate", - "read_rdf", - "resolve_source", - "validate", - "write_package", - # "check_input_shape", - # "check_output_shape", - # "create_prediction_pipeline", - # "export_resource_package", - # "load_raw_resource_description", - # "load_resource_description", - # "predict_image", - # "predict_images", - # "predict_with_padding", - # "predict_with_tiling", - # "save_raw_resource_description", - # "serialize_raw_resource_description", - # "test_resource", -] +# __all__ = [ +# "check_input_shape", +# "check_output_shape", +# "create_prediction_pipeline", +# "export_resource_package", +# "load_raw_resource_description", +# "load_resource_description", +# "predict_image", +# "predict_images", +# "predict_with_padding", +# "predict_with_tiling", +# "save_raw_resource_description", +# "serialize_raw_resource_description", +# "test_resource", diff --git a/bioimageio/core/_io.py b/bioimageio/core/_io.py index 8e85078b..96724f33 100644 --- a/bioimageio/core/_io.py +++ b/bioimageio/core/_io.py @@ -4,10 +4,14 @@ import os from pathlib import Path from tempfile import NamedTemporaryFile -from typing import Any, Dict, Literal, NamedTuple, Optional, Sequence, Tuple, Union, cast +from typing import Any, Dict, Literal, NamedTuple, Optional, Sequence, TextIO, Tuple, Union, cast from zipfile import ZIP_DEFLATED, ZipFile, is_zipfile import pooch +from pydantic import AnyUrl, DirectoryPath, FilePath, HttpUrl, TypeAdapter, ValidationError +from ruamel.yaml import YAML + +from bioimageio.core._internal.utils import get_parent_url, write_zip from bioimageio.spec import ResourceDescription from bioimageio.spec import load_description as load_description_from_content from bioimageio.spec._internal.base_nodes import ResourceDescriptionBase @@ -17,17 +21,12 @@ from bioimageio.spec.model.v0_4 import WeightsFormat from bioimageio.spec.package import extract_file_name, get_resource_package_content from bioimageio.spec.summary import ValidationSummary -from pydantic import AnyUrl, DirectoryPath, FilePath, HttpUrl, TypeAdapter -from ruamel.yaml import YAML - -from bioimageio.core._internal.utils import get_parent_url, write_zip yaml = YAML(typ="safe") StrictFileSource = Union[HttpUrl, FilePath] FileSource = Union[StrictFileSource, str] -StrictRdfSource = Union[StrictFileSource, RdfContent, ResourceDescription] -RdfSource = Union[StrictRdfSource, str] +RdfSource = Union[FileSource, RdfContent, ResourceDescription, str] class RawRdf(NamedTuple): @@ -62,10 +61,10 @@ def read_rdf_content( known_hash: Optional[str] = None, rdf_encoding: str = "utf-8", ) -> RawRdf: - class FileSourceInterpreter(BaseModel): - source: StrictFileSource - - rdf_source = FileSourceInterpreter(source=rdf_source).source + try: + rdf_source = TypeAdapter(StrictFileSource).validate_python(rdf_source) + except ValidationError as e: + raise e if isinstance(rdf_source, AnyUrl): _ls: Any = pooch.retrieve(url=str(rdf_source), known_hash=known_hash) @@ -129,14 +128,17 @@ def resolve_source( return source -def write_description(rd: Union[ResourceDescription, RdfContent], /, file_path: FilePath): +def write_description(rd: Union[ResourceDescription, RdfContent], /, file: Union[FilePath, TextIO]): if isinstance(rd, ResourceDescriptionBase): 
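        # a ResourceDescription node is first dumped back to its raw RDF dict;
        # plain RdfContent mappings are written out unchanged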
content = dump_description(rd) else: content = rd - with file_path.open("w", encoding="utf-8") as f: - yaml.dump(content, f) + if isinstance(file, Path): + with file.open("w", encoding="utf-8") as f: + yaml.dump(content, f) + else: + yaml.dump(content, file) def load_description_and_validate( @@ -153,22 +155,7 @@ def load_description_and_validate( return rd, summary -# def _get_default_io_context(context: Union[ValidationContext, CompleteValidationContext, None]) -> Union[ValidationContext, CompleteValidationContext]: -# if context is None: -# context = ValidationContext() - -# if "warning_level" not in context: -# context["warning_level"] = INFO - -# return context - - def _get_rdf_content_and_update_context(rdf_source: RdfSource, context: ValidationContext) -> RdfContent: - class RdfSourceInterpreter(BaseModel): - source: RdfSource - - rdf_source = RdfSourceInterpreter(source=rdf_source).source - if isinstance(rdf_source, (AnyUrl, Path, str)): rdf = read_rdf_content(rdf_source) rdf_source = rdf.content diff --git a/bioimageio/core/build_spec/build_model.py b/bioimageio/core/build_spec/build_model.py index 59a72994..c1f52229 100644 --- a/bioimageio/core/build_spec/build_model.py +++ b/bioimageio/core/build_spec/build_model.py @@ -2,7 +2,7 @@ import hashlib import os from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, get_args from warnings import warn import imageio @@ -14,14 +14,8 @@ import bioimageio.spec.model as model_spec from bioimageio.core import export_resource_package, load_raw_resource_description from bioimageio.core.resource_io.nodes import URI -from bioimageio.spec.shared.raw_nodes import ImportableModule, ImportableSourceFile from bioimageio.spec.shared import resolve_local_source, resolve_source - -try: - from typing import get_args -except ImportError: - from typing_extensions import get_args # type: ignore - +from bioimageio.spec.shared.raw_nodes import ImportableModule, ImportableSourceFile # # utility functions to build the spec from python diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/_tensorflow_model_adapter.py b/bioimageio/core/prediction_pipeline/_model_adapters/_tensorflow_model_adapter.py index 57f8de41..7b470608 100644 --- a/bioimageio/core/prediction_pipeline/_model_adapters/_tensorflow_model_adapter.py +++ b/bioimageio/core/prediction_pipeline/_model_adapters/_tensorflow_model_adapter.py @@ -1,6 +1,6 @@ import warnings import zipfile -from typing import List, Optional +from typing import List, Literal, Optional import numpy as np import tensorflow as tf @@ -9,11 +9,6 @@ from ._model_adapter import ModelAdapter -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal # type: ignore - class TensorflowModelAdapterBase(ModelAdapter): weight_format: Literal["keras_hdf5", "tensorflow_saved_model_bundle"] @@ -79,7 +74,6 @@ def _forward_tf(self, *input_tensors): graph = tf.Graph() with graph.as_default(): with tf.Session(graph=graph) as sess: - # load the model and the signature graph_def = tf.saved_model.loader.load(sess, [tag], self._model) signature = graph_def.signature_def diff --git a/pyproject.toml b/pyproject.toml index e54676f7..d98987d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,11 +1,6 @@ [tool.black] line-length = 120 -target-version = ['py38', 'py39', 'py310', 'py311'] - -[tool.isort] -line_length = 120 -profile = "black" -# sort_reexports = true # buggy +target-version = ["py38", "py39", "py310", "py311"] 
[tool.pyright] include = ["bioimageio", "scripts", "tests"] @@ -27,3 +22,7 @@ pythonPlatform = "All" [tool.pytest.ini_options] addopts = "--capture=no --doctest-modules --failed-first" # testpaths = ["bioimageio", "scripts", "example", "tests"] + +[tool.ruff] +line-length = 120 +include = ["*.py", "*.pyi", "**/pyproject.toml", "*.ipynb"] diff --git a/tests/conftest.py b/tests/conftest.py index db626a10..a4efa28a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,9 +11,8 @@ from pytest import FixtureRequest, fixture os.environ["BIOIMAGEIO_COUNT_RDF_DOWNLOADS"] = "false" # disable tracking before bioimageio imports -from bioimageio.spec import __version__ as bioimageio_spec_version - from bioimageio.core import write_package +from bioimageio.spec import __version__ as bioimageio_spec_version logger = logging.getLogger(__name__) warnings.warn(f"testing with bioimageio.spec {bioimageio_spec_version}") From 5c6bece8b2e71b5c1e9b0cdc41b8b267580b531c Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 8 Sep 2023 20:45:31 +0200 Subject: [PATCH 029/244] add show_diff.py --- scripts/show_diff.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 scripts/show_diff.py diff --git a/scripts/show_diff.py b/scripts/show_diff.py new file mode 100644 index 00000000..f0fb20d8 --- /dev/null +++ b/scripts/show_diff.py @@ -0,0 +1,24 @@ +import subprocess +from pathlib import Path +from tempfile import TemporaryDirectory + +import pooch + +from bioimageio.core import load_description, write_description + +rdf_source = "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/pydantic_axes/example_specs/models/unet2d_nuclei_broad/rdf_v0_4_9.yaml" + +local_source = Path(pooch.retrieve(rdf_source, None)) # type: ignore +model_as_is, summary_as_is = load_description(rdf_source, format_version="discover") +assert model_as_is is not None, summary_as_is +model_latest, summary_latest = load_description(rdf_source, format_version="latest") +print(summary_latest) +assert model_latest is not None + +with TemporaryDirectory() as tmp: + as_is = Path(tmp) / "as_is.bioimageio.yaml" + write_description(model_as_is, as_is) # write out as is to avoid sorting diff + latest = Path(tmp) / "latest.bioimageio.yaml" + write_description(model_latest, latest) + + _ = subprocess.run(f"git diff --no-index --ignore-all-space {as_is} {latest}") From 2cc2e138deafec5d95c9241fcccfdcde5f82bdf9 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 8 Sep 2023 20:45:46 +0200 Subject: [PATCH 030/244] add dev env --- dev/env.yaml | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 dev/env.yaml diff --git a/dev/env.yaml b/dev/env.yaml new file mode 100644 index 00000000..51c69875 --- /dev/null +++ b/dev/env.yaml @@ -0,0 +1,29 @@ +name: bio38 +channels: + - conda-forge + - defaults +dependencies: + - annotated-types + - black + - deepdiff + - email-validator + - imageio[version='>=2.5'] + - lxml + - numpy + - onnxruntime + - packaging[version='>=17.0'] + - pooch + - pre-commit + - pydantic[version='>=2.3.0'] + - pyright + - pytest + - python-dateutil + - python=3.8 + - pytorch + - ruamel.yaml + - ruff + - torchvision + - tqdm + - typer + - typing-extensions + - xarray From 102425c0e94ad0abf13317f792ef5f37228ab56d Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 8 Sep 2023 23:47:47 +0200 Subject: [PATCH 031/244] improve io funcs --- bioimageio/core/_io.py | 312 +++++++++++++++++++++-------------------- pyproject.toml | 2 +- 2 files changed, 163 insertions(+), 151 deletions(-) diff 
--git a/bioimageio/core/_io.py b/bioimageio/core/_io.py index 96724f33..f92cb102 100644 --- a/bioimageio/core/_io.py +++ b/bioimageio/core/_io.py @@ -13,9 +13,9 @@ from bioimageio.core._internal.utils import get_parent_url, write_zip from bioimageio.spec import ResourceDescription -from bioimageio.spec import load_description as load_description_from_content +from bioimageio.spec import load_description as load_description from bioimageio.spec._internal.base_nodes import ResourceDescriptionBase -from bioimageio.spec._internal.constants import DISCOVER, LATEST +from bioimageio.spec._internal.constants import DISCOVER from bioimageio.spec._internal.types import FileName, RdfContent, RelativeFilePath, ValidationContext, YamlValue from bioimageio.spec.description import dump_description from bioimageio.spec.model.v0_4 import WeightsFormat @@ -26,106 +26,43 @@ StrictFileSource = Union[HttpUrl, FilePath] FileSource = Union[StrictFileSource, str] -RdfSource = Union[FileSource, RdfContent, ResourceDescription, str] +RdfSource = Union[FileSource, ResourceDescription] - -class RawRdf(NamedTuple): - content: RdfContent - root: Union[HttpUrl, DirectoryPath] - file_name: str +LEGACY_RDF_NAME = "rdf.yaml" -def load_description( - rdf_source: RdfSource, +def read_description( + rdf_source: FileSource, /, *, - context: Optional[ValidationContext] = None, format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, ) -> Tuple[Optional[ResourceDescription], ValidationSummary]: - context = context or ValidationContext() - rdf_content = _get_rdf_content_and_update_context(rdf_source, context) - return load_description_from_content( - rdf_content, - context=context, + rdf = download_rdf(rdf_source) + return load_description( + rdf.content, + context=ValidationContext(root=rdf.root, file_name=rdf.file_name), format_version=format_version, ) -LEGACY_RDF_NAME = "rdf.yaml" - - -def read_rdf_content( +def read_description_and_validate( rdf_source: FileSource, - /, - *, - known_hash: Optional[str] = None, - rdf_encoding: str = "utf-8", -) -> RawRdf: - try: - rdf_source = TypeAdapter(StrictFileSource).validate_python(rdf_source) - except ValidationError as e: - raise e - - if isinstance(rdf_source, AnyUrl): - _ls: Any = pooch.retrieve(url=str(rdf_source), known_hash=known_hash) - local_source = Path(_ls) - root: Union[HttpUrl, DirectoryPath] = get_parent_url(rdf_source) - else: - local_source = rdf_source - root = rdf_source.parent - - if is_zipfile(local_source): - out_path = local_source.with_suffix(local_source.suffix + ".unzip") - with ZipFile(local_source, "r") as f: - rdfs = [fname for fname in f.namelist() if fname.endswith(".bioimageio.yaml")] - if len(rdfs) > 1: - raise ValueError(f"Multiple RDFs in one package not yet supported (found {rdfs}).") - elif len(rdfs) == 1: - rdf_file_name = rdfs[0] - elif LEGACY_RDF_NAME in f.namelist(): - rdf_file_name = LEGACY_RDF_NAME - else: - raise ValueError( - f"No RDF found in {local_source}. (Looking for any '*.bioimageio.yaml' file or an 'rdf.yaml' file)." 
- ) - - f.extractall(out_path) - local_source = out_path / rdf_file_name - - with local_source.open(encoding=rdf_encoding) as f: - content: YamlValue = yaml.load(f) - - if not isinstance(content, collections.abc.Mapping): - raise TypeError(f"Expected RDF content to be a mapping, but got '{type(content)}'.") - - return RawRdf( - content=cast(RdfContent, content), - root=root, - file_name=extract_file_name(rdf_source), - ) +) -> Tuple[Optional[ResourceDescription], ValidationSummary]: + rdf = download_rdf(rdf_source) + return load_description_and_validate(rdf.content, context=ValidationContext(root=rdf.root, file_name=rdf.file_name)) -def resolve_source( - source: Union[HttpUrl, FilePath, RelativeFilePath, str], +def load_description_and_validate( + rdf_content: RdfContent, /, *, - known_hash: Optional[str] = None, - root: Union[DirectoryPath, AnyUrl, None] = None, -) -> FilePath: - if isinstance(source, str): - source = TypeAdapter(Union[HttpUrl, FilePath, RelativeFilePath]).validate_python(source) - - if isinstance(source, RelativeFilePath): - if root is None: - raise ValueError(f"Cannot resolve relative file path '{source}' without root.") - - source = source.get_absolute(root) - - if isinstance(source, AnyUrl): - _s: Any = pooch.retrieve(str(source), known_hash=known_hash) - source = Path(_s) - - return source + context: Optional[ValidationContext] = None, + format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, +) -> Tuple[Optional[ResourceDescription], ValidationSummary]: + """load and validate a BioImage.IO description from the content of a resource description file (RDF)""" + rd, summary = load_description(rdf_content, context=context, format_version=format_version) + # todo: add dynamic validation + return rd, summary def write_description(rd: Union[ResourceDescription, RdfContent], /, file: Union[FilePath, TextIO]): @@ -141,69 +78,10 @@ def write_description(rd: Union[ResourceDescription, RdfContent], /, file: Union yaml.dump(content, file) -def load_description_and_validate( - rdf_source: RdfSource, - /, - *, - context: Optional[ValidationContext] = None, -) -> Tuple[Optional[ResourceDescription], ValidationSummary]: - """load and validate a BioImage.IO description from the content of a resource description file (RDF)""" - context = context or ValidationContext() - rdf_content = _get_rdf_content_and_update_context(rdf_source, context) - rd, summary = load_description_from_content(rdf_content, context=context, format_version=LATEST) - # todo: add dynamic validation - return rd, summary - - -def _get_rdf_content_and_update_context(rdf_source: RdfSource, context: ValidationContext) -> RdfContent: - if isinstance(rdf_source, (AnyUrl, Path, str)): - rdf = read_rdf_content(rdf_source) - rdf_source = rdf.content - context.root = rdf.root - context.file_name = rdf.file_name - elif isinstance(rdf_source, ResourceDescriptionBase): - rdf_source = dump_description(rdf_source, exclude_unset=False) - - return rdf_source - - -def _get_description_and_update_context(rdf_source: RdfSource, context: ValidationContext) -> ResourceDescription: - if not isinstance(rdf_source, ResourceDescriptionBase): - descr, summary = load_description(rdf_source, context=context) - if descr is None: - rdf_source_msg = ( - f"{{name={rdf_source.get('name', 'missing'), ...}}})" - if isinstance(rdf_source, collections.abc.Mapping) - else rdf_source - ) - raise ValueError(f"Failed to load {rdf_source_msg}:\n{summary.format()}") - rdf_source = descr - - return rdf_source - - -def validate( - 
rdf_source: RdfSource, - /, - *, - context: Optional[ValidationContext] = None, -) -> ValidationSummary: - _rd, summary = load_description_and_validate(rdf_source, context=context) - return summary - - -def validate_format_only( - rdf_source: Union[ResourceDescription, RdfContent, FileSource], context: Optional[ValidationContext] = None -) -> ValidationSummary: - _rd, summary = load_description(rdf_source, context=context) - return summary - - def prepare_resource_package( rdf_source: RdfSource, /, *, - context: Optional[ValidationContext] = None, weights_priority_order: Optional[Sequence[WeightsFormat]] = None, ) -> Dict[FileName, Union[FilePath, RdfContent]]: """Prepare to package a resource description; downloads all required files. @@ -214,8 +92,17 @@ def prepare_resource_package( weights_priority_order: If given only the first weights format present in the model is included. If none of the prioritized weights formats is found all are included. """ - context = context or ValidationContext() - rd = _get_description_and_update_context(rdf_source, context) + if isinstance(rdf_source, ResourceDescriptionBase): + rd = rdf_source + _ctxt = rd._internal_validation_context # pyright: ignore[reportPrivateUsage] + context = ValidationContext(root=_ctxt["root"], file_name=_ctxt["file_name"]) + else: + rdf = download_rdf(rdf_source) + context = ValidationContext(root=rdf.root, file_name=rdf.file_name) + rd = load_description( + rdf.content, + context=context, + ) package_content = get_resource_package_content(rd, weights_priority_order=weights_priority_order) local_package_content: Dict[FileName, Union[FilePath, RdfContent]] = {} @@ -227,14 +114,11 @@ def prepare_resource_package( return local_package_content - # output_folder.mkdir(parents=True, exist_ok=True) - def write_package( rdf_source: RdfSource, /, *, - context: Optional[ValidationContext] = None, compression: int = ZIP_DEFLATED, compression_level: int = 1, output_path: Optional[os.PathLike[str]] = None, @@ -278,3 +162,131 @@ def write_package( write_zip(output_path, package_content, compression=compression, compression_level=compression_level) return output_path + + +class _LocalFile(NamedTuple): + path: FilePath + original_root: Union[AnyUrl, DirectoryPath] + original_file_name: str + + +class _LocalRdf(NamedTuple): + content: RdfContent + root: Union[AnyUrl, DirectoryPath] + file_name: str + + +def download( + source: FileSource, + /, + *, + known_hash: Optional[str] = None, +) -> _LocalFile: + source = _interprete_file_source(source) + if isinstance(source, AnyUrl): + _ls: Any = pooch.retrieve(url=str(source), known_hash=known_hash) + local_source = Path(_ls) + root: Union[HttpUrl, DirectoryPath] = get_parent_url(source) + else: + local_source = source + root = source.parent + + return _LocalFile( + local_source, + root, + extract_file_name(source), + ) + + +def download_rdf(source: FileSource, /, *, known_hash: Optional[str] = None, rdf_encoding: str = "utf-8"): + local_source, root, file_name = download(source, known_hash=known_hash) + if is_zipfile(local_source): + out_path = local_source.with_suffix(local_source.suffix + ".unzip") + with ZipFile(local_source, "r") as f: + rdfs = [fname for fname in f.namelist() if fname.endswith(".bioimageio.yaml")] + if len(rdfs) > 1: + raise ValueError(f"Multiple RDFs in one package not yet supported (found {rdfs}).") + elif len(rdfs) == 1: + rdf_file_name = rdfs[0] + elif LEGACY_RDF_NAME in f.namelist(): + rdf_file_name = LEGACY_RDF_NAME + else: + raise ValueError( + f"No RDF found in 
{local_source}. (Looking for any '*.bioimageio.yaml' file or an 'rdf.yaml' file)." + ) + + f.extractall(out_path) + local_source = out_path / rdf_file_name + + with local_source.open(encoding=rdf_encoding) as f: + content: YamlValue = yaml.load(f) + + if not isinstance(content, collections.abc.Mapping): + raise TypeError(f"Expected RDF content to be a mapping, but got '{type(content)}'.") + + return _LocalRdf(cast(RdfContent, content), root, file_name) + + +def resolve_source( + source: Union[FileSource, RelativeFilePath], + /, + *, + known_hash: Optional[str] = None, + root: Union[DirectoryPath, AnyUrl, None] = None, +) -> FilePath: + if isinstance(source, RelativeFilePath): + if root is None: + raise ValueError(f"Cannot resolve relative file path '{source}' without root.") + + source = source.get_absolute(root) + + return download(source, known_hash=known_hash).path + + +# def _get_rdf_content(rdf_source: RdfSource) -> Tuple[RdfContent, ValidationContext]: +# if isinstance(rdf_source, (AnyUrl, Path, str)): +# rdf = read_rdf_content(rdf_source) +# rdf_content = rdf.content +# context = ValidationContext(root=rdf.root, file_name=rdf.file_name) +# elif isinstance(rdf_source, ResourceDescriptionBase): +# rdf_content = dump_description(rdf_source, exclude_unset=False) +# ctxt = rdf_source._internal_validation_context # pyright: ignore[reportPrivateUsage] +# context = ValidationContext(root=ctxt["root"], file_name=ctxt["file_name"]) +# else: +# rdf_content = rdf_source +# context = ValidationContext() + +# return rdf_content, context + + +# def _get_rdf_content_and_update_context(rdf_source: RdfSource, context: ValidationContext) -> RdfContent: +# if isinstance(rdf_source, (AnyUrl, Path, str)): +# rdf = read_rdf_content(rdf_source) +# rdf_source = rdf.content +# context.root = rdf.root +# context.file_name = rdf.file_name +# elif isinstance(rdf_source, ResourceDescriptionBase): +# rdf_source = dump_description(rdf_source, exclude_unset=False) + +# return rdf_source + + +def _get_description_and_update_context(rdf_source: RdfSource, context: ValidationContext) -> ResourceDescription: + if isinstance(rdf_source, dict): + descr, summary = load_description(rdf_source, context=context) + if descr is None: + rdf_source_msg = ( + f"{{name={rdf_source.get('name', 'missing'), ...}}})" + if isinstance(rdf_source, collections.abc.Mapping) + else rdf_source + ) + raise ValueError(f"Failed to load {rdf_source_msg}:\n{summary.format()}") + + return descr + + +def _interprete_file_source(file_source: FileSource) -> StrictFileSource: + return TypeAdapter(StrictFileSource).validate_python(file_source) + # todo: prettier file source validation error + # try: + # except ValidationError as e: diff --git a/pyproject.toml b/pyproject.toml index d98987d6..4857d8c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,8 +21,8 @@ pythonPlatform = "All" [tool.pytest.ini_options] addopts = "--capture=no --doctest-modules --failed-first" -# testpaths = ["bioimageio", "scripts", "example", "tests"] [tool.ruff] line-length = 120 include = ["*.py", "*.pyi", "**/pyproject.toml", "*.ipynb"] +target-version = "py38" From f4d84650d32aa0fe83e30d5c803381632e937b3b Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 19 Sep 2023 11:25:24 +0200 Subject: [PATCH 032/244] add validate --- bioimageio/core/__init__.py | 4 +- bioimageio/core/_io.py | 92 ++++++++++++++++--------------------- 2 files changed, 41 insertions(+), 55 deletions(-) diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index 36c1b290..408636e8 
100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -3,9 +3,9 @@ from bioimageio.core._internal.utils import files from bioimageio.core._io import load_description as load_description from bioimageio.core._io import load_description_and_validate as load_description_and_validate -from bioimageio.core._io import read_rdf_content as read_rdf_content +from bioimageio.core._io import read_description as read_description +from bioimageio.core._io import read_description_and_validate as read_description_and_validate from bioimageio.core._io import resolve_source as resolve_source -from bioimageio.core._io import validate as validate from bioimageio.core._io import write_description as write_description from bioimageio.core._io import write_package as write_package diff --git a/bioimageio/core/_io.py b/bioimageio/core/_io.py index f92cb102..95099a6c 100644 --- a/bioimageio/core/_io.py +++ b/bioimageio/core/_io.py @@ -4,11 +4,11 @@ import os from pathlib import Path from tempfile import NamedTemporaryFile -from typing import Any, Dict, Literal, NamedTuple, Optional, Sequence, TextIO, Tuple, Union, cast +from typing import Any, Dict, List, Literal, NamedTuple, Optional, Sequence, TextIO, Union, cast from zipfile import ZIP_DEFLATED, ZipFile, is_zipfile import pooch -from pydantic import AnyUrl, DirectoryPath, FilePath, HttpUrl, TypeAdapter, ValidationError +from pydantic import AnyUrl, DirectoryPath, FilePath, HttpUrl, TypeAdapter from ruamel.yaml import YAML from bioimageio.core._internal.utils import get_parent_url, write_zip @@ -17,7 +17,7 @@ from bioimageio.spec._internal.base_nodes import ResourceDescriptionBase from bioimageio.spec._internal.constants import DISCOVER from bioimageio.spec._internal.types import FileName, RdfContent, RelativeFilePath, ValidationContext, YamlValue -from bioimageio.spec.description import dump_description +from bioimageio.spec.description import InvalidDescription, dump_description from bioimageio.spec.model.v0_4 import WeightsFormat from bioimageio.spec.package import extract_file_name, get_resource_package_content from bioimageio.spec.summary import ValidationSummary @@ -36,7 +36,7 @@ def read_description( /, *, format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, -) -> Tuple[Optional[ResourceDescription], ValidationSummary]: +) -> Union[ResourceDescription, InvalidDescription]: rdf = download_rdf(rdf_source) return load_description( rdf.content, @@ -47,9 +47,14 @@ def read_description( def read_description_and_validate( rdf_source: FileSource, -) -> Tuple[Optional[ResourceDescription], ValidationSummary]: + /, + *, + format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, +) -> Union[ResourceDescription, InvalidDescription]: rdf = download_rdf(rdf_source) - return load_description_and_validate(rdf.content, context=ValidationContext(root=rdf.root, file_name=rdf.file_name)) + return load_description_and_validate( + rdf.content, context=ValidationContext(root=rdf.root, file_name=rdf.file_name), format_version=format_version + ) def load_description_and_validate( @@ -58,11 +63,26 @@ def load_description_and_validate( *, context: Optional[ValidationContext] = None, format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, -) -> Tuple[Optional[ResourceDescription], ValidationSummary]: +) -> Union[ResourceDescription, InvalidDescription]: """load and validate a BioImage.IO description from the content of a resource description file (RDF)""" - rd, summary = 
load_description(rdf_content, context=context, format_version=format_version) + rd = load_description(rdf_content, context=context, format_version=format_version) # todo: add dynamic validation - return rd, summary + return rd + + +def validate( + rdf_source: Union[FileSource, RdfContent], + /, + *, + context: Optional[ValidationContext] = None, + format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, +) -> List[ValidationSummary]: + if isinstance(rdf_source, dict): + rd = load_description_and_validate(rdf_source, context=context, format_version=format_version) + else: + rd = read_description_and_validate(rdf_source, format_version=format_version) + + return rd.validation_summaries def write_description(rd: Union[ResourceDescription, RdfContent], /, file: Union[FilePath, TextIO]): @@ -103,6 +123,10 @@ def prepare_resource_package( rdf.content, context=context, ) + + if isinstance(rd, InvalidDescription): + raise ValueError(f"{rdf_source} is invalid: {rd.validation_summaries[0]}") + package_content = get_resource_package_content(rd, weights_priority_order=weights_priority_order) local_package_content: Dict[FileName, Union[FilePath, RdfContent]] = {} @@ -152,7 +176,6 @@ def write_package( """ package_content = prepare_resource_package( rdf_source, - context=context, weights_priority_order=weights_priority_order, ) if output_path is None: @@ -184,7 +207,12 @@ def download( ) -> _LocalFile: source = _interprete_file_source(source) if isinstance(source, AnyUrl): - _ls: Any = pooch.retrieve(url=str(source), known_hash=known_hash) + if source.scheme in ("http", "https") and os.environ.get("CI", "false").lower() in ("1", "true"): + downloader = pooch.HTTPDownloader(headers={"User-Agent": "ci"}) + else: + downloader = None + + _ls: Any = pooch.retrieve(url=str(source), known_hash=known_hash, downloader=downloader) local_source = Path(_ls) root: Union[HttpUrl, DirectoryPath] = get_parent_url(source) else: @@ -243,48 +271,6 @@ def resolve_source( return download(source, known_hash=known_hash).path -# def _get_rdf_content(rdf_source: RdfSource) -> Tuple[RdfContent, ValidationContext]: -# if isinstance(rdf_source, (AnyUrl, Path, str)): -# rdf = read_rdf_content(rdf_source) -# rdf_content = rdf.content -# context = ValidationContext(root=rdf.root, file_name=rdf.file_name) -# elif isinstance(rdf_source, ResourceDescriptionBase): -# rdf_content = dump_description(rdf_source, exclude_unset=False) -# ctxt = rdf_source._internal_validation_context # pyright: ignore[reportPrivateUsage] -# context = ValidationContext(root=ctxt["root"], file_name=ctxt["file_name"]) -# else: -# rdf_content = rdf_source -# context = ValidationContext() - -# return rdf_content, context - - -# def _get_rdf_content_and_update_context(rdf_source: RdfSource, context: ValidationContext) -> RdfContent: -# if isinstance(rdf_source, (AnyUrl, Path, str)): -# rdf = read_rdf_content(rdf_source) -# rdf_source = rdf.content -# context.root = rdf.root -# context.file_name = rdf.file_name -# elif isinstance(rdf_source, ResourceDescriptionBase): -# rdf_source = dump_description(rdf_source, exclude_unset=False) - -# return rdf_source - - -def _get_description_and_update_context(rdf_source: RdfSource, context: ValidationContext) -> ResourceDescription: - if isinstance(rdf_source, dict): - descr, summary = load_description(rdf_source, context=context) - if descr is None: - rdf_source_msg = ( - f"{{name={rdf_source.get('name', 'missing'), ...}}})" - if isinstance(rdf_source, collections.abc.Mapping) - else rdf_source - ) - 
raise ValueError(f"Failed to load {rdf_source_msg}:\n{summary.format()}") - - return descr - - def _interprete_file_source(file_source: FileSource) -> StrictFileSource: return TypeAdapter(StrictFileSource).validate_python(file_source) # todo: prettier file source validation error From 9c10e6d4e862b168a60dea02c453b09a34668978 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 26 Sep 2023 14:05:46 +0200 Subject: [PATCH 033/244] WIP update Processings --- .../_combined_processing.py | 22 +- .../core/prediction_pipeline/_processing.py | 429 ++++++++++-------- bioimageio/core/statistical_measures.py | 1 + tests/prediction_pipeline/test_processing.py | 4 +- 4 files changed, 265 insertions(+), 191 deletions(-) diff --git a/bioimageio/core/prediction_pipeline/_combined_processing.py b/bioimageio/core/prediction_pipeline/_combined_processing.py index 71c0693d..1c9d122e 100644 --- a/bioimageio/core/prediction_pipeline/_combined_processing.py +++ b/bioimageio/core/prediction_pipeline/_combined_processing.py @@ -3,19 +3,20 @@ from bioimageio.core.resource_io import nodes -from ._processing import KNOWN_PROCESSING, AssertDtype, EnsureDtype, Processing, TensorName +from ._processing import AssertDtype, EnsureDtype, Processing from ._utils import PER_DATASET, PER_SAMPLE, ComputedMeasures, RequiredMeasures, Sample +import ._processing as proc_impl +from bioimageio.spec.model.v0_5 import TensorId - -@dataclasses.dataclass -class ProcessingInfoStep: - name: str - kwargs: Dict[str, Any] +# @dataclasses.dataclass +# class ProcessingInfoStep: +# name: str +# kwargs: Dict[str, Any] @dataclasses.dataclass class ProcessingInfo: - steps: List[ProcessingInfoStep] + steps: List[Processing] # assert_dtype_before: Optional[Union[str, Sequence[str]]] = None # throw AssertionError if data type doesn't match ensure_dtype_before: Optional[str] = None # cast data type if needed # assert_dtype_after: Optional[Union[str, Sequence[str]]] = None # throw AssertionError if data type doesn't match @@ -23,10 +24,8 @@ class ProcessingInfo: class CombinedProcessing: - def __init__(self, combine_tensors: Dict[TensorName, ProcessingInfo]): + def __init__(self, combine_tensors: Dict[TensorId, ProcessingInfo]): self._procs = [] - known = dict(KNOWN_PROCESSING["pre"]) - known.update(KNOWN_PROCESSING["post"]) # ensure all tensors have correct data type before any processing for tensor_name, info in combine_tensors.items(): @@ -38,7 +37,8 @@ def __init__(self, combine_tensors: Dict[TensorName, ProcessingInfo]): for tensor_name, info in combine_tensors.items(): for step in info.steps: - self._procs.append(known[step.name](tensor_name=tensor_name, **step.kwargs)) + + self._procs.append((tensor_name=tensor_name, **step.kwargs)) if info.assert_dtype_after is not None: self._procs.append(AssertDtype(tensor_name=tensor_name, dtype=info.assert_dtype_after)) diff --git a/bioimageio/core/prediction_pipeline/_processing.py b/bioimageio/core/prediction_pipeline/_processing.py index ecf40efc..f72dda2e 100644 --- a/bioimageio/core/prediction_pipeline/_processing.py +++ b/bioimageio/core/prediction_pipeline/_processing.py @@ -1,27 +1,56 @@ -"""Here pre- and postprocessing operations are implemented according to their definitions in bioimageio.spec: -see https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/preprocessing_spec_latest.md -and https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/postprocessing_spec_latest.md -""" from abc import ABC, abstractmethod -import numbers from dataclasses import InitVar, dataclass, field, fields 
-from typing import Dict, Generic, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union -from typing_extensions import Self +from types import MappingProxyType +from typing import ( + Any, + ClassVar, + Dict, + FrozenSet, + Generic, + Literal, + Mapping, + NamedTuple, + NotRequired, + Optional, + Sequence, + Set, + Tuple, + Type, + TypedDict, + TypeVar, + Union, + get_args, +) import numpy import numpy as np -from pydantic import model_validator # type: ignore -from pydantic import field_validator import xarray as xr -from bioimageio.spec._internal.base_nodes import Node +from numpy.typing import DTypeLike +from typing_extensions import LiteralString, Self + +from bioimageio.core.statistical_measures import Mean, Measure, MeasureValue, Percentile, Std +from bioimageio.spec._internal.base_nodes import Node, NodeWithExplicitlySetFields from bioimageio.spec.model import v0_4, v0_5 -from bioimageio.spec.model.v0_5 import Processing as ProcessingSpec, ProcessingKwargs, Binarize, Clip from bioimageio.spec.model.v0_5 import TensorId -from numpy.typing import DTypeLike -from bioimageio.core.statistical_measures import Mean, Measure, Percentile, Std, MeasureValue -from ._utils import FIXED, PER_DATASET, PER_SAMPLE, DatasetMode, Mode, RequiredMeasure, SampleMode, Sample -from typing import Literal, TypedDict, get_args +from ._utils import FIXED, PER_DATASET, PER_SAMPLE, DatasetMode, Mode, RequiredMeasure, Sample, SampleMode + +Binarize = Union[v0_4.Binarize, v0_5.Binarize] +BinarizeKwargs = Union[v0_4.BinarizeKwargs, v0_5.BinarizeKwargs] +Clip = Union[v0_4.Clip, v0_5.Clip] +ClipKwargs = Union[v0_4.ClipKwargs, v0_5.ClipKwargs] +EnsureDtypeKwargs = v0_5.EnsureDtypeKwargs +Processing = Union[v0_4.Processing, v0_5.Processing] +ProcessingKwargs = Union[v0_4.ProcessingKwargs, v0_5.ProcessingKwargs] +ScaleLinear = Union[v0_4.ScaleLinear, v0_5.ScaleLinear] +ScaleLinearKwargs = Union[v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs] +ScaleMeanVariance = Union[v0_4.ScaleMeanVariance, v0_5.ScaleMeanVariance] +ScaleMeanVarianceKwargs = Union[v0_4.ScaleMeanVarianceKwargs, v0_5.ScaleMeanVarianceKwargs] +ScaleRange = Union[v0_4.ScaleRange, v0_5.ScaleRange] +ScaleRangeKwargs = Union[v0_4.ScaleRangeKwargs, v0_5.ScaleRangeKwargs] +ZeroMeanUnitVariance = Union[v0_4.ZeroMeanUnitVariance, v0_5.ZeroMeanUnitVariance] +ZeroMeanUnitVarianceKwargs = Union[v0_4.ZeroMeanUnitVarianceKwargs, v0_5.ZeroMeanUnitVarianceKwargs] + def _get_fixed( fixed: Union[float, Sequence[float]], tensor: xr.DataArray, axes: Optional[Sequence[str]] @@ -35,33 +64,57 @@ def _get_fixed( return xr.DataArray(fixed, dims=fixed_dims) - PKwargs = TypeVar("PKwargs", bound=ProcessingKwargs) ProcInput = TypeVar("ProcInput", xr.DataArray, Sample) -class ProcessingBase(Node, Generic[PKwargs], ABC, frozen=True): - """base class for all Pre- and Postprocessing transformations.""" + +RCV = TypeVar("RCV", RequiredMeasure, MeasureValue) + + +@dataclass +class _NamedMeasures(Generic[RCV]): + def get_set(self) -> Set[RCV]: + return {getattr(self, f.name) for f in fields(self)} + + +_NoRequiredMeasures = _NamedMeasures[RequiredMeasure] +_NoMeasureValues = _NamedMeasures[MeasureValue] + +R = TypeVar("R", bound=_NamedMeasures[RequiredMeasure]) +C = TypeVar("C", bound=_NamedMeasures[MeasureValue]) + + +@dataclass +class ProcessingImplBase(Generic[PKwargs, R, C], ABC): + """Base class for all Pre- and Postprocessing implementations.""" tensor_id: TensorId """id of tensor to operate on""" kwargs: PKwargs - computed_measures: Dict[RequiredMeasure, MeasureValue] = 
field(default_factory=dict) - - @model_validator(mode="after") - def check_required_measures_in_computed(self) -> Self: - for req in self.required_measures: - if req not in self.computed_measures: - raise ValueError(f"Missing computed {req}.") - - return self + computed_measures: InitVar[Mapping[RequiredMeasure, MeasureValue]] = field( + default=MappingProxyType[RequiredMeasure, MeasureValue]({}) + ) + required: R = field(init=False) + computed: C = field(init=False) + + def __post_init__(self, computed_measures: Mapping[RequiredMeasure, MeasureValue]) -> None: + self.required = self.get_required_measures(self.tensor_id, self.kwargs) + selected = {} + for f in fields(self.required): + req = getattr(self.required, f.name) + if req in computed_measures: + selected[f.name] = computed_measures[req] + else: + raise ValueError(f"Missing computed measure: {req} (as '{f.name}').") @classmethod - def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> Tuple[RequiredMeasure, ...]: - return () + @abstractmethod + def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> R: + ... @property - def required_measures(self) -> Tuple[RequiredMeasure, ...]: - return self.get_required_measures(tensor_id=self.tensor_id, kwargs=self.kwargs) + def required_measures(self) -> Set[RequiredMeasure]: + return self.required.get_set() def __call__(self, __input: ProcInput, /) -> ProcInput: if isinstance(__input, xr.DataArray): @@ -79,213 +132,233 @@ def apply_to_sample(self, sample: Sample) -> Sample: ret[self.tensor_id] = self.apply(sample[self.tensor_id]) return ret -class Processing(ProcessingSpec, ProcessingBase[PKwargs], frozen=True): - pass -# -# Pre- and Postprocessing implementations -# -class NonSpecProcessing(ProcessingBase[PKwargs], frozen=True): - """processings operations beyond what is currently defined in bioimageio.spec""" - pass + def get_spec(self) -> v0_5.Processing: + raise NotImplementedError -class AssertDtype(NonSpecProcessing[ProcessingKwargs], frozen=True): - """Helper Processing to assert dtype.""" - id: Literal["assert_dtype"] = "assert_dtype" +class ProcessingImplBaseWoMeasures(ProcessingImplBase[PKwargs, _NoRequiredMeasures, _NoMeasureValues]): + @classmethod + def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> _NoRequiredMeasures: + return _NamedMeasures() + + +class AssertProcessing(NodeWithExplicitlySetFields, ABC, frozen=True): + id: str + kwargs: ProcessingKwargs + fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"id"}) + +class AssertDtypeKwargs(ProcessingKwargs, frozen=True): dtype: Union[str, Sequence[str]] + + +class AssertDtype(AssertProcessing, frozen=True): + id: Literal["assert_dtype"] = "assert_dtype" + kwargs: AssertDtypeKwargs + + +class AssertDtypeImpl(ProcessingImplBaseWoMeasures[AssertDtypeKwargs]): _assert_with: Tuple[Type[DTypeLike], ...] 
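+    # sketch of the intent (hypothetical values): `_assert_with` caches the concrete
+    # numpy dtype classes for `dtype`, so `apply` can verify via a plain isinstance
+    # check, e.g. dtype="uint8" -> (type(numpy.dtype("uint8")),)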
- def __pydantic_postinit__(self): - if isinstance(self.dtype, str): - dtype = [self.dtype] + def __post_init__(self, computed_measures: Mapping[RequiredMeasure, MeasureValue]) -> None: + super().__post_init__(computed_measures) + if isinstance(self.kwargs.dtype, str): + dtype = [self.kwargs.dtype] else: - dtype = self.dtype + dtype = self.kwargs.dtype - object.__setattr__(self, "_assert_with", tuple(type(numpy.dtype(dt)) for dt in dtype)) + assert_w = tuple(type(numpy.dtype(dt)) for dt in dtype) + self._assert_with = assert_w def apply(self, tensor: xr.DataArray) -> xr.DataArray: assert isinstance(tensor.dtype, self._assert_with) return tensor -class Binarize(Processing[BinarizeKwargs]): +class BinarizeImpl(ProcessingImplBaseWoMeasures[BinarizeKwargs]): """'output = tensor > threshold'.""" - threshold: float = MISSING # make dataclass inheritance work for py<3.10 by using an explicit MISSING value. + kwargs: BinarizeKwargs def apply(self, tensor: xr.DataArray) -> xr.DataArray: - return tensor > self.threshold + return tensor > self.kwargs.threshold + def get_spec(self): + return v0_5.Binarize(kwargs=self.kwargs) -@dataclass -class Clip(Processing): - """Limit tensor values to [min, max].""" - - min: float = MISSING - max: float = MISSING +class ClipImpl(ProcessingImplBaseWoMeasures[ClipKwargs]): def apply(self, tensor: xr.DataArray) -> xr.DataArray: - return tensor.clip(min=self.min, max=self.max) - + return tensor.clip(min=self.kwargs.min, max=self.kwargs.max) -@dataclass -class EnsureDtype(Processing): - """Helper Processing to cast dtype if needed.""" + def get_spec(self): + return v0_5.Clip(kwargs=self.kwargs) - dtype: str = MISSING +class EnsureDtypeImpl(ProcessingImplBaseWoMeasures[EnsureDtypeKwargs]): def apply(self, tensor: xr.DataArray) -> xr.DataArray: - return tensor.astype(self.dtype) - - -@dataclass -class ScaleLinear(Processing): - """Scale the tensor with a fixed multiplicative and additive factor.""" + return tensor.astype(self.kwargs.dtype) - gain: Union[float, Sequence[float]] = MISSING - offset: Union[float, Sequence[float]] = MISSING - axes: Optional[Sequence[str]] = None +class ScaleLinearImpl(ProcessingImplBaseWoMeasures[ScaleLinearKwargs]): def apply(self, tensor: xr.DataArray) -> xr.DataArray: - scale_axes = tuple(ax for ax in tensor.dims if (ax not in self.axes and ax != "b")) - if scale_axes: - gain = xr.DataArray(np.atleast_1d(self.gain), dims=scale_axes) - offset = xr.DataArray(np.atleast_1d(self.offset), dims=scale_axes) + joint_axes = self.kwargs.axes or () + batch_axis_names = ("b", "batch") + scale_along = tuple( + ax for ax in tensor.dims if isinstance(ax, str) and ax not in joint_axes and ax not in batch_axis_names + ) + if scale_along: + gain = xr.DataArray(np.atleast_1d(self.kwargs.gain), dims=scale_along) + offset = xr.DataArray(np.atleast_1d(self.kwargs.offset), dims=scale_along) else: - gain = self.gain - offset = self.offset + assert isinstance(self.kwargs.gain, float) or len(self.kwargs.gain) == 1 + gain = self.kwargs.gain if isinstance(self.kwargs.gain, float) else self.kwargs.gain[0] + assert isinstance(self.kwargs.offset, float) or len(self.kwargs.offset) == 1 + offset = self.kwargs.offset if isinstance(self.kwargs.offset, float) else self.kwargs.offset[0] return tensor * gain + offset - def __post_init__(self): - super().__post_init__() - if self.axes is None: - assert isinstance(self.gain, (int, float)) - assert isinstance(self.offset, (int, float)) + def get_spec(self): + if isinstance(self.kwargs, v0_4.ScaleLinearKwargs): + raise 
NotImplementedError + + return v0_5.ScaleLinear(kwargs=self.kwargs) + + +@dataclass +class _MeanStd(_NamedMeasures[RCV]): + mean: RCV + std: RCV @dataclass -class ScaleMeanVariance(Processing): - """Scale the tensor s.t. its mean and variance match a reference tensor.""" - - mode: Literal[SampleMode, DatasetMode] = PER_SAMPLE - reference_tensor: TensorName = MISSING - axes: Optional[Sequence[str]] = None - eps: float = 1e-6 - - def get_required_measures(self) -> RequiredMeasures: - axes = None if self.axes is None else tuple(self.axes) - return { - self.mode: { - self.tensor_id: {Mean(axes=axes), Std(axes=axes)}, - self.reference_tensor: {Mean(axes=axes), Std(axes=axes)}, - } - } +class _MeanStdAndRef(_MeanStd[RCV]): + ref_mean: RCV + ref_std: RCV + + +class ScaleMeanVarianceImpl( + ProcessingImplBase[ScaleMeanVarianceKwargs, _MeanStdAndRef[RequiredMeasure], _MeanStdAndRef[MeasureValue]] +): + @classmethod + def get_required_measures(cls, tensor_id: TensorId, kwargs: ScaleMeanVarianceKwargs): + axes = tuple(kwargs.axes) if isinstance(kwargs.axes, str) else kwargs.axes + return _MeanStdAndRef( + mean=RequiredMeasure(Mean(axes), tensor_id, mode=kwargs.mode), + std=RequiredMeasure(Std(axes), tensor_id, mode=kwargs.mode), + ref_mean=RequiredMeasure(Mean(axes), kwargs.reference_tensor, mode=kwargs.mode), + ref_std=RequiredMeasure(Std(axes), kwargs.reference_tensor, mode=kwargs.mode), + ) def apply(self, tensor: xr.DataArray) -> xr.DataArray: - axes = None if self.axes is None else tuple(self.axes) - assert self.mode in (PER_SAMPLE, PER_DATASET) - mean = self.get_computed_measure(self.tensor_id, Mean(axes), mode=self.mode) - std = self.get_computed_measure(self.tensor_id, Std(axes), mode=self.mode) - ref_mean = self.get_computed_measure(self.reference_tensor, Mean(axes), mode=self.mode) - ref_std = self.get_computed_measure(self.reference_tensor, Std(axes), mode=self.mode) + c = self.computed + eps = self.kwargs.eps + return (tensor - c.mean) / (c.std + eps) * (c.ref_std + eps) + c.ref_mean - return (tensor - mean) / (std + self.eps) * (ref_std + self.eps) + ref_mean + def get_spec(self): + if isinstance(self.kwargs, v0_4.ScaleMeanVarianceKwargs): + raise NotImplementedError + + return v0_5.ScaleMeanVariance(kwargs=self.kwargs) @dataclass -class ScaleRange(Processing): - """Scale with percentiles.""" +class _MinMaxPerc(_NamedMeasures[RCV]): + lower: RCV + upper: RCV - mode: Literal[SampleMode, DatasetMode] = PER_SAMPLE - axes: Optional[Sequence[str]] = None - min_percentile: float = 0.0 - max_percentile: float = 100.0 - eps: float = 1e-6 - reference_tensor: Optional[TensorName] = None - def get_required_measures(self) -> RequiredMeasures: - axes = None if self.axes is None else tuple(self.axes) - measures = {Percentile(self.min_percentile, axes=axes), Percentile(self.max_percentile, axes=axes)} - return {self.mode: {self.reference_tensor or self.tensor_id: measures}} +class ScaleRangeImpl(ProcessingImplBase[ScaleRangeKwargs, _MinMaxPerc[RequiredMeasure], _MinMaxPerc[MeasureValue]]): + # def get_required_measures(self): + @classmethod + def get_required_measures(cls, tensor_id: TensorId, kwargs: ScaleRangeKwargs) -> _MinMaxPerc[RequiredMeasure]: + ref_name = kwargs.reference_tensor or tensor_id + axes = None if kwargs.axes is None else tuple(kwargs.axes) + return _MinMaxPerc( + lower=RequiredMeasure(Percentile(kwargs.min_percentile, axes=axes), ref_name, kwargs.mode), + upper=RequiredMeasure(Percentile(kwargs.max_percentile, axes=axes), ref_name, kwargs.mode), + ) def apply(self, tensor: 
xr.DataArray) -> xr.DataArray: - ref_name = self.reference_tensor or self.tensor_id - axes = None if self.axes is None else tuple(self.axes) - v_lower = self.get_computed_measure(ref_name, Percentile(self.min_percentile, axes=axes)) - v_upper = self.get_computed_measure(ref_name, Percentile(self.max_percentile, axes=axes)) + c = self.computed + return (tensor - c.lower) / (c.upper - c.lower + self.kwargs.eps) - return (tensor - v_lower) / (v_upper - v_lower + self.eps) + def get_spec(self): + if isinstance(self.kwargs, v0_4.ScaleRangeKwargs): + raise NotImplementedError - def __post_init__(self): - super().__post_init__() - self.axes = None if self.axes is None else tuple(self.axes) # make sure axes is Tuple[str] or None + return v0_5.ScaleRange(kwargs=self.kwargs) @dataclass -class Sigmoid(Processing): +class SigmoidImpl(ProcessingImplBaseWoMeasures[ProcessingKwargs]): """1 / (1 + e^(-tensor)).""" def apply(self, tensor: xr.DataArray) -> xr.DataArray: - return 1.0 / (1.0 + np.exp(-tensor)) + return 1.0 / (1.0 + np.exp(-tensor)) # type: ignore @dataclass -class ZeroMeanUnitVariance(Processing): +class ZeroMeanUnitVarianceImpl( + ProcessingImplBase[ + ZeroMeanUnitVarianceKwargs, + Union[_NoRequiredMeasures, _MeanStd[RequiredMeasure]], + Union[_NoMeasureValues, _MeanStd[MeasureValue]], + ] +): """normalize to zero mean, unit variance.""" - mode: Mode = PER_SAMPLE - mean: Optional[Union[float, Sequence[float]]] = None - std: Optional[Union[float, Sequence[float]]] = None - axes: Optional[Sequence[str]] = None - eps: float = 1.0e-6 - - def get_required_measures(self) -> RequiredMeasures: - if self.mode == FIXED: - return {} + @classmethod + def get_required_measures( + cls, tensor_id: TensorId, kwargs: ZeroMeanUnitVarianceKwargs + ) -> Union[_NoRequiredMeasures, _MeanStd[RequiredMeasure]]: + if kwargs.mode == FIXED: + return _NamedMeasures() else: - axes = None if self.axes is None else tuple(self.axes) - return {self.mode: {self.tensor_id: {Mean(axes=axes), Std(axes=axes)}}} + axes = None if kwargs.axes is None else tuple(kwargs.axes) + return _MeanStd( + mean=RequiredMeasure(Mean(axes=axes), tensor_id, kwargs.mode), + std=RequiredMeasure(Std(axes=axes), tensor_id, kwargs.mode), + ) def apply(self, tensor: xr.DataArray) -> xr.DataArray: - axes = None if self.axes is None else tuple(self.axes) - if self.mode == FIXED: - assert self.mean is not None and self.std is not None - mean = _get_fixed(self.mean, tensor, axes) - std = _get_fixed(self.std, tensor, axes) - elif self.mode in (PER_SAMPLE, PER_DATASET): - assert self.mean is None and self.std is None - mean = self.get_computed_measure(self.tensor_id, Mean(axes), mode=self.mode) - std = self.get_computed_measure(self.tensor_id, Std(axes), mode=self.mode) + if self.kwargs.mode == FIXED: + assert self.kwargs.mean is not None + assert self.kwargs.std is not None + assert not isinstance(self.computed, _MeanStd) + axes = None if self.kwargs.axes is None else tuple(self.kwargs.axes) + mean = _get_fixed(self.kwargs.mean, tensor, axes) + std = _get_fixed(self.kwargs.std, tensor, axes) else: - raise ValueError(self.mode) - - return (tensor - mean) / (std + self.eps) - - -class _KNOWN_PREPROCESSING(TypedDict): - -class _KnownProcessing(TypedDict): - pre: Mapping[PreprocessingName, Type[Processing]] - post: Mapping[PostprocessingName, Type[Processing]] - -KNOWN_PROCESSING = _KnownProcessing( - pre={ - "binarize": Binarize, - "clip": Clip, - "scale_linear": ScaleLinear, - "scale_range": ScaleRange, - "sigmoid": Sigmoid, - "zero_mean_unit_variance": 
ZeroMeanUnitVariance, - }, - post={ - "binarize": Binarize, - "clip": Clip, - "scale_linear": ScaleLinear, - "scale_mean_variance": ScaleMeanVariance, - "scale_range": ScaleRange, - "sigmoid": Sigmoid, - "zero_mean_unit_variance": ZeroMeanUnitVariance, - }, -) + assert self.kwargs.mode in (PER_SAMPLE, PER_DATASET) + assert self.kwargs.mean is None + assert self.kwargs.std is None + assert isinstance(self.computed, _MeanStd) + mean = self.computed.mean + std = self.computed.std + + return (tensor - mean) / (std + self.kwargs.eps) + + +IMPLEMENTED_PREPROCESSING = { + v0_5.Binarize.model_fields["id"].default + # binarize = Binarize + # clip = Clip + # scale_linear = ScaleLinear + # scale_range = ScaleRange + # sigmoid = Sigmoid + # zero_mean_unit_variance = ZeroMeanUnitVariance +} + +class IMPLEMENTED_POSTPROCESSING: + binarize = Binarize + clip = Clip + scale_linear = ScaleLinear + scale_mean_variance = ScaleMeanVariance + scale_range = ScaleRange + sigmoid = Sigmoid + zero_mean_unit_variance = ZeroMeanUnitVariance + + +class IMPLEMENTED_PROCESSING(IMPLEMENTED_PREPROCESSING, IMPLEMENTED_POSTPROCESSING): + pass diff --git a/bioimageio/core/statistical_measures.py b/bioimageio/core/statistical_measures.py index 0c9c94ec..c554ec33 100644 --- a/bioimageio/core/statistical_measures.py +++ b/bioimageio/core/statistical_measures.py @@ -5,6 +5,7 @@ from typing import Optional, Tuple import xarray as xr + from bioimageio.spec.model.v0_5 import AxisName MeasureValue = xr.DataArray diff --git a/tests/prediction_pipeline/test_processing.py b/tests/prediction_pipeline/test_processing.py index b693bcfc..819982a2 100644 --- a/tests/prediction_pipeline/test_processing.py +++ b/tests/prediction_pipeline/test_processing.py @@ -4,7 +4,7 @@ import pytest import xarray as xr -from bioimageio.core.prediction_pipeline._processing import KNOWN_PROCESSING +from bioimageio.core.prediction_pipeline._processing import IMPLEMENTED_PROCESSING from bioimageio.core.prediction_pipeline._utils import FIXED try: @@ -29,7 +29,7 @@ def test_assert_dtype(): @pytest.mark.parametrize( "proc", - list(KNOWN_PROCESSING["pre"].values()) + list(KNOWN_PROCESSING["post"].values()), + list(IMPLEMENTED_PROCESSING["pre"].values()) + list(IMPLEMENTED_PROCESSING["post"].values()), ) def test_no_req_measures_for_mode_fixed(proc): # check if mode=fixed is valid for this proc From d5d6756d64fe00004e11a4c75ff37e321fb9b463 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 26 Sep 2023 14:06:26 +0200 Subject: [PATCH 034/244] show download progressbar if not in CI --- bioimageio/core/_io.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/bioimageio/core/_io.py b/bioimageio/core/_io.py index 95099a6c..d403e7de 100644 --- a/bioimageio/core/_io.py +++ b/bioimageio/core/_io.py @@ -207,11 +207,17 @@ def download( ) -> _LocalFile: source = _interprete_file_source(source) if isinstance(source, AnyUrl): - if source.scheme in ("http", "https") and os.environ.get("CI", "false").lower() in ("1", "true"): - downloader = pooch.HTTPDownloader(headers={"User-Agent": "ci"}) + if source.scheme not in ("http", "https"): + raise NotImplementedError(source.scheme) + + if os.environ.get("CI", "false").lower() in ("1", "t", "true", "yes", "y"): + headers = {"User-Agent": "ci"} + progressbar = False else: - downloader = None + headers = {} + progressbar = True + downloader = pooch.HTTPDownloader(headers=headers, progressbar=progressbar) _ls: Any = pooch.retrieve(url=str(source), known_hash=known_hash, downloader=downloader) local_source = 
Path(_ls) root: Union[HttpUrl, DirectoryPath] = get_parent_url(source) From 390cd1a5b5caf87b894f5464d096bdcda712fdaf Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 26 Sep 2023 14:09:15 +0200 Subject: [PATCH 035/244] allow to set User-Agent via env var --- bioimageio/core/_io.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bioimageio/core/_io.py b/bioimageio/core/_io.py index d403e7de..84b2d87b 100644 --- a/bioimageio/core/_io.py +++ b/bioimageio/core/_io.py @@ -217,6 +217,9 @@ def download( headers = {} progressbar = True + if (user_agent := os.environ.get("BIOIMAGEIO_USER_AGENT")) is not None: + headers["User-Agent"] = user_agent + downloader = pooch.HTTPDownloader(headers=headers, progressbar=progressbar) _ls: Any = pooch.retrieve(url=str(source), known_hash=known_hash, downloader=downloader) local_source = Path(_ls) From f619b46ab2b1e9954e5cabf3b0d2a76c3c1d05b4 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 4 Oct 2023 14:28:25 +0200 Subject: [PATCH 036/244] WIP update Processing and ModelAdapter --- bioimageio/core/__main__.py | 15 +- .../_combined_processing.py | 16 +- .../_model_adapters/_model_adapter.py | 69 ++- .../core/prediction_pipeline/_processing.py | 364 --------------- .../core/prediction_pipeline/_stat_state.py | 2 +- bioimageio/core/prediction_pipeline/_utils.py | 26 +- .../core/prediction_pipeline/processing.py | 431 ++++++++++++++++++ 7 files changed, 483 insertions(+), 440 deletions(-) delete mode 100644 bioimageio/core/prediction_pipeline/_processing.py create mode 100644 bioimageio/core/prediction_pipeline/processing.py diff --git a/bioimageio/core/__main__.py b/bioimageio/core/__main__.py index a26f43d1..75da0316 100644 --- a/bioimageio/core/__main__.py +++ b/bioimageio/core/__main__.py @@ -4,24 +4,18 @@ import sys import warnings from glob import glob - from pathlib import Path from pprint import pformat, pprint -from typing import List, Optional +from typing import List, Optional, get_args import typer -from bioimageio.core import __version__, prediction, commands, resource_tests, load_raw_resource_description +from bioimageio.core import __version__, commands, load_raw_resource_description, prediction, resource_tests from bioimageio.core.common import TestSummary -from bioimageio.core.prediction_pipeline import get_weight_formats -from bioimageio.spec.__main__ import app, help_version as help_version_spec +from bioimageio.spec.__main__ import app +from bioimageio.spec.__main__ import help_version as help_version_spec from bioimageio.spec.model.raw_nodes import WeightsFormat -try: - from typing import get_args -except ImportError: - from typing_extensions import get_args # type: ignore - try: with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -192,7 +186,6 @@ def predict_image( weight_format: Optional[WeightFormatEnum] = typer.Option(None, help="The weight format to use."), devices: Optional[List[str]] = typer.Option(None, help="Devices for running the model."), ): - if isinstance(padding, str): padding = json.loads(padding.replace("'", '"')) assert isinstance(padding, dict) diff --git a/bioimageio/core/prediction_pipeline/_combined_processing.py b/bioimageio/core/prediction_pipeline/_combined_processing.py index 1c9d122e..c4a0d532 100644 --- a/bioimageio/core/prediction_pipeline/_combined_processing.py +++ b/bioimageio/core/prediction_pipeline/_combined_processing.py @@ -5,13 +5,19 @@ from ._processing import AssertDtype, EnsureDtype, Processing from ._utils import PER_DATASET, PER_SAMPLE, ComputedMeasures, RequiredMeasures, Sample 
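+# `CombinedMeasures` below is meant to aggregate the required measures of several
+# processing steps into one NamedMeasures view; a hedged usage sketch (hypothetical
+# `specs` and `impls` values):
+#   combined = CombinedMeasures(step_specs=specs, steps=impls)
+#   all_required = combined.get_set()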
-from . import _processing as proc_impl
+from .processing import get_impl, NamedMeasures, ProcessingImplBase, ProcSpec, M
 from bioimageio.spec.model.v0_5 import TensorId
 
+
+@dataclasses.dataclass
+class CombinedMeasures(NamedMeasures[M]):
+    step_specs: Sequence[ProcSpec]
+    steps: Sequence[ProcessingImplBase[Any, Any, Any]]
+
+    def get_set(self) -> Set[M]:
+        ret: Set[M] = set()
+        for step in self.steps:
+            ret.update(getattr(step.required, f.name) for f in dataclasses.fields(step.required))
+        return ret
 
 
 @dataclasses.dataclass
diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/_model_adapter.py b/bioimageio/core/prediction_pipeline/_model_adapters/_model_adapter.py
index 8ab3fa88..adac96a7 100644
--- a/bioimageio/core/prediction_pipeline/_model_adapters/_model_adapter.py
+++ b/bioimageio/core/prediction_pipeline/_model_adapters/_model_adapter.py
@@ -1,40 +1,41 @@
 import abc
-from typing import List, Optional, Sequence, Type, Union
+from typing import List, Optional, Sequence, Tuple, Type, Union
 
 import xarray as xr
 
-from bioimageio.core import load_resource_description
-from bioimageio.core.resource_io import nodes
+from bioimageio.spec.model import v0_4, v0_5
+
+WeightsFormat = Union[v0_4.WeightsFormat, v0_5.WeightsFormat]
 
 #: Known weight formats in order of priority
 #: First match wins
-from bioimageio.spec.model import raw_nodes
+_WEIGHT_FORMATS: Tuple[WeightsFormat, ...] = (
+    "pytorch_state_dict",
+    "tensorflow_saved_model_bundle",
+    "torchscript",
+    "onnx",
+    "keras_hdf5",
+)
+
 
-_WEIGHT_FORMATS = ["pytorch_state_dict", "tensorflow_saved_model_bundle", "torchscript", "onnx", "keras_hdf5"]
+BioimageioModel = Union[v0_4.Model, v0_5.Model]
 
 
 class ModelAdapter(abc.ABC):
     """
-    Represents model *without* any preprocessing and postprocessing
+    Represents model *without* any preprocessing or postprocessing
     """
 
-    def __init__(
-        self, *, bioimageio_model: Union[nodes.Model, raw_nodes.Model], devices: Optional[Sequence[str]] = None
-    ):
+    def __init__(self, *, bioimageio_model: BioimageioModel, devices: Optional[Sequence[str]] = None):
+        super().__init__()
         self.bioimageio_model = self._prepare_model(bioimageio_model)
         self.default_devices = devices
         self.loaded = False
 
     @staticmethod
-    def _prepare_model(bioimageio_model):
-        """the (raw) model node is prepared (here: loaded as non-raw model node) for the model adapter to be ready
-        for operation.
-        Note: To write a model adapter that uses the raw model node one can overwrite this method.
- """ - if isinstance(bioimageio_model, nodes.Model): - return bioimageio_model - else: - return load_resource_description(bioimageio_model) + def _prepare_model(bioimageio_model: BioimageioModel) -> BioimageioModel: + """The model node is prepared for the model adapter to be ready for operation.""" + return bioimageio_model def __enter__(self): """load on entering context""" @@ -42,7 +43,7 @@ def __enter__(self): self.load() # using default_devices return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore """unload on exiting context""" assert self.loaded self.unload() @@ -105,39 +106,31 @@ def get_weight_formats() -> List[str]: """ Return list of supported weight types """ - return _WEIGHT_FORMATS.copy() + return list(_WEIGHT_FORMATS) def create_model_adapter( *, - bioimageio_model: Union[nodes.Model, raw_nodes.Model], - devices=Optional[Sequence[str]], - weight_format: Optional[str] = None, + bioimageio_model: Union[v0_4.Model, v0_5.Model], + devices: Optional[Sequence[str]] = None, + weight_format: Optional[WeightsFormat] = None, ) -> ModelAdapter: """ Creates model adapter based on the passed spec Note: All specific adapters should happen inside this function to prevent different framework initializations interfering with each other """ - weights = bioimageio_model.weights - weight_formats = get_weight_formats() - - if weight_format is not None: - if weight_format not in weight_formats: - raise ValueError(f"Weight format {weight_format} is not in supported formats {weight_formats}") - weight_formats = [weight_format] + if weight_format is not None and weight_format not in _WEIGHT_FORMATS: + raise ValueError(f"Weight format {weight_format} is not in supported formats {_WEIGHT_FORMATS}") - for weight in weight_formats: - if weight in weights: - adapter_cls = _get_model_adapter(weight) - return adapter_cls(bioimageio_model=bioimageio_model, devices=devices) + priority_order = _WEIGHT_FORMATS if weight_format is None else (weight_format,) + weight = bioimageio_model.weights.get(*priority_order) - raise RuntimeError( - f"weight format {weight_format} not among formats listed in model: {list(bioimageio_model.weights.keys())}" - ) + adapter_cls = _get_model_adapter(weight.type) + return adapter_cls(bioimageio_model=bioimageio_model, devices=devices) -def _get_model_adapter(weight_format: str) -> Type[ModelAdapter]: +def _get_model_adapter(weight_format: WeightsFormat) -> Type[ModelAdapter]: """ Return adapter class based on the weight format Note: All specific adapters should happen inside this function to prevent different framework diff --git a/bioimageio/core/prediction_pipeline/_processing.py b/bioimageio/core/prediction_pipeline/_processing.py deleted file mode 100644 index f72dda2e..00000000 --- a/bioimageio/core/prediction_pipeline/_processing.py +++ /dev/null @@ -1,364 +0,0 @@ -from abc import ABC, abstractmethod -from dataclasses import InitVar, dataclass, field, fields -from types import MappingProxyType -from typing import ( - Any, - ClassVar, - Dict, - FrozenSet, - Generic, - Literal, - Mapping, - NamedTuple, - NotRequired, - Optional, - Sequence, - Set, - Tuple, - Type, - TypedDict, - TypeVar, - Union, - get_args, -) - -import numpy -import numpy as np -import xarray as xr -from numpy.typing import DTypeLike -from typing_extensions import LiteralString, Self - -from bioimageio.core.statistical_measures import Mean, Measure, MeasureValue, Percentile, Std -from bioimageio.spec._internal.base_nodes import Node, 
NodeWithExplicitlySetFields -from bioimageio.spec.model import v0_4, v0_5 -from bioimageio.spec.model.v0_5 import TensorId - -from ._utils import FIXED, PER_DATASET, PER_SAMPLE, DatasetMode, Mode, RequiredMeasure, Sample, SampleMode - -Binarize = Union[v0_4.Binarize, v0_5.Binarize] -BinarizeKwargs = Union[v0_4.BinarizeKwargs, v0_5.BinarizeKwargs] -Clip = Union[v0_4.Clip, v0_5.Clip] -ClipKwargs = Union[v0_4.ClipKwargs, v0_5.ClipKwargs] -EnsureDtypeKwargs = v0_5.EnsureDtypeKwargs -Processing = Union[v0_4.Processing, v0_5.Processing] -ProcessingKwargs = Union[v0_4.ProcessingKwargs, v0_5.ProcessingKwargs] -ScaleLinear = Union[v0_4.ScaleLinear, v0_5.ScaleLinear] -ScaleLinearKwargs = Union[v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs] -ScaleMeanVariance = Union[v0_4.ScaleMeanVariance, v0_5.ScaleMeanVariance] -ScaleMeanVarianceKwargs = Union[v0_4.ScaleMeanVarianceKwargs, v0_5.ScaleMeanVarianceKwargs] -ScaleRange = Union[v0_4.ScaleRange, v0_5.ScaleRange] -ScaleRangeKwargs = Union[v0_4.ScaleRangeKwargs, v0_5.ScaleRangeKwargs] -ZeroMeanUnitVariance = Union[v0_4.ZeroMeanUnitVariance, v0_5.ZeroMeanUnitVariance] -ZeroMeanUnitVarianceKwargs = Union[v0_4.ZeroMeanUnitVarianceKwargs, v0_5.ZeroMeanUnitVarianceKwargs] - - -def _get_fixed( - fixed: Union[float, Sequence[float]], tensor: xr.DataArray, axes: Optional[Sequence[str]] -) -> Union[float, xr.DataArray]: - if axes is None: - return fixed - - fixed_shape = tuple(s for d, s in tensor.sizes.items() if d not in axes) - fixed_dims = tuple(d for d in tensor.dims if d not in axes) - fixed = np.array(fixed).reshape(fixed_shape) - return xr.DataArray(fixed, dims=fixed_dims) - - -PKwargs = TypeVar("PKwargs", bound=ProcessingKwargs) -ProcInput = TypeVar("ProcInput", xr.DataArray, Sample) - - -RCV = TypeVar("RCV", RequiredMeasure, MeasureValue) - - -@dataclass -class _NamedMeasures(Generic[RCV]): - def get_set(self) -> Set[RCV]: - return {getattr(self, f.name) for f in fields(self)} - - -_NoRequiredMeasures = _NamedMeasures[RequiredMeasure] -_NoMeasureValues = _NamedMeasures[MeasureValue] - -R = TypeVar("R", bound=_NamedMeasures[RequiredMeasure]) -C = TypeVar("C", bound=_NamedMeasures[MeasureValue]) - - -@dataclass -class ProcessingImplBase(Generic[PKwargs, R, C], ABC): - """Base class for all Pre- and Postprocessing implementations.""" - - tensor_id: TensorId - """id of tensor to operate on""" - kwargs: PKwargs - computed_measures: InitVar[Mapping[RequiredMeasure, MeasureValue]] = field( - default=MappingProxyType[RequiredMeasure, MeasureValue]({}) - ) - required: R = field(init=False) - computed: C = field(init=False) - - def __post_init__(self, computed_measures: Mapping[RequiredMeasure, MeasureValue]) -> None: - self.required = self.get_required_measures(self.tensor_id, self.kwargs) - selected = {} - for f in fields(self.required): - req = getattr(self.required, f.name) - if req in computed_measures: - selected[f.name] = computed_measures[req] - else: - raise ValueError(f"Missing computed measure: {req} (as '{f.name}').") - - @classmethod - @abstractmethod - def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> R: - ... - - @property - def required_measures(self) -> Set[RequiredMeasure]: - return self.required.get_set() - - def __call__(self, __input: ProcInput, /) -> ProcInput: - if isinstance(__input, xr.DataArray): - return self.apply(__input) - else: - return self.apply_to_sample(__input) - - @abstractmethod - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - """apply processing""" - ... 
- - def apply_to_sample(self, sample: Sample) -> Sample: - ret = dict(sample) - ret[self.tensor_id] = self.apply(sample[self.tensor_id]) - return ret - - def get_spec(self) -> v0_5.Processing: - raise NotImplementedError - - -class ProcessingImplBaseWoMeasures(ProcessingImplBase[PKwargs, _NoRequiredMeasures, _NoMeasureValues]): - @classmethod - def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> _NoRequiredMeasures: - return _NamedMeasures() - - -class AssertProcessing(NodeWithExplicitlySetFields, ABC, frozen=True): - id: str - kwargs: ProcessingKwargs - fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"id"}) - - -class AssertDtypeKwargs(ProcessingKwargs, frozen=True): - dtype: Union[str, Sequence[str]] - - -class AssertDtype(AssertProcessing, frozen=True): - id: Literal["assert_dtype"] = "assert_dtype" - kwargs: AssertDtypeKwargs - - -class AssertDtypeImpl(ProcessingImplBaseWoMeasures[AssertDtypeKwargs]): - _assert_with: Tuple[Type[DTypeLike], ...] - - def __post_init__(self, computed_measures: Mapping[RequiredMeasure, MeasureValue]) -> None: - super().__post_init__(computed_measures) - if isinstance(self.kwargs.dtype, str): - dtype = [self.kwargs.dtype] - else: - dtype = self.kwargs.dtype - - assert_w = tuple(type(numpy.dtype(dt)) for dt in dtype) - self._assert_with = assert_w - - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - assert isinstance(tensor.dtype, self._assert_with) - return tensor - - -class BinarizeImpl(ProcessingImplBaseWoMeasures[BinarizeKwargs]): - """'output = tensor > threshold'.""" - - kwargs: BinarizeKwargs - - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - return tensor > self.kwargs.threshold - - def get_spec(self): - return v0_5.Binarize(kwargs=self.kwargs) - - -class ClipImpl(ProcessingImplBaseWoMeasures[ClipKwargs]): - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - return tensor.clip(min=self.kwargs.min, max=self.kwargs.max) - - def get_spec(self): - return v0_5.Clip(kwargs=self.kwargs) - - -class EnsureDtypeImpl(ProcessingImplBaseWoMeasures[EnsureDtypeKwargs]): - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - return tensor.astype(self.kwargs.dtype) - - -class ScaleLinearImpl(ProcessingImplBaseWoMeasures[ScaleLinearKwargs]): - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - joint_axes = self.kwargs.axes or () - batch_axis_names = ("b", "batch") - scale_along = tuple( - ax for ax in tensor.dims if isinstance(ax, str) and ax not in joint_axes and ax not in batch_axis_names - ) - if scale_along: - gain = xr.DataArray(np.atleast_1d(self.kwargs.gain), dims=scale_along) - offset = xr.DataArray(np.atleast_1d(self.kwargs.offset), dims=scale_along) - else: - assert isinstance(self.kwargs.gain, float) or len(self.kwargs.gain) == 1 - gain = self.kwargs.gain if isinstance(self.kwargs.gain, float) else self.kwargs.gain[0] - assert isinstance(self.kwargs.offset, float) or len(self.kwargs.offset) == 1 - offset = self.kwargs.offset if isinstance(self.kwargs.offset, float) else self.kwargs.offset[0] - - return tensor * gain + offset - - def get_spec(self): - if isinstance(self.kwargs, v0_4.ScaleLinearKwargs): - raise NotImplementedError - - return v0_5.ScaleLinear(kwargs=self.kwargs) - - -@dataclass -class _MeanStd(_NamedMeasures[RCV]): - mean: RCV - std: RCV - - -@dataclass -class _MeanStdAndRef(_MeanStd[RCV]): - ref_mean: RCV - ref_std: RCV - - -class ScaleMeanVarianceImpl( - ProcessingImplBase[ScaleMeanVarianceKwargs, _MeanStdAndRef[RequiredMeasure], _MeanStdAndRef[MeasureValue]] -): 
- @classmethod - def get_required_measures(cls, tensor_id: TensorId, kwargs: ScaleMeanVarianceKwargs): - axes = tuple(kwargs.axes) if isinstance(kwargs.axes, str) else kwargs.axes - return _MeanStdAndRef( - mean=RequiredMeasure(Mean(axes), tensor_id, mode=kwargs.mode), - std=RequiredMeasure(Std(axes), tensor_id, mode=kwargs.mode), - ref_mean=RequiredMeasure(Mean(axes), kwargs.reference_tensor, mode=kwargs.mode), - ref_std=RequiredMeasure(Std(axes), kwargs.reference_tensor, mode=kwargs.mode), - ) - - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - c = self.computed - eps = self.kwargs.eps - return (tensor - c.mean) / (c.std + eps) * (c.ref_std + eps) + c.ref_mean - - def get_spec(self): - if isinstance(self.kwargs, v0_4.ScaleMeanVarianceKwargs): - raise NotImplementedError - - return v0_5.ScaleMeanVariance(kwargs=self.kwargs) - - -@dataclass -class _MinMaxPerc(_NamedMeasures[RCV]): - lower: RCV - upper: RCV - - -class ScaleRangeImpl(ProcessingImplBase[ScaleRangeKwargs, _MinMaxPerc[RequiredMeasure], _MinMaxPerc[MeasureValue]]): - # def get_required_measures(self): - @classmethod - def get_required_measures(cls, tensor_id: TensorId, kwargs: ScaleRangeKwargs) -> _MinMaxPerc[RequiredMeasure]: - ref_name = kwargs.reference_tensor or tensor_id - axes = None if kwargs.axes is None else tuple(kwargs.axes) - return _MinMaxPerc( - lower=RequiredMeasure(Percentile(kwargs.min_percentile, axes=axes), ref_name, kwargs.mode), - upper=RequiredMeasure(Percentile(kwargs.max_percentile, axes=axes), ref_name, kwargs.mode), - ) - - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - c = self.computed - return (tensor - c.lower) / (c.upper - c.lower + self.kwargs.eps) - - def get_spec(self): - if isinstance(self.kwargs, v0_4.ScaleRangeKwargs): - raise NotImplementedError - - return v0_5.ScaleRange(kwargs=self.kwargs) - - -@dataclass -class SigmoidImpl(ProcessingImplBaseWoMeasures[ProcessingKwargs]): - """1 / (1 + e^(-tensor)).""" - - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - return 1.0 / (1.0 + np.exp(-tensor)) # type: ignore - - -@dataclass -class ZeroMeanUnitVarianceImpl( - ProcessingImplBase[ - ZeroMeanUnitVarianceKwargs, - Union[_NoRequiredMeasures, _MeanStd[RequiredMeasure]], - Union[_NoMeasureValues, _MeanStd[MeasureValue]], - ] -): - """normalize to zero mean, unit variance.""" - - @classmethod - def get_required_measures( - cls, tensor_id: TensorId, kwargs: ZeroMeanUnitVarianceKwargs - ) -> Union[_NoRequiredMeasures, _MeanStd[RequiredMeasure]]: - if kwargs.mode == FIXED: - return _NamedMeasures() - else: - axes = None if kwargs.axes is None else tuple(kwargs.axes) - return _MeanStd( - mean=RequiredMeasure(Mean(axes=axes), tensor_id, kwargs.mode), - std=RequiredMeasure(Std(axes=axes), tensor_id, kwargs.mode), - ) - - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - if self.kwargs.mode == FIXED: - assert self.kwargs.mean is not None - assert self.kwargs.std is not None - assert not isinstance(self.computed, _MeanStd) - axes = None if self.kwargs.axes is None else tuple(self.kwargs.axes) - mean = _get_fixed(self.kwargs.mean, tensor, axes) - std = _get_fixed(self.kwargs.std, tensor, axes) - else: - assert self.kwargs.mode in (PER_SAMPLE, PER_DATASET) - assert self.kwargs.mean is None - assert self.kwargs.std is None - assert isinstance(self.computed, _MeanStd) - mean = self.computed.mean - std = self.computed.std - - return (tensor - mean) / (std + self.kwargs.eps) - - -IMPLEMENTED_PREPROCESSING = { - v0_5.Binarize.model_fields["id"].default - # binarize = Binarize - # 
clip = Clip - # scale_linear = ScaleLinear - # scale_range = ScaleRange - # sigmoid = Sigmoid - # zero_mean_unit_variance = ZeroMeanUnitVariance -} - -class IMPLEMENTED_POSTPROCESSING: - binarize = Binarize - clip = Clip - scale_linear = ScaleLinear - scale_mean_variance = ScaleMeanVariance - scale_range = ScaleRange - sigmoid = Sigmoid - zero_mean_unit_variance = ZeroMeanUnitVariance - - -class IMPLEMENTED_PROCESSING(IMPLEMENTED_PREPROCESSING, IMPLEMENTED_POSTPROCESSING): - pass diff --git a/bioimageio/core/prediction_pipeline/_stat_state.py b/bioimageio/core/prediction_pipeline/_stat_state.py index cf5be64c..c0e72eb0 100644 --- a/bioimageio/core/prediction_pipeline/_stat_state.py +++ b/bioimageio/core/prediction_pipeline/_stat_state.py @@ -5,7 +5,7 @@ from bioimageio.core.statistical_measures import Measure from ._measure_groups import MeasureGroups, MeasureValue, get_measure_groups -from ._utils import PER_DATASET, PER_SAMPLE, ComputedMeasures, RequiredMeasures, Sample, TensorName +from ._utils import PER_DATASET, PER_SAMPLE, MeasureValue, RequiredMeasure, Sample, TensorName class StatsState: diff --git a/bioimageio/core/prediction_pipeline/_utils.py b/bioimageio/core/prediction_pipeline/_utils.py index 78d0e478..83f181ef 100644 --- a/bioimageio/core/prediction_pipeline/_utils.py +++ b/bioimageio/core/prediction_pipeline/_utils.py @@ -5,31 +5,15 @@ from typing import Any, Dict, Iterator, List, Literal, NamedTuple, Set, Union import xarray as xr -from bioimageio.spec.model.v0_5 import TensorId - -from bioimageio.core.statistical_measures import Measure, MeasureValue - -FixedMode = Literal["fixed"] -SampleMode = Literal["per_sample"] -DatasetMode = Literal["per_dataset"] -Mode = Literal[FixedMode, SampleMode, DatasetMode] - -FIXED: FixedMode = "fixed" -PER_SAMPLE: SampleMode = "per_sample" -PER_DATASET: DatasetMode = "per_dataset" -MODES: Set[Mode] = {FIXED, PER_SAMPLE, PER_DATASET} +from bioimageio.core.statistical_measures import Measure +from bioimageio.spec.model.v0_5 import TensorId Sample = Dict[TensorId, xr.DataArray] -class RequiredMeasure(NamedTuple): - measure: Measure - tensor_id: TensorId - mode: Mode - - # def __repr__(self) -> str: - # return f"{self.measure} of {self.tensor_id} ({self.mode})" +# def __repr__(self) -> str: +# return f"{self.measure} of {self.tensor_id} ({self.mode})" # RequiredMeasures = List[ReqMeasure] @@ -92,4 +76,4 @@ class RequiredMeasure(NamedTuple): # elif isinstance(__x, ReqMeasureEntry): # else: -# return super().__contains__(__x) \ No newline at end of file +# return super().__contains__(__x) diff --git a/bioimageio/core/prediction_pipeline/processing.py b/bioimageio/core/prediction_pipeline/processing.py new file mode 100644 index 00000000..dde182e3 --- /dev/null +++ b/bioimageio/core/prediction_pipeline/processing.py @@ -0,0 +1,431 @@ +from abc import ABC, abstractmethod +from dataclasses import InitVar, dataclass, field, fields +from types import MappingProxyType +from typing import ( + Any, + ClassVar, + FrozenSet, + Generic, + Hashable, + List, + Literal, + Mapping, + NamedTuple, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +import numpy +import numpy as np +import xarray as xr +from numpy.typing import DTypeLike +from typing_extensions import LiteralString + +from bioimageio.core.statistical_measures import Mean, Measure, MeasureValue, Percentile, Std +from bioimageio.spec._internal.base_nodes import NodeWithExplicitlySetFields +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.model.v0_5 
import TensorId
+
+from ._utils import Sample
+
+AssertProcessingId = Literal["assert_dtype"]
+
+
+class AssertProcessingBase(NodeWithExplicitlySetFields, frozen=True):
+    id: AssertProcessingId
+    fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"id"})
+
+
+class AssertDtypeKwargs(v0_5.ProcessingKwargs, frozen=True):
+    dtype: Union[str, Sequence[str]]
+
+
+class AssertDtype(AssertProcessingBase, frozen=True):
+    id: Literal["assert_dtype"] = "assert_dtype"
+    kwargs: AssertDtypeKwargs
+
+
+class RequiredMeasure(NamedTuple):
+    measure: Measure
+    tensor_id: TensorId
+    mode: Literal["per_sample", "per_dataset"]
+
+
+M = TypeVar("M", RequiredMeasure, MeasureValue)
+
+
+@dataclass
+class NamedMeasures(Generic[M]):
+    """Named measures that specify all required/computed measures of a Processing instance"""
+
+    def get_set(self) -> Set[M]:
+        return {getattr(self, f.name) for f in fields(self)}
+
+
+# The two generics are conceptually a higher kinded generic
+R = TypeVar("R", bound=NamedMeasures[RequiredMeasure])
+C = TypeVar("C", bound=NamedMeasures[MeasureValue])
+
+
+PKwargs = TypeVar("PKwargs", bound=Union[v0_4.ProcessingKwargs, v0_5.ProcessingKwargs])
+ProcInput = TypeVar("ProcInput", xr.DataArray, Sample)
+ProcessingBase = Union[v0_4.ProcessingBase, v0_5.ProcessingBase]
+
+
+@dataclass(frozen=True)
+class ProcessingImplBase(Generic[PKwargs, R, C], ABC):
+    """Base class for all Pre- and Postprocessing implementations."""
+
+    tensor_id: TensorId
+    """id of tensor to operate on"""
+    kwargs: PKwargs
+    computed_measures: InitVar[Mapping[RequiredMeasure, MeasureValue]] = field(
+        default=MappingProxyType[RequiredMeasure, MeasureValue]({})
+    )
+    assert type(R) is type(C), "R and C are conceptually a higher kinded generic; their class has to be identical"
+    required: R = field(init=False)
+    computed: C = field(init=False)
+
+    def __post_init__(self, computed_measures: Mapping[RequiredMeasure, MeasureValue]) -> None:
+        object.__setattr__(self, "required", self.get_required_measures(self.tensor_id, self.kwargs))
+        selected = {}
+        for f in fields(self.required):
+            req = getattr(self.required, f.name)
+            if req in computed_measures:
+                selected[f.name] = computed_measures[req]
+            else:
+                raise ValueError(f"Missing computed measure: {req} (as '{f.name}').")
+
+        object.__setattr__(self, "computed", self.required.__class__(**selected))
+
+    @classmethod
+    @abstractmethod
+    def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> NamedMeasures[RequiredMeasure]:
+        ...
+
+    def __call__(self, __input: ProcInput, /) -> ProcInput:
+        if isinstance(__input, xr.DataArray):
+            return self.apply(__input)
+        else:
+            return self.apply_to_sample(__input)
+
+    @abstractmethod
+    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
+        """apply processing"""
+        ...
+
+    def apply_to_sample(self, sample: Sample) -> Sample:
+        ret = dict(sample)
+        ret[self.tensor_id] = self.apply(sample[self.tensor_id])
+        return ret
+
+    @abstractmethod
+    def get_spec(self) -> Union[ProcessingBase, AssertProcessingBase]:
+        ...
+
+
+@dataclass(frozen=True)
+class ProcessingImplBaseWoMeasures(
+    ProcessingImplBase[PKwargs, NamedMeasures[RequiredMeasure], NamedMeasures[MeasureValue]]
+):
+    @classmethod
+    def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> NamedMeasures[RequiredMeasure]:
+        return NamedMeasures()
+
+
+@dataclass(frozen=True)
+class AssertDtypeImpl(ProcessingImplBaseWoMeasures[AssertDtypeKwargs]):
+    _assert_with: Tuple[Type[DTypeLike], ...] = field(init=False)
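+    # usage sketch (hypothetical values): construct via kwargs, then call like a function
+    #   impl = AssertDtypeImpl(tensor_id=TensorId("raw"), kwargs=AssertDtypeKwargs(dtype="uint8"))
+    #   tensor = impl(tensor)  # raises AssertionError on dtype mismatch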
+
+    def __post_init__(self, computed_measures: Mapping[RequiredMeasure, MeasureValue]) -> None:
+        super().__post_init__(computed_measures)
+        if isinstance(self.kwargs.dtype, str):
+            dtype = [self.kwargs.dtype]
+        else:
+            dtype = self.kwargs.dtype
+
+        object.__setattr__(self, "_assert_with", tuple(type(numpy.dtype(dt)) for dt in dtype))
+
+    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
+        assert isinstance(tensor.dtype, self._assert_with)
+        return tensor
+
+    def get_spec(self):
+        return AssertDtype(kwargs=self.kwargs)
+
+
+@dataclass(frozen=True)
+class BinarizeImpl(ProcessingImplBaseWoMeasures[Union[v0_4.BinarizeKwargs, v0_5.BinarizeKwargs]]):
+    """'output = tensor > threshold'."""
+
+    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
+        return tensor > self.kwargs.threshold
+
+    def get_spec(self):
+        return v0_5.Binarize(kwargs=self.kwargs)
+
+
+@dataclass(frozen=True)
+class ClipImpl(ProcessingImplBaseWoMeasures[Union[v0_4.ClipKwargs, v0_5.ClipKwargs]]):
+    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
+        return tensor.clip(min=self.kwargs.min, max=self.kwargs.max)
+
+    def get_spec(self):
+        return v0_5.Clip(kwargs=self.kwargs)
+
+
+@dataclass(frozen=True)
+class EnsureDtypeImpl(ProcessingImplBaseWoMeasures[v0_5.EnsureDtypeKwargs]):
+    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
+        return tensor.astype(self.kwargs.dtype)
+
+    def get_spec(self):
+        return v0_5.EnsureDtype(kwargs=self.kwargs)
+
+
+@dataclass(frozen=True)
+class ScaleLinearImpl(ProcessingImplBaseWoMeasures[Union[v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs]]):
+    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
+        axis = (
+            self.kwargs.axis
+            if isinstance(self.kwargs, v0_5.ScaleLinearKwargs)
+            else _get_complement_axis(tensor, self.kwargs.axes)
+        )
+        if axis:
+            gain = xr.DataArray(np.atleast_1d(self.kwargs.gain), dims=axis)
+            offset = xr.DataArray(np.atleast_1d(self.kwargs.offset), dims=axis)
+        else:
+            assert isinstance(self.kwargs.gain, (float, int)) or len(self.kwargs.gain) == 1
+            gain = self.kwargs.gain if isinstance(self.kwargs.gain, (float, int)) else self.kwargs.gain[0]
+            assert isinstance(self.kwargs.offset, (float, int)) or len(self.kwargs.offset) == 1
+            offset = self.kwargs.offset if isinstance(self.kwargs.offset, (float, int)) else self.kwargs.offset[0]
+
+        return tensor * gain + offset
+
+    def get_spec(self):
+        if isinstance(self.kwargs, v0_4.ScaleLinearKwargs):
+            raise NotImplementedError
+
+        return v0_5.ScaleLinear(kwargs=self.kwargs)
+
+
+@dataclass
+class NamedMeasuresScaleMeanVariance(NamedMeasures[M]):
+    mean: M
+    std: M
+    ref_mean: M
+    ref_std: M
+
+
+@dataclass(frozen=True)
+class ScaleMeanVarianceImpl(
+    ProcessingImplBase[
+        Union[v0_4.ScaleMeanVarianceKwargs, v0_5.ScaleMeanVarianceKwargs],
+        NamedMeasuresScaleMeanVariance[RequiredMeasure],
+        NamedMeasuresScaleMeanVariance[MeasureValue],
+    ]
+):
+    @classmethod
+    def get_required_measures(
+        cls, tensor_id: TensorId, kwargs: Union[v0_4.ScaleMeanVarianceKwargs, v0_5.ScaleMeanVarianceKwargs]
+    ):
+        axes = tuple(kwargs.axes) if isinstance(kwargs.axes, str) else kwargs.axes
+        return NamedMeasuresScaleMeanVariance(
+            mean=RequiredMeasure(Mean(axes), tensor_id, mode=kwargs.mode),
+            std=RequiredMeasure(Std(axes), tensor_id, mode=kwargs.mode),
+            ref_mean=RequiredMeasure(Mean(axes), cast(TensorId, kwargs.reference_tensor), mode=kwargs.mode),
+            ref_std=RequiredMeasure(Std(axes), cast(TensorId, kwargs.reference_tensor), mode=kwargs.mode),
+        )
+
+    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
+        c = 
self.computed + eps = self.kwargs.eps + return (tensor - c.mean) / (c.std + eps) * (c.ref_std + eps) + c.ref_mean + + def get_spec(self): + if isinstance(self.kwargs, v0_4.ScaleMeanVarianceKwargs): + raise NotImplementedError + + return v0_5.ScaleMeanVariance(kwargs=self.kwargs) + + +@dataclass +class NamedMeasuresScaleRange(NamedMeasures[M]): + lower: M + upper: M + + +@dataclass(frozen=True) +class ScaleRangeImpl( + ProcessingImplBase[ + Union[v0_4.ScaleRangeKwargs, v0_5.ScaleRangeKwargs], + NamedMeasuresScaleRange[RequiredMeasure], + NamedMeasuresScaleRange[MeasureValue], + ] +): + @classmethod + def get_required_measures(cls, tensor_id: TensorId, kwargs: Union[v0_4.ScaleRangeKwargs, v0_5.ScaleRangeKwargs]): + ref_name = kwargs.reference_tensor or tensor_id + axes = None if kwargs.axes is None else tuple(kwargs.axes) + return NamedMeasuresScaleRange( + lower=RequiredMeasure(Percentile(kwargs.min_percentile, axes=axes), cast(TensorId, ref_name), kwargs.mode), + upper=RequiredMeasure(Percentile(kwargs.max_percentile, axes=axes), cast(TensorId, ref_name), kwargs.mode), + ) + + def apply(self, tensor: xr.DataArray) -> xr.DataArray: + c = self.computed + return (tensor - c.lower) / (c.upper - c.lower + self.kwargs.eps) + + def get_spec(self): + if isinstance(self.kwargs, v0_4.ScaleRangeKwargs): + raise NotImplementedError + + return v0_5.ScaleRange(kwargs=self.kwargs) + + +@dataclass(frozen=True) +class SigmoidImpl(ProcessingImplBaseWoMeasures[v0_5.ProcessingKwargs]): + """1 / (1 + e^(-tensor)).""" + + def apply(self, tensor: xr.DataArray) -> xr.DataArray: + return 1.0 / (1.0 + np.exp(-tensor)) # type: ignore + + def get_spec(self): + return v0_5.Sigmoid() + + +@dataclass +class NamedMeasuresZeroMeanUnitVariance(NamedMeasures[M]): + mean: M + std: M + + +@dataclass(frozen=True) +class ZeroMeanUnitVarianceImpl( + ProcessingImplBase[ + Union[v0_4.ZeroMeanUnitVarianceKwargs, v0_5.ZeroMeanUnitVarianceKwargs], + NamedMeasuresZeroMeanUnitVariance[RequiredMeasure], + NamedMeasuresZeroMeanUnitVariance[MeasureValue], + ] +): + """normalize to zero mean, unit variance.""" + + @classmethod + def get_required_measures( + cls, tensor_id: TensorId, kwargs: Union[v0_4.ZeroMeanUnitVarianceKwargs, v0_5.ZeroMeanUnitVarianceKwargs] + ): + axes = None if kwargs.axes is None else tuple(kwargs.axes) + assert kwargs.mode != "fixed" # should use FixedZeroMeanUnitVarianceImpl + return NamedMeasuresZeroMeanUnitVariance( + mean=RequiredMeasure(Mean(axes=axes), tensor_id, kwargs.mode), + std=RequiredMeasure(Std(axes=axes), tensor_id, kwargs.mode), + ) + + def apply(self, tensor: xr.DataArray) -> xr.DataArray: + mean = self.computed.mean + std = self.computed.std + return (tensor - mean) / (std + self.kwargs.eps) + + def get_spec(self): + if isinstance(self.kwargs, v0_4.ZeroMeanUnitVarianceKwargs): + raise NotImplementedError + + return v0_5.ZeroMeanUnitVariance(kwargs=self.kwargs) + + +@dataclass(frozen=True) +class FixedZeroMeanUnitVarianceImpl( + ProcessingImplBaseWoMeasures[Union[v0_4.ZeroMeanUnitVarianceKwargs, v0_5.FixedZeroMeanUnitVarianceKwargs]] +): + """normalize to zero mean, unit variance with precomputed values.""" + + def apply(self, tensor: xr.DataArray) -> xr.DataArray: + if isinstance(self.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs): + axis = self.kwargs.axis + elif isinstance(self.kwargs.mean, float) and isinstance(self.kwargs.std, float): + axis = None + else: + axis = _get_complement_axis(tensor, self.kwargs.axes) + + mean = xr.DataArray(self.kwargs.mean, dims=axis) + std = 
xr.DataArray(self.kwargs.std, dims=axis)
+        return (tensor - mean) / std
+
+    def get_spec(self):
+        if isinstance(self.kwargs, v0_4.ZeroMeanUnitVarianceKwargs):
+            raise NotImplementedError
+
+        return v0_5.FixedZeroMeanUnitVariance(kwargs=self.kwargs)
+
+
+ProcSpec = Union[AssertDtype, v0_4.Preprocessing, v0_4.Postprocessing, v0_5.Preprocessing, v0_5.Postprocessing]
+
+
+def get_impl(proc_spec: ProcSpec):
+    if isinstance(proc_spec, AssertDtype):
+        return AssertDtypeImpl
+    elif isinstance(proc_spec, (v0_4.Binarize, v0_5.Binarize)):
+        return BinarizeImpl
+    elif isinstance(proc_spec, (v0_4.Clip, v0_5.Clip)):
+        return ClipImpl
+    elif isinstance(proc_spec, v0_5.EnsureDtype):
+        return EnsureDtypeImpl
+    elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVariance):
+        return FixedZeroMeanUnitVarianceImpl
+    elif isinstance(proc_spec, (v0_4.ScaleLinear, v0_5.ScaleLinear)):
+        return ScaleLinearImpl
+    elif isinstance(proc_spec, (v0_4.ScaleMeanVariance, v0_5.ScaleMeanVariance)):
+        return ScaleMeanVarianceImpl
+    elif isinstance(proc_spec, (v0_4.ScaleRange, v0_5.ScaleRange)):
+        return ScaleRangeImpl
+    elif isinstance(proc_spec, (v0_4.Sigmoid, v0_5.Sigmoid)):
+        return SigmoidImpl
+    elif isinstance(proc_spec, v0_4.ZeroMeanUnitVariance) and proc_spec.kwargs.mode == "fixed":
+        return FixedZeroMeanUnitVarianceImpl
+    elif isinstance(proc_spec, (v0_4.ZeroMeanUnitVariance, v0_5.ZeroMeanUnitVariance)):
+        return ZeroMeanUnitVarianceImpl
+    else:
+        raise NotImplementedError(proc_spec)
+
+
+Model = Union[v0_4.Model, v0_5.Model]
+
+
+def get_procs(model: Model):
+    procs: List[ProcessingImplBase[Any, Any, Any]] = []
+    for ipt in model.inputs:
+        if not ipt.preprocessing:
+            continue
+
+        for proc_spec in ipt.preprocessing:
+            impl = get_impl(proc_spec)
+            # assumption: v0_4 inputs are identified by `name`, v0_5 inputs by `id`;
+            # measure-free procs only; stats-based procs additionally require `computed_measures`
+            procs.append(impl(tensor_id=TensorId(str(getattr(ipt, "id", None) or ipt.name)), kwargs=proc_spec.kwargs))
+
+    return procs
+
+
+def _get_complement_axis(tensor: xr.DataArray, axes: Optional[Sequence[Hashable]]) -> Optional[Hashable]:
+    if axes is None:
+        return None
+
+    v04_AXIS_TYPE_MAP = {
+        "b": "batch",
+        "t": "time",
+        "i": "index",
+        "c": "channel",
+        "x": "space",
+        "y": "space",
+        "z": "space",
+    }
+    converted_axes = [v04_AXIS_TYPE_MAP.get(a, a) for a in map(str, axes)] + ["batch"]
+    complement_axes = [a for a in tensor.dims if str(a) not in converted_axes]
+    if len(complement_axes) != 1:
+        raise ValueError(
+            f"Expected a single complement axis, but axes '{converted_axes}' (originally '{axes}') "
+            f"for tensor dims '{tensor.dims}' leave '{complement_axes}'."
+        )
+
+    return complement_axes[0]

From e258d66b7fbf415ef228f46a44184c40ffefb987 Mon Sep 17 00:00:00 2001
From: fynnbe 
Date: Thu, 5 Oct 2023 01:31:56 +0200
Subject: [PATCH 037/244] update onnx model adapter

---
 .../prediction_pipeline/_combined_processing.py |  6 +++---
 .../_model_adapters/_onnx_model_adapter.py      | 16 ++++++++--------
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/bioimageio/core/prediction_pipeline/_combined_processing.py b/bioimageio/core/prediction_pipeline/_combined_processing.py
index c4a0d532..1cfdd66a 100644
--- a/bioimageio/core/prediction_pipeline/_combined_processing.py
+++ b/bioimageio/core/prediction_pipeline/_combined_processing.py
@@ -1,11 +1,11 @@
 import dataclasses
-from typing import Any, Dict, List, Literal, Optional, Sequence, Union
+from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Union
 
 from bioimageio.core.resource_io import nodes
 
 from ._processing import AssertDtype, EnsureDtype, Processing
 from ._utils import PER_DATASET, PER_SAMPLE, ComputedMeasures, RequiredMeasures, Sample
-from .processing import get_impl, NamedMeasures, ProcSpec, M
+from .processing import ProcessingImplBase, get_impl, NamedMeasures, ProcSpec, M
 
 from bioimageio.spec.model.v0_5 import TensorId
 
@@ -30,7 +30,7 @@ class ProcessingInfo:
 
 
 class CombinedProcessing:
-    def __init__(self, combine_tensors: Dict[TensorId, ProcessingInfo]):
+    def __init__(self, steps: Dict[TensorId, ProcessingInfo]):
         self._procs = []
 
         # ensure all tensors have correct data type before any processing
diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/_onnx_model_adapter.py b/bioimageio/core/prediction_pipeline/_model_adapters/_onnx_model_adapter.py
index 5495c8bd..45b882e0 100644
--- a/bioimageio/core/prediction_pipeline/_model_adapters/_onnx_model_adapter.py
+++ b/bioimageio/core/prediction_pipeline/_model_adapters/_onnx_model_adapter.py
@@ -1,6 +1,6 @@
 import logging
 import warnings
-from typing import List, Optional
+from typing import List, Optional, Sequence
 
 import onnxruntime as rt
 import xarray as xr
@@ -11,12 +11,12 @@ class ONNXModelAdapter(ModelAdapter):
-    def _load(self, *, devices: Optional[List[str]] = None):
+    def _load(self, *, devices: Optional[Sequence[str]] = None):
         self._internal_output_axes = [tuple(out.axes) for out in self.bioimageio_model.outputs]
-
-        self._session = rt.InferenceSession(str(self.bioimageio_model.weights["onnx"].source))
-        onnx_inputs = self._session.get_inputs()
-        self._input_names = [ipt.name for ipt in onnx_inputs]
+        assert self.bioimageio_model.weights.onnx is not None
+        self._session = rt.InferenceSession(str(self.bioimageio_model.weights.onnx.source))
+        onnx_inputs = self._session.get_inputs()  # type: ignore
+        self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]  # type: ignore
 
         if devices is not None:
             warnings.warn(f"Device management is not implemented for onnx yet, ignoring the devices {devices}")
@@ -24,11 +24,12 @@ def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
         assert len(input_tensors) == len(self._input_names)
         input_arrays = [ipt.data for ipt in input_tensors]
-        result = self._session.run(None, dict(zip(self._input_names, input_arrays)))
+        result = self._session.run(None, dict(zip(self._input_names, input_arrays)))  # type: ignore
         if not isinstance(result, (list, tuple)):
-            result = []
+            result = [result]
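
+            # wrap a lone output so it pairs with the per-output axes in the zip below
 
-        return [xr.DataArray(r, dims=axes) for r, axes in zip(result, self._internal_output_axes)]
+        return [xr.DataArray(r, dims=axes) for 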
r, axes in zip(result, self._internal_output_axes)] # type: ignore def _unload(self) -> None: warnings.warn("Device management is not implemented for onnx yet, cannot unload model") From 4fe9894594f0d50f5faf2b6847cae4342acc28ed Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 11 Oct 2023 09:59:54 +0200 Subject: [PATCH 038/244] WIP update for spec changes --- .../core/_internal/validation_visitors.py | 20 +++++++++++-------- .../_model_adapters/_keras_model_adapter.py | 13 +++++++----- .../test_internal/test_validation_visitors.py | 17 ++++++++-------- 3 files changed, 28 insertions(+), 22 deletions(-) diff --git a/bioimageio/core/_internal/validation_visitors.py b/bioimageio/core/_internal/validation_visitors.py index bb7de16a..bdb2d25e 100644 --- a/bioimageio/core/_internal/validation_visitors.py +++ b/bioimageio/core/_internal/validation_visitors.py @@ -8,15 +8,14 @@ from types import ModuleType from typing import Any, Hashable, List, Optional, Tuple, TypedDict, Union -from annotated_types import SLOTS +from pydantic import AnyUrl, DirectoryPath +from pydantic.fields import FieldInfo +from typing_extensions import NotRequired + from bioimageio.spec._internal.base_nodes import Node from bioimageio.spec._internal.constants import ALERT_TYPE, IN_PACKAGE_MESSAGE, KW_ONLY, SLOTS -from bioimageio.spec._internal.types import Loc from bioimageio.spec.description import ResourceDescription -from bioimageio.spec.summary import ErrorEntry, WarningEntry -from pydantic import AnyUrl, DirectoryPath -from pydantic.fields import FieldInfo -from typing_extensions import NotRequired, Unpack +from bioimageio.spec.summary import ErrorEntry, Loc, WarningEntry class VisitorKwargs(TypedDict): @@ -67,10 +66,15 @@ def __init__(self, root: Union[DirectoryPath, AnyUrl]) -> None: def _visit_path(self, path: PurePath, note: Note): if not Path(path).exists(): - if note.info and note.info.description and note.info.description.startswith(IN_PACKAGE_MESSAGE): + msg = f"{path} not found" + if ( + note.info + and isinstance(note.info.description, str) + and note.info.description.startswith(IN_PACKAGE_MESSAGE) + ): self.errors.append(ErrorEntry(loc=note.loc, msg=msg, type="file-not-found")) else: - self.warnings.append(WarningEntry(loc=note.loc, msg=msg, type=ALERT_TYPE)) + self.warnings.append(WarningEntry(loc=note.loc, msg=msg, type="file-not-found")) # # info.description.startswith(IN_PACKAGE_MESSAGE) diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/_keras_model_adapter.py b/bioimageio/core/prediction_pipeline/_model_adapters/_keras_model_adapter.py index a9ee132b..a9bf74dd 100644 --- a/bioimageio/core/prediction_pipeline/_model_adapters/_keras_model_adapter.py +++ b/bioimageio/core/prediction_pipeline/_model_adapters/_keras_model_adapter.py @@ -1,11 +1,12 @@ import warnings from typing import List, Optional, Sequence -from marshmallow import missing + +from packaging.version import Version # by default, we use the keras integrated with tensorflow try: - from tensorflow import keras import tensorflow as tf + from tensorflow import keras TF_VERSION = tf.__version__ except Exception: @@ -19,11 +20,13 @@ class KerasModelAdapter(ModelAdapter): def _load(self, *, devices: Optional[Sequence[str]] = None) -> None: - model_tf_version = self.bioimageio_model.weights["keras_hdf5"].tensorflow_version - if model_tf_version is missing: + assert self.bioimageio_model.weights.keras_hdf5 is not None + tf_version = self.bioimageio_model.weights.keras_hdf5.tensorflow_version + if tf_version is None: model_tf_version = 
None
         else:
-            model_tf_version = (int(model_tf_version.major), int(model_tf_version.minor))
+            v = Version(tf_version)
+            model_tf_version = (int(v.major), int(v.minor))
 
         if TF_VERSION is None or model_tf_version is None:
             warnings.warn("Could not check tensorflow versions. The prediction results may be wrong.")
diff --git a/tests/test_internal/test_validation_visitors.py b/tests/test_internal/test_validation_visitors.py
index cc702755..9aff615a 100644
--- a/tests/test_internal/test_validation_visitors.py
+++ b/tests/test_internal/test_validation_visitors.py
@@ -1,9 +1,8 @@
 from functools import singledispatchmethod
 
-from bioimageio.spec._internal.base_nodes import Node
-from bioimageio.spec.summary import ErrorOutcome
-
 from bioimageio.core._internal.validation_visitors import Note, ValidationVisitor
+from bioimageio.spec._internal.base_nodes import Node
+from bioimageio.spec.summary import ErrorEntry
 
 
 def test_traversing_nodes():
@@ -15,7 +14,7 @@ def visit(self, obj: type, note: Note = Note()):
         @visit.register
         def _visit_int(self, nr: int, note: Note = Note()):
             super().visit(nr, note)
-            self.errors.append(ErrorOutcome(loc=note.loc, msg=f"nr: {nr}", type="got-int"))
+            self.errors.append(ErrorEntry(loc=note.loc, msg=f"nr: {nr}", type="got-int"))
 
     class NestedNode(Node, frozen=True):
         leaf: int
@@ -32,9 +31,9 @@ class MyNode(Node, frozen=True):
     visitor = MyVisitor()
     visitor.visit(tree)
-    assert len(visitor.errors) == [
-        ErrorOutcome(loc=("a", "nested", "leaf"), msg="nr: 1", type="got-int"),
-        ErrorOutcome(loc=("b", 0, "leaf"), msg="nr: 2", type="got-int"),
-        ErrorOutcome(loc=("b", 1, "leaf"), msg="nr: 3", type="got-int"),
-        ErrorOutcome(loc=("c", 0, "leaf"), msg="nr: 4", type="got-int"),
-        ErrorOutcome(loc=("d", "deep", "nested", "leaf"), msg="nr: 5", type="got-int"),
+    assert visitor.errors == [
+        ErrorEntry(loc=("a", "nested", "leaf"), msg="nr: 1", type="got-int"),
+        ErrorEntry(loc=("b", 0, "leaf"), msg="nr: 2", type="got-int"),
+        ErrorEntry(loc=("b", 1, "leaf"), msg="nr: 3", type="got-int"),
+        ErrorEntry(loc=("c", 0, "leaf"), msg="nr: 4", type="got-int"),
+        ErrorEntry(loc=("d", "deep", "nested", "leaf"), msg="nr: 5", type="got-int"),
     ]

From 77e6de0dce34cb3984e3cc0e592db654215a9e52 Mon Sep 17 00:00:00 2001
From: fynnbe 
Date: Wed, 11 Oct 2023 11:32:06 +0200
Subject: [PATCH 039/244] update examples

---
 example/dataset_creation.ipynb | 36 ++++++++-------
 example/demo.ipynb             | 81 ++++++++++++++++++++++++++++++++--
 2 files changed, 97 insertions(+), 20 deletions(-)

diff --git a/example/dataset_creation.ipynb b/example/dataset_creation.ipynb
index dfefb681..e7f5d27d 100644
--- a/example/dataset_creation.ipynb
+++ b/example/dataset_creation.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -13,13 +13,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "from datetime import datetime\n",
-    "\n",
-    "from bioimageio.spec.dataset.v0_3 import Author, CiteEntry, Dataset\n",
+    "from bioimageio.spec.dataset.v0_3 import Author, CiteEntry, Dataset, HttpUrl, RelativeFilePath\n",
     "\n",
     "nuclei_broad_data = Dataset(\n",
    "    name=\"Kaggle 2018 Data Science Bowl\",\n",
    "    \"2018 Data Science Bowl sponsored by Booz Allen Hamilton with cash prizes. 
The image set was a testing ground \"\n",
    "    \"for the application of novel and cutting edge approaches in computer vision and machine learning to the \"\n",
    "    \"segmentation of the nuclei belonging to cells from a breadth of biological contexts.\",\n",
-    "    documentation=\"README.md\",\n",
+    "    documentation=RelativeFilePath(\"README.md\"),\n",
     "    covers=(\n",
-    "        \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage1.png\",\n",
-    "        \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage2.png\",\n",
-    "        \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage3.png\",\n",
-    "        \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage4.png\",\n",
-    "        \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage5.png\",\n",
+    "        HttpUrl(\"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage1.png\"),\n",
+    "        HttpUrl(\"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage2.png\"),\n",
+    "        HttpUrl(\"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage3.png\"),\n",
+    "        HttpUrl(\"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage4.png\"),\n",
+    "        HttpUrl(\"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage5.png\"),\n",
     "    ),\n",
-    "    authors=(\n",
-    "        Author(name=\"Fynn Beuttenmueller\", affiliation=\"EMBL\", github_user=\"fynnbe\", orcid=\"0000-0002-8567-6389\"),\n",
-    "    ),\n",
-    "    source=\"https://bbbc.lbroadinstitute.org/BBBC038/\",\n",
+    "    authors=(Author(name=\"Fynn Beuttenmueller\", affiliation=\"EMBL\", github_user=\"fynnbe\", orcid=\"0000-0002-8567-6389\"),),\n",
+    "    source=HttpUrl(\"https://bbbc.broadinstitute.org/BBBC038/\"),\n",
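
+    "    # data source: the Broad Bioimage Benchmark Collection entry BBBC038\n",
     "    cite=(\n",
     "        CiteEntry(\n",
     "            text=\"Caicedo, J.C., Goodman, A., Karhohs, K.W. et al. Nucleus segmentation across imaging experiments: \"\n",
@@ -51,11 +47,19 @@
     "            url=\"https://kaggle.com/competitions/data-science-bowl-2018\",\n",
     "        ),\n",
     "    ),\n",
-    "    timestamp=datetime.today(),\n",
     "    license=\"CC0-1.0\",\n",
     ")"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "nuclei_broad_data.source"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
diff --git a/example/demo.ipynb b/example/demo.ipynb
index c12aff25..2bbc64ba 100644
--- a/example/demo.ipynb
+++ b/example/demo.ipynb
@@ -1,5 +1,63 @@
 {
  "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from bioimageio.spec.pretty_validation_errors import enable_pretty_validation_errors_in_ipynb\n",
+    "\n",
+    "enable_pretty_validation_errors_in_ipynb()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from bioimageio.spec.model.v0_4 import Model\n",
+    "from bioimageio.core import read_description\n",
+    "\n",
+    "from pydantic import HttpUrl\n",
+    "\n",
+    "model = read_description(HttpUrl(\"https://bioimage-io.github.io/collection-bioimage-io/rdfs/10.5281/zenodo.6334383/7805067/rdf.yaml\"))\n",
+    "assert isinstance(model, Model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(model.validation_summaries[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model.weights"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  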
},
  {
   "cell_type": "code",
   "execution_count": null,
@@ -13,9 +71,16 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "from bioimageio.spec.pretty_validation_errors import enable_pretty_validation_errors_in_ipynb\n",
+    "import xarray as xr\n",
+    "import numpy as np\n",
     "\n",
-    "enable_pretty_validation_errors_in_ipynb()"
+    "gain = [1, 2, 3]\n",
+    "tensor = xr.DataArray(np.random.randn(2, 3, 2, 2), dims=(\"b\", \"c\", \"y\", \"x\"))\n",
+    "axes = (\"b\", \"x\", \"y\")\n",
+    "scale_axes = tuple(a for a in tensor.dims if a not in axes)\n",
+    "b = xr.DataArray([1, 2, 3], dims=scale_axes)\n",
+    "\n",
+    "tensor * b"
   ]
  },
@@ -24,7 +89,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "from bioimageio.core import load_description_and_validate"
+    "tensor.mean(dim=(\"x\", \"y\"))\n"
   ]
  },
@@ -33,7 +98,15 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "descr, summary = load_description_and_validate(\"10.5281/zenodo.6559929/6559930/rdf.yaml\")"
+    "gain = (2, 1, 3)\n",
+    "offset = (3, 0, 1)\n",
+    "\n",
+    "print(tensor * gain + offset)\n",
+    "axes = (\"x\", \"y\")\n",
+    "tmp = tensor.stack(temp=axes) * gain + offset\n",
+    "print(tmp)\n",
+    "print()\n",
+    "tmp.unstack(\"temp\")"
   ]
  },
 {

From da36bf433a50c873f26017ee76abee824ab7a307 Mon Sep 17 00:00:00 2001
From: fynnbe 
Date: Wed, 11 Oct 2023 23:14:28 +0200
Subject: [PATCH 040/244] refactor and improve utils

---
 bioimageio/core/_internal/utils.py            |  60 -----
 .../core/_internal/validation_visitors.py     | 219 ------------------
 bioimageio/core/{_io.py => io.py}             |  47 +++-
 bioimageio/core/utils/__init__.py             | 106 +++++++++
 .../pytest_utils.py => utils/testing.py}      |   0
 bioimageio/core/utils/validation_visitors.py  | 141 +++++++++++
 6 files changed, 288 insertions(+), 285 deletions(-)
 delete mode 100644 bioimageio/core/_internal/utils.py
 delete mode 100644 bioimageio/core/_internal/validation_visitors.py
 rename bioimageio/core/{_io.py => io.py} (86%)
 create mode 100644 bioimageio/core/utils/__init__.py
 rename bioimageio/core/{_internal/pytest_utils.py => utils/testing.py} (100%)
 create mode 100644 bioimageio/core/utils/validation_visitors.py

diff --git a/bioimageio/core/_internal/utils.py b/bioimageio/core/_internal/utils.py
deleted file mode 100644
index f8424d6c..00000000
--- a/bioimageio/core/_internal/utils.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from __future__ import annotations
-
-import io
-import os
-import sys
-from pathlib import Path
-from typing import Any, Dict, Mapping, Union
-from urllib.parse import urlsplit, urlunsplit
-from zipfile import ZipFile
-
-from bioimageio.spec._internal.types import FileName
-from pydantic import AnyUrl, FilePath, HttpUrl
-from ruamel.yaml import YAML
-
-yaml = YAML(typ="safe")
-if sys.version_info < (3, 9):
-
-    def files(package_name: str):
-        assert package_name == "bioimageio.core"
-        return Path(__file__).parent.parent
-
-else:
-    from importlib.resources import files as files
-
-
-def get_parent_url(url: HttpUrl) -> HttpUrl:
-    parsed = urlsplit(str(url))
-    return AnyUrl(
-        urlunsplit((parsed.scheme, parsed.netloc, "/".join(parsed.path.split("/")[:-1]), parsed.query, parsed.fragment))
-    )
-
-
-def write_zip(
-    path: os.PathLike[str],
-    content: Mapping[FileName, Union[str, FilePath, Dict[Any, Any]]],
-    *,
-    compression: int,
-    compression_level: int,
-) -> None:
-    """Write a zip archive.
-
-    Args:
-        path: output path to write to.
-        content: dict mapping archive names to local file paths, strings (for text files), or dict (for yaml files).
-        compression: The numeric constant of compression method.
- compression_level: Compression level to use when writing files to the archive. - See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile - - """ - with ZipFile(path, "w", compression=compression, compresslevel=compression_level) as myzip: - for arc_name, file in content.items(): - if isinstance(file, dict): - buf = io.StringIO() - YAML.dump(file, buf) - file = buf.getvalue() - - if isinstance(file, str): - myzip.writestr(arc_name, file.encode("utf-8")) - else: - myzip.write(file, arcname=arc_name) diff --git a/bioimageio/core/_internal/validation_visitors.py b/bioimageio/core/_internal/validation_visitors.py deleted file mode 100644 index bdb2d25e..00000000 --- a/bioimageio/core/_internal/validation_visitors.py +++ /dev/null @@ -1,219 +0,0 @@ -import hashlib -import importlib.util -import os -import sys -from dataclasses import dataclass, replace -from functools import singledispatchmethod -from pathlib import Path, PosixPath, PurePath -from types import ModuleType -from typing import Any, Hashable, List, Optional, Tuple, TypedDict, Union - -from pydantic import AnyUrl, DirectoryPath -from pydantic.fields import FieldInfo -from typing_extensions import NotRequired - -from bioimageio.spec._internal.base_nodes import Node -from bioimageio.spec._internal.constants import ALERT_TYPE, IN_PACKAGE_MESSAGE, KW_ONLY, SLOTS -from bioimageio.spec.description import ResourceDescription -from bioimageio.spec.summary import ErrorEntry, Loc, WarningEntry - - -class VisitorKwargs(TypedDict): - info: NotRequired[FieldInfo] - - -@dataclass(frozen=True, **SLOTS, **KW_ONLY) -class Note: - loc: Loc = () - info: Optional[FieldInfo] = None - - -class ValidationVisitor: - def __init__(self) -> None: - super().__init__() - self.errors: List[ErrorEntry] = [] - self.warnings: List[WarningEntry] = [] - - @singledispatchmethod - def visit(self, obj: type, /, note: Note = Note()): - pass - - @visit.register - def _visit_node(self, node: Node, note: Note = Note()): - for k, v in node: - self.visit(v, replace(note, loc=note.loc + (k,), info=node.model_fields[k])) - - @visit.register - def _visit_list(self, lst: list, note: Note = Note()): # type: ignore - for i, e in enumerate(lst): # type: ignore - self.visit(e, replace(note, loc=note.loc + (i,))) - - @visit.register - def _visit_tuple(self, tup: tuple, note: Note = Note()): # type: ignore - for i, e in enumerate(tup): # type: ignore - self.visit(e, replace(note, loc=note.loc + (i,))) - - @visit.register - def _visit_dict(self, dict_: dict, note: Note = Note()): # type: ignore - for k, v in dict_.items(): # type: ignore - self.visit(v, replace(note, loc=note.loc + (k,))) - - -class SourceValidator(ValidationVisitor): - def __init__(self, root: Union[DirectoryPath, AnyUrl]) -> None: - super().__init__() - self.root = root - - def _visit_path(self, path: PurePath, note: Note): - if not Path(path).exists(): - msg = f"{path} not found" - if ( - note.info - and isinstance(note.info.description, str) - and note.info.description.startswith(IN_PACKAGE_MESSAGE) - ): - self.errors.append(ErrorEntry(loc=note.loc, msg=msg, type="file-not-found")) - else: - self.warnings.append(WarningEntry(loc=note.loc, msg=msg, type="file-not-found")) - - -# # info.description.startswith(IN_PACKAGE_MESSAGE) -# if not source_available(leaf, self.root_path): -# raise FileNotFoundError(leaf) - -# def visit_URI(self, node: raw_nodes.URI): -# self._visit_source(node) - -# def visit_PosixPath(self, leaf: PosixPath): -# self._visit_source(leaf) - -# def visit_WindowsPath(self, leaf: 
pathlib.WindowsPath): -# self._visit_source(leaf) - -# def generic_visit(self, node): -# """Called if no explicit visitor function exists for a node.""" - -# if isinstance(node, raw_nodes.RawNode): -# for field, value in iter_fields(node): -# if field != "root_path": # do not visit root_path, as it might be an incomplete (non-available) URL -# self.visit(value) -# else: -# super().generic_visit(node) - - -# def get_sha256(path: os.PathLike) -> str: -# """from https://stackoverflow.com/a/44873382""" -# h = hashlib.sha256() -# b = bytearray(128 * 1024) -# mv = memoryview(b) -# with open(path, "rb", buffering=0) as f: -# for n in iter(lambda: f.readinto(mv), 0): -# h.update(mv[:n]) - -# return h.hexdigest() - - -# class Sha256NodeChecker(NodeVisitor): -# """Check integrity of the source-like field for every sha256-like field encountered""" - -# def __init__(self, *, root_path: os.PathLike): -# self.root_path = root_path if isinstance(root_path, raw_nodes.URI) else pathlib.Path(root_path).resolve() - -# def generic_visit(self, node): -# if isinstance(node, raw_nodes.RawNode): -# for sha_field, expected in ((k, v) for (k, v) in iter_fields(node) if "sha256" in k and v is not missing): -# if sha_field == "sha256": -# source_name = "source" -# if not hasattr(node, "source") and hasattr(node, "uri"): -# source_name = "uri" - -# elif sha_field.endswith("_sha256"): -# source_name = sha_field[: -len("_sha256")] -# else: -# raise NotImplementedError(f"Don't know how to check integrity with {sha_field}") - -# if not hasattr(node, source_name): -# raise ValueError( -# f"Node {node} expected to have '{source_name}' field associated with '{sha_field}'" -# ) - -# source_node = getattr(node, source_name) -# if isinstance(source_node, ImportedSource): -# continue # test is run after loading. Warning issued in resource_tests._test_resource_integrity - -# source = get_resolved_source_path(source_node, root_path=self.root_path) -# actual = get_sha256(source) - -# if not isinstance(expected, str): -# raise TypeError(f"Expected '{sha_field}' to hold string, not {type(expected)}") - -# if actual != expected: -# if actual[:6] != expected[:6]: -# actual = actual[:6] + "..." -# expected = expected[:6] + "..." 
- -# raise ValueError( -# f"Determined {actual} for {source_name}={source}, but expected {sha_field}={expected}" -# ) - -# super().generic_visit(node) - - -# class SourceNodeTransformer(NodeTransformer): -# """ -# Imports all source callables -# note: Requires previous transformation by UriNodeTransformer -# """ - -# class TemporaryInsertionIntoPythonPath: -# def __init__(self, path: str): -# self.path = path - -# def __enter__(self): -# sys.path.insert(0, self.path) - -# def __exit__(self, exc_type, exc_value, traceback): -# sys.path.remove(self.path) - -# def transform_LocalImportableModule(self, node: raw_nodes.LocalImportableModule) -> nodes.ImportedSource: -# with self.TemporaryInsertionIntoPythonPath(str(node.root_path)): -# module = importlib.import_module(node.module_name) - -# return nodes.ImportedSource(factory=getattr(module, node.callable_name)) - -# @staticmethod -# def transform_ResolvedImportableSourceFile(node: raw_nodes.ResolvedImportableSourceFile) -> nodes.ImportedSource: -# module_path = resolve_source(node.source_file) -# module_name = f"module_from_source.{module_path.stem}" -# importlib_spec = importlib.util.spec_from_file_location(module_name, module_path) -# assert importlib_spec is not None -# dep = importlib.util.module_from_spec(importlib_spec) -# importlib_spec.loader.exec_module(dep) # type: ignore # todo: possible to use "loader.load_module"? -# return nodes.ImportedSource(factory=getattr(dep, node.callable_name)) - - -# class RawNodeTypeTransformer(NodeTransformer): -# def __init__(self, nodes_module: ModuleType): -# super().__init__() -# self.nodes = nodes_module - -# def generic_transformer(self, node: GenericRawNode) -> GenericResolvedNode: -# if isinstance(node, raw_nodes.RawNode): -# resolved_data = { -# field.name: self.transform(getattr(node, field.name)) for field in dataclasses.fields(node) -# } -# resolved_node_type: typing.Type[GenericResolvedNode] = getattr(self.nodes, node.__class__.__name__) -# return resolved_node_type(**resolved_data) # type: ignore -# else: -# return super().generic_transformer(node) - - -# def all_sources_available( -# node: typing.Union[GenericNode, list, tuple, dict], root_path: os.PathLike = pathlib.Path() -# ) -> bool: -# try: -# SourceNodeChecker(root_path=root_path).visit(node) -# except FileNotFoundError: -# return False -# else: -# return True diff --git a/bioimageio/core/_io.py b/bioimageio/core/io.py similarity index 86% rename from bioimageio/core/_io.py rename to bioimageio/core/io.py index 84b2d87b..934e0920 100644 --- a/bioimageio/core/_io.py +++ b/bioimageio/core/io.py @@ -1,17 +1,19 @@ from __future__ import annotations import collections.abc +import io import os from pathlib import Path from tempfile import NamedTemporaryFile -from typing import Any, Dict, List, Literal, NamedTuple, Optional, Sequence, TextIO, Union, cast +from typing import Annotated, Any, Dict, List, Literal, Mapping, NamedTuple, Optional, Sequence, TextIO, Union, cast from zipfile import ZIP_DEFLATED, ZipFile, is_zipfile import pooch +from annotated_types import Len, Predicate from pydantic import AnyUrl, DirectoryPath, FilePath, HttpUrl, TypeAdapter from ruamel.yaml import YAML -from bioimageio.core._internal.utils import get_parent_url, write_zip +from bioimageio.core.utils import get_parent_url from bioimageio.spec import ResourceDescription from bioimageio.spec import load_description as load_description from bioimageio.spec._internal.base_nodes import ResourceDescriptionBase @@ -31,6 +33,9 @@ LEGACY_RDF_NAME = "rdf.yaml" +KnownHash 
= Annotated[str, Len(64 + len("sha256:"), 64 + len("sha256:")), Predicate(lambda x: str.startswith(x, "sha256:"))]
+
+
 def read_description(
     rdf_source: FileSource,
     /,
@@ -139,6 +144,41 @@ def prepare_resource_package(
     return local_package_content
 
 
+def write_zip(
+    path: os.PathLike[str],
+    content: Mapping[FileName, Union[str, FilePath, Dict[Any, Any]]],
+    *,
+    compression: int,
+    compression_level: int,
+) -> None:
+    """Write a zip archive.
+
+    Args:
+        path: output path to write to.
+        content: dict mapping archive names to local file paths, strings (for text files), or dict (for yaml files).
+        compression: The numeric constant of compression method.
+        compression_level: Compression level to use when writing files to the archive.
+                           See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile
+
+    """
+    with ZipFile(path, "w", compression=compression, compresslevel=compression_level) as myzip:
+        for arc_name, file in content.items():
+            if isinstance(file, dict):
+                buf = io.StringIO()
+                YAML(typ="safe").dump(file, buf)
+                file = buf.getvalue()
+
+            if isinstance(file, str):
+                myzip.writestr(arc_name, file.encode("utf-8"))
+            else:
+                myzip.write(file, arcname=arc_name)
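
+
+
+# usage sketch (`rdf_dict` and `weights_path` are placeholders, not defined in this module):
+#   write_zip(Path("package.zip"), {"rdf.yaml": rdf_dict, "weights.pt": weights_path},
+#             compression=ZIP_DEFLATED, compression_level=6)
+
+
 def write_package(
     rdf_source: RdfSource,
     /,
@@ -203,7 +243,7 @@ def download(
     source: FileSource,
     /,
     *,
-    known_hash: Optional[str] = None,
+    known_hash: Optional[KnownHash] = None,
 ) -> _LocalFile:
     source = _interprete_file_source(source)
     if isinstance(source, AnyUrl):
@@ -218,7 +258,7 @@ progressbar = True
 
         if (user_agent := os.environ.get("BIOIMAGEIO_USER_AGENT")) is not None:
-            headers["User-Agent"] = user_agent
+        headers["User-Agent"] = user_agent
 
         downloader = pooch.HTTPDownloader(headers=headers, progressbar=progressbar)
         _ls: Any = pooch.retrieve(url=str(source), known_hash=known_hash, downloader=downloader)
@@ -235,7 +275,7 @@
     )
 
 
-def download_rdf(source: FileSource, /, *, known_hash: Optional[str] = None, rdf_encoding: str = "utf-8"):
+def download_rdf(source: FileSource, /, *, known_hash: Optional[KnownHash] = None, rdf_encoding: str = "utf-8"):
     local_source, root, file_name = download(source, known_hash=known_hash)
     if is_zipfile(local_source):
         out_path = local_source.with_suffix(local_source.suffix + ".unzip")
@@ -268,7 +308,7 @@
     source: Union[FileSource, RelativeFilePath],
     /,
     *,
-    known_hash: Optional[str] = None,
+    known_hash: Optional[KnownHash] = None,
     root: Union[DirectoryPath, AnyUrl, None] = None,
 ) -> FilePath:
     if isinstance(source, RelativeFilePath):
diff --git a/bioimageio/core/utils/__init__.py b/bioimageio/core/utils/__init__.py
new file mode 100644
index 00000000..1452042d
--- /dev/null
+++ b/bioimageio/core/utils/__init__.py
@@ -0,0 +1,106 @@
+# todo: cleanup __init__: move stuff to util submodules or elsewhere
+from __future__ import annotations
+
+import hashlib
+import importlib.util
+import os
+import sys
+from contextlib import AbstractContextManager
+from functools import singledispatch
+from pathlib import Path
+from types import TracebackType
+from typing import Any, Callable, Optional
+from urllib.parse import urlsplit, urlunsplit
+
+from pydantic import AnyUrl, HttpUrl
+
+from bioimageio.core.io import FileSource, download
+from bioimageio.spec.model.v0_4 import CallableFromDepencency
+from bioimageio.spec.model.v0_4 import CallableFromFile as CallableFromFile04
+from bioimageio.spec.model.v0_5 import CallableFromFile as CallableFromFile05
+from bioimageio.spec.model.v0_5 import Sha256
+
+if sys.version_info < (3, 9):
+
+    def files(package_name: str):
+        assert package_name 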
== "bioimageio.core" + return Path(__file__).parent.parent + +else: + from importlib.resources import files as files + + +def get_parent_url(url: HttpUrl) -> HttpUrl: + parsed = urlsplit(str(url)) + return AnyUrl( + urlunsplit((parsed.scheme, parsed.netloc, "/".join(parsed.path.split("/")[:-1]), parsed.query, parsed.fragment)) + ) + + +def get_sha256(path: os.PathLike[str]) -> str: + """from https://stackoverflow.com/a/44873382""" + h = hashlib.sha256() + b = bytearray(128 * 1024) + mv = memoryview(b) + with open(path, "rb", buffering=0) as f: + for n in iter(lambda: f.readinto(mv), 0): + h.update(mv[:n]) + + return h.hexdigest() + + +class TemporaryInsertionIntoPythonPath(AbstractContextManager[None]): + def __init__(self, path: Path): + super().__init__() + self.path = str(path) + + def __enter__(self): + super().__enter__() + sys.path.insert(0, self.path) + + def __exit__( + self, + __exc_type: "type[BaseException] | None", + __exc_value: "BaseException | None", + __traceback: "TracebackType | None", + ) -> "bool | None": + assert sys.path[0] == self.path + _ = sys.path.pop(0) + return super().__exit__(__exc_type, __exc_value, __traceback) + + +@singledispatch +def import_callable(node: type, /) -> Callable[..., Any]: + raise TypeError(type(node)) + + +@import_callable.register +def import_from_dependency(node: CallableFromDepencency) -> Callable[..., Any]: + module = importlib.import_module(node.module_name) + c = getattr(module, node.callable_name) + if not callable(c): + raise ValueError(f"{node} (imported: {c}) is not callable") + + return c + + +@import_callable.register +def import_from_file04(node: CallableFromFile04, sha256: Optional[Sha256] = None): + return _import_from_file_impl(node.file, node.callable_name, sha256) + + +@import_callable.register +def import_from_file05(node: CallableFromFile05, sha256: Optional[Sha256] = None): + return _import_from_file_impl(node.source_file, node.callable_name, sha256) + + +def _import_from_file_impl(source: FileSource, callable_name: str, sha256: Optional[Sha256]): + local_file = download(source, known_hash=None if sha256 is None else f"sha256:{sha256}") + module_name = local_file.path.stem + importlib_spec = importlib.util.spec_from_file_location(module_name, local_file.path) + if importlib_spec is None: + raise ImportError(f"Failed to import {module_name} from {source}.") + + dep = importlib.util.module_from_spec(importlib_spec) + importlib_spec.loader.exec_module(dep) # type: ignore # todo: possible to use "loader.load_module"? 
+ return getattr(dep, callable_name) diff --git a/bioimageio/core/_internal/pytest_utils.py b/bioimageio/core/utils/testing.py similarity index 100% rename from bioimageio/core/_internal/pytest_utils.py rename to bioimageio/core/utils/testing.py diff --git a/bioimageio/core/utils/validation_visitors.py b/bioimageio/core/utils/validation_visitors.py new file mode 100644 index 00000000..48cb03c5 --- /dev/null +++ b/bioimageio/core/utils/validation_visitors.py @@ -0,0 +1,141 @@ +import hashlib +import importlib.util +import os +import sys +from dataclasses import dataclass, replace +from functools import singledispatchmethod +from pathlib import Path, PosixPath, PurePath +from typing import Any, Hashable, List, Optional, Tuple, Type, TypedDict, Union + +import requests +from pydantic import AnyUrl, DirectoryPath +from pydantic.fields import FieldInfo +from typing_extensions import NotRequired + +from bioimageio.core.utils import get_sha256 +from bioimageio.spec._internal.base_nodes import Node +from bioimageio.spec._internal.constants import IN_PACKAGE_MESSAGE, KW_ONLY, SLOTS +from bioimageio.spec._internal.types import Sha256 +from bioimageio.spec.summary import ErrorEntry, Loc, WarningEntry + + +class VisitorKwargs(TypedDict): + info: NotRequired[FieldInfo] + + +@dataclass(frozen=True, **SLOTS, **KW_ONLY) +class Memo: + loc: Loc = () + info: Optional[FieldInfo] = None + parent_nodes: Tuple[Node, ...] = () + + +class ValidationVisitor: + def __init__(self) -> None: + super().__init__() + self.errors: List[ErrorEntry] = [] + self.warnings: List[WarningEntry] = [] + + def visit(self, obj: Any, /, memo: Memo = Memo()): + self._traverse(obj, memo=memo) + + @singledispatchmethod + def _traverse(self, obj: type, /, memo: Memo): + pass + + @_traverse.register + def _traverse_node(self, node: Node, memo: Memo): + for k, v in node: + self.visit( + v, + replace(memo, loc=memo.loc + (k,), info=node.model_fields[k], parent_nodes=memo.parent_nodes + (node,)), + ) + + @_traverse.register + def _traverse_list(self, lst: list, memo: Memo): # type: ignore + e: Any + for i, e in enumerate(lst): # type: ignore + self.visit(e, replace(memo, loc=memo.loc + (i,))) + + @_traverse.register + def _traverse_tuple(self, tup: tuple, memo: Memo): # type: ignore + e: Any + for i, e in enumerate(tup): # type: ignore + self.visit(e, replace(memo, loc=memo.loc + (i,))) + + @_traverse.register + def _traverse_dict(self, dict_: dict, memo: Memo): # type: ignore + v: Any + for k, v in dict_.items(): # type: ignore + self.visit(v, replace(memo, loc=memo.loc + (k,))) + + +class _NoSha: + pass + + +class SourceValidator(ValidationVisitor): + def __init__(self, root: Union[DirectoryPath, AnyUrl]) -> None: + super().__init__() + self.root = root + + def visit(self, obj: Any, /, memo: Memo = Memo()): + self._visit_impl(obj, memo=memo) + return super().visit(obj, memo) + + @singledispatchmethod + def _visit_impl(self, obj: type, /, memo: Memo): + pass + + @_visit_impl.register + def _visit_path(self, path: PurePath, memo: Memo): + if Path(path).exists(): + sha256: Union[None, Sha256, Type[_NoSha]] = _NoSha + + for parent in memo.parent_nodes: + if "sha256" in parent.model_fields: + sha256: Optional[Sha256] = parent.sha256 # type: ignore + break + + if sha256 is _NoSha: + return + + actual_sha256 = get_sha256(path) + if sha256 is None: + self.warnings.append( + WarningEntry( + loc=memo.loc, + msg=( + f"Cannot validate file integrity (`sha256` not specified). 
" + f"File {path} has SHA-256: {actual_sha256}" + ), + type="unknown_hash", + ) + ) + elif actual_sha256 != sha256: + self.errors.append( + ErrorEntry( + loc=memo.loc, + msg=f"SHA-256 mismatch: actual ({actual_sha256}) != specified ({sha256})", + type="hash_mismatch", + ) + ) + else: + msg = f"{path} not found" + if ( + memo.info + and isinstance(memo.info.description, str) + and memo.info.description.startswith(IN_PACKAGE_MESSAGE) + ): + self.errors.append(ErrorEntry(loc=memo.loc, msg=msg, type="file_not_found")) + else: + self.warnings.append(WarningEntry(loc=memo.loc, msg=msg, type="file_not_found")) + + @_visit_impl.register + def _visit_url(self, url: AnyUrl, memo: Memo): + if url.scheme not in ("http", "https"): + self.errors.append(ErrorEntry(loc=memo.loc, msg=f"invalid http(s) URL: {url}", type="url_scheme")) + else: + response = requests.head(str(url)) + if response.status_code != 200: + self.errors.append(ErrorEntry(loc=memo.loc, msg=response.reason, type="url_unavailable")) From af369ee9ba92603aea8b733e210ec93093c90421 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 11 Oct 2023 23:16:15 +0200 Subject: [PATCH 041/244] expose LocalFile and LocalRdf --- bioimageio/core/io.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py index 934e0920..dd8294c5 100644 --- a/bioimageio/core/io.py +++ b/bioimageio/core/io.py @@ -222,13 +222,13 @@ def write_package( return output_path -class _LocalFile(NamedTuple): +class LocalFile(NamedTuple): path: FilePath original_root: Union[AnyUrl, DirectoryPath] original_file_name: str -class _LocalRdf(NamedTuple): +class LocalRdf(NamedTuple): content: RdfContent root: Union[AnyUrl, DirectoryPath] file_name: str @@ -239,7 +239,7 @@ def download( /, *, known_hash: Optional[KnownHash] = None, -) -> _LocalFile: +) -> LocalFile: source = _interprete_file_source(source) if isinstance(source, AnyUrl): if source.scheme not in ("http", "https"): @@ -263,7 +263,7 @@ def download( local_source = source root = source.parent - return _LocalFile( + return LocalFile( local_source, root, extract_file_name(source), @@ -296,7 +296,7 @@ def download_rdf(source: FileSource, /, *, known_hash: Optional[KnownHash] = Non if not isinstance(content, collections.abc.Mapping): raise TypeError(f"Expected RDF content to be a mapping, but got '{type(content)}'.") - return _LocalRdf(cast(RdfContent, content), root, file_name) + return LocalRdf(cast(RdfContent, content), root, file_name) def resolve_source( From c590e8cb6eab2feca3e17a4d580a31e5349f3c4b Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 11 Oct 2023 23:17:02 +0200 Subject: [PATCH 042/244] simplify "no sha" case --- .../core/{utils => }/validation_visitors.py | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) rename bioimageio/core/{utils => }/validation_visitors.py (92%) diff --git a/bioimageio/core/utils/validation_visitors.py b/bioimageio/core/validation_visitors.py similarity index 92% rename from bioimageio/core/utils/validation_visitors.py rename to bioimageio/core/validation_visitors.py index 48cb03c5..15420017 100644 --- a/bioimageio/core/utils/validation_visitors.py +++ b/bioimageio/core/validation_visitors.py @@ -1,11 +1,7 @@ -import hashlib -import importlib.util -import os -import sys from dataclasses import dataclass, replace from functools import singledispatchmethod -from pathlib import Path, PosixPath, PurePath -from typing import Any, Hashable, List, Optional, Tuple, Type, TypedDict, Union +from pathlib import 
Path, PurePath +from typing import Any, List, Optional, Tuple, TypedDict, Union import requests from pydantic import AnyUrl, DirectoryPath @@ -70,10 +66,6 @@ def _traverse_dict(self, dict_: dict, memo: Memo): # type: ignore self.visit(v, replace(memo, loc=memo.loc + (k,))) -class _NoSha: - pass - - class SourceValidator(ValidationVisitor): def __init__(self, root: Union[DirectoryPath, AnyUrl]) -> None: super().__init__() @@ -90,14 +82,11 @@ def _visit_impl(self, obj: type, /, memo: Memo): @_visit_impl.register def _visit_path(self, path: PurePath, memo: Memo): if Path(path).exists(): - sha256: Union[None, Sha256, Type[_NoSha]] = _NoSha - for parent in memo.parent_nodes: if "sha256" in parent.model_fields: - sha256: Optional[Sha256] = parent.sha256 # type: ignore + sha256: Union[None, Sha256] = parent.sha256 # type: ignore break - - if sha256 is _NoSha: + else: return actual_sha256 = get_sha256(path) From ca41439181984911b29e366fc8e8bfc691f4a2b9 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 12 Oct 2023 00:31:23 +0200 Subject: [PATCH 043/244] update model adapters --- bioimageio/core/model_adapters/__init__.py | 3 ++ .../_keras_model_adapter.py | 0 .../_model_adapter.py | 0 .../_onnx_model_adapter.py | 0 .../_pytorch_model_adapter.py | 42 ++++++++++++------- .../_tensorflow_model_adapter.py | 0 .../_torchscript_model_adapter.py | 0 .../_model_adapters/__init__.py | 3 -- 8 files changed, 29 insertions(+), 19 deletions(-) create mode 100644 bioimageio/core/model_adapters/__init__.py rename bioimageio/core/{prediction_pipeline/_model_adapters => model_adapters}/_keras_model_adapter.py (100%) rename bioimageio/core/{prediction_pipeline/_model_adapters => model_adapters}/_model_adapter.py (100%) rename bioimageio/core/{prediction_pipeline/_model_adapters => model_adapters}/_onnx_model_adapter.py (100%) rename bioimageio/core/{prediction_pipeline/_model_adapters => model_adapters}/_pytorch_model_adapter.py (56%) rename bioimageio/core/{prediction_pipeline/_model_adapters => model_adapters}/_tensorflow_model_adapter.py (100%) rename bioimageio/core/{prediction_pipeline/_model_adapters => model_adapters}/_torchscript_model_adapter.py (100%) delete mode 100644 bioimageio/core/prediction_pipeline/_model_adapters/__init__.py diff --git a/bioimageio/core/model_adapters/__init__.py b/bioimageio/core/model_adapters/__init__.py new file mode 100644 index 00000000..85387b6a --- /dev/null +++ b/bioimageio/core/model_adapters/__init__.py @@ -0,0 +1,3 @@ +from ._model_adapter import ModelAdapter as ModelAdapter +from ._model_adapter import create_model_adapter as create_model_adapter +from ._model_adapter import get_weight_formats as get_weight_formats diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/_keras_model_adapter.py b/bioimageio/core/model_adapters/_keras_model_adapter.py similarity index 100% rename from bioimageio/core/prediction_pipeline/_model_adapters/_keras_model_adapter.py rename to bioimageio/core/model_adapters/_keras_model_adapter.py diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py similarity index 100% rename from bioimageio/core/prediction_pipeline/_model_adapters/_model_adapter.py rename to bioimageio/core/model_adapters/_model_adapter.py diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/_onnx_model_adapter.py b/bioimageio/core/model_adapters/_onnx_model_adapter.py similarity index 100% rename from bioimageio/core/prediction_pipeline/_model_adapters/_onnx_model_adapter.py 
rename to bioimageio/core/model_adapters/_onnx_model_adapter.py diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py similarity index 56% rename from bioimageio/core/prediction_pipeline/_model_adapters/_pytorch_model_adapter.py rename to bioimageio/core/model_adapters/_pytorch_model_adapter.py index f47aa1d7..0cf0184d 100644 --- a/bioimageio/core/prediction_pipeline/_model_adapters/_pytorch_model_adapter.py +++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py @@ -1,17 +1,19 @@ import gc import warnings -from typing import List, Optional +from typing import Any, List, Optional, Sequence, Tuple, Union import torch import xarray as xr -from marshmallow import missing -from bioimageio.core.resource_io import nodes +from bioimageio.core.utils import import_callable +from bioimageio.spec.model import AnyModel +from bioimageio.spec.model.v0_4 import PytorchStateDictWeights as PytorchStateDictWeights04 + from ._model_adapter import ModelAdapter class PytorchModelAdapter(ModelAdapter): - def _load(self, *, devices: Optional[List[str]] = None): + def _load(self, *, devices: Optional[Sequence[str]] = None): self._model = self.get_nn_instance(self.bioimageio_model) if devices is None: @@ -27,17 +29,18 @@ def _load(self, *, devices: Optional[List[str]] = None): assert isinstance(self._model, torch.nn.Module) weights = self.bioimageio_model.weights.get("pytorch_state_dict") if weights is not None and weights.source: - state = torch.load(weights.source, map_location=self._devices[0]) - self._model.load_state_dict(state) + state: Any = torch.load(weights.source, map_location=self._devices[0]) + _ = self._model.load_state_dict(state) - self._model.eval() + _ = self._model.eval() self._internal_output_axes = [tuple(out.axes) for out in self.bioimageio_model.outputs] def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: + assert self._devices is not None with torch.no_grad(): tensors = [torch.from_numpy(ipt.data) for ipt in input_tensors] tensors = [t.to(self._devices[0]) for t in tensors] - result = self._model(*tensors) + result: Union[Tuple[Any, ...], List[Any], Any] = self._model(*tensors) if not isinstance(result, (tuple, list)): result = [result] @@ -48,15 +51,22 @@ def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: def _unload(self) -> None: self._devices = None del self._model - gc.collect() # deallocate memory + _ = gc.collect() # deallocate memory torch.cuda.empty_cache() # release reserved memory @staticmethod - def get_nn_instance(model_node: nodes.Model, **kwargs): - weight_spec = model_node.weights.get("pytorch_state_dict") + def get_nn_instance(model: AnyModel): + weight_spec = model.weights.pytorch_state_dict assert weight_spec is not None - assert isinstance(weight_spec.architecture, nodes.ImportedSource) - model_kwargs = weight_spec.kwargs - joined_kwargs = {} if model_kwargs is missing else dict(model_kwargs) - joined_kwargs.update(kwargs) - return weight_spec.architecture(**joined_kwargs) + arch = import_callable( + weight_spec.architecture, + sha256=weight_spec.architecture_sha256 + if isinstance(weight_spec, PytorchStateDictWeights04) + else weight_spec.sha256, + ) + model_kwargs = ( + weight_spec.kwargs + if isinstance(weight_spec, PytorchStateDictWeights04) + else weight_spec.architecture.kwargs + ) + return arch(**model_kwargs) diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/_tensorflow_model_adapter.py 
b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py similarity index 100% rename from bioimageio/core/prediction_pipeline/_model_adapters/_tensorflow_model_adapter.py rename to bioimageio/core/model_adapters/_tensorflow_model_adapter.py diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/_torchscript_model_adapter.py b/bioimageio/core/model_adapters/_torchscript_model_adapter.py similarity index 100% rename from bioimageio/core/prediction_pipeline/_model_adapters/_torchscript_model_adapter.py rename to bioimageio/core/model_adapters/_torchscript_model_adapter.py diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/__init__.py b/bioimageio/core/prediction_pipeline/_model_adapters/__init__.py deleted file mode 100644 index 5d2745d6..00000000 --- a/bioimageio/core/prediction_pipeline/_model_adapters/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from ._model_adapter import ModelAdapter, create_model_adapter, get_weight_formats - -__all__ = ["ModelAdapter", "create_model_adapter", "get_weight_formats"] From 8ef56d177a4123587fdc5fdc79fb03f49bd720df Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 12 Oct 2023 00:31:50 +0200 Subject: [PATCH 044/244] WIP update measure groups --- ...measure_groups.py => statistical_measure_groups.py} | 0 bioimageio/core/statistical_measures.py | 10 +++++----- 2 files changed, 5 insertions(+), 5 deletions(-) rename bioimageio/core/{prediction_pipeline/_measure_groups.py => statistical_measure_groups.py} (100%) diff --git a/bioimageio/core/prediction_pipeline/_measure_groups.py b/bioimageio/core/statistical_measure_groups.py similarity index 100% rename from bioimageio/core/prediction_pipeline/_measure_groups.py rename to bioimageio/core/statistical_measure_groups.py diff --git a/bioimageio/core/statistical_measures.py b/bioimageio/core/statistical_measures.py index c554ec33..e19689b8 100644 --- a/bioimageio/core/statistical_measures.py +++ b/bioimageio/core/statistical_measures.py @@ -6,7 +6,7 @@ import xarray as xr -from bioimageio.spec.model.v0_5 import AxisName +from bioimageio.spec.model.v0_5 import NonBatchAxisName MeasureValue = xr.DataArray @@ -21,7 +21,7 @@ def compute(self, tensor: xr.DataArray) -> MeasureValue: @dataclass(frozen=True) class Mean(Measure): - axes: Optional[Tuple[AxisName, ...]] = None + axes: Optional[Tuple[NonBatchAxisName, ...]] = None def compute(self, tensor: xr.DataArray) -> xr.DataArray: return tensor.mean(dim=self.axes) @@ -29,7 +29,7 @@ def compute(self, tensor: xr.DataArray) -> xr.DataArray: @dataclass(frozen=True) class Std(Measure): - axes: Optional[Tuple[AxisName, ...]] = None + axes: Optional[Tuple[NonBatchAxisName, ...]] = None def compute(self, tensor: xr.DataArray) -> xr.DataArray: return tensor.std(dim=self.axes) @@ -37,7 +37,7 @@ def compute(self, tensor: xr.DataArray) -> xr.DataArray: @dataclass(frozen=True) class Var(Measure): - axes: Optional[Tuple[AxisName, ...]] = None + axes: Optional[Tuple[NonBatchAxisName, ...]] = None def compute(self, tensor: xr.DataArray) -> xr.DataArray: return tensor.var(dim=self.axes) @@ -46,7 +46,7 @@ def compute(self, tensor: xr.DataArray) -> xr.DataArray: @dataclass(frozen=True) class Percentile(Measure): n: float - axes: Optional[Tuple[AxisName, ...]] = None + axes: Optional[Tuple[NonBatchAxisName, ...]] = None def __post_init__(self): assert self.n >= 0 From 2c2c5fd1e8be97fb40484cddb6306fc9af874150 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 12 Oct 2023 00:32:13 +0200 Subject: [PATCH 045/244] WIP update prediction pipeline --- 
.../core/prediction_pipeline/__init__.py | 5 +++-- .../_prediction_pipeline.py | 20 +++++++------------ 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/bioimageio/core/prediction_pipeline/__init__.py b/bioimageio/core/prediction_pipeline/__init__.py index d982ea1b..da136844 100644 --- a/bioimageio/core/prediction_pipeline/__init__.py +++ b/bioimageio/core/prediction_pipeline/__init__.py @@ -1,2 +1,3 @@ -from ._model_adapters import get_weight_formats -from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline +from ._model_adapters import get_weight_formats as get_weight_formats +from ._prediction_pipeline import PredictionPipeline as PredictionPipeline +from ._prediction_pipeline import create_prediction_pipeline as create_prediction_pipeline diff --git a/bioimageio/core/prediction_pipeline/_prediction_pipeline.py b/bioimageio/core/prediction_pipeline/_prediction_pipeline.py index 8d92ee62..59e677db 100644 --- a/bioimageio/core/prediction_pipeline/_prediction_pipeline.py +++ b/bioimageio/core/prediction_pipeline/_prediction_pipeline.py @@ -4,14 +4,13 @@ from typing import Iterable, List, Optional, Sequence, Tuple, Union import xarray as xr -from bioimageio.spec.model import raw_nodes from marshmallow import missing -from bioimageio.core._internal.validation_visitors import resolve_raw_node -from bioimageio.core.resource_io import nodes +from bioimageio.core.model_adapters import ModelAdapter, create_model_adapter +from bioimageio.core.validation_visitors import resolve_raw_node +from bioimageio.spec.model import AnyModel, raw_nodes from ._combined_processing import CombinedProcessing -from ._model_adapters import ModelAdapter, create_model_adapter from ._stat_state import StatsState from ._utils import ComputedMeasures, Sample, TensorName @@ -91,7 +90,7 @@ def __init__( self, *, name: str, - bioimageio_model: Union[nodes.Model, raw_nodes.Model], + bioimageio_model: AnyModel, preprocessing: CombinedProcessing, postprocessing: CombinedProcessing, ipt_stats: StatsState, @@ -102,13 +101,8 @@ def __init__( warnings.warn(f"Not yet implemented inference for run mode '{bioimageio_model.run_mode.name}'") self._name = name - if isinstance(bioimageio_model, nodes.Model): - self._input_specs = bioimageio_model.inputs - self._output_specs = bioimageio_model.outputs - else: - assert isinstance(bioimageio_model, raw_nodes.Model) - self._input_specs = [resolve_raw_node(s, nodes) for s in bioimageio_model.inputs] - self._output_specs = [resolve_raw_node(s, nodes) for s in bioimageio_model.outputs] + self._input_specs = bioimageio_model.inputs + self._output_specs = bioimageio_model.outputs self._preprocessing = preprocessing self._postprocessing = postprocessing @@ -185,7 +179,7 @@ def unload(self): def create_prediction_pipeline( - bioimageio_model: Union[nodes.Model, raw_nodes.Model], + bioimageio_model: AnyModel, *, devices: Optional[Sequence[str]] = None, weight_format: Optional[str] = None, From f37e93d1b5d608693652c7a08ce95e3ad8c79d37 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 19 Oct 2023 15:06:27 +0200 Subject: [PATCH 046/244] WIP update model adapters --- .../core/model_adapters/_model_adapter.py | 19 ++--- .../model_adapters/_pytorch_model_adapter.py | 73 ++++++++++++------- 2 files changed, 55 insertions(+), 37 deletions(-) diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py index adac96a7..5f0b71fd 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ 
b/bioimageio/core/model_adapters/_model_adapter.py
@@ -18,7 +18,7 @@
 )
 
-BioimageioModel = Union[v0_4.Model, v0_5.Model]
+ModelDescription = Union[v0_4.Model, v0_5.Model]
 
 
 class ModelAdapter(abc.ABC):
@@ -26,16 +26,16 @@ class ModelAdapter(abc.ABC):
     Represents model *without* any preprocessing or postprocessing
     """
 
-    def __init__(self, *, bioimageio_model: BioimageioModel, devices: Optional[Sequence[str]] = None):
+    def __init__(self, *, model_description: ModelDescription, devices: Optional[Sequence[str]] = None):
         super().__init__()
-        self.bioimageio_model = self._prepare_model(bioimageio_model)
+        self.model_description = self._prepare_model(model_description)
         self.default_devices = devices
         self.loaded = False
 
     @staticmethod
-    def _prepare_model(bioimageio_model: BioimageioModel) -> BioimageioModel:
-        """The model node is prepared for the model adapter to be ready for operation."""
-        return bioimageio_model
+    def _prepare_model(model_description: ModelDescription) -> ModelDescription:
+        """The model description may be altered by the model adapter to be ready for operation."""
+        return model_description
 
     def __enter__(self):
         """load on entering context"""
@@ -111,7 +111,7 @@ def get_weight_formats() -> List[str]:
 
 def create_model_adapter(
     *,
-    bioimageio_model: Union[v0_4.Model, v0_5.Model],
+    model_description: Union[v0_4.Model, v0_5.Model],
     devices: Optional[Sequence[str]] = None,
     weight_format: Optional[WeightsFormat] = None,
 ) -> ModelAdapter:
@@ -124,10 +124,10 @@ def create_model_adapter(
         raise ValueError(f"Weight format {weight_format} is not in supported formats {_WEIGHT_FORMATS}")
 
     priority_order = _WEIGHT_FORMATS if weight_format is None else (weight_format,)
-    weight = bioimageio_model.weights.get(*priority_order)
+    weight = model_description.weights.get(*priority_order)
 
     adapter_cls = _get_model_adapter(weight.type)
-    return adapter_cls(bioimageio_model=bioimageio_model, devices=devices)
+    return adapter_cls(model_description=model_description, devices=devices)
 
 
 def _get_model_adapter(weight_format: WeightsFormat) -> Type[ModelAdapter]:
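At this point in the series the adapter is still a small state machine around `loaded`: entering the context calls `load()` with `default_devices`, exiting calls `unload()`. A hedged sketch of the intended call pattern (`model_description` and `input_tensors` are placeholders for a loaded model description and matching test tensors; the import assumes the package re-exports `create_model_adapter`):

    from bioimageio.core.model_adapters import create_model_adapter

    adapter = create_model_adapter(model_description=model_description)
    with adapter:  # __enter__ -> load() onto default_devices
        outputs = adapter.forward(*input_tensors)  # would lazily load() if needed
    # __exit__ -> unload(), freeing any device memory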
diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py
index 0cf0184d..3db91e37 100644
--- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py
+++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py
@@ -6,67 +6,84 @@
 import xarray as xr
 
 from bioimageio.core.utils import import_callable
-from bioimageio.spec.model import AnyModel
-from bioimageio.spec.model.v0_4 import PytorchStateDictWeights as PytorchStateDictWeights04
+from bioimageio.spec.model import v0_4, v0_5
 
 from ._model_adapter import ModelAdapter
 
 
 class PytorchModelAdapter(ModelAdapter):
-    def _load(self, *, devices: Optional[Sequence[str]] = None):
-        self._model = self.get_nn_instance(self.bioimageio_model)
-
-        if devices is None:
-            self._devices = ["cuda" if torch.cuda.is_available() else "cpu"]
-        else:
-            self._devices = [torch.device(d) for d in devices]
+    _devices: Optional[List[torch.device]] = None
 
-        if len(self._devices) > 1:
-            warnings.warn("Multiple devices for single pytorch model not yet implemented")
+    def _load(self, *, devices: Optional[Sequence[str]] = None):
+        if self.model_description.weights.pytorch_state_dict is None:
+            raise ValueError("missing pytorch_state_dict weights")
 
-        self._model.to(self._devices[0])
+        self._network = self.get_network(self.model_description.weights.pytorch_state_dict)
+        self._devices = self.get_devices(devices)
+        self._network = self._network.to(self._devices[0])
 
-        assert isinstance(self._model, torch.nn.Module)
-        weights = self.bioimageio_model.weights.get("pytorch_state_dict")
-        if weights is not None and weights.source:
-            state: Any = torch.load(weights.source, map_location=self._devices[0])
-            _ = self._model.load_state_dict(state)
+        weights = self.model_description.weights.pytorch_state_dict
+        state: Any = torch.load(weights.source, map_location=self._devices[0])
+        _ = self._network.load_state_dict(state)
 
-        _ = self._model.eval()
-        self._internal_output_axes = [tuple(out.axes) for out in self.bioimageio_model.outputs]
+        self._network = self._network.eval()
 
     def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
         assert self._devices is not None
         with torch.no_grad():
            tensors = [torch.from_numpy(ipt.data) for ipt in input_tensors]
            tensors = [t.to(self._devices[0]) for t in tensors]
-            result: Union[Tuple[Any, ...], List[Any], Any] = self._model(*tensors)
+            result: Union[Tuple[Any, ...], List[Any], Any] = self._network(*tensors)
             if not isinstance(result, (tuple, list)):
                 result = [result]
 
             result = [r.detach().cpu().numpy() if isinstance(r, torch.Tensor) else r for r in result]
+            if len(result) > len(self.model_description.outputs):
+                raise ValueError(
+                    f"Expected at most {len(self.model_description.outputs)} outputs, but got {len(result)}"
+                )
 
-        return [xr.DataArray(r, dims=axes) for r, axes in zip(result, self._internal_output_axes)]
+        return [
+            xr.DataArray(r, dims=tuple(a if isinstance(a, str) else a.id for a in out.axes))
+            for r, out in zip(result, self.model_description.outputs)
+        ]
 
     def _unload(self) -> None:
         self._devices = None
-        del self._model
+        del self._network
         _ = gc.collect()  # deallocate memory
         torch.cuda.empty_cache()  # release reserved memory
 
     @staticmethod
-    def get_nn_instance(model: AnyModel):
-        weight_spec = model.weights.pytorch_state_dict
-        assert weight_spec is not None
+    def get_network(weight_spec: Union[v0_4.PytorchStateDictWeights, v0_5.PytorchStateDictWeights]):
         arch = import_callable(
             weight_spec.architecture,
             sha256=weight_spec.architecture_sha256
-            if isinstance(weight_spec, PytorchStateDictWeights04)
+            if isinstance(weight_spec, v0_4.PytorchStateDictWeights)
             else weight_spec.sha256,
         )
         model_kwargs = (
             weight_spec.kwargs
-            if isinstance(weight_spec, PytorchStateDictWeights04)
+            if isinstance(weight_spec, v0_4.PytorchStateDictWeights)
             else weight_spec.architecture.kwargs
         )
-        return arch(**model_kwargs)
+        network = arch(**model_kwargs)
+        if not isinstance(network, torch.nn.Module):
+            raise ValueError(f"calling {weight_spec.architecture.callable} did not return a torch.nn.Module")
+
+        return network
+
+    @staticmethod
+    def get_devices(devices: Optional[Sequence[str]] = None):
+        if not devices:
+            torch_devices = [torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")]
+        else:
+            torch_devices = [torch.device(d) for d in devices]
+
+        if len(torch_devices) > 1:
+            warnings.warn(
+                f"Multiple devices for single pytorch model not yet implemented; ignoring {torch_devices[1:]}"
+            )
+            torch_devices = torch_devices[:1]
+
+        return torch_devices
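The `_load`/`get_network`/`get_devices` trio above boils down to the standard torch pattern: instantiate the architecture, load the state dict onto one device, and switch to eval mode. In miniature (a self-contained toy; the weights file name is illustrative):

    import torch

    net = torch.nn.Linear(4, 2)                  # stands in for the imported architecture
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.save(net.state_dict(), "weights.pt")   # pretend these are the packaged weights
    state = torch.load("weights.pt", map_location=device)
    _ = net.load_state_dict(state)
    net = net.to(device).eval()                  # disable dropout etc. for inference

    with torch.no_grad():
        out = net(torch.zeros(1, 4, device=device))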
From e21b9ed52a121b05e946728641d8f5990985c622 Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Tue, 24 Oct 2023 13:17:15 +0200
Subject: [PATCH 047/244] update model adapters

---
 .../model_adapters/_keras_model_adapter.py    |  63 +++--
 .../core/model_adapters/_model_adapter.py     | 228 ++++++++----------
 .../model_adapters/_onnx_model_adapter.py     |  29 ++-
 .../model_adapters/_pytorch_model_adapter.py  |  14 +-
 .../_tensorflow_model_adapter.py              | 140 +++++++----
 .../_torchscript_model_adapter.py             |  46 +++-
 6 files changed, 289 insertions(+), 231 deletions(-)

diff --git a/bioimageio/core/model_adapters/_keras_model_adapter.py b/bioimageio/core/model_adapters/_keras_model_adapter.py
index a9bf74dd..9c6f842a 100644
--- a/bioimageio/core/model_adapters/_keras_model_adapter.py
+++ b/bioimageio/core/model_adapters/_keras_model_adapter.py
@@ -1,6 +1,7 @@
 import warnings
-from typing import List, Optional, Sequence
+from typing import Any, List, Optional, Sequence, Union
 
+from numpy.typing import NDArray
 from packaging.version import Version
 
 # by default, we use the keras integrated with tensorflow
 try:
     import tensorflow as tf
     from tensorflow import keras
 
-    TF_VERSION = tf.__version__
+    tf_version = Version(tf.__version__)
 except Exception:
     import keras
 
-    TF_VERSION = None
+    tf_version = None
 
 import xarray as xr
 
+from bioimageio.core.io import download
+from bioimageio.spec.model import v0_4, v0_5
+from bioimageio.spec.model.v0_5 import RelativeFilePath
+
 from ._model_adapter import ModelAdapter
 
 
 class KerasModelAdapter(ModelAdapter):
-    def _load(self, *, devices: Optional[Sequence[str]] = None) -> None:
-        assert self.bioimageio_model.weights.keras_hdf5 is not None
-        tf_version = self.bioimageio_model.weights.keras_hdf5.tensorflow_version
-        if tf_version is None:
-            model_tf_version = None
-        else:
-            v = Version(tf_version)
-            model_tf_version = (int(v.major), int(v.minor))
+    def __init__(
+        self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: Optional[Sequence[str]] = None
+    ) -> None:
+        super().__init__()
+        if model_description.weights.keras_hdf5 is None:
+            raise ValueError("model has no keras_hdf5 weights specified")
+        model_tf_version = model_description.weights.keras_hdf5.tensorflow_version
 
-        if TF_VERSION is None or model_tf_version is None:
-            warnings.warn("Could not check tensorflow versions. The prediction results may be wrong.")
-        elif tuple(model_tf_version[:2]) != tuple(map(int, TF_VERSION.split(".")))[:2]:
+        if tf_version is None or model_tf_version is None:
+            warnings.warn("Could not check tensorflow versions.")
+        elif model_tf_version > tf_version:
             warnings.warn(
-                f"Model tensorflow version {model_tf_version} does not match {TF_VERSION}."
-                "The prediction results may be wrong"
+                f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}."
) + elif (model_tf_version.major, model_tf_version.minor) != (tf_version.major, tf_version.minor): + warnings.warn(f"Model tensorflow version {model_tf_version} does not match {tf_version}.") # TODO keras device management if devices is not None: warnings.warn(f"Device management is not implemented for keras yet, ignoring the devices {devices}") - weight_file = self.bioimageio_model.weights["keras_hdf5"].source - self._model = keras.models.load_model(weight_file) - self._output_axes = [tuple(out.axes) for out in self.bioimageio_model.outputs] + src = model_description.weights.keras_hdf5.source + weight_path = download( + src.get_absolute(model_description.root) if isinstance(src, RelativeFilePath) else src + ).path - def _unload(self) -> None: - warnings.warn("Device management is not implemented for keras yet, cannot unload model") + self._network = keras.models.load_model(weight_path) + self._output_axes = [tuple(out.axes) for out in model_description.outputs] - def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - result = self._model.predict(*input_tensors) - if not isinstance(result, (tuple, list)): - result = [result] + def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: + _result: Union[ # pyright: ignore[reportUnknownVariableType] + Sequence[NDArray[Any]], NDArray[Any] + ] = self._network.predict(*input_tensors) + if isinstance(_result, (tuple, list)): + result: Sequence[NDArray[Any]] = _result + else: + result = [_result] # type: ignore assert len(result) == len(self._output_axes) return [xr.DataArray(r, dims=axes) for r, axes, in zip(result, self._output_axes)] + + def unload(self) -> None: + warnings.warn("Device management is not implemented for keras yet, cannot unload model") diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py index 5f0b71fd..d1e8ecb6 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ b/bioimageio/core/model_adapters/_model_adapter.py @@ -1,15 +1,17 @@ -import abc -from typing import List, Optional, Sequence, Tuple, Type, Union +import warnings +from abc import ABC, abstractmethod +from typing import List, Optional, Sequence, Tuple, Union, final import xarray as xr +from bioimageio.spec._internal.types import NotEmpty from bioimageio.spec.model import v0_4, v0_5 WeightsFormat = Union[v0_4.WeightsFormat, v0_5.WeightsFormat] -#: Known weight formats in order of priority -#: First match wins -_WEIGHT_FORMATS: Tuple[WeightsFormat, ...] = ( +# Known weight formats in order of priority +# First match wins +DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[WeightsFormat, ...] = ( "pytorch_state_dict", "tensorflow_saved_model_bundle", "torchscript", @@ -18,155 +20,115 @@ ) -ModelDescription = Union[v0_4.Model, v0_5.Model] - - -class ModelAdapter(abc.ABC): - """ - Represents model *without* any preprocessing or postprocessing +class ModelAdapter(ABC): """ + Represents model *without* any preprocessing or postprocessing. 
+
+    >>> print("option 1:")
+    option 1:
+    >>> adapter = ModelAdapter.create(model_description=model_description)
+    >>> outputs = adapter.forward(*input_tensors)
+    >>> adapter.unload()
+    >>> print("option 2:")
+    option 2:
+    >>> adapter = create_model_adapter(model_description=model_description)
+    >>> outputs = adapter.forward(*input_tensors)
+    >>> adapter.unload()
+
+    (``model_description`` and ``input_tensors`` are placeholders for a loaded model
+    description and matching input data)

-    def __init__(self, *, model_description: ModelDescription, devices: Optional[Sequence[str]] = None):
-        super().__init__()
-        self.model_description = self._prepare_model(model_description)
-        self.default_devices = devices
-        self.loaded = False
-
-    @staticmethod
-    def _prepare_model(model_description: ModelDescription) -> ModelDescription:
-        """The model description may be altered by the model adapter to be ready for operation."""
-        return model_description
-
-    def __enter__(self):
-        """load on entering context"""
-        assert not self.loaded
-        self.load()  # using default_devices
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
-        """unload on exiting context"""
-        assert self.loaded
-        self.unload()
-        return False
-
-    def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
-        """
-        Note: Use ModelAdapter as context to not worry about calling unload()!
-        Load model onto devices. If devices is None, self.default_devices are chosen
-        (which may be None as well, in which case a framework dependent default is chosen)
-        """
-        self._load(devices=devices or self.default_devices)
-        self.loaded = True
+    """

-    @abc.abstractmethod
-    def _load(self, *, devices: Optional[Sequence[str]] = None) -> None:
+    @final
+    @classmethod
+    def create(
+        cls,
+        *,
+        model_description: Union[v0_4.Model, v0_5.Model],
+        devices: Optional[Sequence[str]] = None,
+        weight_format_priority_order: NotEmpty[Sequence[WeightsFormat]] = DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER,
+    ):
         """
-        Load model onto devices. If devices is None a framework dependent default is chosen
+        Creates model adapter based on the passed spec
+        Note: All specific adapter imports should happen inside this function to prevent different framework
+        initializations interfering with each other
         """
+ weights = model_description.weights + errors: List[Exception] = [] + for wf in weight_format_priority_order: + if wf == "pytorch_state_dict" and weights.pytorch_state_dict is not None: + try: + from ._pytorch_model_adapter import PytorchModelAdapter + + return PytorchModelAdapter(model_description=model_description, devices=devices) + except Exception as e: + errors.append(e) + elif wf == "tensorflow_saved_model_bundle" and weights.tensorflow_saved_model_bundle is not None: + try: + from ._tensorflow_model_adapter import TensorflowModelAdapter + + return TensorflowModelAdapter(model_description=model_description, devices=devices) + except Exception as e: + errors.append(e) + elif wf == "onnx" and weights.onnx is not None: + try: + from ._onnx_model_adapter import ONNXModelAdapter + + return ONNXModelAdapter(model_description=model_description, devices=devices) + except Exception as e: + errors.append(e) + elif wf == "torchscript" and weights.torchscript is not None: + try: + from ._torchscript_model_adapter import TorchscriptModelAdapter + + return TorchscriptModelAdapter(model_description=model_description, devices=devices) + except Exception as e: + errors.append(e) + elif wf == "keras_hdf5" and weights.keras_hdf5 is not None: + # keras can either be installed as a separate package or used as part of tensorflow + # we try to first import the keras model adapter using the separate package and, + # if it is not available, try to load the one using tf + try: + try: + from ._keras_model_adapter import KerasModelAdapter + except ImportError: + from ._tensorflow_model_adapter import KerasModelAdapter + + return KerasModelAdapter(model_description=model_description, devices=devices) + except Exception as e: + errors.append(e) + + if errors: + error_msg = f" Errors are: {errors}." + else: + error_msg = "" + + raise ValueError( + f"None of the weight formats {weight_format_priority_order} is supported for {model_description.name} " + f"in this environment.{error_msg}" + ) + + @final + def load(self, *, devices: Optional[Sequence[str]] = None) -> None: + warnings.warn("Deprecated. ModelAdapter is always loaded") + @abstractmethod def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - """ - Load model if unloaded/outside context; then run forward pass of model to get model predictions - """ - if not self.loaded: - self.load() - - assert self.loaded - return self._forward(*input_tensors) - - @abc.abstractmethod - def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: """ Run forward pass of model to get model predictions - Note: model is responsible converting it's data representation to - xarray.DataArray """ - ... + @abstractmethod def unload(self): """ Unload model from any devices, freeing their memory. Note: Use ModelAdapter as context to not worry about calling unload()! """ - # implementation of non-state-machine logic in _unload() - assert self.loaded - self._unload() - self.loaded = False - - @abc.abstractmethod - def _unload(self) -> None: - """ - Unload model from any devices, freeing their memory. - """ - ... 
def get_weight_formats() -> List[str]: """ Return list of supported weight types """ - return list(_WEIGHT_FORMATS) - - -def create_model_adapter( - *, - model_description: Union[v0_4.Model, v0_5.Model], - devices: Optional[Sequence[str]] = None, - weight_format: Optional[WeightsFormat] = None, -) -> ModelAdapter: - """ - Creates model adapter based on the passed spec - Note: All specific adapters should happen inside this function to prevent different framework - initializations interfering with each other - """ - if weight_format is not None and weight_format not in _WEIGHT_FORMATS: - raise ValueError(f"Weight format {weight_format} is not in supported formats {_WEIGHT_FORMATS}") - - priority_order = _WEIGHT_FORMATS if weight_format is None else (weight_format,) - weight = model_description.weights.get(*priority_order) - - adapter_cls = _get_model_adapter(weight.type) - return adapter_cls(model_description=model_description, devices=devices) - - -def _get_model_adapter(weight_format: WeightsFormat) -> Type[ModelAdapter]: - """ - Return adapter class based on the weight format - Note: All specific adapters should happen inside this function to prevent different framework - initializations interfering with each other - """ - if weight_format == "pytorch_state_dict": - from ._pytorch_model_adapter import PytorchModelAdapter - - return PytorchModelAdapter - - elif weight_format == "tensorflow_saved_model_bundle": - from ._tensorflow_model_adapter import TensorflowModelAdapter - - return TensorflowModelAdapter - - elif weight_format == "onnx": - from ._onnx_model_adapter import ONNXModelAdapter - - return ONNXModelAdapter - - elif weight_format == "torchscript": - from ._torchscript_model_adapter import TorchscriptModelAdapter - - return TorchscriptModelAdapter - - elif weight_format == "keras_hdf5": - # keras can either be installed as a separate package or used as part of tensorflow - # we try to first import the keras model adapter using the separate package and, - # if it is not available, try to load the one using tf - try: - from ._keras_model_adapter import KerasModelAdapter - except ImportError: - from ._tensorflow_model_adapter import KerasModelAdapter + return list(DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER) - return KerasModelAdapter - else: - raise ValueError(f"Weight format {weight_format} is not supported.") +create_model_adapter = ModelAdapter.create diff --git a/bioimageio/core/model_adapters/_onnx_model_adapter.py b/bioimageio/core/model_adapters/_onnx_model_adapter.py index 45b882e0..14ed36d1 100644 --- a/bioimageio/core/model_adapters/_onnx_model_adapter.py +++ b/bioimageio/core/model_adapters/_onnx_model_adapter.py @@ -1,9 +1,12 @@ import logging import warnings -from typing import List, Optional, Sequence +from typing import Any, List, Optional, Sequence, Union import onnxruntime as rt import xarray as xr +from numpy.typing import NDArray + +from bioimageio.spec.model import v0_4, v0_5 from ._model_adapter import ModelAdapter @@ -11,24 +14,32 @@ class ONNXModelAdapter(ModelAdapter): - def _load(self, *, devices: Optional[Sequence[str]] = None): - self._internal_output_axes = [tuple(out.axes) for out in self.bioimageio_model.outputs] - assert self.bioimageio_model.weights.onnx is not None - self._session = rt.InferenceSession(str(self.bioimageio_model.weights.onnx.source)) + def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: Optional[Sequence[str]] = None): + super().__init__() + self._internal_output_axes = [ + tuple(out.axes) if 
isinstance(out.axes, str) else tuple(a.id for a in out.axes)
+            for out in model_description.outputs
+        ]
+        if model_description.weights.onnx is None:
+            raise ValueError(f"No ONNX weights specified for {model_description.name}")
+
+        self._session = rt.InferenceSession(str(model_description.weights.onnx.source))
         onnx_inputs = self._session.get_inputs()  # type: ignore
         self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]  # type: ignore
 
         if devices is not None:
             warnings.warn(f"Device management is not implemented for onnx yet, ignoring the devices {devices}")
 
-    def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
+    def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
         assert len(input_tensors) == len(self._input_names)
         input_arrays = [ipt.data for ipt in input_tensors]
-        result = self._session.run(None, dict(zip(self._input_names, input_arrays)))  # type: ignore
+        result: Union[  # pyright: ignore[reportUnknownVariableType]
+            Sequence[NDArray[Any]], NDArray[Any]
+        ] = self._session.run(None, dict(zip(self._input_names, input_arrays)))
         if not isinstance(result, (list, tuple)):
-            result = []
+            result = [result]
 
-        return [xr.DataArray(r, dims=axes) for r, axes in zip(result, self._internal_output_axes)]  # type: ignore
+        return [xr.DataArray(r, dims=axes) for r, axes in zip(result, self._internal_output_axes)]
 
-    def _unload(self) -> None:
+    def unload(self) -> None:
         warnings.warn("Device management is not implemented for onnx yet, cannot unload model")
 
diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py
index 3db91e37..6b9e67b5 100644
--- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py
+++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py
@@ -14,21 +14,23 @@ class PytorchModelAdapter(ModelAdapter):
 
     _devices: Optional[List[torch.device]] = None
 
-    def _load(self, *, devices: Optional[Sequence[str]] = None):
-        if self.model_description.weights.pytorch_state_dict is None:
+    def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: Optional[Sequence[str]] = None):
+        super().__init__()
+        if model_description.weights.pytorch_state_dict is None:
             raise ValueError("missing pytorch_state_dict weights")
 
-        self._network = self.get_network(self.model_description.weights.pytorch_state_dict)
+        self.model_description = model_description
+        self._network = self.get_network(model_description.weights.pytorch_state_dict)
         self._devices = self.get_devices(devices)
         self._network = self._network.to(self._devices[0])
 
-        weights = self.model_description.weights.pytorch_state_dict
+        weights = model_description.weights.pytorch_state_dict
         state: Any = torch.load(weights.source, map_location=self._devices[0])
         _ = self._network.load_state_dict(state)
 
         self._network = self._network.eval()
 
-    def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
+    def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
         assert self._devices is not None
         with torch.no_grad():
             tensors = [torch.from_numpy(ipt.data) for ipt in input_tensors]
@@ -48,7 +50,7 @@ def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
             for r, out in zip(result, self.model_description.outputs)
         ]
 
-    def _unload(self) -> None:
+    def unload(self) -> None:
         self._devices = None
         del self._network
         _ = gc.collect()  # deallocate memory
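The ONNX adapter above is essentially a thin wrapper around an onnxruntime session that feeds arrays by input name; stripped of the spec handling it amounts to the following sketch (model path and input shape are placeholders):

    import numpy as np
    import onnxruntime as rt

    session = rt.InferenceSession("model.onnx")
    input_names = [ipt.name for ipt in session.get_inputs()]
    feed = {name: np.zeros((1, 3, 64, 64), dtype=np.float32) for name in input_names}
    outputs = session.run(None, feed)  # None -> return all outputs, in declared order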
diff --git a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py
index 7b470608..929cd57c 100644
--- a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py
+++ b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py
@@ -1,11 +1,14 @@
 import warnings
 import zipfile
-from typing import List, Literal, Optional
+from typing import List, Literal, Optional, Sequence, Union
 
 import numpy as np
 import tensorflow as tf
 import xarray as xr
-from marshmallow import missing
+
+from bioimageio.core.io import FileSource, download
+from bioimageio.spec.model import v0_4, v0_5
+from bioimageio.spec.model.v0_5 import RelativeFilePath
 
 from ._model_adapter import ModelAdapter
 
@@ -13,59 +16,82 @@
 class TensorflowModelAdapterBase(ModelAdapter):
     weight_format: Literal["keras_hdf5", "tensorflow_saved_model_bundle"]
 
-    def require_unzipped(self, weight_file):
-        if zipfile.is_zipfile(weight_file):
-            out_path = weight_file.with_suffix("")
-            with zipfile.ZipFile(weight_file, "r") as f:
-                f.extractall(out_path)
-            return out_path
-        return weight_file
-
-    def _load_model(self, weight_file):
-        weight_file = self.require_unzipped(weight_file)
-        if self.use_keras_api:
-            return tf.keras.models.load_model(weight_file, compile=False)
-        else:
-            # NOTE in tf1 the model needs to be loaded inside of the session, so we cannot preload the model
-            return str(weight_file)
-
-    def _load(self, *, devices: Optional[List[str]] = None):
-        model_tf_version = self.bioimageio_model.weights[self.weight_format].tensorflow_version
-        if model_tf_version is missing:
-            model_tf_version = None
-        else:
-            model_tf_version = (int(model_tf_version.major), int(model_tf_version.minor))
-
-        tf_version = tf.__version__
-        tf_major_and_minor = tuple(map(int, tf_version.split(".")))[:2]
+    def __init__(
+        self,
+        *,
+        devices: Optional[Sequence[str]] = None,
+        weights: Union[
+            v0_4.KerasHdf5Weights,
+            v0_4.TensorflowSavedModelBundleWeights,
+            v0_5.KerasHdf5Weights,
+            v0_5.TensorflowSavedModelBundleWeights,
+        ],
+        model_description: Union[v0_4.Model, v0_5.Model],
+    ):
+        super().__init__()
+        self.model_description = model_description
+        tf_version = v0_5.Version(tf.__version__)
+        model_tf_version = weights.tensorflow_version
         if model_tf_version is None:
             warnings.warn(
-                "The model did not contain metadata about the tensorflow version used for training."
-                f"Cannot check if it is compatible with tf {tf_version}. The prediction result may be wrong."
+                "The model does not specify the tensorflow version."
+                f"Cannot check if it is compatible with installed tensorflow {tf_version}."
+            )
+        elif model_tf_version > tf_version:
+            warnings.warn(
+                f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}."
             )
-        elif tuple(model_tf_version[:2]) != tf_major_and_minor:
+        elif (model_tf_version.major, model_tf_version.minor) != (tf_version.major, tf_version.minor):
             warnings.warn(
-                f"Model tensorflow version {model_tf_version} does not match {tf_version}."
-                "The prediction results may be wrong"
+                "The tensorflow version specified by the model does not match the installed: "
+                f"{model_tf_version} != {tf_version}."
)
 
-        tf_major_ver = tf_major_and_minor[0]
-        assert tf_major_ver in (1, 2)
-        self.use_keras_api = tf_major_ver > 1 or self.weight_format == KerasModelAdapter.weight_format
+        self.use_keras_api = tf_version.major > 1 or self.weight_format == KerasModelAdapter.weight_format
 
         # TODO tf device management
         if devices is not None:
             warnings.warn(f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}")
 
-        weight_file = self.require_unzipped(self.bioimageio_model.weights[self.weight_format].source)
-        self._model = self._load_model(weight_file)
-        self._internal_output_axes = [tuple(out.axes) for out in self.bioimageio_model.outputs]
+        weight_file = self.require_unzipped(
+            weights.source.get_absolute(model_description.root)
+            if isinstance(weights.source, RelativeFilePath)
+            else weights.source
+        )
+        self._network = self._get_network(weight_file)
+        self._internal_output_axes = [
+            tuple(out.axes) if isinstance(out.axes, str) else tuple(a.id for a in out.axes)
+            for out in model_description.outputs
+        ]
+
+    def require_unzipped(self, weight_file: FileSource):
+        local_weights_file = download(weight_file).path
+        if zipfile.is_zipfile(local_weights_file):
+            out_path = local_weights_file.with_suffix(".unzipped")
+            with zipfile.ZipFile(local_weights_file, "r") as f:
+                f.extractall(out_path)
+
+            return out_path
+        else:
+            return local_weights_file
+
+    def _get_network(self, weight_file: FileSource):
+        weight_file = self.require_unzipped(weight_file)
+        if self.use_keras_api:
+            return tf.keras.models.load_model(weight_file, compile=False)
+        else:
+            # NOTE in tf1 the model needs to be loaded inside of the session, so we cannot preload the model
+            return str(weight_file)
 
     # TODO currently we reload the model every time. it would be better to keep the graph and session
     # alive in between of forward passes (but then the sessions need to be properly opened / closed)
     def _forward_tf(self, *input_tensors):
-        input_keys = [ipt.name for ipt in self.bioimageio_model.inputs]
-        output_keys = [out.name for out in self.bioimageio_model.outputs]
+        input_keys = [
+            ipt.name if isinstance(ipt, v0_4.InputTensor) else ipt.id for ipt in self.model_description.inputs
+        ]
+        output_keys = [
+            out.name if isinstance(out, v0_4.OutputTensor) else out.id for out in self.model_description.outputs
+        ]
 
         # TODO read from spec
         tag = tf.saved_model.tag_constants.SERVING
@@ -75,7 +101,7 @@ def _forward_tf(self, *input_tensors):
         with graph.as_default():
             with tf.Session(graph=graph) as sess:
                 # load the model and the signature
-                graph_def = tf.saved_model.loader.load(sess, [tag], self._model)
+                graph_def = tf.saved_model.loader.load(sess, [tag], self._network)
                 signature = graph_def.signature_def
 
                 # get the tensors into the graph
@@ -91,20 +117,22 @@ def _forward_tf(self, *input_tensors):
 
         return res
 
-    def _forward_keras(self, *input_tensors):
+    def _forward_keras(self, *input_tensors: xr.DataArray):
+        assert self.use_keras_api
+        assert not isinstance(self._network, str)
         tf_tensor = [tf.convert_to_tensor(ipt) for ipt in input_tensors]
 
         try:
-            result = self._model.forward(*tf_tensor)
+            result = self._network.forward(*tf_tensor)
         except AttributeError:
-            result = self._model.predict(*tf_tensor)
+            result = self._network.predict(*tf_tensor)
 
         if not isinstance(result, (tuple, list)):
             result = [result]
 
         return [r if isinstance(r, np.ndarray) else tf.make_ndarray(r) for r in result]
 
-    def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
+    def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
         data
= [ipt.data for ipt in input_tensors] if self.use_keras_api: result = self._forward_keras(*data) @@ -113,13 +141,33 @@ def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: return [xr.DataArray(r, dims=axes) for r, axes in zip(result, self._internal_output_axes)] - def _unload(self) -> None: + def unload(self) -> None: warnings.warn("Device management is not implemented for keras yet, cannot unload model") class TensorflowModelAdapter(TensorflowModelAdapterBase): weight_format = "tensorflow_saved_model_bundle" + def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: Optional[Sequence[str]] = None): + if model_description.weights.tensorflow_saved_model_bundle is None: + raise ValueError("missing tensorflow_saved_model_bundle weights") + + super().__init__( + devices=devices, + weights=model_description.weights.tensorflow_saved_model_bundle, + model_description=model_description, + ) + class KerasModelAdapter(TensorflowModelAdapterBase): weight_format = "keras_hdf5" + + def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: Optional[Sequence[str]] = None): + if model_description.weights.keras_hdf5 is None: + raise ValueError("missing keras_hdf5 weights") + + super().__init__( + model_description=model_description, + devices=devices, + weights=model_description.weights.keras_hdf5, + ) diff --git a/bioimageio/core/model_adapters/_torchscript_model_adapter.py b/bioimageio/core/model_adapters/_torchscript_model_adapter.py index 3b339722..9fe183bb 100644 --- a/bioimageio/core/model_adapters/_torchscript_model_adapter.py +++ b/bioimageio/core/model_adapters/_torchscript_model_adapter.py @@ -1,17 +1,29 @@ import gc import warnings -from typing import List, Optional +from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np import torch import xarray as xr +from numpy.typing import NDArray + +from bioimageio.core.io import download +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.model.v0_5 import RelativeFilePath from ._model_adapter import ModelAdapter class TorchscriptModelAdapter(ModelAdapter): - def _load(self, *, devices: Optional[List[str]] = None): - weight_path = str(self.bioimageio_model.weights["torchscript"].source.resolve()) + def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: Optional[Sequence[str]] = None): + super().__init__() + if model_description.weights.torchscript is None: + raise ValueError(f"No torchscript weights found for model {model_description.name}") + + src = model_description.weights.torchscript.source + weight_path = download( + src.get_absolute(model_description.root) if isinstance(src, RelativeFilePath) else src + ).path if devices is None: self.devices = ["cuda" if torch.cuda.is_available() else "cpu"] else: @@ -20,24 +32,34 @@ def _load(self, *, devices: Optional[List[str]] = None): if len(self.devices) > 1: warnings.warn("Multiple devices for single torchscript model not yet implemented") - self._model = torch.jit.load(weight_path) + self._model = torch.jit.load(weight_path) # pyright: ignore[reportPrivateImportUsage] self._model.to(self.devices[0]) - self._internal_output_axes = [tuple(out.axes) for out in self.bioimageio_model.outputs] + self._internal_output_axes = [ + tuple(out.axes) if isinstance(out.axes, str) else tuple(a.id for a in out.axes) + for out in model_description.outputs + ] - def _forward(self, *batch: xr.DataArray) -> List[xr.DataArray]: + def forward(self, *batch: xr.DataArray) -> List[xr.DataArray]: 
with torch.no_grad():
             torch_tensor = [torch.from_numpy(b.data).to(self.devices[0]) for b in batch]
-            result = self._model.forward(*torch_tensor)
-            if not isinstance(result, (tuple, list)):
-                result = [result]
+            _result: Union[  # pyright: ignore[reportUnknownVariableType]
+                Tuple[NDArray[Any], ...], List[NDArray[Any]], NDArray[Any]
+            ] = self._model.forward(*torch_tensor)
+            if isinstance(_result, (tuple, list)):
+                result: Sequence[NDArray[Any]] = _result
+            else:
+                result = [_result]
 
-            result = [r.cpu().numpy() if not isinstance(r, np.ndarray) else r for r in result]
+            result = [
+                r.cpu().numpy() if not isinstance(r, np.ndarray) else r  # pyright: ignore[reportUnnecessaryIsInstance]
+                for r in result
+            ]
 
         assert len(result) == len(self._internal_output_axes)
         return [xr.DataArray(r, dims=axes) for r, axes in zip(result, self._internal_output_axes)]
 
-    def _unload(self) -> None:
+    def unload(self) -> None:
-        self._devices = None
+        self.devices = None
         del self._model
-        gc.collect()  # deallocate memory
+        _ = gc.collect()  # deallocate memory
         torch.cuda.empty_cache()  # release reserved memory

From 226407f096fd96228d98d92a9cd1ada359be747b Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Wed, 25 Oct 2023 23:44:54 +0200
Subject: [PATCH 048/244] update pytorch model adapter

---
 .github/workflows/build.yml                   |  2 +-
 .../core/model_adapters/_model_adapter.py     |  4 ++-
 .../model_adapters/_pytorch_model_adapter.py  | 33 ++++++++-----------
 pyproject.toml                                |  2 +-
 4 files changed, 18 insertions(+), 23 deletions(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 9dbb37b1..065dfa53 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -8,7 +8,7 @@ on:
 
 defaults:
   run:
-    shell: bash -l {0}
+    shell: bash -el {0}
 
 jobs:
   black:
diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py
index d1e8ecb6..bcba49b6 100644
--- a/bioimageio/core/model_adapters/_model_adapter.py
+++ b/bioimageio/core/model_adapters/_model_adapter.py
@@ -58,7 +58,9 @@ def create(
                 try:
                     from ._pytorch_model_adapter import PytorchModelAdapter
 
-                    return PytorchModelAdapter(model_description=model_description, devices=devices)
+                    return PytorchModelAdapter(
+                        outputs=model_description.outputs, weights=weights.pytorch_state_dict, devices=devices
+                    )
                 except Exception as e:
                     errors.append(e)
             elif wf == "tensorflow_saved_model_bundle" and weights.tensorflow_saved_model_bundle is not None:
diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py
index 6b9e67b5..7e5fd706 100644
--- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py
+++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py
@@ -12,26 +12,25 @@
 
 
 class PytorchModelAdapter(ModelAdapter):
-    def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: Optional[Sequence[str]] = None):
+    def __init__(
+        self,
+        *,
+        outputs: Union[Sequence[v0_4.OutputTensor], Sequence[v0_5.OutputTensor]],
+        weights: Union[v0_4.PytorchStateDictWeights, v0_5.PytorchStateDictWeights],
+        devices: Optional[Sequence[str]] = None,
+    ):
         super().__init__()
-        if model_description.weights.pytorch_state_dict is None:
-            raise ValueError("missing pytorch_state_dict weights")
+        self.output_dims = [tuple(a if isinstance(a, str) else a.id for a in out.axes) for out in outputs]
+
self._network = self.get_network(weights) self._devices = self.get_devices(devices) self._network = self._network.to(self._devices[0]) - weights = model_description.weights.pytorch_state_dict state: Any = torch.load(weights.source, map_location=self._devices[0]) _ = self._network.load_state_dict(state) self._network = self._network.eval() def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - assert self._devices is not None with torch.no_grad(): tensors = [torch.from_numpy(ipt.data) for ipt in input_tensors] tensors = [t.to(self._devices[0]) for t in tensors] @@ -40,18 +39,12 @@ def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: result = [result] result = [r.detach().cpu().numpy() if isinstance(r, torch.Tensor) else r for r in result] - if len(result) > len(self.model_description.outputs): - raise ValueError( - f"Expected at most {len(self.model_description.outputs)} outpus, but got {len(result)}" - ) + if len(result) > len(self.output_dims): + raise ValueError(f"Expected at most {len(self.output_dims)} outputs, but got {len(result)}") - return [ - xr.DataArray(r, dims=tuple(a if isinstance(a, str) else a.id for a in out.axes)) - for r, out in zip(result, self.model_description.outputs) - ] + return [xr.DataArray(r, dims=out) for r, out in zip(result, self.output_dims)] def unload(self) -> None: - self._devices = None del self._network _ = gc.collect() # deallocate memory torch.cuda.empty_cache() # release reserved memory @@ -76,7 +69,7 @@ def get_network(weight_spec: Union[v0_4.PytorchStateDictWeights, v0_5.PytorchSta return network @staticmethod - def get_devices(devices: Optional[Sequence[str]] = None): + def get_devices(devices: Optional[Sequence[str]] = None) -> List[torch.device]: if not devices: torch_devices = [torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")] else: diff --git a/pyproject.toml b/pyproject.toml index 4857d8c7..e0296a1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ reportMissingTypeStubs = "warning" useLibraryCodeForTypes = true reportUnusedCallResult = "error" reportUnusedVariable = "error" -pythonVersion = "3.9" +pythonVersion = "3.8" pythonPlatform = "All" [tool.pytest.ini_options] From 62e8443a7fd29a9ff98b785b7a058c03cff22581 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 25 Oct 2023 23:46:48 +0200 Subject: [PATCH 049/244] add build_model_spec --- bioimageio/core/build_spec/__init__.py | 1 + bioimageio/core/build_spec/_build_spec.py | 240 ++++++++++++++++++++++ 2 files changed, 241 insertions(+) create mode 100644 bioimageio/core/build_spec/_build_spec.py diff --git a/bioimageio/core/build_spec/__init__.py b/bioimageio/core/build_spec/__init__.py index c11615df..e4dcb7b7 100644 --- a/bioimageio/core/build_spec/__init__.py +++ b/bioimageio/core/build_spec/__init__.py @@ -1,2 +1,3 @@ +from ._build_spec import build_model_spec as build_model_spec from .add_weights import add_weights from .build_model import build_model diff --git a/bioimageio/core/build_spec/_build_spec.py b/bioimageio/core/build_spec/_build_spec.py new file mode 100644 index 00000000..311d67c1 --- /dev/null +++ b/bioimageio/core/build_spec/_build_spec.py @@ -0,0 +1,240 @@ +import collections.abc +import shutil +from datetime import datetime +from pathlib import Path +from typing import Any, Optional, Sequence, Type, TypedDict, Union + +import numpy as np +from numpy.typing import NDArray + +# from bioimageio.core import export_resource_package, load_raw_resource_description +from typing_extensions import 
NotRequired, Self, Unpack + +from bioimageio.core.io import FileSource, download, read_description_and_validate, write_description +from bioimageio.spec.model.v0_5 import ( + Architecture, + Author, + CiteEntry, + Dependencies, + InputAxis, + InputTensor, + IntervalOrRatioData, + IntervalOrRatioDType, + LicenseId, + Maintainer, + Model, + NotEmpty, + OutputAxis, + OutputTensor, + Postprocessing, + Preprocessing, + PytorchStateDictWeights, + RelativeFilePath, + Sha256, + TensorData, + TensorId, + Version, + Weights, +) + + +class CoreGenericBaseKwargs(TypedDict): + name: str + description: str + authors: NotEmpty[Sequence[Author]] + maintainers: NotRequired[Sequence[Maintainer]] + tags: Sequence[str] + documentation: FileSource + cite: NotEmpty[Sequence[CiteEntry]] + license: LicenseId + output_path: Path + + +class CoreTensorKwargs(TypedDict): + test_tensor: FileSource + sample_tensor: NotRequired[FileSource] + id: NotRequired[Optional[TensorId]] + data: NotRequired[Optional[Union[TensorData, NotEmpty[Sequence[TensorData]]]]] + output_path: Path + + +class CoreInputTensorKwargs(CoreTensorKwargs): + axes: NotEmpty[Sequence[InputAxis]] + preprocessing: NotRequired[Sequence[Preprocessing]] + + +class CoreOutputTensorKwargs(CoreTensorKwargs): + axes: NotEmpty[Sequence[OutputAxis]] + postprocessing: NotRequired[Sequence[Postprocessing]] + + +def ensure_file_in_folder(source: FileSource, folder: Path) -> RelativeFilePath: + """download/copy `source` to `folder` if `source` is not already in (a subfolder of) `folder`. + Returns a relative file path (relative to `folder`)""" + path = download(source).path + try: + rel_path = path.relative_to(folder) # todo: improve for py >= 3.9 with path.is_relative_to + except ValueError: + path = Path(shutil.copy(path, folder)) + rel_path = path.relative_to(folder) + + return RelativeFilePath(rel_path) + + +class _CoreTensorMixin: + @staticmethod + def get_data_description(kwargs: Union[CoreInputTensorKwargs, CoreOutputTensorKwargs]): + tensor_data = kwargs.get("data") + if isinstance(tensor_data, TensorData): + return tensor_data + elif tensor_data is None: + test_tensor: NDArray[Any] = np.load(download(kwargs["test_tensor"]).path) + assert isinstance(test_tensor, np.ndarray) + dtype_str = str(test_tensor.dtype) + if dtype_str.startswith("float") and test_tensor.min() >= 0.0 and test_tensor.max() <= 1.0: + range_ = (0.0, 1.0) + else: + range_ = (None, None) + + dtype: IntervalOrRatioDType = dtype_str # type: ignore # validated by IntervalOrRatioData + return IntervalOrRatioData(type=dtype, range=range_, unit="arbitrary unit", scale=1.0, offset=None) + elif isinstance(tensor_data, collections.abc.Sequence): # pyright: ignore[reportUnnecessaryIsInstance] + return tuple(tensor_data) + else: + raise TypeError(tensor_data) + + +class _CoreInputTensor(InputTensor, _CoreTensorMixin, frozen=True): + @classmethod + def build(cls, **kwargs: Unpack[CoreInputTensorKwargs]): + return cls( + test_tensor=ensure_file_in_folder(kwargs["test_tensor"], kwargs["output_path"]), + id=kwargs.get("id") or TensorId("input"), + axes=tuple(kwargs["axes"]), + preprocessing=tuple(kwargs.get("preprocessing", ())), + data=cls.get_data_description(kwargs), + sample_tensor=ensure_file_in_folder(kwargs["sample_tensor"], kwargs["output_path"]) + if "sample_tensor" in kwargs + else None, + ) + + +class _CoreOutputTensor(OutputTensor, _CoreTensorMixin, frozen=True): + @classmethod + def build(cls, **kwargs: Unpack[CoreOutputTensorKwargs]): + return cls( + 
test_tensor=ensure_file_in_folder(kwargs["test_tensor"], kwargs["output_path"]), + id=kwargs.get("id") or TensorId("output"), + axes=tuple(kwargs["axes"]), + postprocessing=tuple(kwargs.get("postprocessing", ())), + data=cls.get_data_description(kwargs), + ) + + +class CoreModelBaseKwargs(CoreGenericBaseKwargs): + inputs: NotEmpty[Sequence[CoreInputTensorKwargs]] + outputs: NotEmpty[Sequence[CoreOutputTensorKwargs]] + + +class CoreModelKwargs(CoreModelBaseKwargs): + weights: Weights + + +class _CoreModel(Model, frozen=True): + @classmethod + def build(cls, **kwargs: Unpack[CoreModelKwargs]) -> Self: + documentation = ensure_file_in_folder(kwargs["documentation"], kwargs["output_path"]) + + inputs = tuple( + _CoreInputTensor.build( + id=t_kwargs["id"] if "id" in t_kwargs else TensorId(f"input{i}"), + test_tensor=t_kwargs["test_tensor"], + axes=t_kwargs["axes"], + data=t_kwargs.get("data"), + output_path=kwargs["output_path"], + ) + for i, t_kwargs in enumerate(kwargs["inputs"]) + ) + + outputs = tuple( + _CoreOutputTensor.build( + id=t_kwargs["id"] if "id" in t_kwargs else TensorId(f"output{i}"), + test_tensor=t_kwargs["test_tensor"], + axes=t_kwargs["axes"], + data=t_kwargs.get("data"), + output_path=kwargs["output_path"], + ) + for i, t_kwargs in enumerate(kwargs["outputs"]) + ) + + return cls( + name=kwargs["name"], + description=kwargs["description"], + authors=tuple(kwargs["authors"]), + maintainers=tuple(kwargs.get("maintainers", ())), + cite=tuple(kwargs["cite"]), + license=kwargs["license"], + timestamp=datetime.now(), + inputs=inputs, + outputs=outputs, + weights=kwargs["weights"], + documentation=documentation, + ) + + @classmethod + def build_from_pytorch_state_dict( + cls, + weights: FileSource, + architecture: Architecture, + sha256: Optional[Sha256] = None, + pytorch_version: Optional[Version] = None, + dependencies: Optional[Dependencies] = None, + **kwargs: Unpack[CoreModelBaseKwargs], + ): + if pytorch_version is None: + import torch + + pytorch_version = Version(torch.__version__) + + return cls.build( + weights=Weights( + pytorch_state_dict=PytorchStateDictWeights( + source=ensure_file_in_folder(weights, kwargs["output_path"]), + sha256=sha256, + architecture=architecture, + pytorch_version=pytorch_version, + dependencies=dependencies, + ) + ), + **kwargs, + ) + + +def _build_spec_common(core_descr: _CoreModel, descr_path: Path, expected_type: Type[Any]): + write_description(core_descr, descr_path) + loaded = read_description_and_validate(descr_path) + if type(loaded) is not expected_type: + raise RuntimeError(f"Created {descr_path} was loaded as {type(loaded)}, but expected {expected_type}") + + return descr_path, loaded + + +def build_model_spec( + *, + weights: FileSource, + architecture: Architecture, + sha256: Optional[Sha256] = None, + pytorch_version: Optional[Version] = None, + dependencies: Optional[Dependencies] = None, + **kwargs: Unpack[CoreModelBaseKwargs], +): + model = _CoreModel.build_from_pytorch_state_dict( + weights=weights, + architecture=architecture, + sha256=sha256, + pytorch_version=pytorch_version, + dependencies=dependencies, + **kwargs, + ) + + return _build_spec_common(model, kwargs["output_path"] / "description.bioimageio.yaml", Model) From 3397d832c9baec1a0704f3ec3f49768854e9a65e Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 25 Oct 2023 23:47:08 +0200 Subject: [PATCH 050/244] WIP update add_weights --- bioimageio/core/build_spec/add_weights.py | 43 +++++++++++------------ 1 file changed, 21 insertions(+), 22 deletions(-) diff --git 
a/bioimageio/core/build_spec/add_weights.py b/bioimageio/core/build_spec/add_weights.py index 0e4e7949..7dbafe82 100644 --- a/bioimageio/core/build_spec/add_weights.py +++ b/bioimageio/core/build_spec/add_weights.py @@ -1,19 +1,23 @@ import os from pathlib import Path from shutil import copyfile -from typing import Dict, Optional, Union, List +from typing import Dict, List, Optional, Union + +from pydantic import DirectoryPath, FilePath + +from bioimageio.core import export_resource_package +from bioimageio.core.io import FileSource, download, read_description, write_package_as_folder +from bioimageio.spec.model import AnyModel, v0_5 -from bioimageio.core import export_resource_package, load_raw_resource_description -from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription from .build_model import _get_weights def add_weights( - model: Union[RawResourceDescription, os.PathLike, str], - weight_uri: Union[str, Path], - output_path: Union[str, Path], + model: Union[AnyModel, FileSource], + weight_file: FileSource, + output_path: DirectoryPath, *, - weight_type: Optional[str] = None, + weight_type: Optional[v0_5.WeightsFormat] = None, architecture: Optional[str] = None, model_kwargs: Optional[Dict[str, Union[int, float, str]]] = None, tensorflow_version: Optional[str] = None, @@ -25,7 +29,7 @@ def add_weights( Args: model: the resource description of the model to which the weight format is added - weight_uri: the weight file to be added + weight_file: the weight file to be added output_path: where to serialize the new model with additional weight format weight_type: the format of the weights to be added architecture: the file with the source code for the model architecture and the corresponding class. @@ -37,23 +41,18 @@ def add_weights( pytorch_version: the pytorch version for this model. Only for pytoch_state_dict or torchscript weights. attachments: extra weight specific attachments. """ - model = load_raw_resource_description(model) - if not isinstance(model.root_path, Path): - # ensure model is available locally - model = load_raw_resource_description(export_resource_package(model)) - - assert isinstance(model.root_path, Path), model.root_path + downloaded_weight_file = download(weight_file) + output_path = write_package_as_folder(model, output_path=output_path) # copy the weight path to the input model's root, otherwise it will # not be found when packaging the new model - weight_out = os.path.join(model.root_path, Path(weight_uri).name) - if Path(weight_out).absolute() != Path(weight_uri).absolute(): - copyfile(weight_uri, weight_out) + weight_out: FilePath = output_path / downloaded_weight_file.original_file_name # noqa: F821 + _ = copyfile(downloaded_weight_file.path, weight_out) new_weights, tmp_arch = _get_weights( weight_out, weight_type, - root=Path("."), + root=output_path, architecture=architecture, model_kwargs=model_kwargs, tensorflow_version=tensorflow_version, @@ -65,18 +64,18 @@ def add_weights( try: model_package = export_resource_package(model, output_path=output_path) - model = load_raw_resource_description(model_package) + model = read_description(model_package) except Exception as e: raise e finally: # clean up tmp files - if Path(weight_out).absolute() != Path(weight_uri).absolute(): + if Path(weight_out).absolute() != Path(weight_file).absolute(): os.remove(weight_out) if tmp_arch is not None: os.remove(tmp_arch) # for some reason the weights are also copied to the cwd. 
# not sure why this happens, but it needs to be cleaned up, unless these are the input weigths - weights_cwd = Path(os.path.split(weight_uri)[1]) - if weights_cwd.exists() and weights_cwd.absolute() != Path(weight_uri).absolute(): + weights_cwd = Path(os.path.split(weight_file)[1]) + if weights_cwd.exists() and weights_cwd.absolute() != Path(weight_file).absolute(): os.remove(weights_cwd) return model From 6d5e66cdf01bacb8eae0ebd853061f4e2fe82c86 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 25 Oct 2023 23:49:33 +0200 Subject: [PATCH 051/244] WIP update imports --- bioimageio/core/__init__.py | 43 +++++-------------- bioimageio/core/statistical_measure_groups.py | 5 +-- setup.py | 2 +- 3 files changed, 14 insertions(+), 36 deletions(-) diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index 408636e8..e3a665c2 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -1,41 +1,20 @@ import json -from bioimageio.core._internal.utils import files -from bioimageio.core._io import load_description as load_description -from bioimageio.core._io import load_description_and_validate as load_description_and_validate -from bioimageio.core._io import read_description as read_description -from bioimageio.core._io import read_description_and_validate as read_description_and_validate -from bioimageio.core._io import resolve_source as resolve_source -from bioimageio.core._io import write_description as write_description -from bioimageio.core._io import write_package as write_package +from bioimageio.core.io import load_description as load_description +from bioimageio.core.io import load_description_and_validate as load_description_and_validate +from bioimageio.core.io import read_description as read_description +from bioimageio.core.io import read_description_and_validate as read_description_and_validate +from bioimageio.core.io import resolve_source as resolve_source +from bioimageio.core.io import write_description as write_description +from bioimageio.core.io import write_package as write_package +from bioimageio.core.io import write_package_as_folder as write_package_as_folder +from bioimageio.core.utils import files with files("bioimageio.core").joinpath("VERSION").open("r", encoding="utf-8") as f: __version__: str = json.load(f)["version"] assert isinstance(__version__, str) -# __version__ = json.loads((pathlib.Path(__file__).parent / "VERSION").read_text())["version"] # from .prediction import predict_image, predict_images, predict_with_padding, predict_with_tiling -# from .prediction_pipeline import create_prediction_pipeline -# from .resource_io import ( -# export_resource_package, -# load_raw_resource_description, -# load_resource_description, -# save_raw_resource_description, -# serialize_raw_resource_description, -# ) -# from .resource_tests import check_input_shape, check_output_shape, test_resource +from .prediction_pipeline import create_prediction_pipeline -# __all__ = [ -# "check_input_shape", -# "check_output_shape", -# "create_prediction_pipeline", -# "export_resource_package", -# "load_raw_resource_description", -# "load_resource_description", -# "predict_image", -# "predict_images", -# "predict_with_padding", -# "predict_with_tiling", -# "save_raw_resource_description", -# "serialize_raw_resource_description", -# "test_resource", +# from .resource_tests import check_input_shape, check_output_shape, test_resource diff --git a/bioimageio/core/statistical_measure_groups.py b/bioimageio/core/statistical_measure_groups.py index cc2c0646..88f65b3d 
100644 --- a/bioimageio/core/statistical_measure_groups.py +++ b/bioimageio/core/statistical_measure_groups.py @@ -10,11 +10,10 @@ import numpy import xarray as xr from attr import dataclass -from bioimageio.spec.model.v0_5 import AxisName +from bioimageio.core.sta import PER_DATASET, PER_SAMPLE, ComputedMeasures, RequiredMeasures, Sample from bioimageio.core.statistical_measures import Mean, Measure, Percentile, Std, Var - -from ._utils import PER_DATASET, PER_SAMPLE, ComputedMeasures, RequiredMeasures, Sample, TensorName +from bioimageio.spec.model.v0_5 import AxisName try: import crick # type: ignore diff --git a/setup.py b/setup.py index e140ffac..82b05008 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ ], include_package_data=True, extras_require={ - "test": ["pytest", "black[jupyter]"], + "test": ["pytest", "black[jupyter]", "onnxruntime", "torch>=1.6", "torchvision"], "dev": ["pre-commit"], "pytorch": ["torch>=1.6", "torchvision"], "tensorflow": ["tensorflow"], From 28f9a6331e933fda857a33f90538b03ad4199e9a Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 25 Oct 2023 23:50:00 +0200 Subject: [PATCH 052/244] add write_package_as_folder --- bioimageio/core/io.py | 106 +++++++++++++++++++++++++++++++++--------- 1 file changed, 83 insertions(+), 23 deletions(-) diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py index dd8294c5..306a4ffd 100644 --- a/bioimageio/core/io.py +++ b/bioimageio/core/io.py @@ -3,22 +3,24 @@ import collections.abc import io import os +import shutil +from dataclasses import dataclass from pathlib import Path -from tempfile import NamedTemporaryFile -from typing import Annotated, Any, Dict, List, Literal, Mapping, NamedTuple, Optional, Sequence, TextIO, Union, cast +from tempfile import NamedTemporaryFile, TemporaryDirectory, mkdtemp +from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, TextIO, TypedDict, Union, cast from zipfile import ZIP_DEFLATED, ZipFile, is_zipfile import pooch -from annotated_types import Len, Predicate from pydantic import AnyUrl, DirectoryPath, FilePath, HttpUrl, TypeAdapter from ruamel.yaml import YAML +from typing_extensions import NotRequired, Unpack from bioimageio.core.utils import get_parent_url from bioimageio.spec import ResourceDescription from bioimageio.spec import load_description as load_description from bioimageio.spec._internal.base_nodes import ResourceDescriptionBase from bioimageio.spec._internal.constants import DISCOVER -from bioimageio.spec._internal.types import FileName, RdfContent, RelativeFilePath, ValidationContext, YamlValue +from bioimageio.spec._internal.types import FileName, RdfContent, RelativeFilePath, Sha256, ValidationContext, YamlValue from bioimageio.spec.description import InvalidDescription, dump_description from bioimageio.spec.model.v0_4 import WeightsFormat from bioimageio.spec.package import extract_file_name, get_resource_package_content @@ -33,7 +35,15 @@ LEGACY_RDF_NAME = "rdf.yaml" -KnownHash = Annotated[str, Len(64 + len("sha256:")), Predicate(lambda x: str.startswith(x, "sha256:"))] +class HashKwargs(TypedDict): + sha256: NotRequired[Optional[Sha256]] + + +def get_known_hash(hash_kwargs: HashKwargs): + if "sha256" in hash_kwargs: + return f"sha256:{hash_kwargs['sha256']}" + else: + return None def read_description( @@ -45,7 +55,7 @@ def read_description( rdf = download_rdf(rdf_source) return load_description( rdf.content, - context=ValidationContext(root=rdf.root, file_name=rdf.file_name), + context=ValidationContext(root=rdf.original_root, 
file_name=rdf.original_file_name), format_version=format_version, ) @@ -58,7 +68,9 @@ def read_description_and_validate( ) -> Union[ResourceDescription, InvalidDescription]: rdf = download_rdf(rdf_source) return load_description_and_validate( - rdf.content, context=ValidationContext(root=rdf.root, file_name=rdf.file_name), format_version=format_version + rdf.content, + context=ValidationContext(root=rdf.original_root, file_name=rdf.original_file_name), + format_version=format_version, ) @@ -123,7 +135,7 @@ def prepare_resource_package( context = ValidationContext(root=_ctxt["root"], file_name=_ctxt["file_name"]) else: rdf = download_rdf(rdf_source) - context = ValidationContext(root=rdf.root, file_name=rdf.file_name) + context = ValidationContext(root=rdf.original_root, file_name=rdf.original_file_name) rd = load_description( rdf.content, context=context, @@ -198,7 +210,6 @@ def write_package( Args: rd: bioimage.io resource description - context: compression: The numeric constant of compression method. compression_level: Compression level to use when writing files to the archive. See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile @@ -222,24 +233,72 @@ def write_package( return output_path -class LocalFile(NamedTuple): +def write_package_as_folder( + rdf_source: RdfSource, + /, + *, + output_path: Optional[DirectoryPath] = None, + weights_priority_order: Optional[ # model only + Sequence[ + Literal[ + "keras_hdf5", + "onnx", + "pytorch_state_dict", + "tensorflow_js", + "tensorflow_saved_model_bundle", + "torchscript", + ] + ] + ] = None, +) -> DirectoryPath: + """Write the content of a bioimage.io resource package to a folder. + + Args: + rd: bioimage.io resource description + output_path: file path to write package to + weights_priority_order: If given only the first weights format present in the model is included. + If none of the prioritized weights formats is found all are included. 
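+            For example, weights_priority_order=["onnx"] would keep only the ONNX
+            weights of a model that also provides other weights formats.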
+ + Returns: + path to zipped bioimage.io package in BIOIMAGEIO_CACHE_PATH or 'output_path' + """ + package_content = prepare_resource_package( + rdf_source, + weights_priority_order=weights_priority_order, + ) + if output_path is None: + output_path = Path(mkdtemp()) + else: + output_path = Path(output_path) + + for name, source in package_content.items(): + if isinstance(source, collections.abc.Mapping): + yaml.dump(source, output_path / name) + else: + shutil.copy(source, output_path / name) + + return output_path + + +@dataclass +class DownloadedFile: path: FilePath original_root: Union[AnyUrl, DirectoryPath] original_file_name: str -class LocalRdf(NamedTuple): +@dataclass +class DownloadedRdf: content: RdfContent - root: Union[AnyUrl, DirectoryPath] - file_name: str + original_root: Union[AnyUrl, DirectoryPath] + original_file_name: str def download( source: FileSource, /, - *, - known_hash: Optional[KnownHash] = None, -) -> LocalFile: + **kwargs: Unpack[HashKwargs], +) -> DownloadedFile: source = _interprete_file_source(source) if isinstance(source, AnyUrl): if source.scheme not in ("http", "https"): @@ -256,22 +315,23 @@ def download( headers["User-Agent"] = user_agent downloader = pooch.HTTPDownloader(headers=headers, progressbar=progressbar) - _ls: Any = pooch.retrieve(url=str(source), known_hash=known_hash, downloader=downloader) + _ls: Any = pooch.retrieve(url=str(source), known_hash=get_known_hash(kwargs), downloader=downloader) local_source = Path(_ls) root: Union[HttpUrl, DirectoryPath] = get_parent_url(source) else: local_source = source root = source.parent - return LocalFile( + return DownloadedFile( local_source, root, extract_file_name(source), ) -def download_rdf(source: FileSource, /, *, known_hash: Optional[KnownHash] = None, rdf_encoding: str = "utf-8"): - local_source, root, file_name = download(source, known_hash=known_hash) +def download_rdf(source: FileSource, /, *, rdf_encoding: str = "utf-8", **kwargs: Unpack[HashKwargs]): + downloaded = download(source, **kwargs) + local_source = downloaded.path if is_zipfile(local_source): out_path = local_source.with_suffix(local_source.suffix + ".unzip") with ZipFile(local_source, "r") as f: @@ -296,15 +356,15 @@ def download_rdf(source: FileSource, /, *, known_hash: Optional[KnownHash] = Non if not isinstance(content, collections.abc.Mapping): raise TypeError(f"Expected RDF content to be a mapping, but got '{type(content)}'.") - return LocalRdf(cast(RdfContent, content), root, file_name) + return DownloadedRdf(cast(RdfContent, content), downloaded.original_root, downloaded.original_file_name) def resolve_source( source: Union[FileSource, RelativeFilePath], /, *, - known_hash: Optional[KnownHash] = None, root: Union[DirectoryPath, AnyUrl, None] = None, + **kwargs: Unpack[HashKwargs], ) -> FilePath: if isinstance(source, RelativeFilePath): if root is None: @@ -312,7 +372,7 @@ def resolve_source( source = source.get_absolute(root) - return download(source, known_hash=known_hash).path + return download(source, **kwargs).path def _interprete_file_source(file_source: FileSource) -> StrictFileSource: From fbfbf4b74dd43c36f471746826f4a054c908a884 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 25 Oct 2023 23:51:35 +0200 Subject: [PATCH 053/244] WIP update prediction pipeline --- bioimageio/core/prediction_pipeline/__init__.py | 3 ++- .../prediction_pipeline/_prediction_pipeline.py | 1 - bioimageio/core/prediction_pipeline/_stat_state.py | 2 +- bioimageio/core/prediction_pipeline/_utils.py | 14 -------------- 4 files changed, 3 
insertions(+), 17 deletions(-) diff --git a/bioimageio/core/prediction_pipeline/__init__.py b/bioimageio/core/prediction_pipeline/__init__.py index da136844..78ce5590 100644 --- a/bioimageio/core/prediction_pipeline/__init__.py +++ b/bioimageio/core/prediction_pipeline/__init__.py @@ -1,3 +1,4 @@ -from ._model_adapters import get_weight_formats as get_weight_formats +from bioimageio.core.model_adapters import get_weight_formats as get_weight_formats + from ._prediction_pipeline import PredictionPipeline as PredictionPipeline from ._prediction_pipeline import create_prediction_pipeline as create_prediction_pipeline diff --git a/bioimageio/core/prediction_pipeline/_prediction_pipeline.py b/bioimageio/core/prediction_pipeline/_prediction_pipeline.py index 59e677db..8fffc887 100644 --- a/bioimageio/core/prediction_pipeline/_prediction_pipeline.py +++ b/bioimageio/core/prediction_pipeline/_prediction_pipeline.py @@ -4,7 +4,6 @@ from typing import Iterable, List, Optional, Sequence, Tuple, Union import xarray as xr -from marshmallow import missing from bioimageio.core.model_adapters import ModelAdapter, create_model_adapter from bioimageio.core.validation_visitors import resolve_raw_node diff --git a/bioimageio/core/prediction_pipeline/_stat_state.py b/bioimageio/core/prediction_pipeline/_stat_state.py index c0e72eb0..c84c7e4e 100644 --- a/bioimageio/core/prediction_pipeline/_stat_state.py +++ b/bioimageio/core/prediction_pipeline/_stat_state.py @@ -2,9 +2,9 @@ from tqdm import tqdm +from bioimageio.core.statistical_measure_groups import MeasureGroups, MeasureValue, get_measure_groups from bioimageio.core.statistical_measures import Measure -from ._measure_groups import MeasureGroups, MeasureValue, get_measure_groups from ._utils import PER_DATASET, PER_SAMPLE, MeasureValue, RequiredMeasure, Sample, TensorName diff --git a/bioimageio/core/prediction_pipeline/_utils.py b/bioimageio/core/prediction_pipeline/_utils.py index 83f181ef..b1f5c2c7 100644 --- a/bioimageio/core/prediction_pipeline/_utils.py +++ b/bioimageio/core/prediction_pipeline/_utils.py @@ -1,17 +1,3 @@ -from __future__ import annotations - -import collections.abc -from dataclasses import dataclass, field -from typing import Any, Dict, Iterator, List, Literal, NamedTuple, Set, Union - -import xarray as xr - -from bioimageio.core.statistical_measures import Measure -from bioimageio.spec.model.v0_5 import TensorId - -Sample = Dict[TensorId, xr.DataArray] - - # def __repr__(self) -> str: # return f"{self.measure} of {self.tensor_id} ({self.mode})" From 16f7fefeab46b384526ddf896e0d716ceb6df8a7 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 25 Oct 2023 23:52:04 +0200 Subject: [PATCH 054/244] WIP update processing --- .../core/prediction_pipeline/processing.py | 66 ++++++++++++++----- 1 file changed, 51 insertions(+), 15 deletions(-) diff --git a/bioimageio/core/prediction_pipeline/processing.py b/bioimageio/core/prediction_pipeline/processing.py index dde182e3..acef48b5 100644 --- a/bioimageio/core/prediction_pipeline/processing.py +++ b/bioimageio/core/prediction_pipeline/processing.py @@ -4,6 +4,7 @@ from typing import ( Any, ClassVar, + Dict, FrozenSet, Generic, Hashable, @@ -25,14 +26,12 @@ import numpy as np import xarray as xr from numpy.typing import DTypeLike -from typing_extensions import LiteralString +from typing_extensions import LiteralString, assert_never from bioimageio.core.statistical_measures import Mean, Measure, MeasureValue, Percentile, Std from bioimageio.spec._internal.base_nodes import NodeWithExplicitlySetFields 
 from bioimageio.spec.model import v0_4, v0_5
-from bioimageio.spec.model.v0_5 import TensorId
-
-from ._utils import Sample
+from bioimageio.spec.model.v0_5 import AxisName, NonBatchAxisName, TensorId

 AssertProcessingId = Literal["assert_dtype"]
@@ -73,6 +72,7 @@ def get_set(self) -> Set[M]:

 C = TypeVar("C", bound=NamedMeasures[MeasureValue])
+Sample = Dict[TensorId, xr.DataArray]
 PKwargs = TypeVar("PKwargs", bound=Union[v0_4.ProcessingKwargs, v0_5.ProcessingKwargs])
 ProcInput = TypeVar("ProcInput", xr.DataArray, Sample)
 ProcessingBase = Union[v0_4.ProcessingBase, v0_5.ProcessingBase]
@@ -141,6 +141,7 @@ def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> NamedMea
 @dataclass(frozen=True)
 class AssertDtypeImpl(ProcessingImplBaseWoMeasures[AssertDtypeKwargs]):
+    kwargs_class = AssertDtypeKwargs
     _assert_with: Tuple[Type[DTypeLike], ...] = field(init=False)

     def __post_init__(self, computed_measures: Mapping[RequiredMeasure, MeasureValue]) -> None:
@@ -189,6 +190,26 @@ def get_spec(self):
         return v0_5.EnsureDtype(kwargs=self.kwargs)


+# todo: factor out a common ScaleLinearImplBase
+class ScaleLinearImpl04(ProcessingImplBaseWoMeasures[Union[v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs]]):
+    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
+        axis = (
+            self.kwargs.axis
+            if isinstance(self.kwargs, v0_5.ScaleLinearKwargs)
+            else _get_complement_axis(tensor, self.kwargs.axes)
+        )
+        if axis:
+            gain = xr.DataArray(np.atleast_1d(self.kwargs.gain), dims=axis)
+            offset = xr.DataArray(np.atleast_1d(self.kwargs.offset), dims=axis)
+        else:
+            assert isinstance(self.kwargs.gain, (float, int)) or len(self.kwargs.gain) == 1
+            gain = self.kwargs.gain if isinstance(self.kwargs.gain, (float, int)) else self.kwargs.gain[0]
+            assert isinstance(self.kwargs.offset, (float, int)) or len(self.kwargs.offset) == 1
+            offset = self.kwargs.offset if isinstance(self.kwargs.offset, (float, int)) else self.kwargs.offset[0]
+
+        return tensor * gain + offset
+
+
 @dataclass(frozen=True)
 class ScaleLinearImpl(ProcessingImplBaseWoMeasures[Union[v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs]]):
     def apply(self, tensor: xr.DataArray) -> xr.DataArray:
@@ -235,7 +256,7 @@ class ScaleMeanVarianceImpl(
     def get_required_measures(
         cls, tensor_id: TensorId, kwargs: Union[v0_4.ScaleMeanVarianceKwargs, v0_5.ScaleMeanVarianceKwargs]
     ):
-        axes = tuple(kwargs.axes) if isinstance(kwargs.axes, str) else kwargs.axes
+        axes = tuple(NonBatchAxisName(a) for a in kwargs.axes) if isinstance(kwargs.axes, str) else kwargs.axes
         return NamedMeasuresScaleMeanVariance(
             mean=RequiredMeasure(Mean(axes), tensor_id, mode=kwargs.mode),
             std=RequiredMeasure(Std(axes), tensor_id, mode=kwargs.mode),
@@ -272,7 +293,7 @@ class ScaleRangeImpl(
     @classmethod
     def get_required_measures(cls, tensor_id: TensorId, kwargs: Union[v0_4.ScaleRangeKwargs, v0_5.ScaleRangeKwargs]):
         ref_name = kwargs.reference_tensor or tensor_id
-        axes = None if kwargs.axes is None else tuple(kwargs.axes)
+        axes = None if kwargs.axes is None else tuple(NonBatchAxisName(a) for a in kwargs.axes)
         return NamedMeasuresScaleRange(
             lower=RequiredMeasure(Percentile(kwargs.min_percentile, axes=axes), cast(TensorId, ref_name), kwargs.mode),
             upper=RequiredMeasure(Percentile(kwargs.max_percentile, axes=axes), cast(TensorId, ref_name), kwargs.mode),
@@ -320,7 +341,7 @@ class ZeroMeanUnitVarianceImpl(
     def get_required_measures(
         cls, tensor_id: TensorId, kwargs: Union[v0_4.ZeroMeanUnitVarianceKwargs, v0_5.ZeroMeanUnitVarianceKwargs]
     ):
-        axes = None if kwargs.axes is None else tuple(kwargs.axes)
+        axes = None if kwargs.axes is None else tuple(NonBatchAxisName(a) for a in kwargs.axes)
         assert kwargs.mode != "fixed"  # should use FixedZeroMeanUnitVarianceImpl
         return NamedMeasuresZeroMeanUnitVariance(
             mean=RequiredMeasure(Mean(axes=axes), tensor_id, kwargs.mode),
@@ -367,11 +388,20 @@ def get_spec(self):

 ProcSpec = Union[AssertDtype, v0_4.Preprocessing, v0_4.Postprocessing, v0_5.Preprocessing, v0_5.Postprocessing]

+# todo:
+
+class ProcSelector:
+    def __init__(self, proc_spec: ProcSpec) -> None:
+        self.proc_spec = proc_spec
+
+
 def get_impl(proc_spec: ProcSpec):
     if isinstance(proc_spec, AssertDtype):
-        return AssertDtypeImpl
-    elif isinstance(proc_spec, (v0_4.Binarize, v0_5.Binarize)):
-        return BinarizeImpl
+        return AssertDtypeImpl, AssertDtypeKwargs
+    elif isinstance(proc_spec, v0_4.Binarize):
+        return BinarizeImpl, v0_4.BinarizeKwargs
+    elif isinstance(proc_spec, v0_5.Binarize):
+        return BinarizeImpl, v0_5.BinarizeKwargs
     elif isinstance(proc_spec, (v0_4.Clip, v0_5.Clip)):
         return ClipImpl
     elif isinstance(proc_spec, v0_5.EnsureDtype):
@@ -388,11 +418,12 @@ def get_impl(proc_spec: ProcSpec):
         return SigmoidImpl
     elif isinstance(proc_spec, v0_4.ZeroMeanUnitVariance) and proc_spec.kwargs.mode == "fixed":
         return FixedZeroMeanUnitVarianceImpl
-    elif isinstance(proc_spec, (v0_4.ZeroMeanUnitVariance, v0_5.ZeroMeanUnitVariance)):
-        return ZeroMeanUnitVarianceImpl
+    elif isinstance(proc_spec,  # pyright: ignore[reportUnnecessaryIsInstance]
+        (v0_4.ZeroMeanUnitVariance, v0_5.ZeroMeanUnitVariance)
+    ):
+        return ZeroMeanUnitVarianceImpl
     else:
-        raise NotImplementedError(proc_spec)
-
+        assert_never(proc_spec)

 Model = Union[v0_4.Model, v0_5.Model]

@@ -403,8 +434,13 @@ def get_procs(model: Model):
         if not ipt.preprocessing:
             continue

+        assert isinstance(ipt, v0_5.InputTensor)
         for proc_spec in ipt.preprocessing:
-            impl = get_impl(proc_spec)
+            impl = get_impl(proc_spec, ipt.id, computed_measures)
+            assert isinstance(proc_spec.kwargs, (v0_4.ProcessingKwargs, v0_5.ProcessingKwargs))
+            procs.append(impl(tensor_id=ipt.id, kwargs=proc_spec.kwargs))
+
+    return procs


 def _get_complement_axis(tensor: xr.DataArray, axes: Optional[Sequence[Hashable]]) -> Optional[Hashable]:

From 34fde4c07daaa5f25aa444c874bdf596fe16aeba Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Wed, 25 Oct 2023 23:52:35 +0200
Subject: [PATCH 055/244] use common hashkwargs

---
 bioimageio/core/utils/__init__.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/bioimageio/core/utils/__init__.py b/bioimageio/core/utils/__init__.py
index 1452042d..32b178be 100644
--- a/bioimageio/core/utils/__init__.py
+++ b/bioimageio/core/utils/__init__.py
@@ -9,16 +9,16 @@
 from functools import singledispatch
 from pathlib import Path
 from types import TracebackType
-from typing import Any, Callable, Optional
+from typing import Any, Callable
 from urllib.parse import urlsplit, urlunsplit

 from pydantic import AnyUrl, HttpUrl
+from typing_extensions import Unpack

-from bioimageio.core.io import FileSource, download
+from bioimageio.core.io import FileSource, HashKwargs, download
 from bioimageio.spec.model.v0_4 import CallableFromDepencency
 from bioimageio.spec.model.v0_4 import CallableFromFile as CallableFromFile04
 from bioimageio.spec.model.v0_5 import CallableFromFile as CallableFromFile05
-from bioimageio.spec.model.v0_5 import Sha256

 if sys.version_info < (3, 9):
@@ -85,17 +85,17 @@ def import_from_dependency(node: CallableFromDepencency) -> Callable[..., Any]:

 @import_callable.register
-def import_from_file04(node: CallableFromFile04, sha256: Optional[Sha256] = None):
-    return
_import_from_file_impl(node.file, node.callable_name, sha256) +def import_from_file04(node: CallableFromFile04, **kwargs: Unpack[HashKwargs]): + return _import_from_file_impl(node.file, node.callable_name, **kwargs) @import_callable.register -def import_from_file05(node: CallableFromFile05, sha256: Optional[Sha256] = None): - return _import_from_file_impl(node.source_file, node.callable_name, sha256) +def import_from_file05(node: CallableFromFile05, **kwargs: Unpack[HashKwargs]): + return _import_from_file_impl(node.source_file, node.callable_name, **kwargs) -def _import_from_file_impl(source: FileSource, callable_name: str, sha256: Optional[Sha256]): - local_file = download(source, known_hash=None if sha256 is None else f"sha256:{sha256}") +def _import_from_file_impl(source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]): + local_file = download(source, **kwargs) module_name = local_file.path.stem importlib_spec = importlib.util.spec_from_file_location(module_name, local_file.path) if importlib_spec is None: From 4949558a04494082a6456c2a30ba5d4e832e2f25 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 25 Oct 2023 23:53:39 +0200 Subject: [PATCH 056/244] WIP update existing build_model --- bioimageio/core/build_spec/build_model.py | 38 +++++++++++++++++------ 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/bioimageio/core/build_spec/build_model.py b/bioimageio/core/build_spec/build_model.py index c1f52229..c7800124 100644 --- a/bioimageio/core/build_spec/build_model.py +++ b/bioimageio/core/build_spec/build_model.py @@ -2,20 +2,40 @@ import hashlib import os from pathlib import Path -from typing import Any, Dict, List, Optional, Union, get_args +from typing import Any, Dict, List, Optional, Sequence, TypedDict, Union, get_args from warnings import warn import imageio import numpy as np import requests import tifffile +from numpy.typing import NDArray + +# from bioimageio.core import export_resource_package, load_raw_resource_description +from pydantic import AnyUrl, HttpUrl +from typing_extensions import NotRequired, Unpack import bioimageio.spec as spec import bioimageio.spec.model as model_spec -from bioimageio.core import export_resource_package, load_raw_resource_description -from bioimageio.core.resource_io.nodes import URI -from bioimageio.spec.shared import resolve_local_source, resolve_source -from bioimageio.spec.shared.raw_nodes import ImportableModule, ImportableSourceFile +from bioimageio.core.io import FileSource, download +from bioimageio.core.utils import import_callable +from bioimageio.spec.model.v0_5 import ( + Author, + CiteEntry, + InputAxis, + InputTensor, + IntervalOrRatioData, + LicenseId, + Maintainer, + Model, + NominalOrOrdinalData, + NotEmpty, + OutputAxis, + Postprocessing, + Preprocessing, + TensorData, + TensorId, +) # # utility functions to build the spec from python @@ -581,12 +601,12 @@ def _ensure_local_or_url(source: Union[Path, URI, str, list], root: Path) -> Uni def build_model( # model or tensor specific and required - weight_uri: str, - test_inputs: List[Union[str, Path]], - test_outputs: List[Union[str, Path]], + weight_uri: FileSource, + test_inputs: List[FileSource], + test_outputs: List[FileSource], input_axes: List[str], output_axes: List[str], - # general required + # general metadata name: str, description: str, authors: List[Dict[str, str]], From a1b699e7861e87ece624d24581cdb91780ca3f31 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 27 Oct 2023 09:58:58 +0200 Subject: [PATCH 057/244] draft SpecBuilder and 
SpecGuesser --- bioimageio/core/build_spec/_build_spec.py | 1 + bioimageio/core/build_spec/build_model.py | 4 +- bioimageio/core/build_spec/get_description.py | 314 ++++++++++++++++++ .../core/model_adapters/_model_adapter.py | 11 +- 4 files changed, 323 insertions(+), 7 deletions(-) create mode 100644 bioimageio/core/build_spec/get_description.py diff --git a/bioimageio/core/build_spec/_build_spec.py b/bioimageio/core/build_spec/_build_spec.py index 311d67c1..29149f65 100644 --- a/bioimageio/core/build_spec/_build_spec.py +++ b/bioimageio/core/build_spec/_build_spec.py @@ -99,6 +99,7 @@ def get_data_description(kwargs: Union[CoreInputTensorKwargs, CoreOutputTensorKw dtype: IntervalOrRatioDType = dtype_str # type: ignore # validated by IntervalOrRatioData return IntervalOrRatioData(type=dtype, range=range_, unit="arbitrary unit", scale=1.0, offset=None) elif isinstance(tensor_data, collections.abc.Sequence): # pyright: ignore[reportUnnecessaryIsInstance] + assert all(isinstance(td, TensorData) for td in tensor_data) return tuple(tensor_data) else: raise TypeError(tensor_data) diff --git a/bioimageio/core/build_spec/build_model.py b/bioimageio/core/build_spec/build_model.py index c7800124..10c9dfa1 100644 --- a/bioimageio/core/build_spec/build_model.py +++ b/bioimageio/core/build_spec/build_model.py @@ -600,7 +600,7 @@ def _ensure_local_or_url(source: Union[Path, URI, str, list], root: Path) -> Uni def build_model( - # model or tensor specific and required + # model specific and required weight_uri: FileSource, test_inputs: List[FileSource], test_outputs: List[FileSource], @@ -636,7 +636,7 @@ def build_model( pixel_sizes: Optional[List[Dict[str, float]]] = None, # general optional maintainers: Optional[List[Dict[str, str]]] = None, - license: Optional[str] = None, + license: LicenseId = "CC-BY-4.0", covers: Optional[List[str]] = None, git_repo: Optional[str] = None, attachments: Optional[Dict[str, Union[str, List[str]]]] = None, diff --git a/bioimageio/core/build_spec/get_description.py b/bioimageio/core/build_spec/get_description.py new file mode 100644 index 00000000..950574bf --- /dev/null +++ b/bioimageio/core/build_spec/get_description.py @@ -0,0 +1,314 @@ +import hashlib +import shutil +import warnings +from datetime import datetime +from pathlib import Path +from typing import Any, List, Optional, Sequence, Type, TypedDict, Union + +import numpy as np +from numpy.typing import NDArray +from pydantic import FilePath + +# from bioimageio.core import export_resource_package, load_raw_resource_description +from typing_extensions import NotRequired, Self, Unpack + +from bioimageio.spec.description import ValidationContext +from bioimageio.core.io import FileSource, download, read_description_and_validate, write_description +from bioimageio.core.utils import get_sha256 +from bioimageio.spec.model.v0_5 import ( + Architecture, + Author, + CiteEntry, + Dependencies, + InputAxis, + InputTensor, + IntervalOrRatioData, + IntervalOrRatioDType, + LicenseId, + Maintainer, + Model, + NotEmpty, + OutputAxis, + OutputTensor, + Postprocessing, + Preprocessing, + PytorchStateDictWeights, + RelativeFilePath, + Sha256, + TensorData, + TensorId, + Version, + Weights, +) + + +class _CoreGenericBaseKwargs(TypedDict): + name: str + description: str + authors: NotEmpty[Sequence[Author]] + maintainers: NotRequired[Sequence[Maintainer]] + tags: Sequence[str] + documentation: FileSource + cite: NotEmpty[Sequence[CiteEntry]] + license: LicenseId + output_path: Path + + +class _TensorKwargs(TypedDict): + 
test_tensor: FileSource + sample_tensor: NotRequired[FileSource] + id: NotRequired[Optional[TensorId]] + data: NotRequired[Optional[Union[TensorData, NotEmpty[Sequence[TensorData]]]]] + output_path: Path + + +class _OutputTensorKwargs(_TensorKwargs): + axes: NotEmpty[Sequence[OutputAxis]] + postprocessing: NotRequired[Sequence[Postprocessing]] + + +class SpecBuilder: + def __init__(self, output_path: Path, output_path_exist_ok: bool = False) -> None: + super().__init__() + output_path.mkdir(parents=True, exist_ok=output_path_exist_ok) + self.output_path = output_path + + def include_file(self, source: FileSource) -> RelativeFilePath: + local_source = download(source) + try: + rel_path = local_source.path.relative_to( + self.output_path + ) # todo: improve for py >= 3.9 with path.is_relative_to + except ValueError: + # local source is not in output path + dest_path = self.output_path / local_source.original_file_name + if dest_path.exists(): + file_hash = get_sha256(local_source.path) + for i in range(10): + dest_hash = get_sha256(dest_path) + if dest_hash == file_hash: + break + + dest_path = dest_path.with_name(f"{dest_path.stem}-{i}{dest_path.suffix}") + if not dest_path.exists(): + break + else: + raise RuntimeError("Encountered too many unidentical files with the same file name.") + + if not dest_path.exists(): + _ = Path(shutil.copy(local_source.path, dest_path)) + + rel_path = dest_path.relative_to(self.output_path) + + return RelativeFilePath(rel_path) + +class ModelBuilder(SpecBuilder): + def add_cite(self): + self._cite.append(CiteEntry()) + + def add_input_tensor( + self, + *, + test_tensor: Union[NDArray[Any], FileSource], + axes: Sequence[InputAxis], + preprocessing: Sequence[Preprocessing], + id_: TensorId, + data: TensorData, + sample_tensor: Optional[FileSource], + ) -> InputTensor: + return InputTensor.model_validate(InputTensor( + test_tensor=self.include_file(test_tensor), + id=id_, + axes=tuple(axes), + preprocessing=tuple(preprocessing), + data=data, + sample_tensor=None + if sample_tensor is None + else self.include_file(sample_tensor) + ), context=self.context) + + # def add_input_tensor() + def add_cover_image(cover) + def build(self, output_path: Path, *, inputs: Sequence[InputTensor]): + + assert False + +mb = ModelBuilder(Path("output_path")) +mb.build(inputs=[mb.build_input_tensor(test_tensor=tt) for tt in test_tensors], outputs=based_on.outputs) + + + +class SpecGuesser: + @staticmethod + def guess_data_range(array: NDArray[Any]): + if np.issubdtype(array.dtype, np.floating) and array.min() >= 0.0 and array.max() <= 1.0: + return (0.0, 1.0) + else: + return (None, None) + + @classmethod + def guess_data_description(cls, test_tensor: FileSource): + try: + array: Union[Any, NDArray[Any]] = np.load(download(test_tensor).path) + if not isinstance(array, np.ndarray): + raise TypeError(f"Expected numpy array, but got {type(array)}") + except Exception as e: + warnings.warn(f"Could not guess data type of {test_tensor}: {e}") + return None + + dtype_str = str(array.dtype) + dtype: IntervalOrRatioDType = dtype_str # type: ignore # validated by IntervalOrRatioData + return IntervalOrRatioData( + type=dtype, range=cls.guess_data_range(array), unit="arbitrary unit", scale=1.0, offset=None + ) + + + +class SpecBuilderWithGuesses(SpecBuilder, SpecGuesser): + # def __init__(self, output_path: Path) -> None: + # super().__init__(output_path) + + def build_input_tensor( + self, + *, + test_tensor: FileSource, + axes: Sequence[InputAxis], + preprocessing: Sequence[Preprocessing], 
+ id_: TensorId, + data: Optional[TensorData] = None, + sample_tensor: FileSource | None, + ) -> InputTensor: + return super().build_input_tensor( + test_tensor=test_tensor, + axes=axes, + preprocessing=preprocessing, + id_=id_, + data=data or self.guess_data_description(test_tensor), + sample_tensor=sample_tensor, + ) + + +def build_spec_interactively(output_path: Path): + guesser = SpecGuesser(output_path) + builder = SpecBuilder(output_path) + + +class _CoreOutputTensor(OutputTensor, _CoreTensorMixin, frozen=True): + @classmethod + def build(cls, **kwargs: Unpack[_OutputTensorKwargs]): + return cls( + test_tensor=ensure_file_in_folder(kwargs["test_tensor"], kwargs["output_path"]), + id=kwargs.get("id") or TensorId("output"), + axes=tuple(kwargs["axes"]), + postprocessing=tuple(kwargs.get("postprocessing", ())), + data=cls.get_data_description(kwargs), + ) + + +class CoreModelBaseKwargs(_CoreGenericBaseKwargs): + inputs: NotEmpty[Sequence[_InputTensorKwargs]] + outputs: NotEmpty[Sequence[_OutputTensorKwargs]] + + +class CoreModelKwargs(CoreModelBaseKwargs): + weights: Weights + + +class _CoreModel(Model, frozen=True): + @classmethod + def build(cls, **kwargs: Unpack[CoreModelKwargs]) -> Self: + documentation = ensure_file_in_folder(kwargs["documentation"], kwargs["output_path"]) + + inputs = tuple( + _CoreInputTensor.build( + id=t_kwargs["id"] if "id" in t_kwargs else TensorId(f"input{i}"), + test_tensor=t_kwargs["test_tensor"], + axes=t_kwargs["axes"], + data=t_kwargs.get("data"), + output_path=kwargs["output_path"], + ) + for i, t_kwargs in enumerate(kwargs["inputs"]) + ) + + outputs = tuple( + _CoreOutputTensor.build( + id=t_kwargs["id"] if "id" in t_kwargs else TensorId(f"output{i}"), + test_tensor=t_kwargs["test_tensor"], + axes=t_kwargs["axes"], + data=t_kwargs.get("data"), + output_path=kwargs["output_path"], + ) + for i, t_kwargs in enumerate(kwargs["outputs"]) + ) + + return cls( + name=kwargs["name"], + description=kwargs["description"], + authors=tuple(kwargs["authors"]), + maintainers=tuple(kwargs.get("maintainers", ())), + cite=tuple(kwargs["cite"]), + license=kwargs["license"], + timestamp=datetime.now(), + inputs=inputs, + outputs=outputs, + weights=kwargs["weights"], + documentation=documentation, + ) + + @classmethod + def build_from_pytorch_state_dict( + cls, + weights: FileSource, + architecture: Architecture, + sha256: Optional[Sha256] = None, + pytorch_version: Optional[Version] = None, + dependencies: Optional[Dependencies] = None, + **kwargs: Unpack[CoreModelBaseKwargs], + ): + if pytorch_version is None: + import torch + + pytorch_version = Version(torch.__version__) + + return cls.build( + weights=Weights( + pytorch_state_dict=PytorchStateDictWeights( + source=ensure_file_in_folder(weights, kwargs["output_path"]), + sha256=sha256, + architecture=architecture, + pytorch_version=pytorch_version, + dependencies=dependencies, + ) + ), + **kwargs, + ) + + +def _build_spec_common(core_descr: _CoreModel, descr_path: Path, expected_type: Type[Any]): + write_description(core_descr, descr_path) + loaded = read_description_and_validate(descr_path) + if type(loaded) is not expected_type: + raise RuntimeError(f"Created {descr_path} was loaded as {type(loaded)}, but expected {expected_type}") + + return descr_path, loaded + + +def build_model_spec( + *, + weights: FileSource, + architecture: Architecture, + sha256: Optional[Sha256] = None, + pytorch_version: Optional[Version] = None, + dependencies: Optional[Dependencies] = None, + **kwargs: Unpack[CoreModelBaseKwargs], +): + 
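+    """Build a bioimage.io model description from pytorch_state_dict weights and write it to disk.
+
+    Hypothetical usage sketch (all argument values are placeholders; the accepted
+    kwargs are the fields of CoreModelBaseKwargs defined above):
+
+        build_model_spec(
+            weights="weights.pt",
+            architecture=architecture,
+            name="my model",
+            description="...",
+            authors=[...],
+            cite=[...],
+            license="CC-BY-4.0",
+            tags=[],
+            documentation="README.md",
+            inputs=[...],
+            outputs=[...],
+            output_path=Path("my_model_package"),
+        )
+    """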
+    model = _CoreModel.build_from_pytorch_state_dict(
+        weights=weights,
+        architecture=architecture,
+        sha256=sha256,
+        pytorch_version=pytorch_version,
+        dependencies=dependencies,
+        **kwargs,
+    )
+
+    return _build_spec_common(model, kwargs["output_path"] / "description.bioimageio.yaml", Model)

diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py
index bcba49b6..a2809293 100644
--- a/bioimageio/core/model_adapters/_model_adapter.py
+++ b/bioimageio/core/model_adapters/_model_adapter.py
@@ -24,16 +24,17 @@ class ModelAdapter(ABC):
     """
     Represents model *without* any preprocessing or postprocessing.

+    >>> from bioimageio.core import read_description
+    >>> model = read_description()
     >>> print("option 1:")
     option 1:
-    >>> adapter = create_model_adapter()
-    >>> adapter.load()
+    >>> adapter = ModelAdapter.create(model)
     >>> adapter.forward()
     >>> adapter.unload()
     >>> print("option 2:")
     option 2:
-    >>> with create_model_adapter_context() as adapter:
-            adapter.forward()
+    >>> with ModelAdapter.create(model) as adapter:
+    ...     adapter.forward()
     """

     @classmethod
     def create(
         cls,
-        *,
         model_description: Union[v0_4.Model, v0_5.Model],
+        *,
         devices: Optional[Sequence[str]] = None,
         weight_format_priority_order: NotEmpty[Sequence[WeightsFormat]] = DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER,
     ):

From d5c682d8ac12269137c4225ec1a247befe313083 Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Tue, 7 Nov 2023 09:17:35 +0100
Subject: [PATCH 058/244] WIP unfreeze aftermath

---
 bioimageio/core/build_spec/_build_spec.py         | 6 +++---
 bioimageio/core/io.py                             | 2 +-
 bioimageio/core/prediction_pipeline/processing.py | 6 +++---
 tests/test_internal/test_validation_visitors.py   | 4 ++--
 4 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/bioimageio/core/build_spec/_build_spec.py b/bioimageio/core/build_spec/_build_spec.py
index 29149f65..158f4bca 100644
--- a/bioimageio/core/build_spec/_build_spec.py
+++ b/bioimageio/core/build_spec/_build_spec.py
@@ -105,7 +105,7 @@ def get_data_description(kwargs: Union[CoreInputTensorKwargs, CoreOutputTensorKw
         raise TypeError(tensor_data)


-class _CoreInputTensor(InputTensor, _CoreTensorMixin, frozen=True):
+class _CoreInputTensor(InputTensor, _CoreTensorMixin):
     @classmethod
     def build(cls, **kwargs: Unpack[CoreInputTensorKwargs]):
         return cls(
@@ -120,7 +120,7 @@ def build(cls, **kwargs: Unpack[CoreInputTensorKwargs]):
         )


-class _CoreOutputTensor(OutputTensor, _CoreTensorMixin, frozen=True):
+class _CoreOutputTensor(OutputTensor, _CoreTensorMixin):
     @classmethod
     def build(cls, **kwargs: Unpack[CoreOutputTensorKwargs]):
         return cls(
@@ -141,7 +141,7 @@ class CoreModelKwargs(CoreModelBaseKwargs):
     weights: Weights


-class _CoreModel(Model, frozen=True):
+class _CoreModel(Model):
     @classmethod
     def build(cls, **kwargs: Unpack[CoreModelKwargs]) -> Self:
         documentation = ensure_file_in_folder(kwargs["documentation"], kwargs["output_path"])

diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py
index 306a4ffd..74e9964b 100644
--- a/bioimageio/core/io.py
+++ b/bioimageio/core/io.py
@@ -6,7 +6,7 @@
 import shutil
 from dataclasses import dataclass
 from pathlib import Path
-from tempfile import NamedTemporaryFile, TemporaryDirectory, mkdtemp
+from tempfile import NamedTemporaryFile, mkdtemp
 from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, TextIO, TypedDict, Union, cast
 from zipfile import ZIP_DEFLATED, ZipFile, is_zipfile

diff --git a/bioimageio/core/prediction_pipeline/processing.py
b/bioimageio/core/prediction_pipeline/processing.py index acef48b5..ab467ca7 100644 --- a/bioimageio/core/prediction_pipeline/processing.py +++ b/bioimageio/core/prediction_pipeline/processing.py @@ -36,16 +36,16 @@ AssertProcessingId = Literal["assert_dtype"] -class AssertProcessingBase(NodeWithExplicitlySetFields, frozen=True): +class AssertProcessingBase(NodeWithExplicitlySetFields ): id: AssertProcessingId fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"id"}) -class AssertDtypeKwargs(v0_5.ProcessingKwargs, frozen=True): +class AssertDtypeKwargs(v0_5.ProcessingKwargs ): dtype: Union[str, Sequence[str]] -class AssertDtype(AssertProcessingBase, frozen=True): +class AssertDtype(AssertProcessingBase ): id: Literal["assert_dtype"] = "assert_dtype" kwargs: AssertDtypeKwargs diff --git a/tests/test_internal/test_validation_visitors.py b/tests/test_internal/test_validation_visitors.py index 9aff615a..7988f658 100644 --- a/tests/test_internal/test_validation_visitors.py +++ b/tests/test_internal/test_validation_visitors.py @@ -16,10 +16,10 @@ def _visit_int(self, nr: int, note: Note = Note()): super().visit(nr, note) self.errors.append(ErrorEntry(loc=note.loc, msg=f"nr: {nr}", type="got-int")) - class NestedNode(Node, frozen=True): + class NestedNode(Node): leaf: int - class MyNode(Node, frozen=True): + class MyNode(Node): nested: NestedNode tree = { From 2d5118e0821ea392c22ec097b686c8e5432d507a Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 14 Nov 2023 20:34:48 +0100 Subject: [PATCH 059/244] start moving io to spec --- bioimageio/core/__init__.py | 2 +- bioimageio/core/build_spec/_build_spec.py | 4 +- bioimageio/core/build_spec/get_description.py | 38 +- bioimageio/core/io.py | 332 +----------------- .../_prediction_pipeline.py | 2 +- bioimageio/core/resource_tests.py | 10 +- bioimageio/core/utils/__init__.py | 19 - .../node_visitor.py} | 41 ++- 8 files changed, 59 insertions(+), 389 deletions(-) rename bioimageio/core/{validation_visitors.py => utils/node_visitor.py} (91%) diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index e3a665c2..8fe9b81a 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -1,9 +1,9 @@ import json +from bioimageio.core.io import build_description_and_validate as build_description_and_validate from bioimageio.core.io import load_description as load_description from bioimageio.core.io import load_description_and_validate as load_description_and_validate from bioimageio.core.io import read_description as read_description -from bioimageio.core.io import read_description_and_validate as read_description_and_validate from bioimageio.core.io import resolve_source as resolve_source from bioimageio.core.io import write_description as write_description from bioimageio.core.io import write_package as write_package diff --git a/bioimageio/core/build_spec/_build_spec.py b/bioimageio/core/build_spec/_build_spec.py index 158f4bca..d0bfb0de 100644 --- a/bioimageio/core/build_spec/_build_spec.py +++ b/bioimageio/core/build_spec/_build_spec.py @@ -10,7 +10,7 @@ # from bioimageio.core import export_resource_package, load_raw_resource_description from typing_extensions import NotRequired, Self, Unpack -from bioimageio.core.io import FileSource, download, read_description_and_validate, write_description +from bioimageio.core.io import FileSource, download, load_description_and_validate, write_description from bioimageio.spec.model.v0_5 import ( Architecture, Author, @@ -213,7 +213,7 @@ def 
build_from_pytorch_state_dict( def _build_spec_common(core_descr: _CoreModel, descr_path: Path, expected_type: Type[Any]): write_description(core_descr, descr_path) - loaded = read_description_and_validate(descr_path) + loaded = load_description_and_validate(descr_path) if type(loaded) is not expected_type: raise RuntimeError(f"Created {descr_path} was loaded as {type(loaded)}, but expected {expected_type}") diff --git a/bioimageio/core/build_spec/get_description.py b/bioimageio/core/build_spec/get_description.py index 950574bf..c4cd0168 100644 --- a/bioimageio/core/build_spec/get_description.py +++ b/bioimageio/core/build_spec/get_description.py @@ -12,9 +12,9 @@ # from bioimageio.core import export_resource_package, load_raw_resource_description from typing_extensions import NotRequired, Self, Unpack -from bioimageio.spec.description import ValidationContext -from bioimageio.core.io import FileSource, download, read_description_and_validate, write_description +from bioimageio.core.io import FileSource, download, load_description_and_validate, write_description from bioimageio.core.utils import get_sha256 +from bioimageio.spec.description import ValidationContext from bioimageio.spec.model.v0_5 import ( Architecture, Author, @@ -102,6 +102,7 @@ def include_file(self, source: FileSource) -> RelativeFilePath: return RelativeFilePath(rel_path) + class ModelBuilder(SpecBuilder): def add_cite(self): self._cite.append(CiteEntry()) @@ -116,28 +117,28 @@ def add_input_tensor( data: TensorData, sample_tensor: Optional[FileSource], ) -> InputTensor: - return InputTensor.model_validate(InputTensor( - test_tensor=self.include_file(test_tensor), - id=id_, - axes=tuple(axes), - preprocessing=tuple(preprocessing), - data=data, - sample_tensor=None - if sample_tensor is None - else self.include_file(sample_tensor) - ), context=self.context) + return InputTensor.model_validate( + InputTensor( + test_tensor=self.include_file(test_tensor), + id=id_, + axes=tuple(axes), + preprocessing=tuple(preprocessing), + data=data, + sample_tensor=None if sample_tensor is None else self.include_file(sample_tensor), + ), + context=self.context, + ) # def add_input_tensor() - def add_cover_image(cover) + # def add_cover_image(cover) def build(self, output_path: Path, *, inputs: Sequence[InputTensor]): - assert False + mb = ModelBuilder(Path("output_path")) mb.build(inputs=[mb.build_input_tensor(test_tensor=tt) for tt in test_tensors], outputs=based_on.outputs) - class SpecGuesser: @staticmethod def guess_data_range(array: NDArray[Any]): @@ -163,7 +164,6 @@ def guess_data_description(cls, test_tensor: FileSource): ) - class SpecBuilderWithGuesses(SpecBuilder, SpecGuesser): # def __init__(self, output_path: Path) -> None: # super().__init__(output_path) @@ -193,7 +193,7 @@ def build_spec_interactively(output_path: Path): builder = SpecBuilder(output_path) -class _CoreOutputTensor(OutputTensor, _CoreTensorMixin, frozen=True): +class _CoreOutputTensor(OutputTensor, _CoreTensorMixin): @classmethod def build(cls, **kwargs: Unpack[_OutputTensorKwargs]): return cls( @@ -214,7 +214,7 @@ class CoreModelKwargs(CoreModelBaseKwargs): weights: Weights -class _CoreModel(Model, frozen=True): +class _CoreModel(Model): @classmethod def build(cls, **kwargs: Unpack[CoreModelKwargs]) -> Self: documentation = ensure_file_in_folder(kwargs["documentation"], kwargs["output_path"]) @@ -286,7 +286,7 @@ def build_from_pytorch_state_dict( def _build_spec_common(core_descr: _CoreModel, descr_path: Path, expected_type: Type[Any]): 
write_description(core_descr, descr_path) - loaded = read_description_and_validate(descr_path) + loaded = load_description_and_validate(descr_path) if type(loaded) is not expected_type: raise RuntimeError(f"Created {descr_path} was loaded as {type(loaded)}, but expected {expected_type}") diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py index 74e9964b..0f0784e1 100644 --- a/bioimageio/core/io.py +++ b/bioimageio/core/io.py @@ -15,7 +15,6 @@ from ruamel.yaml import YAML from typing_extensions import NotRequired, Unpack -from bioimageio.core.utils import get_parent_url from bioimageio.spec import ResourceDescription from bioimageio.spec import load_description as load_description from bioimageio.spec._internal.base_nodes import ResourceDescriptionBase @@ -26,55 +25,22 @@ from bioimageio.spec.package import extract_file_name, get_resource_package_content from bioimageio.spec.summary import ValidationSummary -yaml = YAML(typ="safe") -StrictFileSource = Union[HttpUrl, FilePath] -FileSource = Union[StrictFileSource, str] -RdfSource = Union[FileSource, ResourceDescription] - -LEGACY_RDF_NAME = "rdf.yaml" - - -class HashKwargs(TypedDict): - sha256: NotRequired[Optional[Sha256]] - - -def get_known_hash(hash_kwargs: HashKwargs): - if "sha256" in hash_kwargs: - return f"sha256:{hash_kwargs['sha256']}" - else: - return None - - -def read_description( - rdf_source: FileSource, - /, - *, - format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, -) -> Union[ResourceDescription, InvalidDescription]: - rdf = download_rdf(rdf_source) - return load_description( - rdf.content, - context=ValidationContext(root=rdf.original_root, file_name=rdf.original_file_name), - format_version=format_version, - ) - - -def read_description_and_validate( - rdf_source: FileSource, +def load_description_and_validate( + source: FileSource, /, *, format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, ) -> Union[ResourceDescription, InvalidDescription]: - rdf = download_rdf(rdf_source) - return load_description_and_validate( + rdf = download_rdf(source) + return build_description_and_validate( rdf.content, context=ValidationContext(root=rdf.original_root, file_name=rdf.original_file_name), format_version=format_version, ) -def load_description_and_validate( +def build_description_and_validate( rdf_content: RdfContent, /, *, @@ -88,295 +54,15 @@ def load_description_and_validate( def validate( - rdf_source: Union[FileSource, RdfContent], + source: RdfSource, /, *, context: Optional[ValidationContext] = None, format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, ) -> List[ValidationSummary]: - if isinstance(rdf_source, dict): - rd = load_description_and_validate(rdf_source, context=context, format_version=format_version) + if isinstance(source, dict): + rd = build_description_and_validate(source, context=context, format_version=format_version) else: - rd = read_description_and_validate(rdf_source, format_version=format_version) + rd = load_description_and_validate(source, format_version=format_version) return rd.validation_summaries - - -def write_description(rd: Union[ResourceDescription, RdfContent], /, file: Union[FilePath, TextIO]): - if isinstance(rd, ResourceDescriptionBase): - content = dump_description(rd) - else: - content = rd - - if isinstance(file, Path): - with file.open("w", encoding="utf-8") as f: - yaml.dump(content, f) - else: - yaml.dump(content, file) - - -def prepare_resource_package( - rdf_source: RdfSource, - /, - *, - 
weights_priority_order: Optional[Sequence[WeightsFormat]] = None, -) -> Dict[FileName, Union[FilePath, RdfContent]]: - """Prepare to package a resource description; downloads all required files. - - Args: - rdf_source: A bioimage.io resource description (as file, raw YAML content or description class) - context: validation context - weights_priority_order: If given only the first weights format present in the model is included. - If none of the prioritized weights formats is found all are included. - """ - if isinstance(rdf_source, ResourceDescriptionBase): - rd = rdf_source - _ctxt = rd._internal_validation_context # pyright: ignore[reportPrivateUsage] - context = ValidationContext(root=_ctxt["root"], file_name=_ctxt["file_name"]) - else: - rdf = download_rdf(rdf_source) - context = ValidationContext(root=rdf.original_root, file_name=rdf.original_file_name) - rd = load_description( - rdf.content, - context=context, - ) - - if isinstance(rd, InvalidDescription): - raise ValueError(f"{rdf_source} is invalid: {rd.validation_summaries[0]}") - - package_content = get_resource_package_content(rd, weights_priority_order=weights_priority_order) - - local_package_content: Dict[FileName, Union[FilePath, RdfContent]] = {} - for k, v in package_content.items(): - if not isinstance(v, collections.abc.Mapping): - v = resolve_source(v, root=context.root) - - local_package_content[k] = v - - return local_package_content - - -def write_zip( - path: os.PathLike[str], - content: Mapping[FileName, Union[str, FilePath, Dict[Any, Any]]], - *, - compression: int, - compression_level: int, -) -> None: - """Write a zip archive. - - Args: - path: output path to write to. - content: dict mapping archive names to local file paths, strings (for text files), or dict (for yaml files). - compression: The numeric constant of compression method. - compression_level: Compression level to use when writing files to the archive. - See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile - - """ - with ZipFile(path, "w", compression=compression, compresslevel=compression_level) as myzip: - for arc_name, file in content.items(): - if isinstance(file, dict): - buf = io.StringIO() - YAML.dump(file, buf) - file = buf.getvalue() - - if isinstance(file, str): - myzip.writestr(arc_name, file.encode("utf-8")) - else: - myzip.write(file, arcname=arc_name) - - -def write_package( - rdf_source: RdfSource, - /, - *, - compression: int = ZIP_DEFLATED, - compression_level: int = 1, - output_path: Optional[os.PathLike[str]] = None, - weights_priority_order: Optional[ # model only - Sequence[ - Literal[ - "keras_hdf5", - "onnx", - "pytorch_state_dict", - "tensorflow_js", - "tensorflow_saved_model_bundle", - "torchscript", - ] - ] - ] = None, -) -> FilePath: - """Package a bioimage.io resource as a zip file. - - Args: - rd: bioimage.io resource description - compression: The numeric constant of compression method. - compression_level: Compression level to use when writing files to the archive. - See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile - output_path: file path to write package to - weights_priority_order: If given only the first weights format present in the model is included. - If none of the prioritized weights formats is found all are included. 
- - Returns: - path to zipped bioimage.io package in BIOIMAGEIO_CACHE_PATH or 'output_path' - """ - package_content = prepare_resource_package( - rdf_source, - weights_priority_order=weights_priority_order, - ) - if output_path is None: - output_path = Path(NamedTemporaryFile(suffix=".bioimageio.zip", delete=False).name) - else: - output_path = Path(output_path) - - write_zip(output_path, package_content, compression=compression, compression_level=compression_level) - return output_path - - -def write_package_as_folder( - rdf_source: RdfSource, - /, - *, - output_path: Optional[DirectoryPath] = None, - weights_priority_order: Optional[ # model only - Sequence[ - Literal[ - "keras_hdf5", - "onnx", - "pytorch_state_dict", - "tensorflow_js", - "tensorflow_saved_model_bundle", - "torchscript", - ] - ] - ] = None, -) -> DirectoryPath: - """Write the content of a bioimage.io resource package to a folder. - - Args: - rd: bioimage.io resource description - output_path: file path to write package to - weights_priority_order: If given only the first weights format present in the model is included. - If none of the prioritized weights formats is found all are included. - - Returns: - path to zipped bioimage.io package in BIOIMAGEIO_CACHE_PATH or 'output_path' - """ - package_content = prepare_resource_package( - rdf_source, - weights_priority_order=weights_priority_order, - ) - if output_path is None: - output_path = Path(mkdtemp()) - else: - output_path = Path(output_path) - - for name, source in package_content.items(): - if isinstance(source, collections.abc.Mapping): - yaml.dump(source, output_path / name) - else: - shutil.copy(source, output_path / name) - - return output_path - - -@dataclass -class DownloadedFile: - path: FilePath - original_root: Union[AnyUrl, DirectoryPath] - original_file_name: str - - -@dataclass -class DownloadedRdf: - content: RdfContent - original_root: Union[AnyUrl, DirectoryPath] - original_file_name: str - - -def download( - source: FileSource, - /, - **kwargs: Unpack[HashKwargs], -) -> DownloadedFile: - source = _interprete_file_source(source) - if isinstance(source, AnyUrl): - if source.scheme not in ("http", "https"): - raise NotImplementedError(source.scheme) - - if os.environ.get("CI", "false").lower() in ("1", "t", "true", "yes", "y"): - headers = {"User-Agent": "ci"} - progressbar = False - else: - headers = {} - progressbar = True - - if (user_agent := os.environ.get("BIOIMAGEIO_USER_AGENT")) is not None: - headers["User-Agent"] = user_agent - - downloader = pooch.HTTPDownloader(headers=headers, progressbar=progressbar) - _ls: Any = pooch.retrieve(url=str(source), known_hash=get_known_hash(kwargs), downloader=downloader) - local_source = Path(_ls) - root: Union[HttpUrl, DirectoryPath] = get_parent_url(source) - else: - local_source = source - root = source.parent - - return DownloadedFile( - local_source, - root, - extract_file_name(source), - ) - - -def download_rdf(source: FileSource, /, *, rdf_encoding: str = "utf-8", **kwargs: Unpack[HashKwargs]): - downloaded = download(source, **kwargs) - local_source = downloaded.path - if is_zipfile(local_source): - out_path = local_source.with_suffix(local_source.suffix + ".unzip") - with ZipFile(local_source, "r") as f: - rdfs = [fname for fname in f.namelist() if fname.endswith(".bioimageio.yaml")] - if len(rdfs) > 1: - raise ValueError(f"Multiple RDFs in one package not yet supported (found {rdfs}).") - elif len(rdfs) == 1: - rdf_file_name = rdfs[0] - elif LEGACY_RDF_NAME in f.namelist(): - rdf_file_name = 
LEGACY_RDF_NAME - else: - raise ValueError( - f"No RDF found in {local_source}. (Looking for any '*.bioimageio.yaml' file or an 'rdf.yaml' file)." - ) - - f.extractall(out_path) - local_source = out_path / rdf_file_name - - with local_source.open(encoding=rdf_encoding) as f: - content: YamlValue = yaml.load(f) - - if not isinstance(content, collections.abc.Mapping): - raise TypeError(f"Expected RDF content to be a mapping, but got '{type(content)}'.") - - return DownloadedRdf(cast(RdfContent, content), downloaded.original_root, downloaded.original_file_name) - - -def resolve_source( - source: Union[FileSource, RelativeFilePath], - /, - *, - root: Union[DirectoryPath, AnyUrl, None] = None, - **kwargs: Unpack[HashKwargs], -) -> FilePath: - if isinstance(source, RelativeFilePath): - if root is None: - raise ValueError(f"Cannot resolve relative file path '{source}' without root.") - - source = source.get_absolute(root) - - return download(source, **kwargs).path - - -def _interprete_file_source(file_source: FileSource) -> StrictFileSource: - return TypeAdapter(StrictFileSource).validate_python(file_source) - # todo: prettier file source validation error - # try: - # except ValidationError as e: diff --git a/bioimageio/core/prediction_pipeline/_prediction_pipeline.py b/bioimageio/core/prediction_pipeline/_prediction_pipeline.py index 8fffc887..483a7ff9 100644 --- a/bioimageio/core/prediction_pipeline/_prediction_pipeline.py +++ b/bioimageio/core/prediction_pipeline/_prediction_pipeline.py @@ -6,7 +6,7 @@ import xarray as xr from bioimageio.core.model_adapters import ModelAdapter, create_model_adapter -from bioimageio.core.validation_visitors import resolve_raw_node +from bioimageio.core.utils.node_visitor import resolve_raw_node from bioimageio.spec.model import AnyModel, raw_nodes from ._combined_processing import CombinedProcessing diff --git a/bioimageio/core/resource_tests.py b/bioimageio/core/resource_tests.py index 32556671..dabc72ff 100644 --- a/bioimageio/core/resource_tests.py +++ b/bioimageio/core/resource_tests.py @@ -9,11 +9,6 @@ import numpy import numpy as np import xarray as xr -from bioimageio.spec import __version__ as bioimageio_spec_version -from bioimageio.spec.model.raw_nodes import WeightsFormat -from bioimageio.spec.shared import resolve_source -from bioimageio.spec.shared.common import ValidationWarning -from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription from marshmallow import ValidationError from bioimageio.core import __version__ as bioimageio_core_version @@ -29,6 +24,11 @@ ParametrizedInputShape, ResourceDescription, ) +from bioimageio.spec import __version__ as bioimageio_spec_version +from bioimageio.spec.model.raw_nodes import WeightsFormat +from bioimageio.spec.shared import resolve_source +from bioimageio.spec.shared.common import ValidationWarning +from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription def test_model( diff --git a/bioimageio/core/utils/__init__.py b/bioimageio/core/utils/__init__.py index 32b178be..eb1dbbfc 100644 --- a/bioimageio/core/utils/__init__.py +++ b/bioimageio/core/utils/__init__.py @@ -30,25 +30,6 @@ def files(package_name: str): from importlib.resources import files as files -def get_parent_url(url: HttpUrl) -> HttpUrl: - parsed = urlsplit(str(url)) - return AnyUrl( - urlunsplit((parsed.scheme, parsed.netloc, "/".join(parsed.path.split("/")[:-1]), parsed.query, parsed.fragment)) - ) - - -def get_sha256(path: os.PathLike[str]) -> str: - """from 
https://stackoverflow.com/a/44873382""" - h = hashlib.sha256() - b = bytearray(128 * 1024) - mv = memoryview(b) - with open(path, "rb", buffering=0) as f: - for n in iter(lambda: f.readinto(mv), 0): - h.update(mv[:n]) - - return h.hexdigest() - - class TemporaryInsertionIntoPythonPath(AbstractContextManager[None]): def __init__(self, path: Path): super().__init__() diff --git a/bioimageio/core/validation_visitors.py b/bioimageio/core/utils/node_visitor.py similarity index 91% rename from bioimageio/core/validation_visitors.py rename to bioimageio/core/utils/node_visitor.py index 15420017..e991f87c 100644 --- a/bioimageio/core/validation_visitors.py +++ b/bioimageio/core/utils/node_visitor.py @@ -1,12 +1,12 @@ +from abc import ABC, abstractmethod from dataclasses import dataclass, replace from functools import singledispatchmethod from pathlib import Path, PurePath -from typing import Any, List, Optional, Tuple, TypedDict, Union +from typing import Any, List, Optional, Tuple, Union import requests from pydantic import AnyUrl, DirectoryPath from pydantic.fields import FieldInfo -from typing_extensions import NotRequired from bioimageio.core.utils import get_sha256 from bioimageio.spec._internal.base_nodes import Node @@ -15,10 +15,6 @@ from bioimageio.spec.summary import ErrorEntry, Loc, WarningEntry -class VisitorKwargs(TypedDict): - info: NotRequired[FieldInfo] - - @dataclass(frozen=True, **SLOTS, **KW_ONLY) class Memo: loc: Loc = () @@ -26,12 +22,7 @@ class Memo: parent_nodes: Tuple[Node, ...] = () -class ValidationVisitor: - def __init__(self) -> None: - super().__init__() - self.errors: List[ErrorEntry] = [] - self.warnings: List[WarningEntry] = [] - +class NodeVisitor: def visit(self, obj: Any, /, memo: Memo = Memo()): self._traverse(obj, memo=memo) @@ -66,20 +57,32 @@ def _traverse_dict(self, dict_: dict, memo: Memo): # type: ignore self.visit(v, replace(memo, loc=memo.loc + (k,))) -class SourceValidator(ValidationVisitor): - def __init__(self, root: Union[DirectoryPath, AnyUrl]) -> None: +class ValidationVisitor(NodeVisitor, ABC): + def __init__(self) -> None: super().__init__() - self.root = root + self.errors: List[ErrorEntry] = [] + self.warnings: List[WarningEntry] = [] def visit(self, obj: Any, /, memo: Memo = Memo()): - self._visit_impl(obj, memo=memo) + self.validate(obj, memo=memo) return super().visit(obj, memo) @singledispatchmethod - def _visit_impl(self, obj: type, /, memo: Memo): + @abstractmethod + def validate(self, obj: type, /, memo: Memo): + ... 
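+
+# Concrete visitors implement `validate` as a `singledispatchmethod` and register
+# additional overloads per node type via `@validate.register` (see SourceValidator
+# below for the pattern).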
+ + +class SourceValidator(ValidationVisitor): + def __init__(self, root: Union[DirectoryPath, AnyUrl]) -> None: + super().__init__() + self.root = root + + @singledispatchmethod + def validate(self, obj: type, /, memo: Memo): pass - @_visit_impl.register + @validate.register def _visit_path(self, path: PurePath, memo: Memo): if Path(path).exists(): for parent in memo.parent_nodes: @@ -120,7 +123,7 @@ def _visit_path(self, path: PurePath, memo: Memo): else: self.warnings.append(WarningEntry(loc=memo.loc, msg=msg, type="file_not_found")) - @_visit_impl.register + @validate.register def _visit_url(self, url: AnyUrl, memo: Memo): if url.scheme not in ("http", "https"): self.errors.append(ErrorEntry(loc=memo.loc, msg=f"invalid http(s) URL: {url}", type="url_scheme")) From a240d498ee8d592529ef3e1fef453552eac79b7a Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 14 Nov 2023 23:00:56 +0100 Subject: [PATCH 060/244] remove print --- tests/test_bioimageio_spec_version.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_bioimageio_spec_version.py b/tests/test_bioimageio_spec_version.py index 59cddd83..cbe10744 100644 --- a/tests/test_bioimageio_spec_version.py +++ b/tests/test_bioimageio_spec_version.py @@ -26,7 +26,6 @@ def test_bioimageio_spec_version(): # get currently pinned bioimageio.spec version meta = metadata("bioimageio.core") req = meta["Requires-Dist"] - print(req) assert req.startswith("bioimageio.spec ==") spec_ver = req[len("bioimageio.spec ==") :] assert spec_ver.count(".") == 2 From 30bbaaf184f7ecfede1e0d6a2579d674e3caca1c Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 15 Nov 2023 09:18:31 +0100 Subject: [PATCH 061/244] remove SourceValidator --- bioimageio/core/utils/node_visitor.py | 60 --------------------------- 1 file changed, 60 deletions(-) diff --git a/bioimageio/core/utils/node_visitor.py b/bioimageio/core/utils/node_visitor.py index e991f87c..dd523a7f 100644 --- a/bioimageio/core/utils/node_visitor.py +++ b/bioimageio/core/utils/node_visitor.py @@ -71,63 +71,3 @@ def visit(self, obj: Any, /, memo: Memo = Memo()): @abstractmethod def validate(self, obj: type, /, memo: Memo): ... - - -class SourceValidator(ValidationVisitor): - def __init__(self, root: Union[DirectoryPath, AnyUrl]) -> None: - super().__init__() - self.root = root - - @singledispatchmethod - def validate(self, obj: type, /, memo: Memo): - pass - - @validate.register - def _visit_path(self, path: PurePath, memo: Memo): - if Path(path).exists(): - for parent in memo.parent_nodes: - if "sha256" in parent.model_fields: - sha256: Union[None, Sha256] = parent.sha256 # type: ignore - break - else: - return - - actual_sha256 = get_sha256(path) - if sha256 is None: - self.warnings.append( - WarningEntry( - loc=memo.loc, - msg=( - f"Cannot validate file integrity (`sha256` not specified). 
" - f"File {path} has SHA-256: {actual_sha256}" - ), - type="unknown_hash", - ) - ) - elif actual_sha256 != sha256: - self.errors.append( - ErrorEntry( - loc=memo.loc, - msg=f"SHA-256 mismatch: actual ({actual_sha256}) != specified ({sha256})", - type="hash_mismatch", - ) - ) - else: - msg = f"{path} not found" - if ( - memo.info - and isinstance(memo.info.description, str) - and memo.info.description.startswith(IN_PACKAGE_MESSAGE) - ): - self.errors.append(ErrorEntry(loc=memo.loc, msg=msg, type="file_not_found")) - else: - self.warnings.append(WarningEntry(loc=memo.loc, msg=msg, type="file_not_found")) - - @validate.register - def _visit_url(self, url: AnyUrl, memo: Memo): - if url.scheme not in ("http", "https"): - self.errors.append(ErrorEntry(loc=memo.loc, msg=f"invalid http(s) URL: {url}", type="url_scheme")) - else: - response = requests.head(str(url)) - if response.status_code != 200: - self.errors.append(ErrorEntry(loc=memo.loc, msg=response.reason, type="url_unavailable")) From a5ccb8751d5de09574b1b173e52773f9ab23983a Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 15 Nov 2023 13:24:05 +0100 Subject: [PATCH 062/244] WIP start cleaning up keras->tf converter --- .../core/weight_converter/keras/tensorflow.py | 63 +++++++++---------- 1 file changed, 29 insertions(+), 34 deletions(-) diff --git a/bioimageio/core/weight_converter/keras/tensorflow.py b/bioimageio/core/weight_converter/keras/tensorflow.py index 0656ac9e..38af1ea2 100644 --- a/bioimageio/core/weight_converter/keras/tensorflow.py +++ b/bioimageio/core/weight_converter/keras/tensorflow.py @@ -4,38 +4,41 @@ from typing import Union from zipfile import ZipFile -import bioimageio.spec as spec -from bioimageio.core import load_resource_description - import tensorflow from tensorflow import saved_model +from bioimageio.spec import AnyModel, load_description +from bioimageio.spec._internal.io_utils import download + -def _zip_weights(output_path): - zipped_model = f"{output_path}.zip" - # zip the weights - file_paths = [] - for folder_names, subfolder, filenames in os.walk(os.path.join(output_path)): - for filename in filenames: - # create complete filepath of file in directory - file_paths.append(os.path.join(folder_names, filename)) +def _zip_model_bundle(model_bundle_folder: Path): + zipped_model_bundle = f"{model_bundle_folder}.zip" - with ZipFile(zipped_model, "w") as zip_obj: - for f in file_paths: - # Add file to zip - zip_obj.write(f, os.path.relpath(f, output_path)) + with ZipFile(zipped_model_bundle, "w") as zip_obj: + for root, _, files in os.walk(model_bundle_folder): + for filename in files: + src = os.path.join(root, filename) + zip_obj.write(src, os.path.relpath(src, model_bundle_folder)) try: - shutil.rmtree(output_path) + shutil.rmtree(model_bundle_folder) except Exception: print("TensorFlow bundled model was not removed after compression") - return zipped_model + return zipped_model_bundle # adapted from # https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236 -def _convert_tf1(keras_weight_path, output_path, input_name, output_name, zip_weights): +def _convert_tf1(keras_weight_path: Path, output_path: Path, input_name: str, output_name: str, zip_weights: bool): + try: + # try to build the tf model with the keras import from tensorflow + from tensorflow import keras + + except Exception: + # if the above fails try to export with the standalone keras + import keras + def build_tf_model(): keras_model = keras.models.load_model(keras_weight_path) @@ -51,18 +54,10 @@ def 
build_tf_model():
     )
         builder.save()
 
-    try:
-        # try to build the tf model with the keras import from tensorflow
-        from tensorflow import keras
 
-        build_tf_model()
-    except Exception:
-        # if the above fails try to export with the standalone keras
-        import keras
-
-        build_tf_model()
+    build_tf_model()
 
     if zip_weights:
-        output_path = _zip_weights(output_path)
+        output_path = _zip_model_bundle(output_path)
 
     print("TensorFlow model exported to", output_path)
     return 0
@@ -80,14 +75,14 @@ def _convert_tf2(keras_weight_path, output_path, zip_weights):
     keras.models.save_model(model, output_path)
 
     if zip_weights:
-        output_path = _zip_weights(output_path)
+        output_path = _zip_model_bundle(output_path)
 
     print("TensorFlow model exported to", output_path)
     return 0
 
 
 def convert_weights_to_tensorflow_saved_model_bundle(
-    model_spec: Union[str, Path, spec.model.raw_nodes.Model], output_path: Union[str, Path]
+    model_spec: Union[str, Path, AnyModel], output_path: Union[str, Path]
 ):
     """Convert model weights from format 'keras_hdf5' to 'tensorflow_saved_model_bundle'.
 
@@ -110,10 +105,10 @@ def convert_weights_to_tensorflow_saved_model_bundle(
     if path_.exists():
         raise ValueError(f"The output directory at {path_} must not exist.")
 
-    model = load_resource_description(model_spec)
-    assert "keras_hdf5" in model.weights
-    weight_spec = model.weights["keras_hdf5"]
-    weight_path = str(weight_spec.source)
+    model = load_description(model_spec)
+    assert model.weights.keras_hdf5 is not None
+    weight_spec = model.weights.keras_hdf5
+    weight_path = download(weight_spec.source).path
 
     if weight_spec.tensorflow_version:
         model_tf_major_ver = int(weight_spec.tensorflow_version.major)

From 3e57ef41363bcaf61519c51bef888a25db33102b Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Mon, 20 Nov 2023 14:58:18 +0100
Subject: [PATCH 063/244] WIP various updates

---
 .../_tensorflow_model_adapter.py              | 10 ++---
 .../core/prediction_pipeline/processing.py    | 22 ++++++-----
 .../core/weight_converter/torch/onnx.py       | 38 ++++++++++---------
 .../weight_converter/torch/torchscript.py     |  7 ++--
 4 files changed, 41 insertions(+), 36 deletions(-)

diff --git a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py
index 929cd57c..9a97a6e7 100644
--- a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py
+++ b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py
@@ -21,12 +21,12 @@ def __init__(
         *,
         devices: Optional[Sequence[str]] = None,
         weights: Union[
-            v0_4.KerasHdf5Weights,
-            v0_4.TensorflowSavedModelBundleWeights,
-            v0_5.KerasHdf5Weights,
-            v0_5.TensorflowSavedModelBundleWeights,
+            v0_4.KerasHdf5WeightsDescr,
+            v0_4.TensorflowSavedModelBundleWeightsDescr,
+            v0_5.KerasHdf5WeightsDescr,
+            v0_5.TensorflowSavedModelBundleWeightsDescr,
         ],
-        model_description: Union[v0_4.Model, v0_5.Model],
+        model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
     ):
         super().__init__()
         self.model_description = model_description
diff --git a/bioimageio/core/prediction_pipeline/processing.py b/bioimageio/core/prediction_pipeline/processing.py
index ab467ca7..1ef1b88c 100644
--- a/bioimageio/core/prediction_pipeline/processing.py
+++ b/bioimageio/core/prediction_pipeline/processing.py
@@ -36,16 +36,16 @@
 AssertProcessingId = Literal["assert_dtype"]
 
 
-class AssertProcessingBase(NodeWithExplicitlySetFields ):
+class AssertProcessingBase(NodeWithExplicitlySetFields):
     id: AssertProcessingId
     fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"id"})
 
 
-class 
AssertDtypeKwargs(v0_5.ProcessingKwargs ): +class AssertDtypeKwargs(v0_5.ProcessingKwargs): dtype: Union[str, Sequence[str]] -class AssertDtype(AssertProcessingBase ): +class AssertDtype(AssertProcessingBase): id: Literal["assert_dtype"] = "assert_dtype" kwargs: AssertDtypeKwargs @@ -190,7 +190,6 @@ def get_spec(self): return v0_5.EnsureDtype(kwargs=self.kwargs) -class ScaleLinearImplBase class ScaleLinearImpl04(ProcessingImplBaseWoMeasures[Union[v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs]]): def apply(self, tensor: xr.DataArray) -> xr.DataArray: axis = ( @@ -390,6 +389,7 @@ def get_spec(self): # todo: + class ProcSelector: def __init__(proc_spec: ProcSpec) -> None: self.proc_spec = proc_spec @@ -418,13 +418,15 @@ def get_impl(proc_spec: ProcSpec): return SigmoidImpl elif isinstance(proc_spec, v0_4.ZeroMeanUnitVariance) and proc_spec.kwargs.mode == "fixed": return FixedZeroMeanUnitVarianceImpl - elif isinstance(proc_spec, # pyright: ignore[reportUnnecessaryIsInstance] - (v0_4.ZeroMeanUnitVariance, v0_5.ZeroMeanUnitVariance) - ): - return ZeroMeanUnitVarianceImpl + elif isinstance( + proc_spec, # pyright: ignore[reportUnnecessaryIsInstance] + (v0_4.ZeroMeanUnitVariance, v0_5.ZeroMeanUnitVariance), + ): + return ZeroMeanUnitVarianceImpl else: assert_never(proc_spec) + Model = Union[v0_4.Model, v0_5.Model] @@ -437,7 +439,9 @@ def get_procs(model: Model): assert isinstance(ipt, v0_5.InputTensor) for proc_spec in ipt.preprocessing: impl = get_impl(proc_spec, ipt.id, computed_measures) - assert isinstance(proc_spec.kwargs, ) + assert isinstance( + proc_spec.kwargs, + ) procs.append(impl(tensor_id=ipt.id, kwargs=proc_spec.kwargs)) return procs diff --git a/bioimageio/core/weight_converter/torch/onnx.py b/bioimageio/core/weight_converter/torch/onnx.py index 6f9ac1d2..104421a1 100644 --- a/bioimageio/core/weight_converter/torch/onnx.py +++ b/bioimageio/core/weight_converter/torch/onnx.py @@ -1,47 +1,49 @@ import warnings from pathlib import Path -from typing import Union, Optional +from typing import Optional, Union import numpy as np import torch from numpy.testing import assert_array_almost_equal -import bioimageio.spec as spec -from bioimageio.core import load_resource_description -from bioimageio.core.resource_io import nodes -from .utils import load_model +from bioimageio.spec import load_description +from bioimageio.spec._internal.types import BioimageioYamlSource +from bioimageio.spec.model import v0_4, v0_5 try: import onnxruntime as rt except ImportError: rt = None +# def add_converted_onnx_weights(model_spec: AnyModel, *, opset_version: Optional[int] = 12, use_tracing: bool = True, +# verbose: bool = True, +# test_decimal: int = 4): -def convert_weights_to_onnx( - model_spec: Union[str, Path, spec.model.raw_nodes.Model], - output_path: Union[str, Path], - opset_version: Optional[int] = 12, + +# def add_onnx_weights_from_pytorch_state_dict(model_spec: Union[BioimageioYamlSource, AnyModel], test_decimals: int = 4): + + +def add_onnx_weights( + source_model: Union[BioimageioYamlSource, AnyModel], + *, use_tracing: bool = True, - verbose: bool = True, - test_decimal: int = 4 + test_decimal: int = 4, ): """Convert model weights from format 'pytorch_state_dict' to 'onnx'. 
Args: - model_spec: location of the resource for the input bioimageio model - output_path: where to save the onnx weights + source_model: model without onnx weights opset_version: onnx opset version use_tracing: whether to use tracing or scripting to export the onnx format - verbose: be verbose during the onnx export test_decimal: precision for testing whether the results agree """ - if isinstance(model_spec, (str, Path)): - model_spec = load_resource_description(Path(model_spec)) + if isinstance(source_model, (str, Path)): + model = load_description(Path(source_model)) + assert isinstance(model, (v0_4.Model, v0_5.Model)) - assert isinstance(model_spec, nodes.Model) with torch.no_grad(): # load input and expected output data - input_data = [np.load(inp).astype("float32") for inp in model_spec.test_inputs] + input_data = [np.load(ipt).astype("float32") for ipt in model.test_inputs] input_tensors = [torch.from_numpy(inp) for inp in input_data] # instantiate and generate the expected output diff --git a/bioimageio/core/weight_converter/torch/torchscript.py b/bioimageio/core/weight_converter/torch/torchscript.py index 7da79bfe..3feca51c 100644 --- a/bioimageio/core/weight_converter/torch/torchscript.py +++ b/bioimageio/core/weight_converter/torch/torchscript.py @@ -1,5 +1,4 @@ import warnings - from pathlib import Path from typing import Union @@ -8,7 +7,8 @@ from numpy.testing import assert_array_almost_equal import bioimageio.spec as spec -from bioimageio.core import load_resource_description +from bioimageio.spec import load_description + from .utils import load_model @@ -57,7 +57,6 @@ def _check(input_): # check that input and output agree for decreasing input sizes while True: - slice_ = tuple(slice(None) if st == 0 else slice(step_factor * st, -step_factor * st) for st in half_step) this_input = [inp[slice_] for inp in input_data] this_shape = this_input[0].shape @@ -83,7 +82,7 @@ def convert_weights_to_torchscript( use_tracing: whether to use tracing or scripting to export the torchscript format """ if isinstance(model_spec, (str, Path)): - model_spec = load_resource_description(Path(model_spec)) + model_spec = load_description(Path(model_spec)) with torch.no_grad(): # load input and expected output data From 6c9bf0fbaca05e1a8b98453842a39245e4fc90ab Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 20 Nov 2023 14:58:31 +0100 Subject: [PATCH 064/244] WIP add model_utils --- bioimageio/core/model_utils.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 bioimageio/core/model_utils.py diff --git a/bioimageio/core/model_utils.py b/bioimageio/core/model_utils.py new file mode 100644 index 00000000..2c8dd51f --- /dev/null +++ b/bioimageio/core/model_utils.py @@ -0,0 +1,31 @@ +from functools import singledispatch +from typing import Any, List, Union + +import numpy as np +import xarray as xr +from numpy.typing import NDArray + +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.utils import download + +# @singledispatch +# def is_valid_tensor(description: object, tensor: Union[NDArray[Any], xr.DataArray]) -> bool: +# raise NotImplementedError(type(description)) + +# is_valid_tensor.register +# def _(description: v0_4.InputTensor, tensor: Union[NDArray[Any], xr.DataArray]): + + +@singledispatch +def get_test_input_tensors(model: object) -> List[xr.DataArray]: + raise NotImplementedError(type(model)) + + +@get_test_input_tensors.register +def _(model: v0_4.Model): + data = [np.load(download(ipt).path) for ipt in model.test_inputs] + assert 
all(isinstance(d, np.ndarray) for d in data)
+    return [xr.DataArray(d, dims=tuple(ipt.axes)) for d, ipt in zip(data, model.inputs)]
+
+
+# @get_test_input_tensors.register
+# def _(model: v0_5.Model):

From 26c6f5895923afebf3246fcee09c9af5f3493c4c Mon Sep 17 00:00:00 2001
From: Tomaz Vieira
Date: Mon, 20 Nov 2023 16:11:24 +0100
Subject: [PATCH 065/244] Fixes type-checking with spec v0.5 in torchscript
 model adapter

---
 .../core/model_adapters/_torchscript_model_adapter.py      | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/bioimageio/core/model_adapters/_torchscript_model_adapter.py b/bioimageio/core/model_adapters/_torchscript_model_adapter.py
index 9fe183bb..804c8503 100644
--- a/bioimageio/core/model_adapters/_torchscript_model_adapter.py
+++ b/bioimageio/core/model_adapters/_torchscript_model_adapter.py
@@ -7,7 +7,7 @@
 import xarray as xr
 from numpy.typing import NDArray
 
-from bioimageio.core.io import download
+from bioimageio.spec.utils import download
 from bioimageio.spec.model import v0_4, v0_5
 from bioimageio.spec.model.v0_5 import RelativeFilePath
 
@@ -15,7 +15,7 @@
 
 
 class TorchscriptModelAdapter(ModelAdapter):
-    def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: Optional[Sequence[str]] = None):
+    def __init__(self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None):
         super().__init__()
         if model_description.weights.torchscript is None:
             raise ValueError(f"No torchscript weights found for model {model_description.name}")
@@ -32,7 +32,7 @@
         if len(self.devices) > 1:
             warnings.warn("Multiple devices for single torchscript model not yet implemented")
 
-        self._model = torch.jit.load(weight_path)  # pyright: ignore[reportPrivateImportUsage]
+        self._model = torch.jit.load(weight_path)
         self._model.to(self.devices[0])
         self._internal_output_axes = [
             tuple(out.axes) if isinstance(out.axes, str) else tuple(a.id for a in out.axes)

From 32e53f50cf90706a14932444f48c0a834f0f47ac Mon Sep 17 00:00:00 2001
From: Tomaz Vieira
Date: Mon, 20 Nov 2023 16:45:55 +0100
Subject: [PATCH 066/244] Fixes some typing issues in tensorflow adapter with
 spec v0.5

---
 .../core/model_adapters/_tensorflow_model_adapter.py   | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py
index 9a97a6e7..5ecd3cb3 100644
--- a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py
+++ b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py
@@ -6,7 +6,8 @@
 import tensorflow as tf
 import xarray as xr
 
-from bioimageio.core.io import FileSource, download
+from bioimageio.spec.utils import download
+from bioimageio.spec.generic.v0_3 import FileSource  # FIXME: get a re-export from somewhere? 
from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import RelativeFilePath @@ -87,10 +88,10 @@ def _get_network(self, weight_file: FileSource): # alive in between of forward passes (but then the sessions need to be properly opened / closed) def _forward_tf(self, *input_tensors): input_keys = [ - ipt.name if isinstance(ipt, v0_4.InputTensor) else ipt.id for ipt in self.model_description.inputs + ipt.name if isinstance(ipt, v0_4.InputTensorDescr) else ipt.id for ipt in self.model_description.inputs ] output_keys = [ - out.name if isinstance(out, v0_4.OutputTensor) else out.id for out in self.model_description.outputs + out.name if isinstance(out, v0_4.OutputTensorDescr) else out.id for out in self.model_description.outputs ] # TODO read from spec @@ -148,7 +149,7 @@ def unload(self) -> None: class TensorflowModelAdapter(TensorflowModelAdapterBase): weight_format = "tensorflow_saved_model_bundle" - def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: Optional[Sequence[str]] = None): + def __init__(self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None): if model_description.weights.tensorflow_saved_model_bundle is None: raise ValueError("missing tensorflow_saved_model_bundle weights") @@ -162,7 +163,7 @@ def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: class KerasModelAdapter(TensorflowModelAdapterBase): weight_format = "keras_hdf5" - def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: Optional[Sequence[str]] = None): + def __init__(self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None): if model_description.weights.keras_hdf5 is None: raise ValueError("missing keras_hdf5 weights") From f7469cb75c9d7c7e5bb6db6bda7040616db5fd3f Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 22 Nov 2023 10:35:19 +0100 Subject: [PATCH 067/244] moved demo notebook to spec --- example/demo.ipynb | 322 --------------------------------------------- 1 file changed, 322 deletions(-) delete mode 100644 example/demo.ipynb diff --git a/example/demo.ipynb b/example/demo.ipynb deleted file mode 100644 index 2bbc64ba..00000000 --- a/example/demo.ipynb +++ /dev/null @@ -1,322 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from bioimageio.spec.pretty_validation_errors import enable_pretty_validation_errors_in_ipynb\n", - "\n", - "enable_pretty_validation_errors_in_ipynb()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from bioimageio.spec.model.v0_4 import Model\n", - "from bioimageio.core import read_description\n", - "\n", - "from pydantic import HttpUrl\n", - "\n", - "model = read_description(HttpUrl(\"https://bioimage-io.github.io/collection-bioimage-io/rdfs/10.5281/zenodo.6334383/7805067/rdf.yaml\"))\n", - "assert isinstance(model, Model)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(model.validation_summaries[0])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model.weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - 
"cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import xarray as xr\n", - "import numpy as np\n", - "\n", - "gain = [1, 2, 3]\n", - "tensor = xr.DataArray(np.random.randn(2, 3, 2, 2), dims=(\"b\", \"c\", \"y\", \"x\"))\n", - "axes = (\"b\", \"x\", \"y\")\n", - "scale_axes = tuple(a for a in tensor.dims if a not in axes)\n", - "b = xr.DataArray([1, 2, 3], dims=scale_axes)\n", - "\n", - "a * b" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "a.mean(dim=(\"x\", \"y\"))\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "gain = (2, 1, 3)\n", - "offset = (3, 0, 1)\n", - "\n", - "print(a * gain + offset)\n", - "axes = (\"x\", \"y\")\n", - "tmp = a.stack(temp=axes) * gain + offset\n", - "print(tmp)\n", - "print()\n", - "tmp.unstack(\"temp\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# from bioimageio.spec.model.v0_5 import Author, CiteEntry, Model\n", - "\n", - "# # id: raw\n", - "# # description: raw input\n", - "# # axes:\n", - "# # - type: batch\n", - "# # - type: channel\n", - "# # channel_names: [raw_intensity]\n", - "# # - type: space # todo: scale/unit\n", - "# # name: y\n", - "# # size: 512\n", - "# # - type: space\n", - "# # name: x\n", - "# # size: 512\n", - "# # test_tensor: test_input.npy\n", - "# # sample_tensor: test_input.npy\n", - "# # preprocessing: # list of preprocessing steps\n", - "# # - id: zero_mean_unit_variance # name of preprocessing step\n", - "# # kwargs:\n", - "# # mode: per_sample\n", - "# # axes: [x, y]\n", - "\n", - "# # outputs:\n", - "# # - id: probability\n", - "# # description: probability in [0,1]\n", - "# # data:\n", - "# # type: float32\n", - "# # range:\n", - "# # - 0.0\n", - "# # - 1.0\n", - "# # axes:\n", - "# # - type: batch\n", - "# # - type: channel\n", - "# # channel_names: [probability]\n", - "# # - type: space\n", - "# # name: y\n", - "# # size: raw.y\n", - "# # halo: 32\n", - "# # - type: space\n", - "# # size: raw.x\n", - "# # name: x\n", - "# # halo: 32\n", - "# # test_tensor: test_output.npy\n", - "# # sample_tensor: test_output.npy\n", - "\n", - "# # weights:\n", - "# # pytorch_state_dict:\n", - "# # authors:\n", - "# # - name: \"Constantin Pape;@bioimage-io\"\n", - "# # affiliation: \"EMBL Heidelberg\"\n", - "# # orcid: \"0000-0001-6562-7187\"\n", - "# # sha256: e4d3885bccbe41cbf6c1d825f3cd2b707c7021ead5593156007e407a16b27cf2\n", - "# # source: https://zenodo.org/record/3446812/files/unet2d_weights.torch\n", - "# # architecture:\n", - "# # callable: unet2d.py:UNet2d\n", - "# # sha256: cf42a6d86adeb4eb6e8e37b539a20e5413866b183bed88f4e2e26ad1639761ed\n", - "# # kwargs: { input_channels: 1, output_channels: 1 }\n", - "# # dependencies: conda:environment.yaml\n", - "# # pytorch_version: \"1.5.1\"\n", - "# # onnx:\n", - "# # sha256: f1f086d5e340f9d4d7001a1b62a2b835f9b87a2fb5452c4fe7d8cc821bdf539c\n", - "# # source: weights.onnx\n", - "# # opset_version: 12\n", - "# # parent: pytorch_state_dict\n", - "# # torchscript:\n", - "# # sha256: 62fa1c39923bee7d58a192277e0dd58f2da9ee810662addadd0f44a3784d9210\n", - "# # source: weights.pt\n", - "# # parent: pytorch_state_dict\n", - "# # 
pytorch_version: \"1.5.1\"\n", - "\n", - "\n", - "# my_model = Model(\n", - "# name=\"UNet 2D Nuclei Broad\",\n", - "# version=\"0.2.0\",\n", - "# description=\"A 2d U-Net trained on the nuclei broad dataset.\",\n", - "# documentation=\"README.md\",\n", - "# authors=(\n", - "# Author(\n", - "# name=\"Constantin Pape\",\n", - "# affiliation=\"EMBL Heidelberg\",\n", - "# orcid=\"0000-0001-6562-7187\",\n", - "# ),\n", - "# Author(\n", - "# name=\"Fynn Beuttenmueller\",\n", - "# affiliation=\"EMBL Heidelberg\",\n", - "# orcid=\"0000-0002-8567-6389\",\n", - "# ),\n", - "# ),\n", - "# cite=(CiteEntry(text=\"bioimage.io\", doi=\"10.1101/2022.06.07.495102\"),),\n", - "# inputs=(),\n", - "# outputs=(),\n", - "# timestamp=\"2019-12-11T12:22:32\",\n", - "# training_data={\"id\": \"ilastik/covid_if_training_data\"}, # note: not the real training data\n", - "# license=\"MIT\",\n", - "# )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "my_model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import xarray as xr\n", - "\n", - "a = xr.DataArray([[1, 2], [3, 4]], dims=(\"x\", \"y\"))\n", - "a[{\"x\": slice(None)}]" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "bio38", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.17" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} From fd8b5ca522f94abfd9c0a9878adb1dac9a60142e Mon Sep 17 00:00:00 2001 From: Tomaz Vieira Date: Wed, 22 Nov 2023 14:38:57 +0100 Subject: [PATCH 068/244] Fixes typing in pytorch model adapter --- .../core/model_adapters/_pytorch_model_adapter.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py index 7e5fd706..5810da18 100644 --- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py +++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py @@ -7,6 +7,7 @@ from bioimageio.core.utils import import_callable from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.utils import download from ._model_adapter import ModelAdapter @@ -15,8 +16,8 @@ class PytorchModelAdapter(ModelAdapter): def __init__( self, *, - outputs: Union[Sequence[v0_4.OutputTensor], Sequence[v0_5.OutputTensor]], - weights: 
Union[v0_4.PytorchStateDictWeights, v0_5.PytorchStateDictWeights], + outputs: Union[Sequence[v0_4.OutputTensorDescr], Sequence[v0_5.OutputTensorDescr]], + weights: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr], devices: Optional[Sequence[str]] = None, ): super().__init__() @@ -25,7 +26,7 @@ def __init__( self._devices = self.get_devices(devices) self._network = self._network.to(self._devices[0]) - state: Any = torch.load(weights.source, map_location=self._devices[0]) + state: Any = torch.load(download(weights.source).path, map_location=self._devices[0]) _ = self._network.load_state_dict(state) self._network = self._network.eval() @@ -50,16 +51,16 @@ def unload(self) -> None: torch.cuda.empty_cache() # release reserved memory @staticmethod - def get_network(weight_spec: Union[v0_4.PytorchStateDictWeights, v0_5.PytorchStateDictWeights]): + def get_network(weight_spec: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr]): arch = import_callable( weight_spec.architecture, sha256=weight_spec.architecture_sha256 - if isinstance(weight_spec, v0_4.PytorchStateDictWeights) + if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr) else weight_spec.sha256, ) model_kwargs = ( weight_spec.kwargs - if isinstance(weight_spec, v0_4.PytorchStateDictWeights) + if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr) else weight_spec.architecture.kwargs ) network = arch(**model_kwargs) From 2defed2430969618537a5a1fabd4e44a7e505c18 Mon Sep 17 00:00:00 2001 From: Tomaz Vieira Date: Wed, 22 Nov 2023 14:58:55 +0100 Subject: [PATCH 069/244] Fixes load_model pytorch helper --- .../model_adapters/_pytorch_model_adapter.py | 2 +- bioimageio/core/weight_converter/torch/utils.py | 17 ++++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py index 5810da18..3a9a3109 100644 --- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py +++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py @@ -51,7 +51,7 @@ def unload(self) -> None: torch.cuda.empty_cache() # release reserved memory @staticmethod - def get_network(weight_spec: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr]): + def get_network(weight_spec: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr]) -> torch.nn.Module: arch = import_callable( weight_spec.architecture, sha256=weight_spec.architecture_sha256 diff --git a/bioimageio/core/weight_converter/torch/utils.py b/bioimageio/core/weight_converter/torch/utils.py index 9c122ad5..14e08514 100644 --- a/bioimageio/core/weight_converter/torch/utils.py +++ b/bioimageio/core/weight_converter/torch/utils.py @@ -1,12 +1,15 @@ import torch -from bioimageio.core.prediction_pipeline._model_adapters._pytorch_model_adapter import PytorchModelAdapter + +from bioimageio.core.model_adapters._pytorch_model_adapter import PytorchModelAdapter +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.utils import download + # additional convenience for pytorch state dict, eventually we want this in python-bioimageio too # and for each weight format -def load_model(node): - model = PytorchModelAdapter.get_nn_instance(node) - state = torch.load(node.weights["pytorch_state_dict"].source, map_location="cpu") - model.load_state_dict(state) - model.eval() - return model +def load_model(node: "v0_4.PytorchStateDictWeightsDescr | v0_5.PytorchStateDictWeightsDescr"): + 
model = PytorchModelAdapter.get_network(node) + state = torch.load(download(node.source).path, map_location="cpu") + _ = model.load_state_dict(state) #FIXME: check incompatible keys? + return model.eval() From e2da2c44e29fdd6b2caa8c285564bea0944c433e Mon Sep 17 00:00:00 2001 From: Tomaz Vieira Date: Wed, 22 Nov 2023 16:16:50 +0100 Subject: [PATCH 070/244] Mostly fixes typing in torchscript converter. Missing impl for v0_5 --- bioimageio/core/__main__.py | 4 +- .../weight_converter/torch/torchscript.py | 113 ++++++++++-------- 2 files changed, 63 insertions(+), 54 deletions(-) diff --git a/bioimageio/core/__main__.py b/bioimageio/core/__main__.py index 75da0316..aabd2b05 100644 --- a/bioimageio/core/__main__.py +++ b/bioimageio/core/__main__.py @@ -280,8 +280,8 @@ def convert_torch_weights_to_torchscript( output_path: Path = typer.Argument(..., help="Where to save the torchscript weights."), use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."), ): - ret_code = torch_converter.convert_weights_to_torchscript(model_rdf, output_path, use_tracing) - sys.exit(ret_code) + torch_converter.convert_weights_to_torchscript(model_rdf, output_path, use_tracing) + sys.exit(0) convert_torch_weights_to_torchscript.__doc__ = torch_converter.convert_weights_to_torchscript.__doc__ diff --git a/bioimageio/core/weight_converter/torch/torchscript.py b/bioimageio/core/weight_converter/torch/torchscript.py index 3feca51c..0ebe6201 100644 --- a/bioimageio/core/weight_converter/torch/torchscript.py +++ b/bioimageio/core/weight_converter/torch/torchscript.py @@ -1,4 +1,5 @@ -import warnings +from typing import List, Sequence +from typing_extensions import Any from pathlib import Path from typing import Union @@ -6,73 +7,65 @@ import torch from numpy.testing import assert_array_almost_equal -import bioimageio.spec as spec from bioimageio.spec import load_description +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec import load_description +from bioimageio.spec.common import InvalidDescription +from bioimageio.spec.utils import download from .utils import load_model - -def _check_predictions(model, scripted_model, model_spec, input_data): - assert isinstance(input_data, list) - - def _check(input_): +# FIXME: remove Any +def _check_predictions(model: Any, scripted_model: Any, model_spec: "v0_4.ModelDescr | v0_5.ModelDescr", input_data: Sequence[torch.Tensor]): + def _check(input_: Sequence[torch.Tensor]) -> None: # get the expected output to validate the torchscript weights - expected_outputs = model(*input_) - if isinstance(expected_outputs, (torch.Tensor)): - expected_outputs = [expected_outputs] - expected_outputs = [out.numpy() for out in expected_outputs] + expected_tensors = model(*input_) + if isinstance(expected_tensors, torch.Tensor): + expected_tensors = [expected_tensors] + expected_outputs: List[np.ndarray[Any, Any]] = [out.numpy() for out in expected_tensors] - outputs = scripted_model(*input_) - if isinstance(outputs, (torch.Tensor)): - outputs = [outputs] - outputs = [out.numpy() for out in outputs] + output_tensors = scripted_model(*input_) + if isinstance(output_tensors, torch.Tensor): + output_tensors = [output_tensors] + outputs: List[np.ndarray[Any, Any]] = [out.numpy() for out in output_tensors] try: for exp, out in zip(expected_outputs, outputs): assert_array_almost_equal(exp, out, decimal=4) - return 0 except AssertionError as e: - msg = f"The onnx weights were exported, but results before and after conversion do not agree:\n {str(e)}" - 
warnings.warn(msg) - return 1 + raise ValueError(f"Results before and after weights conversion do not agree:\n {str(e)}") - ret = _check(input_data) - n_inputs = len(model_spec.inputs) - # check has not passed or we have more tahn one input? then return immediately - if ret == 1 or n_inputs > 1: - return ret + _check(input_data) + + if len(model_spec.inputs) > 1: + return # FIXME: why don't we check multiple inputs? # do we have fixed input size or variable? # if variable, we need to check multiple sizes! - shape_spec = model_spec.inputs[0].shape - try: # we have a variable shape - min_shape = shape_spec.min - step = shape_spec.step - except AttributeError: # we have fixed shape - return ret + input_descr = model_spec.inputs[0] + if isinstance(input_descr, v0_4.InputTensorDescr): + if not isinstance(input_descr.shape, v0_4.ParametrizedInputShape): + return + min_shape = input_descr.shape.min + step = input_descr.shape.step + else: + raise NotImplementedError("FIXME: Can't handle v0.5 parameterized inputs yet") half_step = [st // 2 for st in step] max_steps = 4 step_factor = 1 # check that input and output agree for decreasing input sizes - while True: + for step_factor in range(1, max_steps + 1): slice_ = tuple(slice(None) if st == 0 else slice(step_factor * st, -step_factor * st) for st in half_step) this_input = [inp[slice_] for inp in input_data] this_shape = this_input[0].shape if any(tsh < msh for tsh, msh in zip(this_shape, min_shape)): - return ret - - ret = _check(this_input) - if ret == 1: - return ret - step_factor += 1 - if step_factor > max_steps: - return ret - + raise ValueError(f"Mismatched shapes: {this_shape}. Expected at least {min_shape}") + _check(this_input) def convert_weights_to_torchscript( - model_spec: Union[str, Path, spec.model.raw_nodes.Model], output_path: Union[str, Path], use_tracing: bool = True + model_spec: Union[str, Path, v0_4.ModelDescr, v0_5.ModelDescr], output_path: Path, use_tracing: bool = True ): """Convert model weights from format 'pytorch_state_dict' to 'torchscript'. 
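# --- editor's note (not part of the patch) ---------------------------------
# With the signature shown above, a call might look like the following
# sketch (the paths are illustrative); the same entry point is exercised by
# the CLI in bioimageio/core/__main__.py and by
# tests/weight_converter/torch/test_torchscript.py:
#
#     from pathlib import Path
#     from bioimageio.core.weight_converter.torch import convert_weights_to_torchscript
#
#     convert_weights_to_torchscript(Path("rdf.yaml"), Path("weights.pt"), use_tracing=True)
# ----------------------------------------------------------------------------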
@@ -82,24 +75,40 @@ def convert_weights_to_torchscript(
         use_tracing: whether to use tracing or scripting to export the torchscript format
     """
     if isinstance(model_spec, (str, Path)):
-        model_spec = load_description(Path(model_spec))
+        loaded_spec = load_description(Path(model_spec))
+        if isinstance(loaded_spec, InvalidDescription):
+            raise ValueError(f"Bad resource description: {loaded_spec}")
+        if not isinstance(loaded_spec, (v0_4.ModelDescr, v0_5.ModelDescr)):
+            raise TypeError(f"Path {model_spec} is a {loaded_spec.__class__.__name__}, expected a v0_4.ModelDescr or v0_5.ModelDescr")
+        model_spec = loaded_spec
+
+    state_dict_weights_descr = model_spec.weights.pytorch_state_dict
+    if state_dict_weights_descr is None:
+        raise ValueError("The provided model does not have weights in the pytorch state dict format")
 
     with torch.no_grad():
-        # load input and expected output data
-        input_data = [np.load(inp).astype("float32") for inp in model_spec.test_inputs]
+        if isinstance(model_spec, v0_4.ModelDescr):
+            downloaded_test_inputs = [download(inp) for inp in model_spec.test_inputs]
+        else:
+            downloaded_test_inputs = [inp.test_tensor.download() for inp in model_spec.inputs]
+
+        input_data = [np.load(dl.path).astype("float32") for dl in downloaded_test_inputs]
         input_data = [torch.from_numpy(inp) for inp in input_data]
 
-        # instantiate model and get reference output
-        model = load_model(model_spec)
+        model = load_model(state_dict_weights_descr)
 
-        # make scripted model
+        # FIXME: remove Any
        if use_tracing:
-            scripted_model = torch.jit.trace(model, input_data)
+            scripted_model: Any = torch.jit.trace(model, input_data)
         else:
-            scripted_model = torch.jit.script(model)
-
-        # check the scripted model
-        ret = _check_predictions(model, scripted_model, model_spec, input_data)
+            scripted_model: Any = torch.jit.script(model)
+
+        _check_predictions(
+            model=model,
+            scripted_model=scripted_model,
+            model_spec=model_spec,
+            input_data=input_data
+        )
 
     # save the torchscript model
     scripted_model.save(str(output_path))  # does not support Path, so need to cast to str

From 02bfedbe054fbb8f0c5375761f86fc725704824a Mon Sep 17 00:00:00 2001
From: Tomaz Vieira
Date: Thu, 23 Nov 2023 13:35:25 +0100
Subject: [PATCH 071/244] Adds more conversion logic in torchscript converter.
 Fixes test typing

---
 .../weight_converter/torch/torchscript.py     | 22 ++++++++++++-----
 .../torch/test_torchscript.py                 |  6 ++++-
 2 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/bioimageio/core/weight_converter/torch/torchscript.py b/bioimageio/core/weight_converter/torch/torchscript.py
index 0ebe6201..bace789e 100644
--- a/bioimageio/core/weight_converter/torch/torchscript.py
+++ b/bioimageio/core/weight_converter/torch/torchscript.py
@@ -1,5 +1,5 @@
 from typing import List, Sequence
-from typing_extensions import Any
+from typing_extensions import Any, assert_never
 from pathlib import Path
 from typing import Union
 
@@ -18,7 +18,6 @@
 # FIXME: remove Any
 def _check_predictions(model: Any, scripted_model: Any, model_spec: "v0_4.ModelDescr | v0_5.ModelDescr", input_data: Sequence[torch.Tensor]):
     def _check(input_: Sequence[torch.Tensor]) -> None:
-        # get the expected output to validate the torchscript weights
         expected_tensors = model(*input_)
         if isinstance(expected_tensors, torch.Tensor):
             expected_tensors = [expected_tensors]
@@ -40,8 +39,6 @@ def _check(input_: Sequence[torch.Tensor]) -> None:
     if len(model_spec.inputs) > 1:
         return # FIXME: why don't we check multiple inputs?
 
-    # do we have fixed input size or variable? 
- # if variable, we need to check multiple sizes! input_descr = model_spec.inputs[0] if isinstance(input_descr, v0_4.InputTensorDescr): if not isinstance(input_descr.shape, v0_4.ParametrizedInputShape): @@ -49,11 +46,24 @@ def _check(input_: Sequence[torch.Tensor]) -> None: min_shape = input_descr.shape.min step = input_descr.shape.step else: - raise NotImplementedError("FIXME: Can't handle v0.5 parameterized inputs yet") + min_shape: List[int] = [] + step: List[int] = [] + for axis in input_descr.axes: + if isinstance(axis.size, v0_5.ParameterizedSize): + min_shape.append(axis.size.min) + step.append(axis.size.step) + elif isinstance(axis.size, int): + min_shape.append(axis.size) + step.append(0) + elif isinstance(axis.size, (v0_5.AxisId, v0_5.TensorAxisId, type(None))): + raise NotImplementedError(f"Can't verify inputs that don't specify their shape fully: {axis}") + elif isinstance(axis.size, v0_5.SizeReference): # pyright: ignore [reportUnnecessaryIsInstance] + raise NotImplementedError(f"Can't handle axes like '{axis}' yet") + else: + assert_never(axis.size) half_step = [st // 2 for st in step] max_steps = 4 - step_factor = 1 # check that input and output agree for decreasing input sizes for step_factor in range(1, max_steps + 1): diff --git a/tests/weight_converter/torch/test_torchscript.py b/tests/weight_converter/torch/test_torchscript.py index 5c879577..2c1e47d2 100644 --- a/tests/weight_converter/torch/test_torchscript.py +++ b/tests/weight_converter/torch/test_torchscript.py @@ -1,4 +1,8 @@ -def test_torchscript_converter(any_torch_model, tmp_path): +from pathlib import Path +from bioimageio.spec.model import v0_4, v0_5 + + +def test_torchscript_converter(any_torch_model: "v0_4.ModelDescr | v0_5.ModelDescr", tmp_path: Path): from bioimageio.core.weight_converter.torch import convert_weights_to_torchscript out_path = tmp_path / "weights.pt" From ce3e0145bb1be788d0ac0a10d3d567b7f542cf05 Mon Sep 17 00:00:00 2001 From: Tomaz Vieira Date: Thu, 23 Nov 2023 15:15:34 +0100 Subject: [PATCH 072/244] Fixes typing in onnx conversion function --- .../core/weight_converter/torch/onnx.py | 77 +++++++++++-------- 1 file changed, 43 insertions(+), 34 deletions(-) diff --git a/bioimageio/core/weight_converter/torch/onnx.py b/bioimageio/core/weight_converter/torch/onnx.py index 104421a1..acdecc41 100644 --- a/bioimageio/core/weight_converter/torch/onnx.py +++ b/bioimageio/core/weight_converter/torch/onnx.py @@ -1,33 +1,25 @@ import warnings from pathlib import Path -from typing import Optional, Union +from typing import Any, Dict, List, Sequence, cast import numpy as np import torch from numpy.testing import assert_array_almost_equal from bioimageio.spec import load_description -from bioimageio.spec._internal.types import BioimageioYamlSource from bioimageio.spec.model import v0_4, v0_5 - -try: - import onnxruntime as rt -except ImportError: - rt = None - -# def add_converted_onnx_weights(model_spec: AnyModel, *, opset_version: Optional[int] = 12, use_tracing: bool = True, -# verbose: bool = True, -# test_decimal: int = 4): - - -# def add_onnx_weights_from_pytorch_state_dict(model_spec: Union[BioimageioYamlSource, AnyModel], test_decimals: int = 4): - +from bioimageio.core.weight_converter.torch.utils import load_model +from bioimageio.spec.common import InvalidDescription +from bioimageio.spec.utils import download def add_onnx_weights( - source_model: Union[BioimageioYamlSource, AnyModel], + model_spec: "str | Path | v0_4.ModelDescr | v0_5.ModelDescr", *, + output_path: Path, use_tracing: bool = 
True, test_decimal: int = 4, + verbose: bool = False, + opset_version: "int | None" = None, ): """Convert model weights from format 'pytorch_state_dict' to 'onnx'. @@ -37,42 +29,59 @@ def add_onnx_weights( use_tracing: whether to use tracing or scripting to export the onnx format test_decimal: precision for testing whether the results agree """ - if isinstance(source_model, (str, Path)): - model = load_description(Path(source_model)) - assert isinstance(model, (v0_4.Model, v0_5.Model)) + if isinstance(model_spec, (str, Path)): + loaded_spec = load_description(Path(model_spec)) + if isinstance(loaded_spec, InvalidDescription): + raise ValueError(f"Bad resource description: {loaded_spec}") + if not isinstance(loaded_spec, (v0_4.ModelDescr, v0_5.ModelDescr)): + raise TypeError(f"Path {model_spec} is a {loaded_spec.__class__.__name__}, expected a v0_4.ModelDescr or v0_5.ModelDescr") + model_spec = loaded_spec + + state_dict_weights_descr = model_spec.weights.pytorch_state_dict + if state_dict_weights_descr is None: + raise ValueError(f"The provided model does not have weights in the pytorch state dict format") with torch.no_grad(): - # load input and expected output data - input_data = [np.load(ipt).astype("float32") for ipt in model.test_inputs] + if isinstance(model_spec, v0_4.ModelDescr): + downloaded_test_inputs = [download(inp) for inp in model_spec.test_inputs] + else: + downloaded_test_inputs = [inp.test_tensor.download() for inp in model_spec.inputs] + + input_data: List[np.ndarray[Any, Any]] = [np.load(dl.path).astype("float32") for dl in downloaded_test_inputs] input_tensors = [torch.from_numpy(inp) for inp in input_data] - # instantiate and generate the expected output - model = load_model(model_spec) - expected_outputs = model(*input_tensors) - if isinstance(expected_outputs, torch.Tensor): - expected_outputs = [expected_outputs] - expected_outputs = [out.numpy() for out in expected_outputs] + model = load_model(state_dict_weights_descr) + + expected_tensors = model(*input_tensors) + if isinstance(expected_tensors, torch.Tensor): + expected_tensors = [expected_tensors] + expected_outputs: List[np.ndarray[Any, Any]] = [out.numpy() for out in expected_tensors] if use_tracing: torch.onnx.export( model, - input_tensors if len(input_tensors) > 1 else input_tensors[0], - output_path, + tuple(input_tensors) if len(input_tensors) > 1 else input_tensors[0], + str(output_path), verbose=verbose, opset_version=opset_version, ) else: raise NotImplementedError - if rt is None: + try: + import onnxruntime as rt # pyright: ignore [reportMissingTypeStubs] + except ImportError: msg = "The onnx weights were exported, but onnx rt is not available and weights cannot be checked." 
warnings.warn(msg) - return 1 + return # check the onnx model - sess = rt.InferenceSession(str(output_path)) # does not support Path, so need to cast to str - onnx_inputs = {input_name.name: inp for input_name, inp in zip(sess.get_inputs(), input_data)} - outputs = sess.run(None, onnx_inputs) + sess = rt.InferenceSession(str(output_path)) + onnx_input_node_args = cast(List[Any], sess.get_inputs()) # fixme: remove cast, try using rt.NodeArg instead of Any + onnx_inputs: Dict[str, np.ndarray[Any, Any]] = { + input_name.name: inp for input_name, inp in zip(onnx_input_node_args, input_data) + } + outputs = cast(Sequence[np.ndarray[Any, Any]], sess.run(None, onnx_inputs)) #FIXME: remove cast try: for exp, out in zip(expected_outputs, outputs): From 7a94ce99e11a37116e81cf24ac2c87f1df43de3d Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 24 Nov 2023 10:42:28 +0100 Subject: [PATCH 073/244] WIP restructure API --- bioimageio/core/common.py | 45 ++ bioimageio/core/image_helper.py | 38 +- bioimageio/core/io.py | 13 +- ...ion_pipeline.py => prediction_pipeline.py} | 9 +- .../core/prediction_pipeline/__init__.py | 4 - .../_combined_processing.py | 105 ----- bioimageio/core/prediction_pipeline/_utils.py | 65 --- .../processing.py => proc_impl.py} | 188 +++++---- bioimageio/core/proc_setup.py | 75 ++++ bioimageio/core/stat_calculators.py | 396 ++++++++++++++++++ ...atistical_measures.py => stat_measures.py} | 13 +- .../_stat_state.py => stat_state.py} | 63 +-- bioimageio/core/statistical_measure_groups.py | 340 --------------- bioimageio/core/{model_utils.py => utils.py} | 3 +- tests/test_image_helper.py | 6 +- 15 files changed, 697 insertions(+), 666 deletions(-) create mode 100644 bioimageio/core/common.py rename bioimageio/core/{prediction_pipeline/_prediction_pipeline.py => prediction_pipeline.py} (97%) delete mode 100644 bioimageio/core/prediction_pipeline/__init__.py delete mode 100644 bioimageio/core/prediction_pipeline/_combined_processing.py delete mode 100644 bioimageio/core/prediction_pipeline/_utils.py rename bioimageio/core/{prediction_pipeline/processing.py => proc_impl.py} (73%) create mode 100644 bioimageio/core/proc_setup.py create mode 100644 bioimageio/core/stat_calculators.py rename bioimageio/core/{statistical_measures.py => stat_measures.py} (78%) rename bioimageio/core/{prediction_pipeline/_stat_state.py => stat_state.py} (52%) delete mode 100644 bioimageio/core/statistical_measure_groups.py rename bioimageio/core/{model_utils.py => utils.py} (90%) diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py new file mode 100644 index 00000000..96c71592 --- /dev/null +++ b/bioimageio/core/common.py @@ -0,0 +1,45 @@ +from typing import Any, Dict, Generic, List, Literal, NamedTuple, TypeVar, Union + +import numpy as np +import xarray as xr +from attr import dataclass +from typing_extensions import Final + +from bioimageio.core.stat_measures import Measure +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.model.v0_5 import TensorId + +TensorId = v0_5.TensorId +AxisId = v0_5.AxisId + +Sample = Dict[TensorId, xr.DataArray] + +ProcessingDescrBase = Union[v0_4.ProcessingDescrBase, v0_5.ProcessingDescrBase] +ProcessingKwargs = Union[v0_4.ProcessingKwargs, v0_5.ProcessingKwargs] + +PER_SAMPLE = "per_sample" +PER_DATASET = "per_dataset" + + +MeasureVar = TypeVar("MeasureVar", bound=Measure) +ModeVar = TypeVar("ModeVar", Literal["per_sample"], Literal["per_dataset"]) + + +@dataclass(frozen=True) +class RequiredMeasure(Generic[MeasureVar, ModeVar]): + measure: 
MeasureVar + tensor_id: TensorId + mode: ModeVar + + +@dataclass(frozen=True) +class SampleMeasure(RequiredMeasure[MeasureVar, Literal["per_sample"]]): + pass + + +@dataclass(frozen=True) +class DatasetMeasure(RequiredMeasure[MeasureVar, Literal["per_dataset"]]): + pass + + +MeasureValue = xr.DataArray diff --git a/bioimageio/core/image_helper.py b/bioimageio/core/image_helper.py index e26d5e4c..8c25b832 100644 --- a/bioimageio/core/image_helper.py +++ b/bioimageio/core/image_helper.py @@ -2,16 +2,17 @@ import os from copy import deepcopy -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Dict, List, Optional, Sequence, Tuple, TypeVar, Union import imageio import numpy as np +from numpy.typing import NDArray +from xarray import DataArray + from bioimageio.spec.model.v0_4 import InputTensor as InputTensor04 from bioimageio.spec.model.v0_4 import OutputTensor as OutputTensor04 from bioimageio.spec.model.v0_5 import InputTensor as InputTensor05 from bioimageio.spec.model.v0_5 import OutputTensor as OutputTensor05 -from numpy.typing import NDArray -from xarray import DataArray InputTensor = Union[InputTensor04, InputTensor05] OutputTensor = Union[OutputTensor04, OutputTensor05] @@ -22,34 +23,37 @@ # -def transform_input_image(image: NDArray, tensor_axes: str, image_axes: Optional[str] = None): - """Transform input image into output tensor with desired axes. +DType = TypeVar("DType", bound=np.dtype) + + +def transpose_image(image: NDArray[DType], desired_axes: str, current_axes: Optional[str] = None) -> NDArray[DType]: + """Transform an image to match desired axes. Args: image: the input image - tensor_axes: the desired tensor axes - input_axes: the axes of the input image (optional) + desired_axes: the desired image axes + current_axes: the axes of the input image """ # if the image axes are not given deduce them from the required axes and image shape - if image_axes is None: - has_z_axis = "z" in tensor_axes + if current_axes is None: + has_z_axis = "z" in desired_axes ndim = image.ndim if ndim == 2: - image_axes = "yx" + current_axes = "yx" elif ndim == 3: - image_axes = "zyx" if has_z_axis else "cyx" + current_axes = "zyx" if has_z_axis else "cyx" elif ndim == 4: - image_axes = "czyx" + current_axes = "czyx" elif ndim == 5: - image_axes = "bczyx" + current_axes = "bczyx" else: raise ValueError(f"Invalid number of image dimensions: {ndim}") - tensor = DataArray(image, dims=tuple(image_axes)) + tensor = DataArray(image, dims=tuple(current_axes)) # expand the missing image axes - missing_axes = tuple(set(tensor_axes) - set(image_axes)) + missing_axes = tuple(set(desired_axes) - set(current_axes)) tensor = tensor.expand_dims(dim=missing_axes) # transpose to the correct axis order - tensor = tensor.transpose(*tuple(tensor_axes)) + tensor = tensor.transpose(*tuple(desired_axes)) # return numpy array return tensor.values @@ -103,7 +107,7 @@ def load_image(in_path, axes: Sequence[str]) -> DataArray: else: is_volume = "z" in axes im = imageio.volread(in_path) if is_volume else imageio.imread(in_path) - im = transform_input_image(im, axes) + im = transpose_image(im, axes) return DataArray(im, dims=axes) diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py index 0f0784e1..53d54b01 100644 --- a/bioimageio/core/io.py +++ b/bioimageio/core/io.py @@ -20,7 +20,7 @@ from bioimageio.spec._internal.base_nodes import ResourceDescriptionBase from bioimageio.spec._internal.constants import DISCOVER from bioimageio.spec._internal.types import FileName, RdfContent, 
RelativeFilePath, Sha256, ValidationContext, YamlValue
-from bioimageio.spec.description import InvalidDescription, dump_description
+from bioimageio.spec.common import BioimageioYamlContent, FileSource, InvalidDescription
 from bioimageio.spec.model.v0_4 import WeightsFormat
 from bioimageio.spec.package import extract_file_name, get_resource_package_content
 from bioimageio.spec.summary import ValidationSummary
@@ -32,23 +32,24 @@ def load_description_and_validate(
     *,
     format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER,
 ) -> Union[ResourceDescription, InvalidDescription]:
-    rdf = download_rdf(source)
+    opened = open_bioimageio_yaml(source)
+
     return build_description_and_validate(
-        rdf.content,
-        context=ValidationContext(root=rdf.original_root, file_name=rdf.original_file_name),
+        opened.content,
+        context=ValidationContext(root=opened.original_root, file_name=opened.original_file_name),
         format_version=format_version,
     )
 
 
 def build_description_and_validate(
-    rdf_content: RdfContent,
+    data: BioimageioYamlContent,
     /,
     *,
     context: Optional[ValidationContext] = None,
     format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER,
 ) -> Union[ResourceDescription, InvalidDescription]:
     """load and validate a BioImage.IO description from the content of a resource description file (RDF)"""
-    rd = load_description(rdf_content, context=context, format_version=format_version)
+    descr = build_description(data, context=context, format_version=format_version)
     # todo: add dynamic validation
-    return rd
+    return descr
diff --git a/bioimageio/core/prediction_pipeline/_prediction_pipeline.py b/bioimageio/core/prediction_pipeline.py
similarity index 97%
rename from bioimageio/core/prediction_pipeline/_prediction_pipeline.py
rename to bioimageio/core/prediction_pipeline.py
index 483a7ff9..fbe10e74 100644
--- a/bioimageio/core/prediction_pipeline/_prediction_pipeline.py
+++ b/bioimageio/core/prediction_pipeline.py
@@ -6,19 +6,20 @@
 import xarray as xr
 
 from bioimageio.core.model_adapters import ModelAdapter, create_model_adapter
+from bioimageio.core.model_adapters import get_weight_formats as get_weight_formats
 from bioimageio.core.utils.node_visitor import resolve_raw_node
 from bioimageio.spec.model import AnyModel, raw_nodes
 
 from ._combined_processing import CombinedProcessing
-from ._stat_state import StatsState
 from ._utils import ComputedMeasures, Sample, TensorName
+from .stat_state import StatsState
 
 
 @dataclass
 class NamedImplicitOutputShape:
-    reference_input: TensorName = missing
-    scale: List[Tuple[str, float]] = missing
-    offset: List[Tuple[str, int]] = missing
+    reference_input: TensorName
+    scale: List[Tuple[str, float]]
+    offset: List[Tuple[str, int]]
 
     def __len__(self):
         return len(self.scale)
diff --git a/bioimageio/core/prediction_pipeline/__init__.py b/bioimageio/core/prediction_pipeline/__init__.py
deleted file mode 100644
index 78ce5590..00000000
--- a/bioimageio/core/prediction_pipeline/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from bioimageio.core.model_adapters import get_weight_formats as get_weight_formats
-
-from ._prediction_pipeline import PredictionPipeline as PredictionPipeline
-from ._prediction_pipeline import create_prediction_pipeline as create_prediction_pipeline
diff --git a/bioimageio/core/prediction_pipeline/_combined_processing.py b/bioimageio/core/prediction_pipeline/_combined_processing.py
deleted file mode 100644
index 1cfdd66a..00000000
--- a/bioimageio/core/prediction_pipeline/_combined_processing.py
+++ /dev/null
@@ -1,105 +0,0 @@
-import 
dataclasses -from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Union - -from bioimageio.core.resource_io import nodes - -from ._processing import AssertDtype, EnsureDtype, Processing -from ._utils import PER_DATASET, PER_SAMPLE, ComputedMeasures, RequiredMeasures, Sample -from .processing import ProcessingImplBase, get_impl, NamedMeasures, ProcSpec, M -from bioimageio.spec.model.v0_5 import TensorId - - -@dataclass -class CombinedMeasures(NamedMeasures[M]): - step_specs: Sequence[ProcSpec] - steps: ProcessingImplBase[Any, Any, Any] - def get_set(self) -> Set[M]: - ret = set() - for step in self.steps: - for f in fields(step) - return {f"{}getattr(self, f.name) for f in fields(self)} - - -@dataclasses.dataclass -class ProcessingInfo: - steps: List[Processing] - # assert_dtype_before: Optional[Union[str, Sequence[str]]] = None # throw AssertionError if data type doesn't match - ensure_dtype_before: Optional[str] = None # cast data type if needed - # assert_dtype_after: Optional[Union[str, Sequence[str]]] = None # throw AssertionError if data type doesn't match - ensure_dtype_after: Optional[str] = None # throw AssertionError if data type doesn't match - - -class CombinedProcessing: - def __init__(self, steps: List[]: Dict[TensorId, ProcessingInfo]): - self._procs = [] - - # ensure all tensors have correct data type before any processing - for tensor_name, info in combine_tensors.items(): - if info.assert_dtype_before is not None: - self._procs.append(AssertDtype(tensor_name=tensor_name, dtype=info.assert_dtype_before)) - - if info.ensure_dtype_before is not None: - self._procs.append(EnsureDtype(tensor_name=tensor_name, dtype=info.ensure_dtype_before)) - - for tensor_name, info in combine_tensors.items(): - for step in info.steps: - - self._procs.append((tensor_name=tensor_name, **step.kwargs)) - - if info.assert_dtype_after is not None: - self._procs.append(AssertDtype(tensor_name=tensor_name, dtype=info.assert_dtype_after)) - - # ensure tensor has correct data type right after its processing - if info.ensure_dtype_after is not None: - self._procs.append(EnsureDtype(tensor_name=tensor_name, dtype=info.ensure_dtype_after)) - - self.required_measures: RequiredMeasures = self._collect_required_measures(self._procs) - self.tensor_names = list(combine_tensors) - - @classmethod - def from_tensor_specs(cls, tensor_specs: List[Union[nodes.InputTensor, nodes.OutputTensor]]): - combine_tensors = {} - for ts in tensor_specs: - # There is a difference between pre-and postprocessing: - # After preprocessing we ensure float32, because the output is consumed by the model. - # After postprocessing the dtype that is specified in the model spec needs to be ensured. - assert ts.name not in combine_tensors - if isinstance(ts, nodes.InputTensor): - # todo: assert nodes.InputTensor.dtype with assert_dtype_before? - # todo: in the long run we do not want to limit model inputs to float32... 
- combine_tensors[ts.name] = ProcessingInfo( - [ProcessingInfoStep(p.name, kwargs=p.kwargs) for p in ts.preprocessing or []], - ensure_dtype_after="float32", - ) - elif isinstance(ts, nodes.OutputTensor): - combine_tensors[ts.name] = ProcessingInfo( - [ProcessingInfoStep(p.name, kwargs=p.kwargs) for p in ts.postprocessing or []], - ensure_dtype_after=ts.data_type, - ) - else: - raise NotImplementedError(type(ts)) - - inst = cls(combine_tensors) - for ts in tensor_specs: - if isinstance(ts, nodes.OutputTensor) and ts.name in inst.required_measures[PER_DATASET]: - raise NotImplementedError("computing statistics for output tensors per dataset is not yet implemented") - - return inst - - def apply(self, sample: Sample, computed_measures: ComputedMeasures) -> None: - for proc in self._procs: - proc.set_computed_measures(computed_measures) - sample[proc.tensor_name] = proc.apply(sample[proc.tensor_name]) - - @staticmethod - def _collect_required_measures(proc: Sequence[Processing]) -> RequiredMeasures: - ret: RequiredMeasures = {PER_SAMPLE: {}, PER_DATASET: {}} - for p in proc: - for mode, ms_per_mode in p.get_required_measures().items(): - for tn, ms_per_tn in ms_per_mode.items(): - if tn not in ret[mode]: - ret[mode][tn] = set() - - ret[mode][tn].update(ms_per_tn) - - return ret diff --git a/bioimageio/core/prediction_pipeline/_utils.py b/bioimageio/core/prediction_pipeline/_utils.py deleted file mode 100644 index b1f5c2c7..00000000 --- a/bioimageio/core/prediction_pipeline/_utils.py +++ /dev/null @@ -1,65 +0,0 @@ -# def __repr__(self) -> str: -# return f"{self.measure} of {self.tensor_id} ({self.mode})" - - -# RequiredMeasures = List[ReqMeasure] -# @dataclass -# class RequiredMeasures(collections.abc.Iterator[ReqMeasureEntry]): -# per_sample: Dict[TensorId, Set[Measure]] = field(default_factory=dict) -# per_dataset: Dict[TensorId, Set[Measure]] = field(default_factory=dict) - -# def update(self, *others: RequiredMeasures): -# for other in others: -# for t, ms in other.per_sample.items(): -# self.per_sample.setdefault(t, set()).update(ms) - -# for t, ms in other.per_dataset.items(): -# self.per_dataset.setdefault(t, set()).update(ms) - -# def __iter__(self) -> Iterator[ReqMeasureEntry]: -# for t, ms in self.per_sample.items(): -# for m in ms: -# yield ReqMeasureEntry("per_sample", t, m) - -# for t, ms in self.per_dataset.items(): -# for m in ms: -# yield ReqMeasureEntry("per_dataset", t, m) - - -# class ComputedMeasure(NamedTuple): -# measure: Measure -# tensor_id: TensorId -# mode: Mode -# value: MeasureValue -# def __repr__(self) -> str: -# return f"{self.measure} of {self.tensor_id} ({self.mode}) is {self.value}" - - -# @dataclass -# class ComputedMeasures(collections.abc.Container[CompMeasureEntry]): -# per_sample: Dict[TensorId, Dict[Measure, MeasureValue]] = field(default_factory=dict) -# per_dataset: Dict[TensorId, Dict[Measure, MeasureValue]] = field(default_factory=dict) - -# def update(self, other: ComputedMeasures) -> None: -# for t, ms in other.per_sample.items(): -# self.per_sample.setdefault(t, {}).update(ms) - -# for t, ms in other.per_dataset.items(): -# self.per_dataset.setdefault(t, {}).update(ms) - -# def __iter__(self) -> Iterator[CompMeasureEntry]: -# for t, ms in self.per_sample.items(): -# for m, v in ms.items(): -# yield CompMeasureEntry("per_sample", t, m, v) - -# for t, ms in self.per_dataset.items(): -# for m, v in ms.items(): -# yield CompMeasureEntry("per_dataset", t, m, v) - -# def __contains__(self, __x: Any) -> bool: -# if isinstance(__x, CompMeasureEntry): - -# 
elif isinstance(__x, ReqMeasureEntry): - -# else: -# return super().__contains__(__x) diff --git a/bioimageio/core/prediction_pipeline/processing.py b/bioimageio/core/proc_impl.py similarity index 73% rename from bioimageio/core/prediction_pipeline/processing.py rename to bioimageio/core/proc_impl.py index 1ef1b88c..d061a1f8 100644 --- a/bioimageio/core/prediction_pipeline/processing.py +++ b/bioimageio/core/proc_impl.py @@ -1,17 +1,14 @@ +import collections.abc from abc import ABC, abstractmethod from dataclasses import InitVar, dataclass, field, fields from types import MappingProxyType from typing import ( - Any, ClassVar, - Dict, FrozenSet, Generic, Hashable, - List, Literal, Mapping, - NamedTuple, Optional, Sequence, Set, @@ -28,10 +25,11 @@ from numpy.typing import DTypeLike from typing_extensions import LiteralString, assert_never -from bioimageio.core.statistical_measures import Mean, Measure, MeasureValue, Percentile, Std +from bioimageio.core.common import MeasureValue, ProcessingDescrBase, ProcessingKwargs, RequiredMeasure, Sample +from bioimageio.core.stat_measures import Mean, Percentile, Std from bioimageio.spec._internal.base_nodes import NodeWithExplicitlySetFields from bioimageio.spec.model import v0_4, v0_5 -from bioimageio.spec.model.v0_5 import AxisName, NonBatchAxisName, TensorId +from bioimageio.spec.model.v0_5 import NonBatchAxisId, TensorId AssertProcessingId = Literal["assert_dtype"] @@ -50,12 +48,6 @@ class AssertDtype(AssertProcessingBase): kwargs: AssertDtypeKwargs -class RequiredMeasure(NamedTuple): - measure: Measure - tensor_id: TensorId - mode: Literal["per_sample", "per_dataset"] - - M = TypeVar("M", RequiredMeasure, MeasureValue) @@ -72,10 +64,8 @@ def get_set(self) -> Set[M]: C = TypeVar("C", bound=NamedMeasures[MeasureValue]) -Sample = Dict[TensorId, xr.DataArray] -PKwargs = TypeVar("PKwargs", bound=Union[v0_4.ProcessingKwargs, v0_5.ProcessingKwargs]) +PKwargs = TypeVar("PKwargs", bound=ProcessingKwargs) ProcInput = TypeVar("ProcInput", xr.DataArray, Sample) -ProcessingBase = Union[v0_4.ProcessingBase, v0_5.ProcessingBase] @dataclass(frozen=True) @@ -126,7 +116,7 @@ def apply_to_sample(self, sample: Sample) -> Sample: return ret @abstractmethod - def get_spec(self) -> Union[ProcessingBase, AssertProcessingBase]: + def get_descr(self) -> Union[ProcessingDescrBase, AssertProcessingBase]: ... @@ -157,7 +147,7 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray: assert isinstance(tensor.dtype, self._assert_with) return tensor - def get_spec(self): + def get_descr(self): return AssertDtype(kwargs=self.kwargs) @@ -168,8 +158,8 @@ class BinarizeImpl(ProcessingImplBaseWoMeasures[Union[v0_4.BinarizeKwargs, v0_5. 
def apply(self, tensor: xr.DataArray) -> xr.DataArray: return tensor > self.kwargs.threshold - def get_spec(self): - return v0_5.Binarize(kwargs=self.kwargs) + def get_descr(self): + return v0_5.BinarizeDescr(kwargs=self.kwargs) @dataclass(frozen=True) @@ -177,8 +167,8 @@ class ClipImpl(ProcessingImplBaseWoMeasures[Union[v0_4.ClipKwargs, v0_5.ClipKwar def apply(self, tensor: xr.DataArray) -> xr.DataArray: return tensor.clip(min=self.kwargs.min, max=self.kwargs.max) - def get_spec(self): - return v0_5.Clip(kwargs=self.kwargs) + def get_descr(self): + return v0_5.ClipDescr(kwargs=self.kwargs) @dataclass(frozen=True) @@ -186,8 +176,8 @@ class EnsureDtypeImpl(ProcessingImplBaseWoMeasures[v0_5.EnsureDtypeKwargs]): def apply(self, tensor: xr.DataArray) -> xr.DataArray: return tensor.astype(self.kwargs.dtype) - def get_spec(self): - return v0_5.EnsureDtype(kwargs=self.kwargs) + def get_descr(self): + return v0_5.EnsureDtypeDescr(kwargs=self.kwargs) class ScaleLinearImpl04(ProcessingImplBaseWoMeasures[Union[v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs]]): @@ -228,11 +218,11 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray: return tensor * gain + offset - def get_spec(self): + def get_descr(self): if isinstance(self.kwargs, v0_4.ScaleLinearKwargs): raise NotImplementedError - return v0_5.ScaleLinear(kwargs=self.kwargs) + return v0_5.ScaleLinearDescr(kwargs=self.kwargs) @dataclass @@ -255,7 +245,15 @@ class ScaleMeanVarianceImpl( def get_required_measures( cls, tensor_id: TensorId, kwargs: Union[v0_4.ScaleMeanVarianceKwargs, v0_5.ScaleMeanVarianceKwargs] ): - axes = tuple(NonBatchAxisName(a) for a in kwargs.axes) if isinstance(kwargs.axes, str) else kwargs.axes + if kwargs.axes is None: + axes = None + elif isinstance(kwargs.axes, str): + axes = tuple(NonBatchAxisId(a) for a in kwargs.axes) + elif isinstance(kwargs.axes, collections.abc.Sequence): # pyright: ignore[reportUnnecessaryIsInstance] + axes = tuple(kwargs.axes) + else: + assert_never(kwargs.axes) + return NamedMeasuresScaleMeanVariance( mean=RequiredMeasure(Mean(axes), tensor_id, mode=kwargs.mode), std=RequiredMeasure(Std(axes), tensor_id, mode=kwargs.mode), @@ -268,11 +266,11 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray: eps = self.kwargs.eps return (tensor - c.mean) / (c.std + eps) * (c.ref_std + eps) + c.ref_mean - def get_spec(self): + def get_descr(self): if isinstance(self.kwargs, v0_4.ScaleMeanVarianceKwargs): raise NotImplementedError - return v0_5.ScaleMeanVariance(kwargs=self.kwargs) + return v0_5.ScaleMeanVarianceDescr(kwargs=self.kwargs) @dataclass @@ -292,7 +290,7 @@ class ScaleRangeImpl( @classmethod def get_required_measures(cls, tensor_id: TensorId, kwargs: Union[v0_4.ScaleRangeKwargs, v0_5.ScaleRangeKwargs]): ref_name = kwargs.reference_tensor or tensor_id - axes = None if kwargs.axes is None else tuple(NonBatchAxisName(a) for a in kwargs.axes) + axes = None if kwargs.axes is None else tuple(NonBatchAxisId(a) for a in kwargs.axes) return NamedMeasuresScaleRange( lower=RequiredMeasure(Percentile(kwargs.min_percentile, axes=axes), cast(TensorId, ref_name), kwargs.mode), upper=RequiredMeasure(Percentile(kwargs.max_percentile, axes=axes), cast(TensorId, ref_name), kwargs.mode), @@ -302,11 +300,11 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray: c = self.computed return (tensor - c.lower) / (c.upper - c.lower + self.kwargs.eps) - def get_spec(self): + def get_descr(self): if isinstance(self.kwargs, v0_4.ScaleRangeKwargs): raise NotImplementedError - return v0_5.ScaleRange(kwargs=self.kwargs) + 
return v0_5.ScaleRangeDescr(kwargs=self.kwargs) @dataclass(frozen=True) @@ -316,8 +314,8 @@ class SigmoidImpl(ProcessingImplBaseWoMeasures[v0_5.ProcessingKwargs]): def apply(self, tensor: xr.DataArray) -> xr.DataArray: return 1.0 / (1.0 + np.exp(-tensor)) # type: ignore - def get_spec(self): - return v0_5.Sigmoid() + def get_descr(self): + return v0_5.SigmoidDescr() @dataclass @@ -340,7 +338,7 @@ class ZeroMeanUnitVarianceImpl( def get_required_measures( cls, tensor_id: TensorId, kwargs: Union[v0_4.ZeroMeanUnitVarianceKwargs, v0_5.ZeroMeanUnitVarianceKwargs] ): - axes = None if kwargs.axes is None else tuple(NonBatchAxisName(a) for a in kwargs.axes) + axes = None if kwargs.axes is None else tuple(NonBatchAxisId(a) for a in kwargs.axes) assert kwargs.mode != "fixed" # should use FixedZeroMeanUnitVarianceImpl return NamedMeasuresZeroMeanUnitVariance( mean=RequiredMeasure(Mean(axes=axes), tensor_id, kwargs.mode), @@ -352,11 +350,11 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray: std = self.computed.std return (tensor - mean) / (std + self.kwargs.eps) - def get_spec(self): + def get_descr(self): if isinstance(self.kwargs, v0_4.ZeroMeanUnitVarianceKwargs): raise NotImplementedError - return v0_5.ZeroMeanUnitVariance(kwargs=self.kwargs) + return v0_5.ZeroMeanUnitVarianceDescr(kwargs=self.kwargs) @dataclass(frozen=True) @@ -377,76 +375,94 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray: std = xr.DataArray(self.kwargs.std, dims=axis) return (tensor - mean) / std - def get_spec(self): + def get_descr(self): if isinstance(self.kwargs, v0_4.ZeroMeanUnitVarianceKwargs): raise NotImplementedError - return v0_5.FixedZeroMeanUnitVariance(kwargs=self.kwargs) - - -ProcSpec = Union[AssertDtype, v0_4.Preprocessing, v0_4.Postprocessing, v0_5.Preprocessing, v0_5.Postprocessing] - - -# todo: - - -class ProcSelector: - def __init__(proc_spec: ProcSpec) -> None: - self.proc_spec = proc_spec - - -def get_impl(proc_spec: ProcSpec): + return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=self.kwargs) + + +ProcDescr = Union[ + AssertDtype, v0_4.PreprocessingDescr, v0_4.PostprocessingDescr, v0_5.PreprocessingDescr, v0_5.PostprocessingDescr +] + +# get_impl_class which also returns the kwargs class +# def get_impl_class(proc_spec: ProcDescr): +# if isinstance(proc_spec, AssertDtype): +# return AssertDtypeImpl, AssertDtypeKwargs +# elif isinstance(proc_spec, v0_4.BinarizeDescr): +# return BinarizeImpl, v0_4.BinarizeKwargs +# elif isinstance(proc_spec, v0_5.BinarizeDescr): +# return BinarizeImpl, v0_5.BinarizeKwargs +# elif isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)): +# return ClipImpl, v0_5.ClipKwargs +# elif isinstance(proc_spec, v0_5.EnsureDtypeDescr): +# return EnsureDtypeImpl, v0_5.EnsureDtypeKwargs +# elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr): +# return FixedZeroMeanUnitVarianceImpl, v0_5.FixedZeroMeanUnitVarianceKwargs +# elif isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)): +# return ScaleLinearImpl, v0_5.ScaleLinearKwargs +# elif isinstance(proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)): +# return ScaleMeanVarianceImpl, v0_5.ScaleMeanVarianceKwargs +# elif isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)): +# return ScaleRangeImpl, v0_5.ScaleRangeKwargs +# elif isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)): +# return SigmoidImpl, v0_5.ProcessingKwargs +# elif isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr) and proc_spec.kwargs.mode == "fixed": +# return 
FixedZeroMeanUnitVarianceImpl, v0_5.FixedZeroMeanUnitVarianceKwargs +# elif isinstance( +# proc_spec, # pyright: ignore[reportUnnecessaryIsInstance +# (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr), +# ): +# return ZeroMeanUnitVarianceImpl, v0_5.ZeroMeanUnitVarianceKwargs +# else: +# assert_never(proc_spec) + +ProcessingImpl = Union[ + AssertDtypeImpl, + BinarizeImpl, + ClipImpl, + EnsureDtypeImpl, + FixedZeroMeanUnitVarianceImpl, + FixedZeroMeanUnitVarianceImpl, + ScaleLinearImpl, + ScaleMeanVarianceImpl, + ScaleRangeImpl, + SigmoidImpl, + ZeroMeanUnitVarianceImpl, +] + + +def get_impl_class(proc_spec: ProcDescr) -> Type[ProcessingImpl]: if isinstance(proc_spec, AssertDtype): - return AssertDtypeImpl, AssertDtypeKwargs - elif isinstance(proc_spec, v0_4.Binarize): - return BinarizeImpl, v0_4.BinarizeKwargs - elif isinstance(proc_spec, v0_5.Binarize): - return BinarizeImpl, v0_5.BinarizeKwargs - elif isinstance(proc_spec, (v0_4.Clip, v0_5.Clip)): + return AssertDtypeImpl + elif isinstance(proc_spec, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)): + return BinarizeImpl + elif isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)): return ClipImpl - elif isinstance(proc_spec, v0_5.EnsureDtype): + elif isinstance(proc_spec, v0_5.EnsureDtypeDescr): return EnsureDtypeImpl - elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVariance): + elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr): return FixedZeroMeanUnitVarianceImpl - elif isinstance(proc_spec, (v0_4.ScaleLinear, v0_5.ScaleLinear)): + elif isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)): return ScaleLinearImpl - elif isinstance(proc_spec, (v0_4.ScaleMeanVariance, v0_5.ScaleMeanVariance)): + elif isinstance(proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)): return ScaleMeanVarianceImpl - elif isinstance(proc_spec, (v0_4.ScaleRange, v0_5.ScaleRange)): + elif isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)): return ScaleRangeImpl - elif isinstance(proc_spec, (v0_4.Sigmoid, v0_5.Sigmoid)): + elif isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)): return SigmoidImpl - elif isinstance(proc_spec, v0_4.ZeroMeanUnitVariance) and proc_spec.kwargs.mode == "fixed": + elif isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr) and proc_spec.kwargs.mode == "fixed": return FixedZeroMeanUnitVarianceImpl elif isinstance( proc_spec, # pyright: ignore[reportUnnecessaryIsInstance] - (v0_4.ZeroMeanUnitVariance, v0_5.ZeroMeanUnitVariance), + (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr), ): return ZeroMeanUnitVarianceImpl else: assert_never(proc_spec) -Model = Union[v0_4.Model, v0_5.Model] - - -def get_procs(model: Model): - procs: List[ProcessingImplBase[Any, Any, Any]] = [] - for ipt in model.inputs: - if not ipt.preprocessing: - continue - - assert isinstance(ipt, v0_5.InputTensor) - for proc_spec in ipt.preprocessing: - impl = get_impl(proc_spec, ipt.id, computed_measures) - assert isinstance( - proc_spec.kwargs, - ) - procs.append(impl(tensor_id=ipt.id, kwargs=proc_spec.kwargs)) - - return procs - - def _get_complement_axis(tensor: xr.DataArray, axes: Optional[Sequence[Hashable]]) -> Optional[Hashable]: if axes is None: return None diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py new file mode 100644 index 00000000..a66c46aa --- /dev/null +++ b/bioimageio/core/proc_setup.py @@ -0,0 +1,75 @@ +from typing import ( + Any, + Iterator, + List, + NamedTuple, + Sequence, + Set, + Tuple, + Type, + Union, + cast, +) + 
+from typing_extensions import assert_never + +from bioimageio.core.common import ProcessingKwargs, RequiredMeasure, Sample +from bioimageio.core.proc_impl import ( + ProcessingImpl, + ProcessingImplBase, + get_impl_class, +) +from bioimageio.core.stat_calculators import compute_measures +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.model.v0_5 import TensorId + +ModelDescr = Union[v0_4.ModelDescr, v0_5.ModelDescr] +TensorDescr = Union[v0_4.InputTensorDescr, v0_4.OutputTensorDescr, v0_5.InputTensorDescr, v0_5.OutputTensorDescr] + + +class _SetupProcessing(NamedTuple): + preprocessing: List[ProcessingImpl] + postprocessing: List[ProcessingImpl] + + +def setup_pre_and_postprocessing(model: ModelDescr, dataset: Iterator[Sample]) -> _SetupProcessing: + Prepared = List[Tuple[Type[ProcessingImplBase[Any, Any, Any]], ProcessingKwargs, TensorId]] + + required_measures: Set[RequiredMeasure] = set() + + def prepare_procs(tensor_descrs: Sequence[TensorDescr]): + prepared: Prepared = [] + for t_descr in tensor_descrs: + if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)): + proc_specs = t_descr.preprocessing + elif isinstance( + t_descr, # pyright: ignore[reportUnnecessaryIsInstance] + (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr), + ): + proc_specs = t_descr.postprocessing + else: + assert_never(t_descr) + + for proc_spec in proc_specs: + impl_class = get_impl_class(proc_spec) + tensor_id = cast(TensorId, t_descr.name) if isinstance(t_descr, v0_4.TensorDescrBase) else t_descr.id + req = impl_class.get_required_measures(tensor_id, proc_spec.kwargs) # type: ignore + required_measures.update(req.get_set()) + prepared.append((impl_class, proc_spec.kwargs, tensor_id)) + + return prepared + + prepared_preps = prepare_procs(model.inputs) + prepared_posts = prepare_procs(model.outputs) + + computed_measures = compute_measures(required_measures, dataset=dataset) + + def init_procs(prepared: Prepared): + initialized: List[ProcessingImpl] = [] + for impl_class, kwargs, tensor_id in prepared: + impl = impl_class(tensor_id=tensor_id, kwargs=kwargs, computed_measures=computed_measures) + initialized.append(impl) + + return initialized + + return _SetupProcessing(init_procs(prepared_preps), init_procs(prepared_posts)) diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py new file mode 100644 index 00000000..f31bcfe0 --- /dev/null +++ b/bioimageio/core/stat_calculators.py @@ -0,0 +1,396 @@ +from __future__ import annotations + +import collections +import warnings +from abc import ABC, abstractmethod +from collections import defaultdict +from dataclasses import field +from itertools import product +from typing import ( + Any, + ClassVar, + DefaultDict, + Dict, + Generic, + Hashable, + Iterable, + Iterator, + List, + Literal, + Mapping, + Optional, + OrderedDict, + Sequence, + Set, + Tuple, + Type, + Union, +) + +import numpy as np +import xarray as xr +from numpy.typing import NDArray + +from bioimageio.core.common import ( + PER_DATASET, + PER_SAMPLE, + AxisId, + DatasetMeasure, + MeasureVar, + RequiredMeasure, + Sample, + SampleMeasure, + TensorId, +) +from bioimageio.core.stat_measures import Mean, Measure, Percentile, Std, Var + +try: + import crick # type: ignore +except ImportError: + crick = None + +MeasureValue = Union[xr.DataArray, float] + + +class SampleMeasureCalculator(ABC, Generic[MeasureVar]): + """group of measures for more efficient computation of multiple measures per sample""" + + @abstractmethod + def compute(self, sample: 
Sample) -> Mapping[SampleMeasure[MeasureVar], MeasureValue]:
+        ...
+
+
+class DatasetMeasureCalculator(ABC, Generic[MeasureVar]):
+    """group of measures for more efficient computation of multiple measures per dataset"""
+
+    @abstractmethod
+    def update_with_sample(self, sample: Sample) -> None:
+        """update intermediate representation with a data sample"""
+        ...
+
+    @abstractmethod
+    def finalize(self) -> Mapping[DatasetMeasure[MeasureVar], MeasureValue]:
+        """compute statistics from intermediate representation"""
+        ...
+
+
+class MeanCalculator(SampleMeasureCalculator[Mean], DatasetMeasureCalculator[Mean]):
+    def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]):
+        super().__init__()
+        self._axes = None if axes is None else tuple(axes)
+        self._tensor_id = tensor_id
+        self._n: int = 0
+        self._mean: Optional[xr.DataArray] = None
+
+    def compute(self, sample: Sample):
+        return {
+            SampleMeasure(measure=Mean(axes=self._axes), tensor_id=self._tensor_id): sample[self._tensor_id].mean(
+                dim=self._axes
+            )
+        }
+
+    def update_with_sample(self, sample: Sample):
+        tensor = sample[self._tensor_id].astype(np.float64, copy=False)
+        mean_b = tensor.mean(dim=self._axes)
+        assert mean_b.dtype == np.float64
+        n_b = int(np.prod(tensor.shape) / np.prod(mean_b.shape))  # reduced voxel count
+        if self._mean is None:
+            assert self._n == 0
+            self._n = n_b
+            self._mean = mean_b
+        else:
+            assert self._n != 0
+            n_a = self._n
+            mean_a = self._mean
+            self._n = n = n_a + n_b
+            self._mean = (n_a * mean_a + n_b * mean_b) / n
+            assert self._mean.dtype == np.float64
+
+    def finalize(self) -> Mapping[DatasetMeasure, MeasureValue]:
+        if self._mean is None:
+            return {}
+        else:
+            return {DatasetMeasure(measure=Mean(axes=self._axes), tensor_id=self._tensor_id): self._mean}
+
+
+class MeanVarStdCalculator(SampleMeasureCalculator, DatasetMeasureCalculator):
+    def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]):
+        super().__init__()
+        self._axes = None if axes is None else tuple(axes)
+        self._tensor_id = tensor_id
+        self._n: int = 0
+        self._mean: Optional[xr.DataArray] = None
+        self._m2: Optional[xr.DataArray] = None
+
+    def compute(self, sample: Sample):
+        tensor = sample[self._tensor_id]
+        mean = tensor.mean(dim=self._axes)
+        c = tensor - mean
+        if self._axes is None:
+            n = tensor.size
+        else:
+            n = int(np.prod([tensor.sizes[d] for d in self._axes]))  # type: ignore  # FIXME: type annotation
+
+        var = xr.dot(c, c, dims=self._axes) / n
+        std = np.sqrt(var)
+        return {
+            SampleMeasure(Mean(axes=self._axes), tensor_id=self._tensor_id): mean,
+            SampleMeasure(Var(axes=self._axes), tensor_id=self._tensor_id): var,
+            SampleMeasure(Std(axes=self._axes), tensor_id=self._tensor_id): std,
+        }
+
+    def update_with_sample(self, sample: Sample):
+        tensor = sample[self._tensor_id].astype(np.float64, copy=False)
+        mean_b = tensor.mean(dim=self._axes)
+        assert mean_b.dtype == np.float64
+        # reduced voxel count
+        n_b = int(np.prod(tensor.shape) / np.prod(mean_b.shape))  # type: ignore
+        m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes)
+        assert m2_b.dtype == np.float64
+        if self._mean is None:
+            assert self._m2 is None
+            self._n = n_b
+            self._mean = mean_b
+            self._m2 = m2_b
+        else:
+            n_a = self._n
+            mean_a = self._mean
+            m2_a = self._m2
+            self._n = n = n_a + n_b
+            self._mean = (n_a * mean_a + n_b * mean_b) / n
+            assert self._mean.dtype == np.float64
+            d = mean_b - mean_a
+            self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
+            assert self._m2.dtype == np.float64
+
+    def finalize(self) -> Mapping[DatasetMeasure, MeasureValue]:
+        if self._mean is None:
+            return {}
+        else:
+            assert self._m2 is not None
+            var = self._m2 / self._n
+            sqrt: xr.DataArray = np.sqrt(var)  # type: ignore
+            return {
+                DatasetMeasure(tensor_id=self._tensor_id, measure=Mean(axes=self._axes)): self._mean,
+                DatasetMeasure(tensor_id=self._tensor_id, measure=Var(axes=self._axes)): var,
+                DatasetMeasure(tensor_id=self._tensor_id, measure=Std(axes=self._axes)): sqrt,
+            }
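+
+
+# MeanCalculator.update_with_sample and MeanVarStdCalculator.update_with_sample
+# use the standard streaming merge of two data partitions a and b
+# (cf. Chan et al., "Updating formulae and a pairwise algorithm for computing
+# sample variances"):
+#     n = n_a + n_b
+#     mean = (n_a * mean_a + n_b * mean_b) / n
+#     m2 = m2_a + m2_b + (mean_b - mean_a) ** 2 * n_a * n_b / n
+# where m2 is the sum of squared deviations from the mean, so that finalize()
+# can report var = m2 / n without a second pass over the dataset.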
+
+
+class SamplePercentilesCalculator(SampleMeasureCalculator):
+    def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Sequence[float]):
+        super().__init__()
+        assert all(0 <= n <= 100 for n in ns)
+        self.ns = ns
+        self._qs = [n / 100 for n in ns]
+        self._axes = None if axes is None else tuple(axes)
+        self._tensor_id = tensor_id
+
+    def compute(self, sample: Sample):
+        tensor = sample[self._tensor_id]
+        ps = tensor.quantile(self._qs, dim=self._axes)  # type: ignore
+        return {
+            SampleMeasure(measure=Percentile(n=n, axes=self._axes), tensor_id=self._tensor_id): p
+            for n, p in zip(self.ns, ps)
+        }
+
+
+class MeanPercentilesCalculator(DatasetMeasureCalculator):
+    def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Sequence[float]):
+        super().__init__()
+        assert all(0 <= n <= 100 for n in ns)
+        self._ns = ns
+        self._qs = np.asarray([n / 100 for n in ns])
+        self._axes = None if axes is None else tuple(axes)
+        self._tensor_id = tensor_id
+        self._n: int = 0
+        self._estimates: Optional[xr.DataArray] = None
+
+    def update_with_sample(self, sample: Sample):
+        tensor = sample[self._tensor_id]
+        sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(np.float64, copy=False)
+
+        # reduced voxel count
+        n = int(np.prod(tensor.shape) / np.prod(sample_estimates.shape[1:]))  # type: ignore
+
+        if self._estimates is None:
+            assert self._n == 0
+            self._estimates = sample_estimates
+        else:
+            self._estimates = (self._n * self._estimates + n * sample_estimates) / (self._n + n)
+            assert self._estimates.dtype == np.float64
+
+        self._n += n
+
+    def finalize(self) -> Mapping[DatasetMeasure, MeasureValue]:
+        if self._estimates is None:
+            return {}
+        else:
+            warnings.warn("Computed dataset percentiles naively by averaging percentiles of samples.")
+            return {
+                DatasetMeasure(measure=Percentile(n=n, axes=self._axes), tensor_id=self._tensor_id): e
+                for n, e in zip(self._ns, self._estimates)
+            }
+
+
+class CrickPercentilesCalculator(DatasetMeasureCalculator):
+    def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Sequence[float]):
+        warnings.warn("Computing dataset percentiles with experimental 'crick' library.")
+        super().__init__()
+        assert all(0 <= n <= 100 for n in ns)
+        assert axes is None or "_percentiles" not in axes
+        self._ns = ns
+        self._qs = [n / 100 for n in ns]
+        self._axes = None if axes is None else tuple(axes)
+        self._tensor_id = tensor_id
+        self._digest: Optional[List[crick.TDigest]] = None
+        self._dims: Optional[Tuple[Hashable, ...]] = None
+        self._indices: Optional[Iterator[Tuple[int, ...]]] = None
+        self._shape: Optional[Tuple[int, ...]] = None
+
+    def _initialize(self, tensor_sizes: Mapping[Hashable, int]):
+        assert crick is not None
+        out_sizes: OrderedDict[Hashable, int] = collections.OrderedDict(_percentiles=len(self._ns))
+        if self._axes is not None:
+            for d, s in tensor_sizes.items():
+                if d not in self._axes:
+                    out_sizes[d] = s
+
+        self._dims, self._shape = zip(*out_sizes.items())
+        d = int(np.prod(self._shape[1:]))  # type: ignore
+        self._digest = [crick.TDigest() for _ in range(d)]
+        self._indices = product(*map(range, self._shape[1:]))
+
+    def update_with_sample(self, sample: Sample):
+        tensor = sample[self._tensor_id]
+        assert "_percentiles" not in tensor.dims
+        if self._digest is None:
+            self._initialize(tensor.sizes)
+
+        assert self._digest is not None
+        assert self._indices is not None
+        assert self._dims is not None
+        for i, idx in enumerate(self._indices):
+            self._digest[i].update(tensor.isel(dict(zip(self._dims[1:], idx))))
+
+    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
+        if self._digest is None:
+            return {}
+        else:
+            assert self._dims is not None
+            vs: NDArray[Any] = np.asarray([[d.quantile(q) for d in self._digest] for q in self._qs]).reshape(self._shape)  # type: ignore
+            return {
+                DatasetMeasure(measure=Percentile(n=n, axes=self._axes), tensor_id=self._tensor_id): xr.DataArray(
+                    v, dims=self._dims[1:]
+                )
+                for n, v in zip(self._ns, vs)
+            }
+
+
+if crick is None:
+    DatasetPercentileCalculator: Type[
+        Union[MeanPercentilesCalculator, CrickPercentilesCalculator]
+    ] = MeanPercentilesCalculator
+else:
+    DatasetPercentileCalculator = CrickPercentilesCalculator
+
+
+class NaivSampleMeasureCalculator(SampleMeasureCalculator):
+    """wrapper for a single measure to match the SampleMeasureCalculator interface"""
+
+    def __init__(self, tensor_id: TensorId, measure: Measure):
+        super().__init__()
+        self.tensor_name = tensor_id
+        self.measure = measure
+
+    def compute(self, sample: Sample) -> Mapping[SampleMeasure, MeasureValue]:
+        return {
+            SampleMeasure(measure=self.measure, tensor_id=self.tensor_name): self.measure.compute(
+                sample[self.tensor_name]
+            )
+        }
+
+
+def get_measure_calculators(
+    required_measures: Iterable[RequiredMeasure],
+) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
+    """determines which calculators are needed to compute the required measures efficiently"""
+
+    sample_calculators: List[SampleMeasureCalculator] = []
+    dataset_calculators: List[DatasetMeasureCalculator] = []
+
+    # split required measures into groups
+    required_means: Set[RequiredMeasure] = set()
+    required_mean_var_std: Set[RequiredMeasure] = set()
+    required_percentiles: Set[RequiredMeasure] = set()
+
+    for rm in required_measures:
+        if isinstance(rm.measure, Mean):
+            required_means.add(rm)
+        elif isinstance(rm.measure, (Var, Std)):
+            required_mean_var_std.update(
+                {
+                    RequiredMeasure(measure=msv(rm.measure.axes), tensor_id=rm.tensor_id, mode=rm.mode)
+                    for msv in (Mean, Std, Var)
+                }
+            )
+            assert rm in required_mean_var_std
+        elif isinstance(rm.measure, Percentile):
+            required_percentiles.add(rm)
+        elif rm.mode == PER_SAMPLE:
+            sample_calculators.append(NaivSampleMeasureCalculator(tensor_id=rm.tensor_id, measure=rm.measure))
+        else:
+            raise NotImplementedError(f"Computing statistics for {rm.measure} {rm.mode} not yet implemented")
+
+    # add all mean measures that are not included in a mean/var/std group
+    for rm in required_means:
+        if rm in required_mean_var_std:
+            # computed together with var and std
+            continue
+
+        # compute only the mean
+        if rm.mode == PER_SAMPLE:
+            sample_calculators.append(MeanCalculator(tensor_id=rm.tensor_id, axes=rm.measure.axes))
+        elif rm.mode == PER_DATASET:
+            dataset_calculators.append(MeanCalculator(tensor_id=rm.tensor_id, axes=rm.measure.axes))
+        else:
+            raise NotImplementedError(rm.mode)
+
+    # one MeanVarStdCalculator per unique (tensor, axes, mode) combination
+    mean_var_std_groups = {(rm.tensor_id, rm.measure.axes, rm.mode) for rm in required_mean_var_std}
+    for tid, axes, mode in mean_var_std_groups:
+        if mode == PER_SAMPLE:
+            sample_calculators.append(MeanVarStdCalculator(tensor_id=tid, axes=axes))
+        elif mode == PER_DATASET:
+            dataset_calculators.append(MeanVarStdCalculator(tensor_id=tid, axes=axes))
+        else:
+            raise NotImplementedError(mode)
+
+    # compute percentiles of the same tensor, axes and mode together
+    percentile_groups: DefaultDict[Tuple[Any, ...], List[float]] = defaultdict(list)
+    for rm in required_percentiles:
+        assert isinstance(rm.measure, Percentile)
+        percentile_groups[(rm.tensor_id, rm.measure.axes, rm.mode)].append(rm.measure.n)
+
+    for (tid, axes, mode), ns in percentile_groups.items():
+        if mode == PER_SAMPLE:
+            sample_calculators.append(SamplePercentilesCalculator(tensor_id=tid, axes=axes, ns=ns))
+        elif mode == PER_DATASET:
+            dataset_calculators.append(DatasetPercentileCalculator(tensor_id=tid, axes=axes, ns=ns))
+        else:
+            raise NotImplementedError(mode)
+
+    return sample_calculators, dataset_calculators
+
+
+def compute_measures(
+    measures: Iterable[RequiredMeasure], *, sample: Optional[Sample] = None, dataset: Iterable[Sample] = ()
+) -> Dict[RequiredMeasure, MeasureValue]:
+    sample_calculators, dataset_calculators = get_measure_calculators(measures)
+    ret: Dict[RequiredMeasure, MeasureValue] = {}
+    if sample is not None:
+        for sample_calc in sample_calculators:
+            ret.update(sample_calc.compute(sample))
+
+    for s in dataset:
+        for dataset_calc in dataset_calculators:
+            dataset_calc.update_with_sample(s)
+
+    for dataset_calc in dataset_calculators:
+        ret.update(dataset_calc.finalize())
+
+    return ret
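+
+
+# Usage sketch (illustrative only; the tensor id, axis ids and sample data
+# below are made up):
+#
+#     import numpy as np
+#     import xarray as xr
+#
+#     t = TensorId("raw")
+#     sample = {t: xr.DataArray(np.random.rand(1, 2, 64, 64), dims=("b", "c", "y", "x"))}
+#     required = [
+#         RequiredMeasure(Mean(axes=(AxisId("y"), AxisId("x"))), t, mode=PER_SAMPLE),
+#         RequiredMeasure(Percentile(n=99.8), t, mode=PER_DATASET),
+#     ]
+#     computed = compute_measures(required, sample=sample, dataset=[sample])
+#     for req, value in computed.items():
+#         print(req.tensor_id, req.measure, req.mode, value)
+#
+# Per-sample results are keyed by SampleMeasure and per-dataset results by
+# DatasetMeasure instances (both subclasses of RequiredMeasure).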
diff --git a/bioimageio/core/statistical_measures.py b/bioimageio/core/stat_measures.py
similarity index 78%
rename from bioimageio/core/statistical_measures.py
rename to bioimageio/core/stat_measures.py
index e19689b8..29d6857a 100644
--- a/bioimageio/core/statistical_measures.py
+++ b/bioimageio/core/stat_measures.py
@@ -6,9 +6,8 @@
 
 import xarray as xr
 
-from bioimageio.spec.model.v0_5 import NonBatchAxisName
-
-MeasureValue = xr.DataArray
+from bioimageio.core.common import MeasureValue
+from bioimageio.spec.model.v0_5 import AxisId
 
 
 @dataclass(frozen=True)
@@ -21,7 +20,7 @@ def compute(self, tensor: xr.DataArray) -> MeasureValue:
 
 @dataclass(frozen=True)
 class Mean(Measure):
-    axes: Optional[Tuple[NonBatchAxisName, ...]] = None
+    axes: Optional[Tuple[AxisId, ...]] = None
 
     def compute(self, tensor: xr.DataArray) -> xr.DataArray:
         return tensor.mean(dim=self.axes)
@@ -29,7 +28,7 @@ def compute(self, tensor: xr.DataArray) -> xr.DataArray:
 
 @dataclass(frozen=True)
 class Std(Measure):
-    axes: Optional[Tuple[NonBatchAxisName, ...]] = None
+    axes: Optional[Tuple[AxisId, ...]] = None
 
     def compute(self, tensor: xr.DataArray) -> xr.DataArray:
         return tensor.std(dim=self.axes)
@@ -37,7 +36,7 @@ def compute(self, tensor: xr.DataArray) -> xr.DataArray:
 
 @dataclass(frozen=True)
 class Var(Measure):
-    axes: Optional[Tuple[NonBatchAxisName, ...]] = None
+    axes: Optional[Tuple[AxisId, ...]] = None
 
     def compute(self, tensor: xr.DataArray) -> xr.DataArray:
         return tensor.var(dim=self.axes)
@@ -46,7 +45,7 @@ def compute(self, tensor: xr.DataArray) -> xr.DataArray:
 @dataclass(frozen=True)
 class Percentile(Measure):
     n: float
-    axes: Optional[Tuple[NonBatchAxisName, ...]] = None
+    axes: Optional[Tuple[AxisId, ...]] = None
 
     def __post_init__(self):
         assert self.n >= 0
diff --git a/bioimageio/core/prediction_pipeline/_stat_state.py b/bioimageio/core/stat_state.py
similarity index 52%
rename from bioimageio/core/prediction_pipeline/_stat_state.py
rename to bioimageio/core/stat_state.py
index c84c7e4e..107383be 100644
--- a/bioimageio/core/prediction_pipeline/_stat_state.py
+++ b/bioimageio/core/stat_state.py
@@ -1,43 +1,50 @@
-from typing import Dict, Iterable, Optional
+from dataclasses import dataclass, field
+from typing import Dict, Iterable, Literal, Optional, Union
 
 from tqdm import tqdm
 
-from bioimageio.core.statistical_measure_groups import MeasureGroups, MeasureValue, get_measure_groups
-from bioimageio.core.statistical_measures import Measure
-
-from ._utils import PER_DATASET, PER_SAMPLE, MeasureValue, RequiredMeasure, Sample, TensorName
+from bioimageio.core.common import PER_DATASET, PER_SAMPLE, MeasureValue, RequiredMeasure, Sample, TensorId
+from bioimageio.core.stat_calculators import MeasureGroups, MeasureValue, get_measure_calculators
+from bioimageio.core.stat_measures import Measure
 
 
+@dataclass
 class StatsState:
-    """class to compute, hold and update dataset and sample statistics"""
+    """class to compute, hold and update dataset and sample statistics
+
+    Iterates over `dataset` to compute dataset statistics (if required). The resulting dataset
+    statistics are further updated with each new sample. A sample in this context may be a mini-batch.
+
+    Args:
+        required_measures: measures to be computed
+        dataset: (partial) dataset to initialize dataset statistics with
+        update_dataset_stats_after_n_samples: update dataset statistics for new samples S_i if i > n
+            (default: len(dataset)). This avoids counting the first n processed samples twice
+            if they make up the given 'dataset'.
+        update_dataset_stats_for_n_samples: stop updating dataset statistics with new samples S_i if
+            i > for_n_samples (+ update_dataset_stats_after_n_samples)
+    """
 
-    sample_count: int
-    last_sample: Optional[Sample]
-    measure_groups: MeasureGroups
-    _n_start: int
-    _n_stop: int
-    _final_dataset_stats: Optional[Dict[TensorName, Dict[Measure, MeasureValue]]]
+    required_measures: Iterable[RequiredMeasure]
+    sample_count: int = field(init=False)
+    last_sample: Optional[Sample] = field(init=False)
+    measure_groups: MeasureGroups = field(init=False)
+    _n_start: Union[int, float] = field(init=False)
+    _n_stop: Union[int, float] = field(init=False)
+    _final_dataset_stats: Optional[Dict[RequiredMeasure, MeasureValue]] = field(init=False)
 
     def __init__(
         self,
-        required_measures: RequiredMeasures,
+        required_measures: Iterable[RequiredMeasure],
         *,
         dataset: Iterable[Sample] = tuple(),
         update_dataset_stats_after_n_samples: Optional[int] = None,
-        update_dataset_stats_for_n_samples: int = float("inf"),
+        update_dataset_stats_for_n_samples: Union[int, float] = float("inf"),
     ):
-        """iterates over dataset to compute dataset statistics (if required). The resulting dataset statistics are further updated with each new sample. A sample in this context may be a mini-batch.
-
-        Args:
-            required_measures: measures to be computed
-            dataset: (partial) dataset to initialize dataset statistics with
-            update_dataset_stats_after_n_samples: Update dataset statistics for new samples S_i if i > n.
-                                                  (default: len(dataset))
-                                                  This parameter allows to avoid weighting the first n processed
-                                                  samples to count twice if they make up the given 'dataset'.
- update_dataset_stats_for_n_samples: stop updating dataset statistics with new samples S_i if - i > for_n_samples (+ update_dataset_stats_after_n_samples) - """ + super().__init__() self.required_measures = required_measures self.update_dataset_stats_after_n_samples = update_dataset_stats_after_n_samples self.update_dataset_stats_for_n_samples = update_dataset_stats_for_n_samples @@ -47,7 +54,7 @@ def reset(self, dataset: Iterable[Sample]): self.sample_count = 0 self.last_sample = None self._final_dataset_stats = None - self.measure_groups = get_measure_groups(self.required_measures) + self.measure_groups = get_measure_calculators(self.required_measures) len_dataset = 0 if self.measure_groups[PER_DATASET]: diff --git a/bioimageio/core/statistical_measure_groups.py b/bioimageio/core/statistical_measure_groups.py deleted file mode 100644 index 88f65b3d..00000000 --- a/bioimageio/core/statistical_measure_groups.py +++ /dev/null @@ -1,340 +0,0 @@ -from __future__ import annotations - -import collections -import warnings -from collections import defaultdict -from dataclasses import field -from itertools import product -from typing import DefaultDict, Dict, Hashable, Iterator, List, Mapping, Optional, Sequence, Set, Tuple, Type, Union - -import numpy -import xarray as xr -from attr import dataclass - -from bioimageio.core.sta import PER_DATASET, PER_SAMPLE, ComputedMeasures, RequiredMeasures, Sample -from bioimageio.core.statistical_measures import Mean, Measure, Percentile, Std, Var -from bioimageio.spec.model.v0_5 import AxisName - -try: - import crick # type: ignore -except ImportError: - crick = None - -MeasureValue = xr.DataArray - - -class SampleMeasureCalculator: - """group of measures for more efficient computation of multiple measures per sample""" - - def compute(self, sample: Sample) -> Dict[TensorName, Dict[Measure, MeasureValue]]: - raise NotImplementedError - - -class DatasetMeasureCalculator: - """group of measures for more efficient computation of multiple measures per dataset""" - - def reset(self) -> None: - """reset any accumulated intermediates""" - raise NotImplementedError - - def update_with_sample(self, sample: Sample) -> None: - """update intermediate representation with a data sample""" - raise NotImplementedError - - def finalize(self) -> Dict[TensorName, Dict[Measure, MeasureValue]]: - """compute statistics from intermediate representation""" - raise NotImplementedError - - -@dataclass -class MeasureGroups: - per_sample: List[SampleMeasureCalculator] = field(default_factory=list) - per_dataset: List[DatasetMeasureCalculator] = field(default_factory=list) - - -class DatasetMean(DatasetMeasureCalculator): - def __init__(self, tensor_name: TensorName, axes: Optional[Sequence[AxisName]]): - super().__init__() - self.axes = None if axes is None else tuple(axes) - self.tensor_name = tensor_name - self.n: int = 0 - self.mean: Optional[xr.DataArray] = None - - def reset(self): - self.n = 0 - self.mean = None - - def update_with_sample(self, sample: Sample): - tensor = sample[self.tensor_name].astype(numpy.float64, copy=False) - mean_b = tensor.mean(dim=self.axes) - assert mean_b.dtype == numpy.float64 - n_b = numpy.prod(tensor.shape) / numpy.prod(mean_b.shape) # reduced voxel count - if self.n == 0: - assert self.mean is None - self.n = n_b - self.mean = mean_b - else: - n_a = self.n - mean_a = self.mean - self.n = n = n_a + n_b - self.mean = (n_a * mean_a + n_b * mean_b) / n - assert self.mean.dtype == numpy.float64 - - def finalize(self) -> Dict[TensorName, Dict[Measure, 
MeasureValue]]: - if self.n == 0: - return {} - else: - return {self.tensor_name: {Mean(axes=self.axes): self.mean}} - - -class MeanVarStd(SampleMeasureCalculator, DatasetMeasureCalculator): - def __init__(self, tensor_name: TensorName, axes: Optional[Sequence[AxisName]]): - self.axes = None if axes is None else tuple(axes) - self.tensor_name = tensor_name - self.n: int = 0 - self.mean: Optional[xr.DataArray] = None - self.m2: Optional[xr.DataArray] = None - - def reset(self): - self.n = 0 - self.mean = None - self.m2 = None - - def compute(self, sample: Sample) -> Dict[TensorName, Dict[Measure, MeasureValue]]: - tensor = sample[self.tensor_name] - mean = tensor.mean(dim=self.axes) - c = tensor - mean - n = tensor.size if self.axes is None else numpy.prod([tensor.sizes[d] for d in self.axes]) - var = xr.dot(c, c, dims=self.axes) / n - std = numpy.sqrt(var) - return {self.tensor_name: {Mean(axes=self.axes): mean, Var(axes=self.axes): var, Std(axes=self.axes): std}} - - def update_with_sample(self, sample: Sample): - tensor = sample[self.tensor_name].astype(numpy.float64, copy=False) - mean_b = tensor.mean(dim=self.axes) - assert mean_b.dtype == numpy.float64 - n_b = numpy.prod(tensor.shape) / numpy.prod(mean_b.shape) # reduced voxel count - m2_b = ((tensor - mean_b) ** 2).sum(dim=self.axes) - assert m2_b.dtype == numpy.float64 - if self.n == 0: - assert self.mean is None - assert self.m2 is None - self.n = n_b - self.mean = mean_b - self.m2 = m2_b - else: - n_a = self.n - mean_a = self.mean - m2_a = self.m2 - self.n = n = n_a + n_b - self.mean = (n_a * mean_a + n_b * mean_b) / n - assert self.mean.dtype == numpy.float64 - d = mean_b - mean_a - self.m2 = m2_a + m2_b + d**2 * n_a * n_b / n - assert self.m2.dtype == numpy.float64 - - def finalize(self) -> Dict[TensorName, Dict[Measure, MeasureValue]]: - if self.n == 0: - return {} - else: - var = self.m2 / self.n - return { - self.tensor_name: { - Mean(axes=self.axes): self.mean, - Var(axes=self.axes): var, - Std(axes=self.axes): numpy.sqrt(var), - } - } - - -class SamplePercentiles(SampleMeasureCalculator): - def __init__(self, tensor_name: TensorName, axes: Optional[Tuple[str]], ns: Sequence[float]): - assert all(0 <= n <= 100 for n in ns) - self.ns = ns - self.qs = [n / 100 for n in ns] - self.axes = axes - self.tensor_name = tensor_name - - def compute(self, sample: Sample) -> Dict[TensorName, Dict[Measure, MeasureValue]]: - tensor = sample[self.tensor_name] - ps = tensor.quantile(self.qs, dim=self.axes) - return {self.tensor_name: {Percentile(n=n, axes=self.axes): p for n, p in zip(self.ns, ps)}} - - -class MeanPercentiles(DatasetMeasureCalculator): - n: int - estimates: Optional[xr.DataArray] - - def __init__(self, tensor_name: TensorName, axes: Optional[Tuple[str]], ns: Sequence[float]): - assert all(0 <= n <= 100 for n in ns) - self.ns = ns - self.qs = [n / 100 for n in ns] - self.axes = axes - self.tensor_name = tensor_name - self.reset() - - def reset(self): - self.n = 0 - self.estimates = None - - def update_with_sample(self, sample: Sample): - tensor = sample[self.tensor_name] - sample_estimates = tensor.quantile(self.qs, dim=self.axes).astype(numpy.float64, copy=False) - - n = numpy.prod(tensor.shape) / numpy.prod(sample_estimates.shape[1:]) # reduced voxel count - - if self.n == 0: - self.estimates = sample_estimates - else: - self.estimates = (self.n * self.estimates + n * sample_estimates) / (self.n + n) - assert self.estimates.dtype == numpy.float64 - - self.n += n - - def finalize(self) -> Dict[TensorName, Dict[Percentile, 
MeasureValue]]: - if self.n == 0: - return {} - else: - warnings.warn(f"Computed dataset percentiles naively by averaging percentiles of samples.") - return {self.tensor_name: {Percentile(n=n, axes=self.axes): e for n, e in zip(self.ns, self.estimates)}} - - -class CrickPercentiles(DatasetMeasureCalculator): - digest: Optional[List["crick.TDigest"]] - dims: Optional[Tuple[Hashable, ...]] - indices: Optional[Iterator[Tuple[int, ...]]] - shape: Optional[Tuple[int, ...]] - - def __init__(self, tensor_name: TensorName, axes: Optional[Tuple[str]], ns: Sequence[float]): - assert all(0 <= n <= 100 for n in ns) - assert axes is None or "_percentiles" not in axes - warnings.warn(f"Computing dataset percentiles with experimental 'crick' library.") - self.ns = ns - self.qs = [n / 100 for n in ns] - self.axes = axes - self.tensor_name = tensor_name - self.reset() - - def reset(self): - self.digest = None - self.dims = None - self.indices = None - self.shape = None - - def _initialize(self, tensor_sizes: Mapping[Hashable, int]): - out_sizes = collections.OrderedDict(_percentiles=len(self.ns)) - if self.axes is not None: - for d, s in tensor_sizes.items(): - if d not in self.axes: - out_sizes[d] = s - - self.dims, self.shape = zip(*out_sizes.items()) - self.digest = [crick.TDigest() for _ in range(int(numpy.prod(self.shape[1:])))] - self.indices = product(*map(range, self.shape[1:])) - - def update_with_sample(self, sample: Sample): - tensor = sample[self.tensor_name] - assert "_percentiles" not in tensor.dims - if self.digest is None: - self._initialize(tensor.sizes) - assert self.digest is not None - - for i, idx in enumerate(self.indices): - self.digest[i].update(tensor.isel(dict(zip(self.dims[1:], idx)))) - - def finalize(self) -> Dict[TensorName, Dict[Measure, MeasureValue]]: - if self.digest is None: - return {} - else: - vs = numpy.asarray([[d.quantile(q) for d in self.digest] for q in self.qs]).reshape(self.shape) - return { - self.tensor_name: { - Percentile(n=n, axes=self.axes): xr.DataArray(v, dims=self.dims[1:]) for n, v in zip(self.ns, vs) - } - } - - -if crick is None: - DatasetPercentileGroup: Type[Union[MeanPercentiles, CrickPercentiles]] = MeanPercentiles -else: - DatasetPercentileGroup = CrickPercentiles - - -class SingleMeasureAsGroup(SampleMeasureCalculator): - """wrapper for measures to match interface of SampleMeasureGroup""" - - def __init__(self, tensor_name: TensorName, measure: Measure): - super().__init__() - self.tensor_name = tensor_name - self.measure = measure - - def compute(self, sample: Sample) -> Dict[TensorName, Dict[Measure, MeasureValue]]: - return {self.tensor_name: {self.measure: self.measure.compute(sample[self.tensor_name])}} - - -def get_measure_groups(measures: RequiredMeasures) -> MeasureGroups: - """find a list of MeasureGroups to compute measures efficiently""" - - measure_groups = MeasureGroups() - means: Set[Tuple[TensorName, Mean]] = set() - mean_var_std_groups: Set[Tuple[TensorName, Optional[Tuple[str, ...]]]] = set() - percentile_groups: DefaultDict[Tuple[TensorName, Optional[Tuple[str, ...]]], List[float]] = defaultdict(list) - for mode, ms_per_mode in measures.items(): - for tn, ms_per_tn in ms_per_mode.items(): - for m in ms_per_tn: - if isinstance(m, Mean): - means.add((tn, m)) - elif isinstance(m, (Var, Std)): - mean_var_std_groups.add((tn, m.axes)) - elif isinstance(m, Percentile): - percentile_groups[(tn, m.axes)].append(m.n) - elif mode == PER_SAMPLE: - measure_groups.per_sample.append(SingleMeasureAsGroup(tensor_name=tn, measure=m)) - else: - 
raise NotImplementedError(f"Computing statistics for {m} {mode} not yet implemented") - - # add all mean measures that are not included in a mean/var/std group - for tn, m in means: - if (tn, m.axes) not in mean_var_std_groups: - # compute only mean - if mode == PER_SAMPLE: - measure_groups[mode].append(SingleMeasureAsGroup(tensor_name=tn, measure=m)) - elif mode == PER_DATASET: - measure_groups[mode].append(DatasetMean(tensor_name=tn, axes=m.axes)) - else: - raise NotImplementedError(mode) - - for tn, axes in mean_var_std_groups: - measure_groups[mode].append(MeanVarStd(tensor_name=tn, axes=axes)) - - for (tn, axes), ns in percentile_groups.items(): - if mode == PER_SAMPLE: - measure_groups[mode].append(SamplePercentiles(tensor_name=tn, axes=axes, ns=ns)) - elif mode == PER_DATASET: - measure_groups[mode].append(DatasetPercentileGroup(tensor_name=tn, axes=axes, ns=ns)) - else: - raise NotImplementedError(mode) - - return measure_groups - - -def compute_measures( - measures: RequiredMeasures, *, sample: Optional[Sample] = None, dataset: Iterator[Sample] = tuple() -) -> ComputedMeasures: - ms_groups = get_measure_groups(measures) - ret = {PER_SAMPLE: {}, PER_DATASET: {}} - if sample is not None: - for mg in ms_groups[PER_SAMPLE]: - assert isinstance(mg, SampleMeasureCalculator) - ret[PER_SAMPLE].update(mg.compute(sample)) - - for sample in dataset: - for mg in ms_groups[PER_DATASET]: - assert isinstance(mg, DatasetMeasureCalculator) - mg.update_with_sample(sample) - - for mg in ms_groups[PER_DATASET]: - assert isinstance(mg, DatasetMeasureCalculator) - ret[PER_DATASET].update(mg.finalize()) - - return ret diff --git a/bioimageio/core/model_utils.py b/bioimageio/core/utils.py similarity index 90% rename from bioimageio/core/model_utils.py rename to bioimageio/core/utils.py index 2c8dd51f..dd3532d7 100644 --- a/bioimageio/core/model_utils.py +++ b/bioimageio/core/utils.py @@ -1,11 +1,12 @@ from functools import singledispatch -from typing import Any, List, Union +from typing import Any, Dict, List, Union import numpy as np import xarray as xr from numpy.typing import NDArray from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.model.v0_5 import TensorId from bioimageio.spec.utils import download # @singledispatch diff --git a/tests/test_image_helper.py b/tests/test_image_helper.py index 9c495de1..d9721fc2 100644 --- a/tests/test_image_helper.py +++ b/tests/test_image_helper.py @@ -2,18 +2,18 @@ def test_transform_input_image(): - from bioimageio.core.image_helper import transform_input_image + from bioimageio.core.image_helper import transpose_image ax_list = ["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"] im = np.random.rand(256, 256) for axes in ax_list: - inp = transform_input_image(im, axes) + inp = transpose_image(im, axes) assert inp.ndim == len(axes) ax_list = ["zyx", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"] vol = np.random.rand(64, 64, 64) for axes in ax_list: - inp = transform_input_image(vol, axes) + inp = transpose_image(vol, axes) assert inp.ndim == len(axes) From 46c840d3e71965fdb6ba40ec55f4f3ed22a1a3dc Mon Sep 17 00:00:00 2001 From: Tomaz Vieira Date: Fri, 24 Nov 2023 13:46:46 +0100 Subject: [PATCH 074/244] Fixes typing issues in core.io --- bioimageio/core/io.py | 32 +++++++++----------------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py index 53d54b01..092e6d6a 100644 --- a/bioimageio/core/io.py +++ b/bioimageio/core/io.py @@ -1,28 +1,14 @@ from __future__ import annotations 
-import collections.abc -import io -import os -import shutil -from dataclasses import dataclass -from pathlib import Path -from tempfile import NamedTemporaryFile, mkdtemp -from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, TextIO, TypedDict, Union, cast -from zipfile import ZIP_DEFLATED, ZipFile, is_zipfile +from typing import List, Literal, Optional, Union -import pooch -from pydantic import AnyUrl, DirectoryPath, FilePath, HttpUrl, TypeAdapter -from ruamel.yaml import YAML -from typing_extensions import NotRequired, Unpack - -from bioimageio.spec import ResourceDescription +from bioimageio.spec import build_description from bioimageio.spec import load_description as load_description -from bioimageio.spec._internal.base_nodes import ResourceDescriptionBase +from bioimageio.spec._description import ResourceDescr from bioimageio.spec._internal.constants import DISCOVER -from bioimageio.spec._internal.types import FileName, RdfContent, RelativeFilePath, Sha256, ValidationContext, YamlValue +from bioimageio.spec._internal.validation_context import ValidationContext +from bioimageio.spec._internal.io_utils import open_bioimageio_yaml from bioimageio.spec.common import BioimageioYamlContent, FileSource, InvalidDescription -from bioimageio.spec.model.v0_4 import WeightsFormat -from bioimageio.spec.package import extract_file_name, get_resource_package_content from bioimageio.spec.summary import ValidationSummary @@ -31,7 +17,7 @@ def load_description_and_validate( /, *, format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, -) -> Union[ResourceDescription, InvalidDescription]: +) -> Union[ResourceDescr, InvalidDescription]: opened = open_bioimageio_yaml(source) return build_description_and_validate( @@ -47,15 +33,15 @@ def build_description_and_validate( *, context: Optional[ValidationContext] = None, format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, -) -> Union[ResourceDescription, InvalidDescription]: +) -> Union[ResourceDescr, InvalidDescription]: """load and validate a BioImage.IO description from the content of a resource description file (RDF)""" - descr = build_description(rdf_content, context=context, format_version=format_version) + rd = build_description(data, context=context, format_version=format_version) # todo: add dynamic validation return rd def validate( - source: RdfSource, + source: "FileSource | BioimageioYamlContent", /, *, context: Optional[ValidationContext] = None, From 876602372a353ac12f2aed0b55bf96747266d2f2 Mon Sep 17 00:00:00 2001 From: Tomaz Vieira Date: Fri, 24 Nov 2023 15:16:30 +0100 Subject: [PATCH 075/244] Fixes conftest --- tests/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index a4efa28a..6a8366ae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,8 +11,8 @@ from pytest import FixtureRequest, fixture os.environ["BIOIMAGEIO_COUNT_RDF_DOWNLOADS"] = "false" # disable tracking before bioimageio imports -from bioimageio.core import write_package from bioimageio.spec import __version__ as bioimageio_spec_version +from bioimageio.spec._package import save_bioimageio_package logger = logging.getLogger(__name__) warnings.warn(f"testing with bioimageio.spec {bioimageio_spec_version}") @@ -133,7 +133,7 @@ @fixture(scope="session") def model_packages(): - return MappingProxyType({name: write_package(MODEL_SOURCES[name]) for name in load_model_packages}) + return MappingProxyType({name: 
save_bioimageio_package(MODEL_SOURCES[name]) for name in load_model_packages}) @fixture(scope="session") From e3b253b8191f1bba5486e345748dd00164dd3a13 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 27 Nov 2023 09:56:54 +0100 Subject: [PATCH 076/244] remove build_model funcs --- bioimageio/core/build_spec/__init__.py | 3 - bioimageio/core/build_spec/_build_spec.py | 241 ----- bioimageio/core/build_spec/build_model.py | 945 ------------------ bioimageio/core/build_spec/get_description.py | 314 ------ 4 files changed, 1503 deletions(-) delete mode 100644 bioimageio/core/build_spec/__init__.py delete mode 100644 bioimageio/core/build_spec/_build_spec.py delete mode 100644 bioimageio/core/build_spec/build_model.py delete mode 100644 bioimageio/core/build_spec/get_description.py diff --git a/bioimageio/core/build_spec/__init__.py b/bioimageio/core/build_spec/__init__.py deleted file mode 100644 index e4dcb7b7..00000000 --- a/bioimageio/core/build_spec/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from ._build_spec import build_model_spec as build_model_spec -from .add_weights import add_weights -from .build_model import build_model diff --git a/bioimageio/core/build_spec/_build_spec.py b/bioimageio/core/build_spec/_build_spec.py deleted file mode 100644 index d0bfb0de..00000000 --- a/bioimageio/core/build_spec/_build_spec.py +++ /dev/null @@ -1,241 +0,0 @@ -import collections.abc -import shutil -from datetime import datetime -from pathlib import Path -from typing import Any, Optional, Sequence, Type, TypedDict, Union - -import numpy as np -from numpy.typing import NDArray - -# from bioimageio.core import export_resource_package, load_raw_resource_description -from typing_extensions import NotRequired, Self, Unpack - -from bioimageio.core.io import FileSource, download, load_description_and_validate, write_description -from bioimageio.spec.model.v0_5 import ( - Architecture, - Author, - CiteEntry, - Dependencies, - InputAxis, - InputTensor, - IntervalOrRatioData, - IntervalOrRatioDType, - LicenseId, - Maintainer, - Model, - NotEmpty, - OutputAxis, - OutputTensor, - Postprocessing, - Preprocessing, - PytorchStateDictWeights, - RelativeFilePath, - Sha256, - TensorData, - TensorId, - Version, - Weights, -) - - -class CoreGenericBaseKwargs(TypedDict): - name: str - description: str - authors: NotEmpty[Sequence[Author]] - maintainers: NotRequired[Sequence[Maintainer]] - tags: Sequence[str] - documentation: FileSource - cite: NotEmpty[Sequence[CiteEntry]] - license: LicenseId - output_path: Path - - -class CoreTensorKwargs(TypedDict): - test_tensor: FileSource - sample_tensor: NotRequired[FileSource] - id: NotRequired[Optional[TensorId]] - data: NotRequired[Optional[Union[TensorData, NotEmpty[Sequence[TensorData]]]]] - output_path: Path - - -class CoreInputTensorKwargs(CoreTensorKwargs): - axes: NotEmpty[Sequence[InputAxis]] - preprocessing: NotRequired[Sequence[Preprocessing]] - - -class CoreOutputTensorKwargs(CoreTensorKwargs): - axes: NotEmpty[Sequence[OutputAxis]] - postprocessing: NotRequired[Sequence[Postprocessing]] - - -def ensure_file_in_folder(source: FileSource, folder: Path) -> RelativeFilePath: - """download/copy `source` to `folder` if `source` is not already in (a subfolder of) `folder`. 
- Returns a relative file path (relative to `folder`)""" - path = download(source).path - try: - rel_path = path.relative_to(folder) # todo: improve for py >= 3.9 with path.is_relative_to - except ValueError: - path = Path(shutil.copy(path, folder)) - rel_path = path.relative_to(folder) - - return RelativeFilePath(rel_path) - - -class _CoreTensorMixin: - @staticmethod - def get_data_description(kwargs: Union[CoreInputTensorKwargs, CoreOutputTensorKwargs]): - tensor_data = kwargs.get("data") - if isinstance(tensor_data, TensorData): - return tensor_data - elif tensor_data is None: - test_tensor: NDArray[Any] = np.load(download(kwargs["test_tensor"]).path) - assert isinstance(test_tensor, np.ndarray) - dtype_str = str(test_tensor.dtype) - if dtype_str.startswith("float") and test_tensor.min() >= 0.0 and test_tensor.max() <= 1.0: - range_ = (0.0, 1.0) - else: - range_ = (None, None) - - dtype: IntervalOrRatioDType = dtype_str # type: ignore # validated by IntervalOrRatioData - return IntervalOrRatioData(type=dtype, range=range_, unit="arbitrary unit", scale=1.0, offset=None) - elif isinstance(tensor_data, collections.abc.Sequence): # pyright: ignore[reportUnnecessaryIsInstance] - assert all(isinstance(td, TensorData) for td in tensor_data) - return tuple(tensor_data) - else: - raise TypeError(tensor_data) - - -class _CoreInputTensor(InputTensor, _CoreTensorMixin): - @classmethod - def build(cls, **kwargs: Unpack[CoreInputTensorKwargs]): - return cls( - test_tensor=ensure_file_in_folder(kwargs["test_tensor"], kwargs["output_path"]), - id=kwargs.get("id") or TensorId("input"), - axes=tuple(kwargs["axes"]), - preprocessing=tuple(kwargs.get("preprocessing", ())), - data=cls.get_data_description(kwargs), - sample_tensor=ensure_file_in_folder(kwargs["sample_tensor"], kwargs["output_path"]) - if "sample_tensor" in kwargs - else None, - ) - - -class _CoreOutputTensor(OutputTensor, _CoreTensorMixin): - @classmethod - def build(cls, **kwargs: Unpack[CoreOutputTensorKwargs]): - return cls( - test_tensor=ensure_file_in_folder(kwargs["test_tensor"], kwargs["output_path"]), - id=kwargs.get("id") or TensorId("output"), - axes=tuple(kwargs["axes"]), - postprocessing=tuple(kwargs.get("postprocessing", ())), - data=cls.get_data_description(kwargs), - ) - - -class CoreModelBaseKwargs(CoreGenericBaseKwargs): - inputs: NotEmpty[Sequence[CoreInputTensorKwargs]] - outputs: NotEmpty[Sequence[CoreOutputTensorKwargs]] - - -class CoreModelKwargs(CoreModelBaseKwargs): - weights: Weights - - -class _CoreModel(Model): - @classmethod - def build(cls, **kwargs: Unpack[CoreModelKwargs]) -> Self: - documentation = ensure_file_in_folder(kwargs["documentation"], kwargs["output_path"]) - - inputs = tuple( - _CoreInputTensor.build( - id=t_kwargs["id"] if "id" in t_kwargs else TensorId(f"input{i}"), - test_tensor=t_kwargs["test_tensor"], - axes=t_kwargs["axes"], - data=t_kwargs.get("data"), - output_path=kwargs["output_path"], - ) - for i, t_kwargs in enumerate(kwargs["inputs"]) - ) - - outputs = tuple( - _CoreOutputTensor.build( - id=t_kwargs["id"] if "id" in t_kwargs else TensorId(f"output{i}"), - test_tensor=t_kwargs["test_tensor"], - axes=t_kwargs["axes"], - data=t_kwargs.get("data"), - output_path=kwargs["output_path"], - ) - for i, t_kwargs in enumerate(kwargs["outputs"]) - ) - - return cls( - name=kwargs["name"], - description=kwargs["description"], - authors=tuple(kwargs["authors"]), - maintainers=tuple(kwargs.get("maintainers", ())), - cite=tuple(kwargs["cite"]), - license=kwargs["license"], - timestamp=datetime.now(), - 
inputs=inputs, - outputs=outputs, - weights=kwargs["weights"], - documentation=documentation, - ) - - @classmethod - def build_from_pytorch_state_dict( - cls, - weights: FileSource, - architecture: Architecture, - sha256: Optional[Sha256] = None, - pytorch_version: Optional[Version] = None, - dependencies: Optional[Dependencies] = None, - **kwargs: Unpack[CoreModelBaseKwargs], - ): - if pytorch_version is None: - import torch - - pytorch_version = Version(torch.__version__) - - return cls.build( - weights=Weights( - pytorch_state_dict=PytorchStateDictWeights( - source=ensure_file_in_folder(weights, kwargs["output_path"]), - sha256=sha256, - architecture=architecture, - pytorch_version=pytorch_version, - dependencies=dependencies, - ) - ), - **kwargs, - ) - - -def _build_spec_common(core_descr: _CoreModel, descr_path: Path, expected_type: Type[Any]): - write_description(core_descr, descr_path) - loaded = load_description_and_validate(descr_path) - if type(loaded) is not expected_type: - raise RuntimeError(f"Created {descr_path} was loaded as {type(loaded)}, but expected {expected_type}") - - return descr_path, loaded - - -def build_model_spec( - *, - weights: FileSource, - architecture: Architecture, - sha256: Optional[Sha256] = None, - pytorch_version: Optional[Version] = None, - dependencies: Optional[Dependencies] = None, - **kwargs: Unpack[CoreModelBaseKwargs], -): - model = _CoreModel.build_from_pytorch_state_dict( - weights=weights, - architecture=architecture, - sha256=sha256, - pytorch_version=pytorch_version, - dependencies=dependencies, - **kwargs, - ) - - return _build_spec_common(model, kwargs["output_path"] / "description.bioimageio.yaml", Model) diff --git a/bioimageio/core/build_spec/build_model.py b/bioimageio/core/build_spec/build_model.py deleted file mode 100644 index 10c9dfa1..00000000 --- a/bioimageio/core/build_spec/build_model.py +++ /dev/null @@ -1,945 +0,0 @@ -import datetime -import hashlib -import os -from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, TypedDict, Union, get_args -from warnings import warn - -import imageio -import numpy as np -import requests -import tifffile -from numpy.typing import NDArray - -# from bioimageio.core import export_resource_package, load_raw_resource_description -from pydantic import AnyUrl, HttpUrl -from typing_extensions import NotRequired, Unpack - -import bioimageio.spec as spec -import bioimageio.spec.model as model_spec -from bioimageio.core.io import FileSource, download -from bioimageio.core.utils import import_callable -from bioimageio.spec.model.v0_5 import ( - Author, - CiteEntry, - InputAxis, - InputTensor, - IntervalOrRatioData, - LicenseId, - Maintainer, - Model, - NominalOrOrdinalData, - NotEmpty, - OutputAxis, - Postprocessing, - Preprocessing, - TensorData, - TensorId, -) - -# -# utility functions to build the spec from python -# - - -def _get_hash(path): - with open(path, "rb") as f: - data = f.read() - return hashlib.sha256(data).hexdigest() - - -def _infer_weight_type(path): - ext = os.path.splitext(path)[-1] - if ext in (".pt", ".pth", ".torch"): - return "pytorch_state_dict" - elif ext == ".onnx": - return "onnx" - elif ext in (".hdf", ".hdf5", ".h5"): - return "keras_hdf5" - elif ext == ".zip": - return "tensorflow_saved_model_bundle" - elif ext == ".json": - return "tensorflow_js" - else: - raise ValueError(f"Could not infer weight type from extension {ext} for weight file {path}") - - -def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root): - assert 
architecture is not None
-    tmp_archtecture = None
-    weight_kwargs = {"kwargs": model_kwargs} if model_kwargs else {}
-    if ":" in architecture:
-        # note: path itself might include : for absolute paths in windows
-        *arch_file_parts, callable_name = architecture.replace("::", ":").split(":")
-        arch_file = _ensure_local(":".join(arch_file_parts), root)
-        arch = ImportableSourceFile(callable_name, arch_file)
-        arch_hash = _get_hash(root / arch.source_file)
-        weight_kwargs["architecture_sha256"] = arch_hash
-    else:
-        arch = spec.shared.fields.ImportableSource().deserialize(architecture)
-        assert isinstance(arch, ImportableModule)
-
-    weight_kwargs["architecture"] = arch
-    return weight_kwargs, tmp_archtecture
-
-
-def _get_attachments(attachments, root):
-    assert isinstance(attachments, dict)
-    if "files" in attachments:
-        afiles = attachments["files"]
-        if isinstance(afiles, str):
-            afiles = [afiles]
-
-        if isinstance(afiles, list):
-            afiles = _ensure_local_or_url(afiles, root)
-        else:
-            raise TypeError(attachments)
-
-        attachments["files"] = afiles
-    return attachments
-
-
-def _get_weights(
-    original_weight_source,
-    weight_type,
-    root,
-    architecture=None,
-    model_kwargs=None,
-    tensorflow_version=None,
-    opset_version=None,
-    pytorch_version=None,
-    dependencies=None,
-    attachments=None,
-):
-    weight_path = resolve_source(original_weight_source, root)
-    if weight_type is None:
-        weight_type = _infer_weight_type(weight_path)
-    weight_hash = _get_hash(weight_path)
-
-    weight_types = model_spec.raw_nodes.WeightsFormat
-    weight_source = _ensure_local_or_url(original_weight_source, root)
-
-    weight_kwargs = {"source": weight_source, "sha256": weight_hash}
-    if attachments is not None:
-        weight_kwargs["attachments"] = _get_attachments(attachments, root)
-    if dependencies is not None:
-        weight_kwargs["dependencies"] = _get_dependencies(dependencies, root)
-
-    tmp_archtecture = None
-    if weight_type == "pytorch_state_dict":
-        # pytorch-state-dict -> we need an architecture definition
-        pytorch_weight_kwargs, tmp_file = _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root)
-        weight_kwargs.update(**pytorch_weight_kwargs)
-        if pytorch_version is not None:
-            weight_kwargs["pytorch_version"] = pytorch_version
-        elif dependencies is None:
-            warn(
-                "You are building a pytorch model but have neither passed dependencies nor the pytorch_version."
-                "It may not be possible to create an environment where your model can be used."
-            )
-        weights = model_spec.raw_nodes.PytorchStateDictWeightsEntry(**weight_kwargs)
-
-    elif weight_type == "onnx":
-        if opset_version is not None:
-            weight_kwargs["opset_version"] = opset_version
-        elif dependencies is None:
-            warn(
-                "You are building an onnx model but have neither passed dependencies nor the opset_version."
-                "It may not be possible to create an environment where your model can be used."
-            )
-        weights = model_spec.raw_nodes.OnnxWeightsEntry(**weight_kwargs)
-
-    elif weight_type == "torchscript":
-        if pytorch_version is not None:
-            weight_kwargs["pytorch_version"] = pytorch_version
-        elif dependencies is None:
-            warn(
-                "You are building a pytorch model but have neither passed dependencies nor the pytorch_version."
-                "It may not be possible to create an environment where your model can be used."
-            )
-        weights = model_spec.raw_nodes.TorchscriptWeightsEntry(**weight_kwargs)
-
-    elif weight_type == "keras_hdf5":
-        if tensorflow_version is not None:
-            weight_kwargs["tensorflow_version"] = tensorflow_version
-        elif dependencies is None:
-            warn(
-                "You are building a keras model but have neither passed dependencies nor the tensorflow_version."
-                "It may not be possible to create an environment where your model can be used."
-            )
-        weights = model_spec.raw_nodes.KerasHdf5WeightsEntry(**weight_kwargs)
-
-    elif weight_type == "tensorflow_saved_model_bundle":
-        if tensorflow_version is not None:
-            weight_kwargs["tensorflow_version"] = tensorflow_version
-        elif dependencies is None:
-            warn(
-                "You are building a tensorflow model but have neither passed dependencies nor the tensorflow_version."
-                "It may not be possible to create an environment where your model can be used."
-            )
-        weights = model_spec.raw_nodes.TensorflowSavedModelBundleWeightsEntry(**weight_kwargs)
-
-    elif weight_type == "tensorflow_js":
-        if tensorflow_version is not None:
-            weight_kwargs["tensorflow_version"] = tensorflow_version
-        elif dependencies is None:
-            warn(
-                "You are building a tensorflow model but have neither passed dependencies nor the tensorflow_version."
-                "It may not be possible to create an environment where your model can be used."
-            )
-        weights = model_spec.raw_nodes.TensorflowJsWeightsEntry(**weight_kwargs)
-
-    elif weight_type in weight_types:
-        raise ValueError(f"Weight type {weight_type} is not supported yet in 'build_spec'")
-    else:
-        raise ValueError(f"Invalid weight type {weight_type}, expect one of {weight_types}")
-
-    return {weight_type: weights}, tmp_archtecture
-
-
-def _get_data_range(data_range, dtype):
-    if data_range is None:
-        if np.issubdtype(np.dtype(dtype), np.integer):
-            min_, max_ = np.iinfo(dtype).min, np.iinfo(dtype).max
-        # for floating point numbers we assume valid range from -inf to inf
-        elif np.issubdtype(np.dtype(dtype), np.floating):
-            min_, max_ = -np.inf, np.inf
-        elif np.issubdtype(np.dtype(dtype), np.bool):
-            min_, max_ = 0, 1
-        else:
-            raise RuntimeError(f"Cannot derive data range for dtype {dtype}")
-        data_range = (min_, max_)
-    assert isinstance(data_range, (tuple, list)), type(data_range)
-    assert len(data_range) == 2
-    return data_range
-
-
-def _get_input_tensor(path, name, step, min_shape, data_range, axes, preprocessing):
-    test_in = np.load(path)
-    shape = test_in.shape
-    if step is None:
-        assert min_shape is None
-        shape_description = shape
-    else:
-        shape_description = {"min": shape if min_shape is None else min_shape, "step": step}
-
-    data_range = _get_data_range(data_range, test_in.dtype)
-    kwargs = {}
-    if preprocessing is not None:
-        kwargs["preprocessing"] = preprocessing
-
-    inputs = model_spec.raw_nodes.InputTensor(
-        name="input" if name is None else name,
-        data_type=str(test_in.dtype),
-        axes=axes,
-        shape=shape_description,
-        data_range=data_range,
-        **kwargs,
-    )
-    return inputs
-
-
-def _get_output_tensor(path, name, reference_tensor, scale, offset, axes, data_range, postprocessing, halo):
-    test_out = np.load(path)
-    shape = test_out.shape
-    if reference_tensor is None:
-        assert scale is None
-        assert offset is None
-        shape_description = shape
-    else:
-        assert scale is not None
-        assert offset is not None
-        shape_description = {"reference_tensor": reference_tensor, "scale": scale, "offset": offset}
-
-    data_range = _get_data_range(data_range, test_out.dtype)
-    kwargs = {}
-    if postprocessing is not None:
-        kwargs["postprocessing"] = postprocessing
-    if halo is not None:
-        kwargs["halo"] = halo
-
-    outputs = model_spec.raw_nodes.OutputTensor(
-        name="output" if name is None else name,
-        data_type=str(test_out.dtype),
-        axes=axes,
-        data_range=data_range,
-        shape=shape_description,
-        **kwargs,
-    )
-    return outputs
-
-
-def _build_cite(cite: List[Dict[str, str]]):
-    citation_list = []
-    for entry in cite:
-        if "doi" in entry:
-            spec_entry = spec.rdf.raw_nodes.CiteEntry(text=entry["text"], doi=entry["doi"])
-        elif "url" in entry:
-            spec_entry = spec.rdf.raw_nodes.CiteEntry(text=entry["text"], url=entry["url"])
-        else:
-            raise ValueError(f"Expect one of doi or url in citation entry {entry}")
-        citation_list.append(spec_entry)
-    return citation_list
-
-
-def _get_dependencies(dependencies, root):
-    if isinstance(dependencies, Path) or ":" not in dependencies:
-        manager = "conda"
-        path = dependencies
-    else:
-        manager, path = dependencies.split(":")
-
-    return model_spec.raw_nodes.Dependencies(manager=manager, file=_ensure_local(path, root))
-
-
-def _get_deepimagej_macro(name, kwargs, export_folder):
-    # macros available in deepimagej
-    macro_names = ("binarize", "scale_linear", "scale_range", "zero_mean_unit_variance")
-    if name == "scale_linear":
-        macro = "scale_linear.ijm"
-        replace = {"gain": kwargs["gain"], "offset": kwargs["offset"]}
-
-    elif name == "scale_range":
-        macro = "per_sample_scale_range.ijm"
-        replace = {"min_precentile": kwargs["min_percentile"], "max_percentile": kwargs["max_percentile"]}
-
-    elif name == "zero_mean_unit_variance":
-        mode = kwargs["mode"]
-        if mode == "fixed":
-            macro = "fixed_zero_mean_unit_variance.ijm"
-            replace = {"paramMean": kwargs["mean"], "paramStd": kwargs["std"]}
-        else:
-            macro = "zero_mean_unit_variance.ijm"
-            replace = {}
-
-    elif name == "binarize":
-        macro = "binarize.ijm"
-        replace = {"optimalThreshold": kwargs["threshold"]}
-
-    else:
-        raise ValueError(f"Macro {name} is not available, must be one of {macro_names}.")
-
-    url = f"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/bioimage.io/{macro}"
-
-    path = os.path.join(export_folder, macro)
-    # use https://github.com/bioimage-io/core-bioimage-io-python/blob/main/bioimageio/core/resource_io/utils.py#L267
-    # instead if the implementation is updated s.t. an output path is accepted
-    with requests.get(url, stream=True) as r:
-        text = r.text
-        if text.startswith("4"):
-            raise RuntimeError(f"An error occurred when downloading {url}: {r.text}")
-        with open(path, "w") as f:
-            f.write(r.text)
-
-    # replace the kwargs in the macro file
-    if replace:
-        lines = []
-        with open(path) as f:
-            for line in f:
-                kwarg = [kwarg for kwarg in replace if line.startswith(kwarg)]
-                if kwarg:
-                    assert len(kwarg) == 1
-                    kwarg = kwarg[0]
-                    # each kwarg should only be replaced once
-                    val = replace.pop(kwarg)
-                    lines.append(f"{kwarg} = {val};\n")
-                else:
-                    lines.append(line)
-
-        with open(path, "w") as f:
-            for line in lines:
-                f.write(line)
-
-    return {"spec": "ij.IJ::runMacroFile", "kwargs": macro}
-
-
-def _get_deepimagej_config(
-    export_folder, test_inputs, test_outputs, input_axes, output_axes, pixel_sizes, preprocessing, postprocessing
-):
-    assert len(test_inputs) == len(test_outputs) == 1, "deepimagej config only valid for single input/output"
-
-    if any(preproc is not None for preproc in preprocessing):
-        assert len(preprocessing) == 1
-        preprocess_ij = [
-            _get_deepimagej_macro(preproc["name"], preproc["kwargs"], export_folder) for preproc in preprocessing[0]
-        ]
-        attachments = [preproc["kwargs"] for preproc in preprocess_ij]
-    else:
-        preprocess_ij = [{"spec": None}]
-        attachments = []
-
-    if any(postproc is not None for postproc in postprocessing):
-        assert len(postprocessing) == 1
-        postprocess_ij = [
-            _get_deepimagej_macro(postproc["name"], postproc["kwargs"], export_folder) for postproc in postprocessing[0]
-        ]
-        attachments.extend([postproc["kwargs"] for postproc in postprocess_ij])
-    else:
-        postprocess_ij = [{"spec": None}]
-
-    def get_size(fname, axes):
-        shape = np.load(export_folder / fname).shape
-        assert len(shape) == len(axes)
-        shape = [sh for sh, ax in zip(shape, axes) if ax != "b"]
-        axes = [ax for ax in axes if ax != "b"]
-        # the shape for deepij is always given as xyzc
-        if len(shape) == 3:
-            axes_ij = "xyc"
-        else:
-            axes_ij = "xyzc"
-        assert set(axes) == set(axes_ij)
-        axis_permutation = [axes_ij.index(ax) for ax in axes]
-        shape = [shape[permut] for permut in axis_permutation]
-        if len(shape) == 3:
-            shape = shape[:2] + [1] + shape[-1:]
-        assert len(shape) == 4
-        return " x ".join(map(str, shape))
-
-    # deepimagej always expects a pixel size for the z axis
-    pixel_sizes_ = [pix_size if "z" in pix_size else dict(z=1.0, **pix_size) for pix_size in pixel_sizes]
-
-    test_info = {
-        "inputs": [
-            {"name": in_path, "size": get_size(in_path, axes), "pixel_size": pix_size}
-            for in_path, axes, pix_size in zip(test_inputs, input_axes, pixel_sizes_)
-        ],
-        "outputs": [
-            {"name": out_path, "type": "image", "size": get_size(out_path, axes)}
-            for out_path, axes in zip(test_outputs, output_axes)
-        ],
-        "memory_peak": None,
-        "runtime": None,
-    }
-
-    config = {
-        "prediction": {"preprocess": preprocess_ij, "postprocess": postprocess_ij},
-        "test_information": test_info,
-        # other stuff deepimagej needs
-        "pyramidal_model": False,
-        "allow_tiling": True,
-        "model_keys": None,
-    }
-    return {"deepimagej": config}, [Path(a) for a in attachments]
-
-
-def _write_sample_data(input_paths, output_paths, input_axes, output_axes, pixel_sizes, export_folder: Path):
-    def write_im(path, im, axes, pixel_size=None):
-        assert len(axes) == im.ndim, f"{len(axes), {im.ndim}}"
-        assert im.ndim in (4, 5), f"{im.ndim}"
-
-        # convert the image to the expected (Z)CYX axis order
-        if im.ndim == 4:
-            assert set(axes) == {"b", "x", "y", "c"}, f"{axes}"
-            resolution_axes_ij = "cyxb"
-        else:
-            assert set(axes) == {"b", "x", "y", "z", "c"}, f"{axes}"
-            resolution_axes_ij = "bzcyx"
-
-        def addMissingAxes(im_axes):
-            needed_axes = ["b", "c", "x", "y", "z", "s"]
-            for ax in needed_axes:
-                if ax not in im_axes:
-                    im_axes += ax
-            return im_axes
-
-        axes_ij = "bzcyxs"
-        # Expand the image to ImageJ dimensions
-        im = np.expand_dims(im, axis=tuple(range(len(axes), len(axes_ij))))
-
-        axis_permutation = tuple(addMissingAxes(axes).index(ax) for ax in axes_ij)
-        im = im.transpose(axis_permutation)
-
-        if pixel_size is None:
-            resolution = None
-        else:
-            spatial_axes = list(set(resolution_axes_ij) - set("bc"))
-            resolution = tuple(1.0 / pixel_size[ax] for ax in resolution_axes_ij if ax in spatial_axes)
-        # does not work for double
-        if np.dtype(im.dtype) == np.dtype("float64"):
-            im = im.astype("float32")
-        tifffile.imwrite(path, im, imagej=True, resolution=resolution)
-
-    sample_in_paths = []
-    for i, (in_path, axes) in enumerate(zip(input_paths, input_axes)):
-        inp = np.load(export_folder / in_path)
-        sample_in_path = export_folder / f"sample_input_{i}.tif"
-        pixel_size = None if pixel_sizes is None else pixel_sizes[i]
-        write_im(sample_in_path, inp, axes, pixel_size)
-        sample_in_paths.append(sample_in_path)
-
-    sample_out_paths = []
-    for i, (out_path, axes) in enumerate(zip(output_paths, output_axes)):
-        outp = np.load(export_folder / out_path)
-        sample_out_path = export_folder / f"sample_output_{i}.tif"
-        write_im(sample_out_path, outp, axes)
-        sample_out_paths.append(sample_out_path)
-
-    return [Path(p.name) for p in sample_in_paths], [Path(p.name) for p in sample_out_paths]
-
-
-# create better cover images for 3d data and non-image outputs
-def _generate_covers(in_path, out_path, input_axes, output_axes, root):
-    def normalize(data, axis, eps=1e-7):
-        data = data.astype("float32")
-        data -= data.min(axis=axis, keepdims=True)
-        data /= data.max(axis=axis, keepdims=True) + eps
-        return data
-
-    def to_image(data, data_axes):
-        assert data.ndim in (4, 5)
-
-        # transpose the data to "bczyx" / "bcyx" order
-        axes = "bczyx" if data.ndim == 5 else "bcyx"
-        assert set(data_axes) == set(axes)
-        if axes != data_axes:
-            ax_permutation = tuple(data_axes.index(ax) for ax in axes)
-            data = data.transpose(ax_permutation)
-
-        # select single image with channels from the data
-        if data.ndim == 5:
-            z0 = data.shape[2] // 2
-            data = data[0, :, z0]
-        else:
-            data = data[0, :]
-
-        # normalize the data and map to 8 bit
-        data = normalize(data, axis=(1, 2))
-        data = (data * 255).astype("uint8")
-        return data
-
-    cover_path = os.path.join(root, "cover.png")
-    input_, output = np.load(in_path), np.load(out_path)
-
-    input_ = to_image(input_, input_axes)
-    # this is not image data so we only save the input image
-    if output.ndim < 4:
-        imageio.imwrite(cover_path, input_.transpose((1, 2, 0)))
-        return [_ensure_local(cover_path, root)]
-    output = to_image(output, output_axes)
-
-    chan_in = input_.shape[0]
-    # make sure the input is rgb
-    if chan_in == 1:  # single channel -> repeat it 3 times
-        input_ = np.repeat(input_, 3, axis=0)
-    elif chan_in != 3:  # != 3 channels -> take first channel and repeat it 3 times
-        input_ = np.repeat(input_[0:1], 3, axis=0)
-
-    im_shape = input_.shape[1:]
-    # we just save the input image if the shapes don't agree
-    if im_shape != output.shape[1:]:
-        imageio.imwrite(cover_path, input_.transpose((1, 2, 0)))
-        return [_ensure_local(cover_path, root)]
-
-    def diagonal_split(im0, im1):
-        assert im0.shape[0] == im1.shape[0] == 3
-        n, m = im_shape
-        out = np.ones((3, n, 
m), dtype="uint8") - for c in range(3): - outc = np.tril(im0[c]) - mask = outc == 0 - outc[mask] = np.triu(im1[c])[mask] - out[c] = outc - return out - - def grid_im(im0, im1): - ims_per_row = 3 - n_chan = im1.shape[0] - n_images = n_chan + 1 - n_rows = int(np.ceil(float(n_images) / ims_per_row)) - - n, m = im_shape - x, y = ims_per_row * n, n_rows * m - out = np.zeros((3, y, x), dtype=im0.dtype) - images = [im0] + [np.repeat(im1[i : i + 1], 3, axis=0) for i in range(n_chan)] - - i, j = 0, 0 - for im in images: - x0, x1 = i * n, (i + 1) * n - y0, y1 = j * m, (j + 1) * m - out[:, y0:y1, x0:x1] = im - - i += 1 - if i == ims_per_row: - i = 0 - j += 1 - - return out - - chan_out = output.shape[0] - if chan_out == 1: # single prediction channel: create diagonal split - im = diagonal_split(input_, np.repeat(output, 3, axis=0)) - elif chan_out == 3: # three prediction channel: create diagonal split with rgb - im = diagonal_split(input_, output) - else: # otherwise create grid image - im = grid_im(input_, output) - - # to channel last - imageio.imwrite(cover_path, im.transpose((1, 2, 0))) - return [_ensure_local(cover_path, root)] - - -def _ensure_local(source: Union[Path, URI, str, list], root: Path) -> Union[Path, URI, list]: - """ensure source is local relative path in root""" - if isinstance(source, list): - return [_ensure_local(s, root) for s in source] - - local_source = resolve_source(source, root) - local_source = resolve_source(local_source, root, root / local_source.name) - return local_source.relative_to(root) - - -def _ensure_local_or_url(source: Union[Path, URI, str, list], root: Path) -> Union[Path, URI, list]: - """ensure source is remote URI or local relative path in root""" - if isinstance(source, list): - return [_ensure_local_or_url(s, root) for s in source] - - local_source = resolve_local_source(source, root) - if not isinstance(local_source, URI): - local_source = resolve_local_source(local_source, root, root / local_source.name) - return local_source.relative_to(root) - - -def build_model( - # model specific and required - weight_uri: FileSource, - test_inputs: List[FileSource], - test_outputs: List[FileSource], - input_axes: List[str], - output_axes: List[str], - # general metadata - name: str, - description: str, - authors: List[Dict[str, str]], - tags: List[Union[str, Path]], - documentation: Union[str, Path], - cite: List[Dict[str, str]], - output_path: Union[str, Path], - # model specific optional - architecture: Optional[str] = None, - model_kwargs: Optional[Dict[str, Union[int, float, str]]] = None, - weight_type: Optional[str] = None, - sample_inputs: Optional[List[str]] = None, - sample_outputs: Optional[List[str]] = None, - # tensor specific - input_names: Optional[List[str]] = None, - input_step: Optional[List[List[int]]] = None, - input_min_shape: Optional[List[List[int]]] = None, - input_data_range: Optional[List[List[Union[int, str]]]] = None, - output_names: Optional[List[str]] = None, - output_reference: Optional[List[str]] = None, - output_scale: Optional[List[List[int]]] = None, - output_offset: Optional[List[List[int]]] = None, - output_data_range: Optional[List[List[Union[int, str]]]] = None, - halo: Optional[List[List[int]]] = None, - preprocessing: Optional[List[List[Dict[str, Dict[str, Union[int, float, str]]]]]] = None, - postprocessing: Optional[List[List[Dict[str, Dict[str, Union[int, float, str]]]]]] = None, - pixel_sizes: Optional[List[Dict[str, float]]] = None, - # general optional - maintainers: Optional[List[Dict[str, str]]] = None, - license: 
LicenseId = "CC-BY-4.0",
-    covers: Optional[List[str]] = None,
-    git_repo: Optional[str] = None,
-    attachments: Optional[Dict[str, Union[str, List[str]]]] = None,
-    packaged_by: Optional[List[str]] = None,
-    run_mode: Optional[str] = None,
-    parent: Optional[Dict[str, str]] = None,
-    config: Optional[Dict[str, Any]] = None,
-    dependencies: Optional[Union[Path, str]] = None,
-    links: Optional[List[str]] = None,
-    training_data: Optional[Dict[str, str]] = None,
-    root: Optional[Union[Path, str]] = None,
-    add_deepimagej_config: bool = False,
-    tensorflow_version: Optional[str] = None,
-    opset_version: Optional[int] = None,
-    pytorch_version: Optional[str] = None,
-    weight_attachments: Optional[Dict[str, Union[str, List[str]]]] = None,
-):
-    """Create a zipped bioimage.io model.
-
-    Example usage:
-    ```
-    from pathlib import Path
-    import bioimageio.spec as spec
-    import bioimageio.core.build_spec as build_spec
-    model_spec = build_spec.build_model(
-        weight_uri="test_weights.pt",
-        test_inputs=["./test_inputs"],
-        test_outputs=["./test_outputs"],
-        input_axes=["bcyx"],
-        output_axes=["bcyx"],
-        name="my-model",
-        description="My very fancy model.",
-        authors=[{"name": "John Doe", "affiliation": "My Institute"}],
-        tags=["segmentation", "light sheet data"],
-        license="CC-BY-4.0",
-        documentation="./documentation.md",
-        cite=[{"text": "Ronneberger et al. U-Net", "doi": "10.1007/978-3-319-24574-4_28"}],
-        output_path="my-model.zip"
-    )
-    ```
-
-    Args:
-        weight_uri: the url or relative local file path to the weight file for this model.
-        test_inputs: list of test input files stored in numpy format.
-        test_outputs: list of test outputs corresponding to test_inputs, stored in numpy format.
-        input_axes: axis names of the input tensors.
-        output_axes: axis names of the output tensors.
-        name: name of this model.
-        description: short description of this model.
-        authors: the authors of this model.
-        tags: list of tags for this model.
-        documentation: relative file path to markdown documentation for this model.
-        cite: references for this model.
-        output_path: where to save the zipped model package.
-        architecture: the file with the source code for the model architecture and the corresponding class.
-            Only required for models with pytorch_state_dict weight format.
-        model_kwargs: the keyword arguments for the model class.
-            Only required for models with pytorch_state_dict weight format.
-        weight_type: the type of the weights.
-        sample_inputs: list of sample inputs to demonstrate the model performance.
-        sample_outputs: list of sample outputs corresponding to sample_inputs.
-        input_names: names of the input tensors.
-        input_step: minimal valid increase of the input tensor shape.
-        input_min_shape: minimal input tensor shape.
-        input_data_range: valid data range for the input tensor.
-        output_names: names of the output tensors.
-        output_reference: name of the input reference tensor used to compute the output tensor shape.
-        output_scale: multiplicative factor to compute the output tensor shape.
-        output_offset: additive term to compute the output tensor shape.
-        output_data_range: valid data range for the output tensor.
-        halo: halo to be cropped from the output tensor.
-        preprocessing: list of preprocessing operations for the input.
-        postprocessing: list of postprocessing operations for the output.
-        pixel_sizes: the pixel sizes for the input tensors, only for spatial axes.
-            This information is currently only used by deepimagej, but will be added to the spec soon.
-        license: the license for this model. By default CC-BY-4.0 will be set as license.
-        covers: list of file paths for cover images.
-            By default a cover will be generated from the input and output data.
-        git_repo: reference git repository for this model.
-        attachments: list of additional files to package with the model.
-        packaged_by: list of authors that have packaged this model.
-        run_mode: custom run mode for this model.
-        parent: id of the parent model from which this model is derived and sha256 of the corresponding rdf file.
-        config: custom configuration for this model.
-        dependencies: relative path to file with dependencies for this model.
-        training_data: the training data for this model, either id for a bioimageio dataset or a dataset spec.
-        root: optional root path for relative paths. This can be helpful when building a spec from another model spec.
-        add_deepimagej_config: add the deepimagej config to the model.
-        tensorflow_version: the tensorflow version for this model. Only for tensorflow or keras weights.
-        opset_version: the opset version for this model. Only for onnx weights.
-        pytorch_version: the pytorch version for this model. Only for pytorch_state_dict or torchscript weights.
-        weight_attachments: extra weight specific attachments.
-    """
-    assert architecture is None or isinstance(architecture, str)
-    if root is None:
-        root = "."
-    root = Path(root)
-
-    if attachments is not None:
-        attachments = _get_attachments(attachments, root)
-
-    #
-    # generate the model specific fields
-    #
-
-    assert len(test_inputs)
-    assert len(test_outputs)
-    test_inputs = _ensure_local_or_url(test_inputs, root)
-    test_outputs = _ensure_local_or_url(test_outputs, root)
-
-    n_inputs = len(test_inputs)
-    if input_names is None:
-        input_names = [f"input{i}" for i in range(n_inputs)]
-    else:
-        assert len(input_names) == len(test_inputs)
-
-    input_step = n_inputs * [None] if input_step is None else input_step
-    input_min_shape = n_inputs * [None] if input_min_shape is None else input_min_shape
-    input_data_range = n_inputs * [None] if input_data_range is None else input_data_range
-    preprocessing = n_inputs * [None] if preprocessing is None else preprocessing
-
-    inputs = [
-        _get_input_tensor(root / test_in, name, step, min_shape, data_range, axes, preproc)
-        for test_in, name, step, min_shape, axes, data_range, preproc in zip(
-            test_inputs, input_names, input_step, input_min_shape, input_axes, input_data_range, preprocessing
-        )
-    ]
-
-    n_outputs = len(test_outputs)
-    if output_names is None:
-        output_names = [f"output{i}" for i in range(n_outputs)]
-    else:
-        assert len(output_names) == len(test_outputs)
-
-    output_reference = n_outputs * [None] if output_reference is None else output_reference
-    output_scale = n_outputs * [None] if output_scale is None else output_scale
-    output_offset = n_outputs * [None] if output_offset is None else output_offset
-    output_data_range = n_outputs * [None] if output_data_range is None else output_data_range
-    postprocessing = n_outputs * [None] if postprocessing is None else postprocessing
-    halo = n_outputs * [None] if halo is None else halo
-
-    outputs = [
-        _get_output_tensor(root / test_out, name, reference, scale, offset, axes, data_range, postproc, hal)
-        for test_out, name, reference, scale, offset, axes, data_range, postproc, hal in zip(
-            test_outputs,
-            output_names,
-            output_reference,
-            output_scale,
-            output_offset,
-            output_axes,
-            output_data_range,
-            postprocessing,
-            halo,
-        )
-    ]
-
-    # validate the pixel sizes (currently only used by 
deepimagej) - spatial_axes = [[ax for ax in inp.axes if ax in "xyz"] for inp in inputs] - if pixel_sizes is None: - pixel_sizes = [{ax: 1.0 for ax in axes} for axes in spatial_axes] - else: - assert len(pixel_sizes) == n_inputs - for pix_size, axes in zip(pixel_sizes, spatial_axes): - assert isinstance(pix_size, dict) - assert set(pix_size.keys()) == set(axes) - - # - # generate general fields - # - format_version = get_args(model_spec.raw_nodes.FormatVersion)[-1] - timestamp = datetime.datetime.now() - - authors = [model_spec.raw_nodes.Author(**a) for a in authors] - cite = _build_cite(cite) - documentation = _ensure_local(documentation, root) - if covers is None: - covers = _generate_covers(root / test_inputs[0], root / test_outputs[0], input_axes[0], output_axes[0], root) - else: - covers = _ensure_local(covers, root) - if license is None: - license = "CC-BY-4.0" - - # parse the weights - weights, tmp_archtecture = _get_weights( - weight_uri, - weight_type, - root, - architecture, - model_kwargs, - tensorflow_version=tensorflow_version, - opset_version=opset_version, - pytorch_version=pytorch_version, - dependencies=dependencies, - attachments=weight_attachments, - ) - - # validate the sample inputs and outputs (if given) - if sample_inputs is not None: - assert sample_outputs is not None - assert len(sample_inputs) == n_inputs - assert len(sample_outputs) == n_outputs - - # add the deepimagej config if specified - if add_deepimagej_config: - if sample_inputs is None: - sample_inputs, sample_outputs = _write_sample_data( - test_inputs, test_outputs, input_axes, output_axes, pixel_sizes, root - ) - # deepimagej expect tifs as sample data - assert all(os.path.splitext(path)[1] in (".tif", ".tiff") for path in sample_inputs) - assert all(os.path.splitext(path)[1] in (".tif", ".tiff") for path in sample_outputs) - - ij_config, ij_attachments = _get_deepimagej_config( - root, test_inputs, test_outputs, input_axes, output_axes, pixel_sizes, preprocessing, postprocessing - ) - - if config is None: - config = ij_config - else: - config.update(ij_config) - - if ij_attachments is not None: - if attachments is None: - attachments = {"files": ij_attachments} - elif "files" not in attachments: - attachments["files"] = ij_attachments - else: - attachments["files"] = list(set(attachments["files"]) | set(ij_attachments)) - - if links is None: - links = ["deepimagej/deepimagej"] - else: - links.append("deepimagej/deepimagej") - - # make sure links are unique - if links is not None: - links = list(set(links)) - - # make sure sample inputs / outputs are relative paths - if sample_inputs is not None: - sample_inputs = _ensure_local_or_url(sample_inputs, root) - - if sample_outputs is not None: - sample_outputs = _ensure_local_or_url(sample_outputs, root) - - # optional kwargs, don't pass them if none - optional_kwargs = { - "config": config, - "git_repo": git_repo, - "packaged_by": packaged_by, - "run_mode": run_mode, - "sample_inputs": sample_inputs, - "sample_outputs": sample_outputs, - "links": links, - } - kwargs = {k: v for k, v in optional_kwargs.items() if v is not None} - - if attachments is not None: - kwargs["attachments"] = spec.rdf.raw_nodes.Attachments(**attachments) - - if maintainers is not None: - kwargs["maintainers"] = [model_spec.raw_nodes.Maintainer(**m) for m in maintainers] - - if parent is not None: - kwargs["parent"] = parent - - if training_data is not None: - if "id" in training_data: - msg = f"If training data is specified via 'id' no other keys are allowed, got {training_data}" 
- assert len(training_data) == 1, msg - kwargs["training_data"] = training_data - else: - if "type" not in training_data: - training_data["type"] = "dataset" - if "format_version" not in training_data: - training_data["format_version"] = spec.dataset.format_version - - try: - model = model_spec.raw_nodes.Model( - authors=authors, - cite=cite, - covers=covers, - description=description, - documentation=documentation, - format_version=format_version, - inputs=inputs, - license=license, - name=name, - outputs=outputs, - root_path=root, - tags=tags, - test_inputs=test_inputs, - test_outputs=test_outputs, - timestamp=timestamp, - weights=weights, - **kwargs, - ) - model_package = export_resource_package(model, output_path=output_path) - except Exception as e: - raise e - finally: - if tmp_archtecture is not None: - os.remove(tmp_archtecture) - - model = load_raw_resource_description(model_package) - return model diff --git a/bioimageio/core/build_spec/get_description.py b/bioimageio/core/build_spec/get_description.py deleted file mode 100644 index c4cd0168..00000000 --- a/bioimageio/core/build_spec/get_description.py +++ /dev/null @@ -1,314 +0,0 @@ -import hashlib -import shutil -import warnings -from datetime import datetime -from pathlib import Path -from typing import Any, List, Optional, Sequence, Type, TypedDict, Union - -import numpy as np -from numpy.typing import NDArray -from pydantic import FilePath - -# from bioimageio.core import export_resource_package, load_raw_resource_description -from typing_extensions import NotRequired, Self, Unpack - -from bioimageio.core.io import FileSource, download, load_description_and_validate, write_description -from bioimageio.core.utils import get_sha256 -from bioimageio.spec.description import ValidationContext -from bioimageio.spec.model.v0_5 import ( - Architecture, - Author, - CiteEntry, - Dependencies, - InputAxis, - InputTensor, - IntervalOrRatioData, - IntervalOrRatioDType, - LicenseId, - Maintainer, - Model, - NotEmpty, - OutputAxis, - OutputTensor, - Postprocessing, - Preprocessing, - PytorchStateDictWeights, - RelativeFilePath, - Sha256, - TensorData, - TensorId, - Version, - Weights, -) - - -class _CoreGenericBaseKwargs(TypedDict): - name: str - description: str - authors: NotEmpty[Sequence[Author]] - maintainers: NotRequired[Sequence[Maintainer]] - tags: Sequence[str] - documentation: FileSource - cite: NotEmpty[Sequence[CiteEntry]] - license: LicenseId - output_path: Path - - -class _TensorKwargs(TypedDict): - test_tensor: FileSource - sample_tensor: NotRequired[FileSource] - id: NotRequired[Optional[TensorId]] - data: NotRequired[Optional[Union[TensorData, NotEmpty[Sequence[TensorData]]]]] - output_path: Path - - -class _OutputTensorKwargs(_TensorKwargs): - axes: NotEmpty[Sequence[OutputAxis]] - postprocessing: NotRequired[Sequence[Postprocessing]] - - -class SpecBuilder: - def __init__(self, output_path: Path, output_path_exist_ok: bool = False) -> None: - super().__init__() - output_path.mkdir(parents=True, exist_ok=output_path_exist_ok) - self.output_path = output_path - - def include_file(self, source: FileSource) -> RelativeFilePath: - local_source = download(source) - try: - rel_path = local_source.path.relative_to( - self.output_path - ) # todo: improve for py >= 3.9 with path.is_relative_to - except ValueError: - # local source is not in output path - dest_path = self.output_path / local_source.original_file_name - if dest_path.exists(): - file_hash = get_sha256(local_source.path) - for i in range(10): - dest_hash = 
get_sha256(dest_path) - if dest_hash == file_hash: - break - - dest_path = dest_path.with_name(f"{dest_path.stem}-{i}{dest_path.suffix}") - if not dest_path.exists(): - break - else: - raise RuntimeError("Encountered too many unidentical files with the same file name.") - - if not dest_path.exists(): - _ = Path(shutil.copy(local_source.path, dest_path)) - - rel_path = dest_path.relative_to(self.output_path) - - return RelativeFilePath(rel_path) - - -class ModelBuilder(SpecBuilder): - def add_cite(self): - self._cite.append(CiteEntry()) - - def add_input_tensor( - self, - *, - test_tensor: Union[NDArray[Any], FileSource], - axes: Sequence[InputAxis], - preprocessing: Sequence[Preprocessing], - id_: TensorId, - data: TensorData, - sample_tensor: Optional[FileSource], - ) -> InputTensor: - return InputTensor.model_validate( - InputTensor( - test_tensor=self.include_file(test_tensor), - id=id_, - axes=tuple(axes), - preprocessing=tuple(preprocessing), - data=data, - sample_tensor=None if sample_tensor is None else self.include_file(sample_tensor), - ), - context=self.context, - ) - - # def add_input_tensor() - # def add_cover_image(cover) - def build(self, output_path: Path, *, inputs: Sequence[InputTensor]): - assert False - - -mb = ModelBuilder(Path("output_path")) -mb.build(inputs=[mb.build_input_tensor(test_tensor=tt) for tt in test_tensors], outputs=based_on.outputs) - - -class SpecGuesser: - @staticmethod - def guess_data_range(array: NDArray[Any]): - if np.issubdtype(array.dtype, np.floating) and array.min() >= 0.0 and array.max() <= 1.0: - return (0.0, 1.0) - else: - return (None, None) - - @classmethod - def guess_data_description(cls, test_tensor: FileSource): - try: - array: Union[Any, NDArray[Any]] = np.load(download(test_tensor).path) - if not isinstance(array, np.ndarray): - raise TypeError(f"Expected numpy array, but got {type(array)}") - except Exception as e: - warnings.warn(f"Could not guess data type of {test_tensor}: {e}") - return None - - dtype_str = str(array.dtype) - dtype: IntervalOrRatioDType = dtype_str # type: ignore # validated by IntervalOrRatioData - return IntervalOrRatioData( - type=dtype, range=cls.guess_data_range(array), unit="arbitrary unit", scale=1.0, offset=None - ) - - -class SpecBuilderWithGuesses(SpecBuilder, SpecGuesser): - # def __init__(self, output_path: Path) -> None: - # super().__init__(output_path) - - def build_input_tensor( - self, - *, - test_tensor: FileSource, - axes: Sequence[InputAxis], - preprocessing: Sequence[Preprocessing], - id_: TensorId, - data: Optional[TensorData] = None, - sample_tensor: FileSource | None, - ) -> InputTensor: - return super().build_input_tensor( - test_tensor=test_tensor, - axes=axes, - preprocessing=preprocessing, - id_=id_, - data=data or self.guess_data_description(test_tensor), - sample_tensor=sample_tensor, - ) - - -def build_spec_interactively(output_path: Path): - guesser = SpecGuesser(output_path) - builder = SpecBuilder(output_path) - - -class _CoreOutputTensor(OutputTensor, _CoreTensorMixin): - @classmethod - def build(cls, **kwargs: Unpack[_OutputTensorKwargs]): - return cls( - test_tensor=ensure_file_in_folder(kwargs["test_tensor"], kwargs["output_path"]), - id=kwargs.get("id") or TensorId("output"), - axes=tuple(kwargs["axes"]), - postprocessing=tuple(kwargs.get("postprocessing", ())), - data=cls.get_data_description(kwargs), - ) - - -class CoreModelBaseKwargs(_CoreGenericBaseKwargs): - inputs: NotEmpty[Sequence[_InputTensorKwargs]] - outputs: NotEmpty[Sequence[_OutputTensorKwargs]] - - -class 
CoreModelKwargs(CoreModelBaseKwargs): - weights: Weights - - -class _CoreModel(Model): - @classmethod - def build(cls, **kwargs: Unpack[CoreModelKwargs]) -> Self: - documentation = ensure_file_in_folder(kwargs["documentation"], kwargs["output_path"]) - - inputs = tuple( - _CoreInputTensor.build( - id=t_kwargs["id"] if "id" in t_kwargs else TensorId(f"input{i}"), - test_tensor=t_kwargs["test_tensor"], - axes=t_kwargs["axes"], - data=t_kwargs.get("data"), - output_path=kwargs["output_path"], - ) - for i, t_kwargs in enumerate(kwargs["inputs"]) - ) - - outputs = tuple( - _CoreOutputTensor.build( - id=t_kwargs["id"] if "id" in t_kwargs else TensorId(f"output{i}"), - test_tensor=t_kwargs["test_tensor"], - axes=t_kwargs["axes"], - data=t_kwargs.get("data"), - output_path=kwargs["output_path"], - ) - for i, t_kwargs in enumerate(kwargs["outputs"]) - ) - - return cls( - name=kwargs["name"], - description=kwargs["description"], - authors=tuple(kwargs["authors"]), - maintainers=tuple(kwargs.get("maintainers", ())), - cite=tuple(kwargs["cite"]), - license=kwargs["license"], - timestamp=datetime.now(), - inputs=inputs, - outputs=outputs, - weights=kwargs["weights"], - documentation=documentation, - ) - - @classmethod - def build_from_pytorch_state_dict( - cls, - weights: FileSource, - architecture: Architecture, - sha256: Optional[Sha256] = None, - pytorch_version: Optional[Version] = None, - dependencies: Optional[Dependencies] = None, - **kwargs: Unpack[CoreModelBaseKwargs], - ): - if pytorch_version is None: - import torch - - pytorch_version = Version(torch.__version__) - - return cls.build( - weights=Weights( - pytorch_state_dict=PytorchStateDictWeights( - source=ensure_file_in_folder(weights, kwargs["output_path"]), - sha256=sha256, - architecture=architecture, - pytorch_version=pytorch_version, - dependencies=dependencies, - ) - ), - **kwargs, - ) - - -def _build_spec_common(core_descr: _CoreModel, descr_path: Path, expected_type: Type[Any]): - write_description(core_descr, descr_path) - loaded = load_description_and_validate(descr_path) - if type(loaded) is not expected_type: - raise RuntimeError(f"Created {descr_path} was loaded as {type(loaded)}, but expected {expected_type}") - - return descr_path, loaded - - -def build_model_spec( - *, - weights: FileSource, - architecture: Architecture, - sha256: Optional[Sha256] = None, - pytorch_version: Optional[Version] = None, - dependencies: Optional[Dependencies] = None, - **kwargs: Unpack[CoreModelBaseKwargs], -): - model = _CoreModel.build_from_pytorch_state_dict( - weights=weights, - architecture=architecture, - sha256=sha256, - pytorch_version=pytorch_version, - dependencies=dependencies, - **kwargs, - ) - - return _build_spec_common(model, kwargs["output_path"] / "description.bioimageio.yaml", Model) From a22738880be48960b961e2949f31b651fe1bb955 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 27 Nov 2023 09:57:14 +0100 Subject: [PATCH 077/244] add numpy_load helper --- bioimageio/core/image_helper.py | 11 ++++++----- bioimageio/core/resource_tests.py | 10 +++++----- bioimageio/core/utils.py | 4 ++-- tests/prediction_pipeline/test_measures.py | 8 ++++---- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/bioimageio/core/image_helper.py b/bioimageio/core/image_helper.py index 8c25b832..6fc3fd24 100644 --- a/bioimageio/core/image_helper.py +++ b/bioimageio/core/image_helper.py @@ -9,10 +9,11 @@ from numpy.typing import NDArray from xarray import DataArray -from bioimageio.spec.model.v0_4 import InputTensor as InputTensor04 -from 
bioimageio.spec.model.v0_4 import OutputTensor as OutputTensor04 -from bioimageio.spec.model.v0_5 import InputTensor as InputTensor05 -from bioimageio.spec.model.v0_5 import OutputTensor as OutputTensor05 +from bioimageio.spec._internal.io_utils import load_array +from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensor04 +from bioimageio.spec.model.v0_4 import OutputTensorDescr as OutputTensor04 +from bioimageio.spec.model.v0_5 import InputTensorDescr as InputTensor05 +from bioimageio.spec.model.v0_5 import OutputTensorDescr as OutputTensor05 InputTensor = Union[InputTensor04, InputTensor05] OutputTensor = Union[OutputTensor04, OutputTensor05] @@ -103,7 +104,7 @@ def to_channel_last(image): def load_image(in_path, axes: Sequence[str]) -> DataArray: ext = os.path.splitext(in_path)[1] if ext == ".npy": - im = np.load(in_path) + im = load_array(in_path) else: is_volume = "z" in axes im = imageio.volread(in_path) if is_volume else imageio.imread(in_path) diff --git a/bioimageio/core/resource_tests.py b/bioimageio/core/resource_tests.py index dabc72ff..51d3dd53 100644 --- a/bioimageio/core/resource_tests.py +++ b/bioimageio/core/resource_tests.py @@ -9,7 +9,6 @@ import numpy import numpy as np import xarray as xr -from marshmallow import ValidationError from bioimageio.core import __version__ as bioimageio_core_version from bioimageio.core import load_raw_resource_description, load_resource_description @@ -25,6 +24,7 @@ ResourceDescription, ) from bioimageio.spec import __version__ as bioimageio_spec_version +from bioimageio.spec._internal.io_utils import load_array from bioimageio.spec.model.raw_nodes import WeightsFormat from bioimageio.spec.shared import resolve_source from bioimageio.spec.shared.common import ValidationWarning @@ -161,8 +161,8 @@ def _test_model_inference(model: Model, weight_format: str, devices: Optional[Li tb: Optional = None with warnings.catch_warnings(record=True) as all_warnings: try: - inputs = [np.load(str(in_path)) for in_path in model.test_inputs] - expected = [np.load(str(out_path)) for out_path in model.test_outputs] + inputs = [load_array(str(in_path)) for in_path in model.test_inputs] + expected = [load_array(str(out_path)) for out_path in model.test_outputs] assert len(inputs) == len(model.inputs) # should be checked by validation input_shapes = {} @@ -362,7 +362,7 @@ def debug_model( bioimageio_model=model, devices=devices, weight_format=weight_format ) inputs = [ - xr.DataArray(np.load(str(in_path)), dims=input_spec.axes) + xr.DataArray(load_array(str(in_path)), dims=input_spec.axes) for in_path, input_spec in zip(model.test_inputs, model.inputs) ] input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)} @@ -383,7 +383,7 @@ def debug_model( outputs = [outputs] expected = [ - xr.DataArray(np.load(str(out_path)), dims=output_spec.axes) + xr.DataArray(load_array(str(out_path)), dims=output_spec.axes) for out_path, output_spec in zip(model.test_outputs, model.outputs) ] if len(outputs) != len(expected): diff --git a/bioimageio/core/utils.py b/bioimageio/core/utils.py index dd3532d7..46fff6db 100644 --- a/bioimageio/core/utils.py +++ b/bioimageio/core/utils.py @@ -7,7 +7,7 @@ from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import TensorId -from bioimageio.spec.utils import download +from bioimageio.spec.utils import download, load_array # @singledispatch # def is_valid_tensor(description: object, tensor: Union[NDArray[Any], xr.DataArray]) -> bool: @@ -24,7 +24,7 @@ def 
get_test_input_tensors(model: object) -> List[xr.DataArray]: @get_test_input_tensors.register def _(model: v0_4.Model): - data = [np.load(download(ipt).path) for ipt in model.test_inputs] + data = [load_array(download(ipt).path) for ipt in model.test_inputs] assert all(isinstance(d, np.ndarray) for d in data) diff --git a/tests/prediction_pipeline/test_measures.py b/tests/prediction_pipeline/test_measures.py index 6eeaa7fd..37916424 100644 --- a/tests/prediction_pipeline/test_measures.py +++ b/tests/prediction_pipeline/test_measures.py @@ -6,16 +6,16 @@ import pytest import xarray as xr -from bioimageio.core import statistical_measures +from bioimageio.core import stat_measures from bioimageio.core.prediction_pipeline._measure_groups import get_measure_groups from bioimageio.core.prediction_pipeline._utils import PER_DATASET, PER_SAMPLE -from bioimageio.core.statistical_measures import Mean, Percentile, Std, Var +from bioimageio.core.stat_measures import Mean, Percentile, Std, Var @pytest.mark.parametrize("name_axes", product(["mean", "var", "std"], [None, ("x", "y")])) def test_individual_normal_measure(name_axes): name, axes = name_axes - measure = getattr(statistical_measures, name.title())(axes=axes) + measure = getattr(stat_measures, name.title())(axes=axes) data = xr.DataArray(np.random.random((5, 6, 3)), dims=("x", "y", "c")) expected = getattr(data, name)(dim=axes) @@ -26,7 +26,7 @@ def test_individual_normal_measure(name_axes): @pytest.mark.parametrize("axes_n", product([None, ("x", "y")], [0, 10, 50, 100])) def test_individual_percentile_measure(axes_n): axes, n = axes_n - measure = statistical_measures.Percentile(axes=axes, n=n) + measure = stat_measures.Percentile(axes=axes, n=n) data = xr.DataArray(np.random.random((5, 6, 3)), dims=("x", "y", "c")) expected = data.quantile(q=n / 100, dim=axes) From 6ea6c853ed96bc6a311bfa44d8da570e276046c5 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 28 Nov 2023 15:34:24 +0100 Subject: [PATCH 078/244] WIP measures --- bioimageio/core/common.py | 26 +- bioimageio/core/proc_impl.py | 246 ++++++++++++------ bioimageio/core/proc_setup.py | 2 +- bioimageio/core/stat_calculators.py | 208 ++++++++------- bioimageio/core/stat_measures.py | 90 +++++-- bioimageio/core/stat_state.py | 6 +- bioimageio/core/utils.py | 32 --- .../weight_converter/torch/torchscript.py | 60 ++--- 8 files changed, 381 insertions(+), 289 deletions(-) diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index 96c71592..73f9c4a9 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -5,7 +5,7 @@ from attr import dataclass from typing_extensions import Final -from bioimageio.core.stat_measures import Measure +from bioimageio.core.stat_measures import MeasureBase from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import TensorId @@ -19,27 +19,3 @@ PER_SAMPLE = "per_sample" PER_DATASET = "per_dataset" - - -MeasureVar = TypeVar("MeasureVar", bound=Measure) -ModeVar = TypeVar("ModeVar", Literal["per_sample"], Literal["per_dataset"]) - - -@dataclass(frozen=True) -class RequiredMeasure(Generic[MeasureVar, ModeVar]): - measure: MeasureVar - tensor_id: TensorId - mode: ModeVar - - -@dataclass(frozen=True) -class SampleMeasure(RequiredMeasure[MeasureVar, Literal["per_sample"]]): - pass - - -@dataclass(frozen=True) -class DatasetMeasure(RequiredMeasure[MeasureVar, Literal["per_dataset"]]): - pass - - -MeasureValue = xr.DataArray diff --git a/bioimageio/core/proc_impl.py b/bioimageio/core/proc_impl.py index d061a1f8..26de8cdf 
diff --git a/bioimageio/core/proc_impl.py b/bioimageio/core/proc_impl.py
index d061a1f8..26de8cdf 100644
--- a/bioimageio/core/proc_impl.py
+++ b/bioimageio/core/proc_impl.py
@@ -3,6 +3,7 @@
 from dataclasses import InitVar, dataclass, field, fields
 from types import MappingProxyType
 from typing import (
+    Any,
     ClassVar,
     FrozenSet,
     Generic,
@@ -23,13 +24,21 @@
 import numpy as np
 import xarray as xr
 from numpy.typing import DTypeLike
-from typing_extensions import LiteralString, assert_never
-
-from bioimageio.core.common import MeasureValue, ProcessingDescrBase, ProcessingKwargs, RequiredMeasure, Sample
-from bioimageio.core.stat_measures import Mean, Percentile, Std
+from typing_extensions import LiteralString, assert_never, Unpack
+
+from bioimageio.core.common import (
+    AnyRequiredMeasure,
+    AxisId,
+    MeasureVar,
+    ProcessingDescrBase,
+    ProcessingKwargs,
+    RequiredMeasure,
+    Sample,
+)
+from bioimageio.core.stat_measures import Mean, MeasureValue, Percentile, Std
 from bioimageio.spec._internal.base_nodes import NodeWithExplicitlySetFields
 from bioimageio.spec.model import v0_4, v0_5
-from bioimageio.spec.model.v0_5 import NonBatchAxisId, TensorId
+from bioimageio.spec.model.v0_5 import NonBatchAxisId, TensorId, BinarizeKwargs

 AssertProcessingId = Literal["assert_dtype"]
@@ -48,19 +57,19 @@ class AssertDtype(AssertProcessingBase):
     kwargs: AssertDtypeKwargs

-M = TypeVar("M", RequiredMeasure, MeasureValue)
+M = TypeVar("M", AnyRequiredMeasure, MeasureValue)


 @dataclass
-class NamedMeasures(Generic[M]):
+class NamedMeasures:
     """Named Measures that specifies all required/computed measures of a Processing instance"""

-    def get_set(self) -> Set[M]:
+    def get_set(self) -> Set[RequiredMeasure[Any, Any]]:
         return {getattr(self, f.name) for f in fields(self)}


 # The two generics are conceptually a higher kinded generic
-R = TypeVar("R", bound=NamedMeasures[RequiredMeasure])
+R = TypeVar("R", bound=NamedMeasures[RequiredMeasure[Any, Any]])
 C = TypeVar("C", bound=NamedMeasures[MeasureValue])
@@ -68,92 +77,171 @@ def get_set(self) -> Set[M]:
 PKwargs = TypeVar("PKwargs", bound=ProcessingKwargs)
 ProcInput = TypeVar("ProcInput", xr.DataArray, Sample)

-@dataclass(frozen=True)
-class ProcessingImplBase(Generic[PKwargs, R, C], ABC):
-    """Base class for all Pre- and Postprocessing implementations."""
+Tensor = xr.DataArray

-    tensor_id: TensorId
-    """id of tensor to operate on"""
+@dataclass
+class Operator(Generic[PKwargs], ABC):
     kwargs: PKwargs
-    computed_measures: InitVar[Mapping[RequiredMeasure, MeasureValue]] = field(
-        default=MappingProxyType[RequiredMeasure, MeasureValue]({})
-    )
-    assert type(R) is type(C), "R and C are conceptually a higher kindes generic, their class has to be identical"
-    required: R = field(init=False)
-    computed: C = field(init=False)
-
-    def __post_init__(self, computed_measures: Mapping[RequiredMeasure, MeasureValue]) -> None:
-        object.__setattr__(self, "required", self.get_required_measures(self.tensor_id, self.kwargs))
-        selected = {}
-        for f in fields(self.required):
-            req = getattr(self.required, f.name)
-            if req in computed_measures:
-                selected[f.name] = computed_measures[req]
-            else:
-                raise ValueError(f"Missing computed measure: {req} (as '{f.name}').")
-
-        object.__setattr__(self, "computed", self.required.__class__(**selected))
-
+    computed:
     @abstractmethod
-    @classmethod
-    def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> NamedMeasures[RequiredMeasure]:
+    def __call__(self) -> Tensor:
         ...

-    def __call__(self, __input: ProcInput, /) -> ProcInput:
-        if isinstance(__input, xr.DataArray):
-            return self.apply(__input)
-        else:
-            return self.apply_to_sample(__input)

-    @abstractmethod
-    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        """apply processing"""
-        ...
+@dataclass
+class Binarize(Operator):
+    threshold: float
+
+
+# class Source(Operator):
+#     def __call__(self) -> Tensor:
+#         return Tensor()

-    def apply_to_sample(self, sample: Sample) -> Sample:
-        ret = dict(sample)
-        ret[self.tensor_id] = self.apply(sample[self.tensor_id])
-        return ret
+
+@dataclass
+class Smooth(Operator):
+    sigma: float
+    tensor_source: Source

     @abstractmethod
-    def get_descr(self) -> Union[ProcessingDescrBase, AssertProcessingBase]:
+    @classmethod
+    def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> R:
         ...

+    def __call__(self) -> Tensor:
+        return tensor * self.sigma  # fixme

-@dataclass(frozen=True)
-class ProcessingImplBaseWoMeasures(
-    ProcessingImplBase[PKwargs, NamedMeasures[RequiredMeasure], NamedMeasures[MeasureValue]]
-):
-    @classmethod
-    def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> NamedMeasures[RequiredMeasure]:
-        return NamedMeasures()
+class Diff(Operator):
+    def __call__(self, a: Tensor, b: Tensor) -> Tensor:
+        return a - b

-@dataclass(frozen=True)
-class AssertDtypeImpl(ProcessingImplBaseWoMeasures[AssertDtypeKwargs]):
-    kwargs_class = AssertDtypeKwargs
-    _assert_with: Tuple[Type[DTypeLike], ...] = field(init=False)
-
-    def __post_init__(self, computed_measures: Mapping[RequiredMeasure, MeasureValue]) -> None:
-        super().__post_init__(computed_measures)
-        if isinstance(self.kwargs.dtype, str):
-            dtype = [self.kwargs.dtype]
-        else:
-            dtype = self.kwargs.dtype
-        object.__setattr__(self, "assert_with", tuple(type(numpy.dtype(dt)) for dt in dtype))
-    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        assert isinstance(tensor.dtype, self._assert_with)
-        return tensor
-    def get_descr(self):
-        return AssertDtype(kwargs=self.kwargs)
-@dataclass(frozen=True)
-class BinarizeImpl(ProcessingImplBaseWoMeasures[Union[v0_4.BinarizeKwargs, v0_5.BinarizeKwargs]]):
+
+@dataclass
+class SimpleOperator(Operator, ABC):
+    input_id: TensorId
+    output_id: TensorId
+
+    def __call__(self, sample: Sample, /) -> Sample:
+        ret = dict(sample)
+        ret[self.output_id] = self.apply(sample[self.input_id])
+        return ret
+
+    @abstractmethod
+    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
+        ...
+
+# @dataclass(frozen=True)
+# class ProcessingImplBase(Generic[PKwargs, R, C], ABC):
+#     """Base class for all Pre- and Postprocessing implementations."""
+
+#     tensor_id: TensorId
+#     """id of tensor to operate on"""
+#     kwargs: PKwargs
+#     computed_measures: InitVar[Mapping[AnyRequiredMeasure, MeasureValue]] = field(
+#         default=MappingProxyType[AnyRequiredMeasure, MeasureValue]({})
+#     )
+#     assert type(R) is type(C), "R and C are conceptually a higher kindes generic, their class has to be identical"
+#     required: R = field(init=False)
+#     computed: C = field(init=False)

+#     def __post_init__(self, computed_measures: Mapping[AnyRequiredMeasure, MeasureValue]) -> None:
+#         object.__setattr__(self, "required", self.get_required_measures(self.tensor_id, self.kwargs))
+#         selected = {}
+#         for f in fields(self.required):
+#             req = getattr(self.required, f.name)
+#             if req in computed_measures:
+#                 selected[f.name] = computed_measures[req]
+#             else:
+#                 raise ValueError(f"Missing computed measure: {req} (as '{f.name}').")

+#         object.__setattr__(self, "computed", self.required.__class__(**selected))
+
+#     @abstractmethod
+#     @classmethod
+#     def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> R:
+#         ...
+
+#     def __call__(self, __input: ProcInput, /) -> ProcInput:
+#         if isinstance(__input, xr.DataArray):
+#             return self.apply(__input)
+#         else:
+#             return self.apply_to_sample(__input)
+
+#     @abstractmethod
+#     def apply(self, tensor: xr.DataArray) -> xr.DataArray:
+#         """apply processing"""
+#         ...
+
+#     def apply_to_sample(self, sample: Sample) -> Sample:
+#         ret = dict(sample)
+#         ret[self.tensor_id] = self.apply(sample[self.tensor_id])
+#         return ret
+
+#     @abstractmethod
+#     def get_descr(self) -> Union[ProcessingDescrBase, AssertProcessingBase]:
+#         ...

+
+# @dataclass(frozen=True)
+# class ProcessingImplBaseWoMeasures(
+#     ProcessingImplBase[PKwargs, NamedMeasures[AnyRequiredMeasure], NamedMeasures[MeasureValue]]
+# ):
+#     @classmethod
+#     def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> NamedMeasures[AnyRequiredMeasure]:
+#         return NamedMeasures()
+
+
+# @dataclass(frozen=True)
+# class AssertDtypeImpl(ProcessingImplBaseWoMeasures[AssertDtypeKwargs]):
+#     kwargs_class = AssertDtypeKwargs
+#     _assert_with: Tuple[Type[DTypeLike], ...] = field(init=False)
+
+#     def __post_init__(self, computed_measures: Mapping[AnyRequiredMeasure, MeasureValue]) -> None:
+#         super().__post_init__(computed_measures)
+#         if isinstance(self.kwargs.dtype, str):
+#             dtype = [self.kwargs.dtype]
+#         else:
+#             dtype = self.kwargs.dtype
+
+#         object.__setattr__(self, "assert_with", tuple(type(numpy.dtype(dt)) for dt in dtype))
+
+#     def apply(self, tensor: xr.DataArray) -> xr.DataArray:
+#         assert isinstance(tensor.dtype, self._assert_with)
+#         return tensor
+
+#     def get_descr(self):
+#         return AssertDtype(kwargs=self.kwargs)
+
+# class AssertDtype(Operator):
+#     dtype: Sequence[DTypeLike]
+#     _assert_with: Tuple[Type[DTypeLike], ...] = field(init=False)
+
+#     def __post_init__(self, computed_measures: Mapping[AnyRequiredMeasure, MeasureValue]) -> None:
+#         super().__post_init__(computed_measures)
+#         if isinstance(self.kwargs.dtype, str):
+#             dtype = [self.kwargs.dtype]
+#         else:
+#             dtype = self.kwargs.dtype
+
+#         object.__setattr__(self, "assert_with", tuple(type(numpy.dtype(dt)) for dt in dtype))
+
+#     def apply(self, tensor: xr.DataArray) -> xr.DataArray:
+#         assert isinstance(tensor.dtype, self._assert_with)
+#         return tensor
+
+#     def get_descr(self):
+#         return AssertDtype(kwargs=self.kwargs)
+
+@dataclass
+class BinarizeImpl(Operator):
     """'output = tensor > threshold'."""
+    threshold: float

     def apply(self, tensor: xr.DataArray) -> xr.DataArray:
         return tensor > self.kwargs.threshold
@@ -237,7 +325,7 @@ class NamedMeasuresScaleMeanVariance(NamedMeasures[M]):
 class ScaleMeanVarianceImpl(
     ProcessingImplBase[
         Union[v0_4.ScaleMeanVarianceKwargs, v0_5.ScaleMeanVarianceKwargs],
-        NamedMeasuresScaleMeanVariance[RequiredMeasure],
+        NamedMeasuresScaleMeanVariance[AnyRequiredMeasure],
         NamedMeasuresScaleMeanVariance[MeasureValue],
     ]
 ):
@@ -248,7 +336,7 @@ def get_required_measures(
         if kwargs.axes is None:
             axes = None
         elif isinstance(kwargs.axes, str):
-            axes = tuple(NonBatchAxisId(a) for a in kwargs.axes)
+            axes = tuple(AxisId(a) for a in kwargs.axes)
         elif isinstance(kwargs.axes, collections.abc.Sequence):  # pyright: ignore[reportUnnecessaryIsInstance]
             axes = tuple(kwargs.axes)
         else:
@@ -283,14 +371,14 @@ class NamedMeasuresScaleRange(NamedMeasures[M]):
 class ScaleRangeImpl(
     ProcessingImplBase[
         Union[v0_4.ScaleRangeKwargs, v0_5.ScaleRangeKwargs],
-        NamedMeasuresScaleRange[RequiredMeasure],
+        NamedMeasuresScaleRange[RequiredMeasure[Percentile, Any]],
         NamedMeasuresScaleRange[MeasureValue],
     ]
 ):
     @classmethod
     def get_required_measures(cls, tensor_id: TensorId, kwargs: Union[v0_4.ScaleRangeKwargs, v0_5.ScaleRangeKwargs]):
         ref_name = kwargs.reference_tensor or tensor_id
-        axes = None if kwargs.axes is None else tuple(NonBatchAxisId(a) for a in kwargs.axes)
+        axes = None if kwargs.axes is None else tuple(AxisId(a) for a in kwargs.axes)
         return NamedMeasuresScaleRange(
             lower=RequiredMeasure(Percentile(kwargs.min_percentile, axes=axes), cast(TensorId, ref_name), kwargs.mode),
             upper=RequiredMeasure(Percentile(kwargs.max_percentile, axes=axes), cast(TensorId, ref_name), kwargs.mode),
diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py
index a66c46aa..e0079b89 100644
--- a/bioimageio/core/proc_setup.py
+++ b/bioimageio/core/proc_setup.py
@@ -35,7 +35,7 @@ class _SetupProcessing(NamedTuple):
 def setup_pre_and_postprocessing(model: ModelDescr, dataset: Iterator[Sample]) -> _SetupProcessing:
     Prepared = List[Tuple[Type[ProcessingImplBase[Any, Any, Any]], ProcessingKwargs, TensorId]]

-    required_measures: Set[RequiredMeasure] = set()
+    required_measures: Set[RequiredMeasure[Any, Any]] = set()

     def prepare_procs(tensor_descrs: Sequence[TensorDescr]):
         prepared: Prepared = []
diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py
index f31bcfe0..d465cd6c 100644
--- a/bioimageio/core/stat_calculators.py
+++ b/bioimageio/core/stat_calculators.py
@@ -25,6 +25,7 @@
     Tuple,
     Type,
     Union,
+    assert_never,
 )

 import numpy as np
@@ -35,14 +36,27 @@
     PER_DATASET,
     PER_SAMPLE,
     AxisId,
-    DatasetMeasure,
-    MeasureVar,
-    RequiredMeasure,
     Sample,
-    SampleMeasure,
     TensorId,
 )
-from bioimageio.core.stat_measures import Mean, Measure, Percentile, Std, Var
+from bioimageio.core.stat_measures import (
+    DatasetMean,
+    DatasetMeasureBase,
+    DatasetMeasureVar,
+    DatasetPercentile,
+    DatasetStd,
+    DatasetVar,
+    Measure,
+    MeasureVar,
+    Percentile,
+    SampleMean,
+    SampleMeasureBase,
+    SamplePercentile,
+    SampleStd,
+    SampleVar,
+    Std,
+    Var,
+)

 try:
     import crick  # type: ignore
@@ -52,29 +66,29 @@
 MeasureValue = Union[xr.DataArray, float]

# class SampleMeasureCalculator(ABC):
#     """group of measures for more efficient computation of multiple measures per sample"""

#     @abstractmethod
#     def compute(self, sample: Sample) -> Mapping[SampleMeasure, MeasureValue]:
#         ...

# class DatasetMeasureCalculator(ABC):
#     """group of measures for more efficient computation of multiple measures per dataset"""

#     @abstractmethod
#     def update_with_sample(self, sample: Sample) -> None:
#         """update intermediate representation with a data sample"""
#         ...

#     @abstractmethod
#     def finalize(self) -> Mapping[DatasetMeasure, MeasureValue]:
#         """compute statistics from intermediate representation"""
#         ...

-class MeanCalculator(SampleMeasureCalculator[Mean], DatasetMeasureCalculator[Mean]):
+class MeanCalculator:
     def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]):
         super().__init__()
         self._axes = None if axes is None else tuple(axes)
@@ -83,11 +97,8 @@
         self._mean: Optional[xr.DataArray] = None

     def compute(self, sample: Sample):
-        return {
-            SampleMeasure(measure=Mean(axes=self._axes), tensor_id=self._tensor_id): sample[self._tensor_id].mean(
-                dim=self._axes
-            )
-        }
+        mean = SampleMean(axes=self._axes, tensor_id=self._tensor_id)
+        return {mean: mean.compute(sample)}

     def update_with_sample(self, sample: Sample):
         tensor = sample[self._tensor_id].astype(np.float64, copy=False)
@@ -106,14 +117,14 @@
             self._mean = (n_a * mean_a + n_b * mean_b) / n
             assert self._mean.dtype == np.float64

-    def finalize(self) -> Mapping[DatasetMeasure, MeasureValue]:
+    def finalize(self) -> Mapping[DatasetMeasureBase, MeasureValue]:
         if self._mean is None:
             return {}
         else:
-            return {DatasetMeasure(measure=Mean(axes=self._axes), tensor_id=self._tensor_id): self._mean}
+            return {DatasetMean(axes=self._axes, tensor_id=self._tensor_id): self._mean}


-class MeanVarStdCalculator(SampleMeasureCalculator, DatasetMeasureCalculator):
+class MeanVarStdCalculator:
     def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]):
         super().__init__()
         self._axes = None if axes is None else tuple(axes)
@@ -134,9 +145,9 @@
         var = xr.dot(c, c, dims=self._axes) / n
         std = np.sqrt(var)
         return {
-            SampleMeasure(Mean(axes=self._axes), tensor_id=self._tensor_id): mean,
-            SampleMeasure(Var(axes=self._axes), tensor_id=self._tensor_id): var,
-            SampleMeasure(Std(axes=self._axes), tensor_id=self._tensor_id): std,
+            SampleMean(axes=self._axes, tensor_id=self._tensor_id): mean,
+            SampleVar(axes=self._axes, tensor_id=self._tensor_id): var,
+            SampleStd(axes=self._axes, tensor_id=self._tensor_id): std,
         }

     def update_with_sample(self, sample: Sample):
@@ -163,7 +174,7 @@
             self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
             assert self._m2.dtype == np.float64

-    def finalize(self) -> Mapping[DatasetMeasure, MeasureValue]:
+    def finalize(self) -> Mapping[DatasetMeasureBase, MeasureValue]:
         if self._mean is None:
             return {}
         else:
@@ -171,13 +182,13 @@
             var = self._m2 / self._n
             sqrt: xr.DataArray = np.sqrt(var)  # type: ignore
             return {
-                DatasetMeasure(tensor_id=self._tensor_id, measure=Mean(axes=self._axes)): self._mean,
-                DatasetMeasure(tensor_id=self._tensor_id, measure=Var(axes=self._axes)): var,
-                DatasetMeasure(tensor_id=self._tensor_id, measure=Std(axes=self._axes)): sqrt,
+                DatasetMean(tensor_id=self._tensor_id, axes=self._axes): self._mean,
+                DatasetVar(tensor_id=self._tensor_id, axes=self._axes): var,
+                DatasetStd(tensor_id=self._tensor_id, axes=self._axes): sqrt,
             }


-class SamplePercentilesCalculator(SampleMeasureCalculator):
+class SamplePercentilesCalculator:
     def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Sequence[float]):
         super().__init__()
         assert all(0 <= n <= 100 for n in ns)
@@ -189,13 +200,10 @@
     def compute(self, sample: Sample):
         tensor = sample[self._tensor_id]
         ps = tensor.quantile(self._qs, dim=self._axes)  # type: ignore
-        return {
-            SampleMeasure(measure=Percentile(n=n, axes=self._axes), tensor_id=self._tensor_id): p
-            for n, p in zip(self.ns, ps)
-        }
+        return {SamplePercentile(n=n, axes=self._axes, tensor_id=self._tensor_id): p for n, p in zip(self.ns, ps)}


-class MeanPercentilesCalculator(DatasetMeasureCalculator):
+class MeanPercentilesCalculator:
     def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Sequence[float]):
         super().__init__()
         assert all(0 <= n <= 100 for n in ns)
@@ -222,18 +230,18 @@
             self._n += n

-    def finalize(self) -> Mapping[DatasetMeasure, MeasureValue]:
+    def finalize(self) -> Mapping[DatasetPercentile, MeasureValue]:
         if self._estimates is None:
             return {}
         else:
             warnings.warn("Computed dataset percentiles naively by averaging percentiles of samples.")
             return {
-                DatasetMeasure(measure=Percentile(n=n, axes=self._axes), tensor_id=self._tensor_id): e
+                DatasetPercentile(n=n, axes=self._axes, tensor_id=self._tensor_id): e
                 for n, e in zip(self._ns, self._estimates)
             }


-class CrickPercentilesCalculator(DatasetMeasureCalculator):
+class CrickPercentilesCalculator:
     def __init__(self, tensor_name: TensorId, axes: Optional[Sequence[AxisId]], ns: Sequence[float]):
         warnings.warn("Computing dataset percentiles with experimental 'crick' library.")
         super().__init__()
@@ -273,16 +281,14 @@
         for i, idx in enumerate(self._indices):
             self._digest[i].update(tensor.isel(dict(zip(self._dims[1:], idx))))

-    def finalize(self) -> Dict[DatasetMeasure, MeasureValue]:
+    def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
         if self._digest is None:
             return {}
         else:
             assert self._dims is not None
             vs: NDArray[Any] = np.asarray([[d.quantile(q) for d in self._digest] for q in self._qs]).reshape(self._shape)  # type: ignore
             return {
-                DatasetMeasure(measure=Percentile(n=n, axes=self._axes), tensor_id=self._tensor_id): xr.DataArray(
-                    v, dims=self._dims[1:]
-                )
+                DatasetPercentile(n=n, axes=self._axes, tensor_id=self._tensor_id): xr.DataArray(v, dims=self._dims[1:])
                 for n, v in zip(self._ns, vs)
             }
@@ -295,24 +301,26 @@

     DatasetPercentileCalculator = CrickPercentilesCalculator


-class NaivSampleMeasureCalculator(SampleMeasureCalculator):
-    """wrapper for measures to match interface of SampleMeasureGroup"""
+class NaivSampleMeasureCalculator:
+    """wrapper for measures to match interface of other sample measure calculators"""

-    def __init__(self, tensor_id: TensorId, measure: Measure):
+    def __init__(self, tensor_id: TensorId, measure: SampleMeasureBase):
         super().__init__()
         self.tensor_name = tensor_id
         self.measure = measure

-    def compute(self, sample: Sample) -> Mapping[SampleMeasure, MeasureValue]:
-        return {
-            SampleMeasure(measure=self.measure, tensor_id=self.tensor_name): self.measure.compute(
-                sample[self.tensor_name]
-            )
-        }
+    def compute(self, sample: Sample) -> Mapping[SampleMeasureBase, MeasureValue]:
+        return {self.measure: self.measure.compute(sample)}
+
+
+SampleMeasureCalculator = Union[
+    MeanCalculator, MeanVarStdCalculator, SamplePercentilesCalculator, NaivSampleMeasureCalculator
+]
+DatasetMeasureCalculator = Union[MeanCalculator, MeanVarStdCalculator, DatasetPercentileCalculator]


 def get_measure_calculators(
-    required_measures: Iterable[RequiredMeasure],
+    required_measures: Iterable[Measure],
 ) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]:
     """determines which calculators are needed to compute the required measures efficiently"""

@@ -320,50 +328,58 @@
     dataset_calculators: List[DatasetMeasureCalculator] = []

     # split required measures into groups
-    required_means: Set[RequiredMeasure] = set()
-    required_mean_var_std: Set[RequiredMeasure] = set()
-    required_percentiles: Set[RequiredMeasure] = set()
+    required_sample_means: Set[SampleMean] = set()
+    required_dataset_means: Set[DatasetMean] = set()
+    required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set()
+    required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = set()
+    required_sample_percentiles: Set[SamplePercentile] = set()
+    required_dataset_percentiles: Set[DatasetPercentile] = set()

     for rm in required_measures:
-        if isinstance(rm.measure, Mean):
-            required_means.add(rm)
-        elif isinstance(rm.measure, (Var, Std)):
-            required_mean_var_std.update(
-                {
-                    RequiredMeasure(measure=msv(rm.measure.axes), tensor_id=rm.tensor_id, mode=rm.mode)
-                    for msv in (Mean, Std, Var)
-                }
+        if isinstance(rm, SampleMean):
+            required_sample_means.add(rm)
+        elif isinstance(rm, DatasetMean):
+            required_dataset_means.add(rm)
+        elif isinstance(rm, (SampleVar, SampleStd)):
+            required_sample_mean_var_std.update(
+                {msv(axes=rm.axes, tensor_id=rm.tensor_id) for msv in (SampleMean, SampleStd, SampleVar)}
+            )
+            assert rm in required_sample_mean_var_std
+        elif isinstance(rm, (DatasetVar, DatasetStd)):
+            required_dataset_mean_var_std.update(
+                {msv(axes=rm.axes, tensor_id=rm.tensor_id) for msv in (DatasetMean, DatasetStd, DatasetVar)}
             )
-            assert rm in required_mean_var_std
-        elif isinstance(rm.measure, Percentile):
-            required_percentiles.add(rm)
-        elif rm.mode == PER_SAMPLE:
-            sample_calculators.append(NaivSampleMeasureCalculator(tensor_id=rm.tensor_id, measure=rm.measure))
+            assert rm in required_dataset_mean_var_std
+        elif isinstance(rm, SamplePercentile):
+            required_sample_percentiles.add(rm)
+        elif isinstance(rm, DatasetPercentile):  # pyright: ignore[reportUnnecessaryIsInstance]
+            required_dataset_percentiles.add(rm)
         else:
-            raise NotImplementedError(f"Computing statistics for {rm.measure} {rm.mode} not yet implemented")
+            assert_never(rm)
+
+    for rm in required_sample_means:
+        if rm in required_sample_mean_var_std:
+            # computed together with var and std
+            continue
+
+        sample_calculators.append(MeanCalculator(tensor_id=rm.tensor_id, axes=rm.axes))
+
+    for rm in required_sample_mean_var_std:
+        sample_calculators.append(MeanVarStdCalculator(tensor_id=rm.tensor_id, axes=rm.axes))

-    for rm in required_means:
-        if rm in required_mean_var_std:
+    for rm in required_dataset_means:
+        if rm in required_dataset_mean_var_std:
             # computed together with var and std
             continue

-        if rm.mode == PER_SAMPLE:
-            sample_calculators.append(MeanCalculator(tensor_id=rm.tensor_id, axes=rm.measure.axes))
-    # add all mean measures that are not included in a mean/var/std group
-    for tn, m in means:
-        if (tn, m.axes) not in required_mean_var_std:
-            # compute only mean
-            if mode == PER_SAMPLE:
-                calculators[mode].append(NaivSampleMeasureCalculator(tensor_id=tn, measure=m))
-            elif mode == PER_DATASET:
-                calculators[mode].append(DatasetMeanCalculator(tensor_id=tn, axes=m.axes))
-            else:
-                raise NotImplementedError(mode)
-
-    for tn, axes in mean_var_std_groups:
-        calculators[mode].append(MeanVarStdCalculator(tensor_id=tn, axes=axes))
-
-    for (tn, axes), ns in required_percentiles.items():
+        dataset_calculators.append(MeanCalculator(tensor_id=rm.tensor_id, axes=rm.axes))
+
+    for rm in required_dataset_mean_var_std:
+        dataset_calculators.append(MeanVarStdCalculator(tensor_id=rm.tensor_id, axes=rm.axes))
+
+    for rm in required_sample_percentiles:
+        sample_calculators.append(SamplePercentilesCalculator(tensor_id=rm.tensor_id, axes=rm.axes))
+    for (tn, axes), ns in required_sample_percentiles.items():
         if mode == PER_SAMPLE:
             calculators[mode].append(SamplePercentilesCalculator(tensor_id=tn, axes=axes, ns=ns))
         elif mode == PER_DATASET:
@@ -375,7 +391,7 @@

 def compute_measures(
-    measures: RequiredMeasures, *, sample: Optional[Sample] = None, dataset: Iterator[Sample] = ()
+    measures: Set[Measure], *, sample: Optional[Sample] = None, dataset: Iterator[Sample] = ()
 ) -> ComputedMeasures:
     ms_groups = get_measure_calculators(measures)
     ret = {PER_SAMPLE: {}, PER_DATASET: {}}
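[A quick usage sketch, not part of the patch: driving the reworked calculators above. It assumes the WIP API exactly as shown in this diff; the tensor id "raw" and the array shape are placeholders.]

    import numpy as np
    import xarray as xr
    from bioimageio.core.stat_calculators import get_measure_calculators
    from bioimageio.core.stat_measures import DatasetMean, SampleStd

    required = {SampleStd(tensor_id="raw", axes=None), DatasetMean(tensor_id="raw", axes=None)}
    sample_calcs, dataset_calcs = get_measure_calculators(required)

    sample = {"raw": xr.DataArray(np.random.rand(2, 8, 8), dims=("b", "y", "x"))}
    stats = {}
    for calc in sample_calcs:  # per-sample statistics are computed directly
        stats.update(calc.compute(sample))
    for calc in dataset_calcs:  # dataset statistics accumulate, then finalize
        calc.update_with_sample(sample)
        stats.update(calc.finalize())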
diff --git a/bioimageio/core/stat_measures.py b/bioimageio/core/stat_measures.py
index 29d6857a..6f4f3aa9 100644
--- a/bioimageio/core/stat_measures.py
+++ b/bioimageio/core/stat_measures.py
@@ -2,48 +2,84 @@

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, TypeVar, Union

 import xarray as xr

-from bioimageio.core.common import MeasureValue
-from bioimageio.spec.model.v0_5 import AxisId
+from bioimageio.core.common import Sample
+from bioimageio.spec.model.v0_5 import AxisId, TensorId
+
+MeasureValue = Union[float, xr.DataArray]
+
+
+@dataclass(frozen=True)
+class MeasureBase(ABC):
+    tensor_id: TensorId


 @dataclass(frozen=True)
-class Measure(ABC):
+class SampleMeasureBase(MeasureBase, ABC):
     @abstractmethod
-    def compute(self, tensor: xr.DataArray) -> MeasureValue:
-        """compute the measure (and also associated other Measures)"""
+    def compute(self, sample: Sample) -> MeasureValue:
+        """compute the measure"""
         ...


 @dataclass(frozen=True)
-class Mean(Measure):
+class DatasetMeasureBase(MeasureBase, ABC):
+    pass
+
+
+@dataclass(frozen=True)
+class _Mean(MeasureBase):
     axes: Optional[Tuple[AxisId, ...]] = None

-    def compute(self, tensor: xr.DataArray) -> xr.DataArray:
-        return tensor.mean(dim=self.axes)
+
+@dataclass(frozen=True)
+class SampleMean(_Mean, SampleMeasureBase):
+    def compute(self, sample: Sample) -> MeasureValue:
+        return sample[self.tensor_id].mean(dim=self.axes)
+
+
+@dataclass(frozen=True)
+class DatasetMean(_Mean, DatasetMeasureBase):
+    pass


 @dataclass(frozen=True)
-class Std(Measure):
+class _Std(MeasureBase):
     axes: Optional[Tuple[AxisId, ...]] = None

-    def compute(self, tensor: xr.DataArray) -> xr.DataArray:
-        return tensor.std(dim=self.axes)
+
+@dataclass(frozen=True)
+class SampleStd(_Std, SampleMeasureBase):
+    def compute(self, sample: Sample) -> MeasureValue:
+        return sample[self.tensor_id].std(dim=self.axes)
+
+
+@dataclass(frozen=True)
+class DatasetStd(_Std, DatasetMeasureBase):
+    pass


 @dataclass(frozen=True)
-class Var(Measure):
+class _Var(MeasureBase):
     axes: Optional[Tuple[AxisId, ...]] = None

-    def compute(self, tensor: xr.DataArray) -> xr.DataArray:
-        return tensor.var(dim=self.axes)
+
+@dataclass(frozen=True)
+class SampleVar(_Var, SampleMeasureBase):
+    def compute(self, sample: Sample) -> MeasureValue:
+        return sample[self.tensor_id].var(dim=self.axes)
+
+
+@dataclass(frozen=True)
+class DatasetVar(_Var, DatasetMeasureBase):
+    pass


 @dataclass(frozen=True)
-class Percentile(Measure):
+class _Percentile(MeasureBase):
     n: float
     axes: Optional[Tuple[AxisId, ...]] = None

@@ -51,5 +87,23 @@
     def __post_init__(self):
         assert self.n >= 0
         assert self.n <= 100

-    def compute(self, tensor: xr.DataArray) -> xr.DataArray:
-        return tensor.quantile(self.n / 100.0, dim=self.axes)
+
+@dataclass(frozen=True)
+class SamplePercentile(_Percentile, SampleMeasureBase):
+    def compute(self, sample: Sample) -> MeasureValue:
+        return sample[self.tensor_id].quantile(self.n / 100.0, dim=self.axes)
+
+
+@dataclass(frozen=True)
+class DatasetPercentile(_Percentile, DatasetMeasureBase):
+    pass
+
+
+SampleMeasure = Union[SampleMean, SampleStd, SampleVar, SamplePercentile]
+DatasetMeasure = Union[DatasetMean, DatasetStd, DatasetVar, DatasetPercentile]
+Measure = Union[SampleMeasure, DatasetMeasure]
+
+# MeasureVar = TypeVar("MeasureVar", bound=MeasureBase)
+# SampleMeasureVar = TypeVar("SampleMeasureVar", bound=SampleMeasureBase)
+# DatasetMeasureVar = TypeVar("DatasetMeasureVar", bound=DatasetMeasureBase)
+# ModeVar = TypeVar("ModeVar", bound=Literal["per_sample", "per_dataset"])
diff --git a/bioimageio/core/stat_state.py b/bioimageio/core/stat_state.py
index 107383be..24f062c9 100644
--- a/bioimageio/core/stat_state.py
+++ b/bioimageio/core/stat_state.py
@@ -3,9 +3,9 @@

 from tqdm import tqdm

-from bioimageio.core.common import PER_DATASET, PER_SAMPLE, MeasureValue, RequiredMeasure, Sample, TensorId
+from bioimageio.core.common import PER_DATASET, PER_SAMPLE, RequiredMeasure, Sample, TensorId
 from bioimageio.core.stat_calculators import MeasureGroups, MeasureValue, get_measure_calculators
-from bioimageio.core.stat_measures import Measure
+from bioimageio.core.stat_measures import MeasureBase, MeasureValue


 @dataclass
@@ -15,7 +15,7 @@
 class StatsState:
     required_measures: Iterable[RequiredMeasure]

-def compute_statistics()
+def compute_statistics():
     dataset: Iterable[Sample]
     update_dataset_stats_after_n_samples: Optional[int] = None
     update_dataset_stats_for_n_samples: Union[int, float] = float("inf")
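[A quick usage sketch, not part of the patch: the sample/dataset measure split above keys each measure by tensor_id and computes it from a whole sample dict. "raw" is a placeholder tensor id; axes=None reduces over all axes.]

    import numpy as np
    import xarray as xr
    from bioimageio.core.stat_measures import SampleMean, SamplePercentile

    sample = {"raw": xr.DataArray(np.random.rand(4, 16), dims=("b", "x"))}
    mean = SampleMean(tensor_id="raw", axes=None).compute(sample)  # scalar DataArray
    p99 = SamplePercentile(tensor_id="raw", n=99, axes=None).compute(sample)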
diff --git a/bioimageio/core/utils.py b/bioimageio/core/utils.py
index 46fff6db..e69de29b 100644
--- a/bioimageio/core/utils.py
+++ b/bioimageio/core/utils.py
@@ -1,32 +0,0 @@
-from functools import singledispatch
-from typing import Any, Dict, List, Union
-
-import numpy as np
-import xarray as xr
-from numpy.typing import NDArray
-
-from bioimageio.spec.model import v0_4, v0_5
-from bioimageio.spec.model.v0_5 import TensorId
-from bioimageio.spec.utils import download, load_array
-
-# @singledispatch
-# def is_valid_tensor(description: object, tensor: Union[NDArray[Any], xr.DataArray]) -> bool:
-#     raise NotImplementedError(type(description))
-
-# is_valid_tensor.register
-# def _(description: v0_4.InputTensor, tensor: Union[NDArray[Any], xr.DataArray]):
-
-
-@singledispatch
-def get_test_input_tensors(model: object) -> List[xr.DataArray]:
-    raise NotImplementedError(type(model))
-
-
-@get_test_input_tensors.register
-def _(model: v0_4.Model):
-    data = [load_array(download(ipt).path) for ipt in model.test_inputs]
-    assert all(isinstance(d, np.ndarray) for d in data)
-
-
-# @get_test_input_tensors.register
-# def _(model: v0_5.Model):
Expected at least {min_shape}") _check(this_input) + def convert_weights_to_torchscript( - model_spec: Union[str, Path, v0_4.ModelDescr, v0_5.ModelDescr], output_path: Path, use_tracing: bool = True -): + model_descr: Union[v0_4.ModelDescr, v0_5.ModelDescr], output_path: Path, use_tracing: bool = True +) -> v0_5.TorchscriptWeightsDescr: """Convert model weights from format 'pytorch_state_dict' to 'torchscript'. Args: - model_spec: location of the resource for the input bioimageio model + model_descr: location of the resource for the input bioimageio model output_path: where to save the torchscript weights use_tracing: whether to use tracing or scripting to export the torchscript format """ - if isinstance(model_spec, (str, Path)): - loaded_spec = load_description(Path(model_spec)) - if isinstance(loaded_spec, InvalidDescription): - raise ValueError(f"Bad resource description: {loaded_spec}") - if not isinstance(loaded_spec, (v0_4.ModelDescr, v0_5.ModelDescr)): - raise TypeError(f"Path {model_spec} is a {loaded_spec.__class__.__name__}, expected a v0_4.ModelDescr or v0_5.ModelDescr") - model_spec = loaded_spec - - state_dict_weights_descr = model_spec.weights.pytorch_state_dict + + state_dict_weights_descr = model_descr.weights.pytorch_state_dict if state_dict_weights_descr is None: - raise ValueError(f"The provided model does not have weights in the pytorch state dict format") + raise ValueError("The provided model does not have weights in the pytorch state dict format") - with torch.no_grad(): - if isinstance(model_spec, v0_4.ModelDescr): - downloaded_test_inputs = [download(inp) for inp in model_spec.test_inputs] - else: - downloaded_test_inputs = [inp.test_tensor.download() for inp in model_spec.inputs] + input_data = model_descr.get_input_test_arrays() - input_data = [np.load(dl.path).astype("float32") for dl in downloaded_test_inputs] - input_data = [torch.from_numpy(inp) for inp in input_data] + with torch.no_grad(): + input_data = [torch.from_numpy(inp.astype("float32")) for inp in input_data] model = load_model(state_dict_weights_descr) @@ -113,13 +105,11 @@ def convert_weights_to_torchscript( else: scripted_model: Any = torch.jit.script(model) - ret = _check_predictions( - model=model, - scripted_model=scripted_model, - model_spec=model_spec, - input_data=input_data - ) + _check_predictions(model=model, scripted_model=scripted_model, model_spec=model_descr, input_data=input_data) # save the torchscript model scripted_model.save(str(output_path)) # does not support Path, so need to cast to str - return ret + + return v0_5.TorchscriptWeightsDescr( + source=output_path, pytorch_version=Version(torch.__version__), parent="pytorch_state_dict" + ) From 392a73aa47ac1774ab4d47c586b46fba13c6f1e2 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 14 Dec 2023 02:44:09 +0100 Subject: [PATCH 079/244] WIP overhaul Measure and Procs (as Ops) --- bioimageio/core/common.py | 23 +- bioimageio/core/op_base.py | 18 + bioimageio/core/proc_impl.py | 575 ---------------------------- bioimageio/core/proc_ops.py | 549 ++++++++++++++++++++++++++ bioimageio/core/proc_setup.py | 30 +- bioimageio/core/stat_calculators.py | 269 +++++++------ bioimageio/core/stat_measures.py | 70 ++-- bioimageio/core/stat_state.py | 100 ----- setup.py | 2 +- 9 files changed, 794 insertions(+), 842 deletions(-) create mode 100644 bioimageio/core/op_base.py delete mode 100644 bioimageio/core/proc_impl.py create mode 100644 bioimageio/core/proc_ops.py delete mode 100644 bioimageio/core/stat_state.py diff --git 
From 392a73aa47ac1774ab4d47c586b46fba13c6f1e2 Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Thu, 14 Dec 2023 02:44:09 +0100
Subject: [PATCH 079/244] WIP overhaul Measure and Procs (as Ops)

---
 bioimageio/core/common.py                     |  23 +-
 bioimageio/core/op_base.py                    |  18 +
 bioimageio/core/proc_impl.py                  | 575 ----------------------
 bioimageio/core/proc_ops.py                   | 549 ++++++++++++++++++++++
 bioimageio/core/proc_setup.py                 |  30 +-
 bioimageio/core/stat_calculators.py           | 269 +++++++------
 bioimageio/core/stat_measures.py              |  70 ++--
 bioimageio/core/stat_state.py                 | 100 -----
 setup.py                                      |   2 +-
 9 files changed, 794 insertions(+), 842 deletions(-)
 create mode 100644 bioimageio/core/op_base.py
 delete mode 100644 bioimageio/core/proc_impl.py
 create mode 100644 bioimageio/core/proc_ops.py
 delete mode 100644 bioimageio/core/stat_state.py

diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py
index 73f9c4a9..db9b73fd 100644
--- a/bioimageio/core/common.py
+++ b/bioimageio/core/common.py
@@ -1,21 +1,26 @@
-from typing import Any, Dict, Generic, List, Literal, NamedTuple, TypeVar, Union
+from dataclasses import field
+from typing import Dict, Union

-import numpy as np
 import xarray as xr
 from attr import dataclass
-from typing_extensions import Final

-from bioimageio.core.stat_measures import MeasureBase
+from bioimageio.core.stat_measures import Measure, MeasureValue
 from bioimageio.spec.model import v0_4, v0_5
-from bioimageio.spec.model.v0_5 import TensorId

 TensorId = v0_5.TensorId
 AxisId = v0_5.AxisId

-Sample = Dict[TensorId, xr.DataArray]
+Tensor = xr.DataArray
+
+Data = Dict[TensorId, Tensor]
+Stat = Dict[Measure, MeasureValue]
+
+
+@dataclass
+class Sample:
+    data: Data = field(default_factory=dict)
+    stat: Stat = field(default_factory=dict)
+

 ProcessingDescrBase = Union[v0_4.ProcessingDescrBase, v0_5.ProcessingDescrBase]
 ProcessingKwargs = Union[v0_4.ProcessingKwargs, v0_5.ProcessingKwargs]
-
-PER_SAMPLE = "per_sample"
-PER_DATASET = "per_dataset"
diff --git a/bioimageio/core/op_base.py b/bioimageio/core/op_base.py
new file mode 100644
index 00000000..2e872a19
--- /dev/null
+++ b/bioimageio/core/op_base.py
@@ -0,0 +1,18 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Collection
+
+from bioimageio.core.common import Sample
+from bioimageio.core.stat_measures import Measure
+
+
+@dataclass
+class Operator(ABC):
+    @abstractmethod
+    def __call__(self, sample: Sample) -> None:
+        ...
+
+    @property
+    @abstractmethod
+    def required_measures(self) -> Collection[Measure]:
+        ...
diff --git a/bioimageio/core/proc_impl.py b/bioimageio/core/proc_impl.py
deleted file mode 100644
index 26de8cdf..00000000
--- a/bioimageio/core/proc_impl.py
+++ /dev/null
@@ -1,575 +0,0 @@
-import collections.abc
-from abc import ABC, abstractmethod
-from dataclasses import InitVar, dataclass, field, fields
-from types import MappingProxyType
-from typing import (
-    Any,
-    ClassVar,
-    FrozenSet,
-    Generic,
-    Hashable,
-    Literal,
-    Mapping,
-    Optional,
-    Sequence,
-    Set,
-    Tuple,
-    Type,
-    TypeVar,
-    Union,
-    cast,
-)
-
-import numpy
-import numpy as np
-import xarray as xr
-from numpy.typing import DTypeLike
-from typing_extensions import LiteralString, assert_never, Unpack
-
-from bioimageio.core.common import (
-    AnyRequiredMeasure,
-    AxisId,
-    MeasureVar,
-    ProcessingDescrBase,
-    ProcessingKwargs,
-    RequiredMeasure,
-    Sample,
-)
-from bioimageio.core.stat_measures import Mean, MeasureValue, Percentile, Std
-from bioimageio.spec._internal.base_nodes import NodeWithExplicitlySetFields
-from bioimageio.spec.model import v0_4, v0_5
-from bioimageio.spec.model.v0_5 import NonBatchAxisId, TensorId, BinarizeKwargs
-
-AssertProcessingId = Literal["assert_dtype"]
-
-
-class AssertProcessingBase(NodeWithExplicitlySetFields):
-    id: AssertProcessingId
-    fields_to_set_explicitly: ClassVar[FrozenSet[LiteralString]] = frozenset({"id"})
-
-
-class AssertDtypeKwargs(v0_5.ProcessingKwargs):
-    dtype: Union[str, Sequence[str]]
-
-
-class AssertDtype(AssertProcessingBase):
-    id: Literal["assert_dtype"] = "assert_dtype"
-    kwargs: AssertDtypeKwargs
-
-
-M = TypeVar("M", AnyRequiredMeasure, MeasureValue)
-
-
-@dataclass
-class NamedMeasures:
-    """Named Measures that specifies all required/computed measures of a Processing instance"""
-
-    def get_set(self) -> Set[RequiredMeasure[Any, Any]]:
-        return {getattr(self, f.name) for f in fields(self)}
-
-
-# The two generics are conceptually a higher kinded generic
-R = TypeVar("R", bound=NamedMeasures[RequiredMeasure[Any, Any]])
-C = TypeVar("C", bound=NamedMeasures[MeasureValue])
-
-
-PKwargs = TypeVar("PKwargs", bound=ProcessingKwargs)
-ProcInput = TypeVar("ProcInput", xr.DataArray, Sample)
-
-
-Tensor = xr.DataArray
-
-@dataclass
-class Operator(Generic[PKwargs], ABC):
-    kwargs: PKwargs
-    computed:
-    @abstractmethod
-    def __call__(self) -> Tensor:
-        ...
-
-
-@dataclass
-class Binarize(Operator):
-    threshold: float
-
-
-# class Source(Operator):
-#     def __call__(self) -> Tensor:
-#         return Tensor()
-
-
-@dataclass
-class Smooth(Operator):
-    sigma: float
-    tensor_source: Source
-
-    @abstractmethod
-    @classmethod
-    def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> R:
-        ...
-
-    def __call__(self) -> Tensor:
-        return tensor * self.sigma  # fixme
-
-
-class Diff(Operator):
-    def __call__(self, a: Tensor, b: Tensor) -> Tensor:
-        return a - b
-
-
-
-
-
-@dataclass
-class SimpleOperator(Operator, ABC):
-    input_id: TensorId
-    output_id: TensorId
-
-    def __call__(self, sample: Sample, /) -> Sample:
-        ret = dict(sample)
-        ret[self.output_id] = self.apply(sample[self.input_id])
-        return ret
-
-
-    @abstractmethod
-    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        ...
-
-# @dataclass(frozen=True)
-# class ProcessingImplBase(Generic[PKwargs, R, C], ABC):
-#     """Base class for all Pre- and Postprocessing implementations."""
-
-#     tensor_id: TensorId
-#     """id of tensor to operate on"""
-#     kwargs: PKwargs
-#     computed_measures: InitVar[Mapping[AnyRequiredMeasure, MeasureValue]] = field(
-#         default=MappingProxyType[AnyRequiredMeasure, MeasureValue]({})
-#     )
-#     assert type(R) is type(C), "R and C are conceptually a higher kindes generic, their class has to be identical"
-#     required: R = field(init=False)
-#     computed: C = field(init=False)
-
-#     def __post_init__(self, computed_measures: Mapping[AnyRequiredMeasure, MeasureValue]) -> None:
-#         object.__setattr__(self, "required", self.get_required_measures(self.tensor_id, self.kwargs))
-#         selected = {}
-#         for f in fields(self.required):
-#             req = getattr(self.required, f.name)
-#             if req in computed_measures:
-#                 selected[f.name] = computed_measures[req]
-#             else:
-#                 raise ValueError(f"Missing computed measure: {req} (as '{f.name}').")
-
-#         object.__setattr__(self, "computed", self.required.__class__(**selected))
-
-#     @abstractmethod
-#     @classmethod
-#     def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> R:
-#         ...
-
-#     def __call__(self, __input: ProcInput, /) -> ProcInput:
-#         if isinstance(__input, xr.DataArray):
-#             return self.apply(__input)
-#         else:
-#             return self.apply_to_sample(__input)
-
-#     @abstractmethod
-#     def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-#         """apply processing"""
-#         ...
-
-#     def apply_to_sample(self, sample: Sample) -> Sample:
-#         ret = dict(sample)
-#         ret[self.tensor_id] = self.apply(sample[self.tensor_id])
-#         return ret
-
-#     @abstractmethod
-#     def get_descr(self) -> Union[ProcessingDescrBase, AssertProcessingBase]:
-#         ...
-
-
-# @dataclass(frozen=True)
-# class ProcessingImplBaseWoMeasures(
-#     ProcessingImplBase[PKwargs, NamedMeasures[AnyRequiredMeasure], NamedMeasures[MeasureValue]]
-# ):
-#     @classmethod
-#     def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> NamedMeasures[AnyRequiredMeasure]:
-#         return NamedMeasures()
-
-
-# @dataclass(frozen=True)
-# class AssertDtypeImpl(ProcessingImplBaseWoMeasures[AssertDtypeKwargs]):
-#     kwargs_class = AssertDtypeKwargs
-#     _assert_with: Tuple[Type[DTypeLike], ...] = field(init=False)
-
-#     def __post_init__(self, computed_measures: Mapping[AnyRequiredMeasure, MeasureValue]) -> None:
-#         super().__post_init__(computed_measures)
-#         if isinstance(self.kwargs.dtype, str):
-#             dtype = [self.kwargs.dtype]
-#         else:
-#             dtype = self.kwargs.dtype
-
-#         object.__setattr__(self, "assert_with", tuple(type(numpy.dtype(dt)) for dt in dtype))
-
-#     def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-#         assert isinstance(tensor.dtype, self._assert_with)
-#         return tensor
-
-#     def get_descr(self):
-#         return AssertDtype(kwargs=self.kwargs)
-
-# class AssertDtype(Operator):
-#     dtype: Sequence[DTypeLike]
-#     _assert_with: Tuple[Type[DTypeLike], ...] = field(init=False)
-
-#     def __post_init__(self, computed_measures: Mapping[AnyRequiredMeasure, MeasureValue]) -> None:
-#         super().__post_init__(computed_measures)
-#         if isinstance(self.kwargs.dtype, str):
-#             dtype = [self.kwargs.dtype]
-#         else:
-#             dtype = self.kwargs.dtype
-
-#         object.__setattr__(self, "assert_with", tuple(type(numpy.dtype(dt)) for dt in dtype))
-
-#     def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-#         assert isinstance(tensor.dtype, self._assert_with)
-#         return tensor
-
-#     def get_descr(self):
-#         return AssertDtype(kwargs=self.kwargs)
-
-@dataclass
-class BinarizeImpl(Operator):
-    """'output = tensor > threshold'."""
-    threshold: float
-
-    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        return tensor > self.kwargs.threshold
-
-    def get_descr(self):
-        return v0_5.BinarizeDescr(kwargs=self.kwargs)
-
-
-@dataclass(frozen=True)
-class ClipImpl(ProcessingImplBaseWoMeasures[Union[v0_4.ClipKwargs, v0_5.ClipKwargs]]):
-    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        return tensor.clip(min=self.kwargs.min, max=self.kwargs.max)
-
-    def get_descr(self):
-        return v0_5.ClipDescr(kwargs=self.kwargs)
-
-
-@dataclass(frozen=True)
-class EnsureDtypeImpl(ProcessingImplBaseWoMeasures[v0_5.EnsureDtypeKwargs]):
-    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        return tensor.astype(self.kwargs.dtype)
-
-    def get_descr(self):
-        return v0_5.EnsureDtypeDescr(kwargs=self.kwargs)
-
-
-class ScaleLinearImpl04(ProcessingImplBaseWoMeasures[Union[v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs]]):
-    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        axis = (
-            self.kwargs.axis
-            if isinstance(self.kwargs, v0_5.ScaleLinearKwargs)
-            else _get_complement_axis(tensor, self.kwargs.axes)
-        )
-        if axis:
-            gain = xr.DataArray(np.atleast_1d(self.kwargs.gain), dims=axis)
-            offset = xr.DataArray(np.atleast_1d(self.kwargs.offset), dims=axis)
-        else:
-            assert isinstance(self.kwargs.gain, (float, int)) or len(self.kwargs.gain) == 1
-            gain = self.kwargs.gain if isinstance(self.kwargs.gain, (float, int)) else self.kwargs.gain[0]
-            assert isinstance(self.kwargs.offset, (float, int)) or len(self.kwargs.offset) == 1
-            offset = self.kwargs.offset if isinstance(self.kwargs.offset, (float, int)) else self.kwargs.offset[0]
-
-        return tensor * gain + offset
-
-
-@dataclass(frozen=True)
-class ScaleLinearImpl(ProcessingImplBaseWoMeasures[Union[v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs]]):
-    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        axis = (
-            self.kwargs.axis
-            if isinstance(self.kwargs, v0_5.ScaleLinearKwargs)
-            else _get_complement_axis(tensor, self.kwargs.axes)
-        )
-        if axis:
-            gain = xr.DataArray(np.atleast_1d(self.kwargs.gain), dims=axis)
-            offset = xr.DataArray(np.atleast_1d(self.kwargs.offset), dims=axis)
-        else:
-            assert isinstance(self.kwargs.gain, (float, int)) or len(self.kwargs.gain) == 1
-            gain = self.kwargs.gain if isinstance(self.kwargs.gain, (float, int)) else self.kwargs.gain[0]
-            assert isinstance(self.kwargs.offset, (float, int)) or len(self.kwargs.offset) == 1
-            offset = self.kwargs.offset if isinstance(self.kwargs.offset, (float, int)) else self.kwargs.offset[0]
-
-        return tensor * gain + offset
-
-    def get_descr(self):
-        if isinstance(self.kwargs, v0_4.ScaleLinearKwargs):
-            raise NotImplementedError
-
-        return v0_5.ScaleLinearDescr(kwargs=self.kwargs)
-
-
-@dataclass
-class NamedMeasuresScaleMeanVariance(NamedMeasures[M]):
-    mean: M
-    std: M
-    ref_mean: M
-    ref_std: M
-
-
-@dataclass(frozen=True)
-class ScaleMeanVarianceImpl(
-    ProcessingImplBase[
-        Union[v0_4.ScaleMeanVarianceKwargs, v0_5.ScaleMeanVarianceKwargs],
-        NamedMeasuresScaleMeanVariance[AnyRequiredMeasure],
-        NamedMeasuresScaleMeanVariance[MeasureValue],
-    ]
-):
-    @classmethod
-    def get_required_measures(
-        cls, tensor_id: TensorId, kwargs: Union[v0_4.ScaleMeanVarianceKwargs, v0_5.ScaleMeanVarianceKwargs]
-    ):
-        if kwargs.axes is None:
-            axes = None
-        elif isinstance(kwargs.axes, str):
-            axes = tuple(AxisId(a) for a in kwargs.axes)
-        elif isinstance(kwargs.axes, collections.abc.Sequence):  # pyright: ignore[reportUnnecessaryIsInstance]
-            axes = tuple(kwargs.axes)
-        else:
-            assert_never(kwargs.axes)
-
-        return NamedMeasuresScaleMeanVariance(
-            mean=RequiredMeasure(Mean(axes), tensor_id, mode=kwargs.mode),
-            std=RequiredMeasure(Std(axes), tensor_id, mode=kwargs.mode),
-            ref_mean=RequiredMeasure(Mean(axes), cast(TensorId, kwargs.reference_tensor), mode=kwargs.mode),
-            ref_std=RequiredMeasure(Std(axes), cast(TensorId, kwargs.reference_tensor), mode=kwargs.mode),
-        )
-
-    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        c = self.computed
-        eps = self.kwargs.eps
-        return (tensor - c.mean) / (c.std + eps) * (c.ref_std + eps) + c.ref_mean
-
-    def get_descr(self):
-        if isinstance(self.kwargs, v0_4.ScaleMeanVarianceKwargs):
-            raise NotImplementedError
-
-        return v0_5.ScaleMeanVarianceDescr(kwargs=self.kwargs)
-
-
-@dataclass
-class NamedMeasuresScaleRange(NamedMeasures[M]):
-    lower: M
-    upper: M
-
-
-@dataclass(frozen=True)
-class ScaleRangeImpl(
-    ProcessingImplBase[
-        Union[v0_4.ScaleRangeKwargs, v0_5.ScaleRangeKwargs],
-        NamedMeasuresScaleRange[RequiredMeasure[Percentile, Any]],
-        NamedMeasuresScaleRange[MeasureValue],
-    ]
-):
-    @classmethod
-    def get_required_measures(cls, tensor_id: TensorId, kwargs: Union[v0_4.ScaleRangeKwargs, v0_5.ScaleRangeKwargs]):
-        ref_name = kwargs.reference_tensor or tensor_id
-        axes = None if kwargs.axes is None else tuple(AxisId(a) for a in kwargs.axes)
-        return NamedMeasuresScaleRange(
-            lower=RequiredMeasure(Percentile(kwargs.min_percentile, axes=axes), cast(TensorId, ref_name), kwargs.mode),
-            upper=RequiredMeasure(Percentile(kwargs.max_percentile, axes=axes), cast(TensorId, ref_name), kwargs.mode),
-        )
-
-    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        c = self.computed
-        return (tensor - c.lower) / (c.upper - c.lower + self.kwargs.eps)
-
-    def get_descr(self):
-        if isinstance(self.kwargs, v0_4.ScaleRangeKwargs):
-            raise NotImplementedError
-
-        return v0_5.ScaleRangeDescr(kwargs=self.kwargs)
-
-
-@dataclass(frozen=True)
-class SigmoidImpl(ProcessingImplBaseWoMeasures[v0_5.ProcessingKwargs]):
-    """1 / (1 + e^(-tensor))."""
-
-    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        return 1.0 / (1.0 + np.exp(-tensor))  # type: ignore
-
-    def get_descr(self):
-        return v0_5.SigmoidDescr()
-
-
-@dataclass
-class NamedMeasuresZeroMeanUnitVariance(NamedMeasures[M]):
-    mean: M
-    std: M
-
-
-@dataclass(frozen=True)
-class ZeroMeanUnitVarianceImpl(
-    ProcessingImplBase[
-        Union[v0_4.ZeroMeanUnitVarianceKwargs, v0_5.ZeroMeanUnitVarianceKwargs],
-        NamedMeasuresZeroMeanUnitVariance[RequiredMeasure],
-        NamedMeasuresZeroMeanUnitVariance[MeasureValue],
-    ]
-):
-    """normalize to zero mean, unit variance."""
-
-    @classmethod
-    def get_required_measures(
-        cls, tensor_id: TensorId, kwargs: Union[v0_4.ZeroMeanUnitVarianceKwargs, v0_5.ZeroMeanUnitVarianceKwargs]
-    ):
-        axes = None if kwargs.axes is None else tuple(NonBatchAxisId(a) for a in kwargs.axes)
-        assert kwargs.mode != "fixed"  # should use FixedZeroMeanUnitVarianceImpl
-        return NamedMeasuresZeroMeanUnitVariance(
-            mean=RequiredMeasure(Mean(axes=axes), tensor_id, kwargs.mode),
-            std=RequiredMeasure(Std(axes=axes), tensor_id, kwargs.mode),
-        )
-
-    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        mean = self.computed.mean
-        std = self.computed.std
-        return (tensor - mean) / (std + self.kwargs.eps)
-
-    def get_descr(self):
-        if isinstance(self.kwargs, v0_4.ZeroMeanUnitVarianceKwargs):
-            raise NotImplementedError
-
-        return v0_5.ZeroMeanUnitVarianceDescr(kwargs=self.kwargs)
-
-
-@dataclass(frozen=True)
-class FixedZeroMeanUnitVarianceImpl(
-    ProcessingImplBaseWoMeasures[Union[v0_4.ZeroMeanUnitVarianceKwargs, v0_5.FixedZeroMeanUnitVarianceKwargs]]
-):
-    """normalize to zero mean, unit variance with precomputed values."""
-
-    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        if isinstance(self.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs):
-            axis = self.kwargs.axis
-        elif isinstance(self.kwargs.mean, float) and isinstance(self.kwargs.std, float):
-            axis = None
-        else:
-            axis = _get_complement_axis(tensor, self.kwargs.axes)
-
-        mean = xr.DataArray(self.kwargs.mean, dims=axis)
-        std = xr.DataArray(self.kwargs.std, dims=axis)
-        return (tensor - mean) / std
-
-    def get_descr(self):
-        if isinstance(self.kwargs, v0_4.ZeroMeanUnitVarianceKwargs):
-            raise NotImplementedError
-
-        return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=self.kwargs)
-
-
-ProcDescr = Union[
-    AssertDtype, v0_4.PreprocessingDescr, v0_4.PostprocessingDescr, v0_5.PreprocessingDescr, v0_5.PostprocessingDescr
-]
-
-# get_impl_class which also returns the kwargs class
-# def get_impl_class(proc_spec: ProcDescr):
-#     if isinstance(proc_spec, AssertDtype):
-#         return AssertDtypeImpl, AssertDtypeKwargs
-#     elif isinstance(proc_spec, v0_4.BinarizeDescr):
-#         return BinarizeImpl, v0_4.BinarizeKwargs
-#     elif isinstance(proc_spec, v0_5.BinarizeDescr):
-#         return BinarizeImpl, v0_5.BinarizeKwargs
-#     elif isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)):
-#         return ClipImpl, v0_5.ClipKwargs
-#     elif isinstance(proc_spec, v0_5.EnsureDtypeDescr):
-#         return EnsureDtypeImpl, v0_5.EnsureDtypeKwargs
-#     elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr):
-#         return FixedZeroMeanUnitVarianceImpl, v0_5.FixedZeroMeanUnitVarianceKwargs
-#     elif isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
-#         return ScaleLinearImpl, v0_5.ScaleLinearKwargs
-#     elif isinstance(proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)):
-#         return ScaleMeanVarianceImpl, v0_5.ScaleMeanVarianceKwargs
-#     elif isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
-#         return ScaleRangeImpl, v0_5.ScaleRangeKwargs
-#     elif isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
-#         return SigmoidImpl, v0_5.ProcessingKwargs
-#     elif isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr) and proc_spec.kwargs.mode == "fixed":
-#         return FixedZeroMeanUnitVarianceImpl, v0_5.FixedZeroMeanUnitVarianceKwargs
-#     elif isinstance(
-#         proc_spec,  # pyright: ignore[reportUnnecessaryIsInstance
-#         (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
-#     ):
-#         return ZeroMeanUnitVarianceImpl, v0_5.ZeroMeanUnitVarianceKwargs
-#     else:
-#         assert_never(proc_spec)
-
-ProcessingImpl = Union[
-    AssertDtypeImpl,
-    BinarizeImpl,
-    ClipImpl,
-    EnsureDtypeImpl,
-    FixedZeroMeanUnitVarianceImpl,
-    FixedZeroMeanUnitVarianceImpl,
-    ScaleLinearImpl,
-    ScaleMeanVarianceImpl,
-    ScaleRangeImpl,
-    SigmoidImpl,
-    ZeroMeanUnitVarianceImpl,
-]
-
-
-def get_impl_class(proc_spec: ProcDescr) -> Type[ProcessingImpl]:
-    if isinstance(proc_spec, AssertDtype):
-        return AssertDtypeImpl
-    elif isinstance(proc_spec, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)):
-        return BinarizeImpl
-    elif isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)):
-        return ClipImpl
-    elif isinstance(proc_spec, v0_5.EnsureDtypeDescr):
-        return EnsureDtypeImpl
-    elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr):
-        return FixedZeroMeanUnitVarianceImpl
-    elif isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
-        return ScaleLinearImpl
-    elif isinstance(proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)):
-        return ScaleMeanVarianceImpl
-    elif isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
-        return ScaleRangeImpl
-    elif isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
-        return SigmoidImpl
-    elif isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr) and proc_spec.kwargs.mode == "fixed":
-        return FixedZeroMeanUnitVarianceImpl
-    elif isinstance(
-        proc_spec,  # pyright: ignore[reportUnnecessaryIsInstance]
-        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
-    ):
-        return ZeroMeanUnitVarianceImpl
-    else:
-        assert_never(proc_spec)
-
-
-def _get_complement_axis(tensor: xr.DataArray, axes: Optional[Sequence[Hashable]]) -> Optional[Hashable]:
-    if axes is None:
-        return None
-
-    v04_AXIS_TYPE_MAP = {
-        "b": "batch",
-        "t": "time",
-        "i": "index",
-        "c": "channel",
-        "x": "space",
-        "y": "space",
-        "z": "space",
-    }
-    converted_axes = [v04_AXIS_TYPE_MAP.get(a, a) for a in map(str, axes)] + ["batch"]
-    complement_axes = [a for a in tensor.dims if str(a) not in converted_axes]
-    if len(complement_axes) != 1:
-        raise ValueError(
-            f"Expected a single complement axis, but axes '{converted_axes}' (orignally '{axes}') "
-            f"for tensor dims '{tensor.dims}' leave '{complement_axes}'."
- ) - - return complement_axes[0] diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py new file mode 100644 index 00000000..1faf1be2 --- /dev/null +++ b/bioimageio/core/proc_ops.py @@ -0,0 +1,549 @@ +import collections.abc +from abc import ABC, abstractmethod +from dataclasses import InitVar, dataclass, field +from typing import ( + Collection, + Hashable, + Literal, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, + cast, +) + +import numpy as np +import xarray as xr +from numpy.typing import DTypeLike +from typing_extensions import Self, assert_never + +from bioimageio.core.common import ( + AxisId, + Sample, + Stat, + Tensor, + TensorId, +) +from bioimageio.core.op_base import Operator +from bioimageio.core.stat_measures import ( + DatasetMean, + DatasetPercentile, + DatasetStd, + Measure, + MeasureValue, + SampleMean, + SamplePercentile, + SampleStd, +) +from bioimageio.spec.model import v0_4, v0_5 + + +def convert_axis_ids( + axes: Union[Sequence[AxisId], v0_4.AxesInCZYX], mode: Literal["per_sample", "per_dataset"] +) -> Tuple[AxisId, ...]: + if not isinstance(axes, str): + return tuple(axes) + + axis_map = dict(b=AxisId("batch"), c=AxisId("channel"), i=AxisId("index")) + if mode == "per_sample": + ret = [] + elif mode == "per_dataset": + ret = [AxisId("batch")] + else: + assert_never(mode) + + ret.extend([axis_map.get(a, AxisId(a)) for a in axes]) + return tuple(ret) + + +@dataclass +class _SimpleOperator(Operator, ABC): + input: TensorId + output: TensorId + + @property + def required_measures(self) -> Collection[Measure]: + return set() + + # @property + # def required_tensors(self) -> Set[TensorId]: + # return {self.input} + + # @property + # def produced_tensors(self) -> Set[TensorId]: + # return {self.output} + + def __call__(self, sample: Sample) -> None: + sample.data[self.output] = self._apply(sample.data[self.input], sample.stat) + + @abstractmethod + def _apply(self, input: Tensor, stat: Stat) -> Tensor: + ... 
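+
+    # subclasses only implement `_apply`; `__call__` above reads
+    # `sample.data[self.input]`, transforms it using any precomputed
+    # statistics available in `sample.stat`, and writes the result to
+    # `sample.data[self.output]` (the same tensor whenever input == output).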
+
+
+@dataclass
+class Dataset(Operator):
+    @property
+    def required_measures(self) -> Set[Measure]:
+        return set()
+
+
+# @dataclass
+# class AssertDtype(Operator):
+#     tensor: TensorId
+#     dtype: Union[Type[DTypeLike], Tuple[Type[DTypeLike], ...]]
+
+#     @property
+#     def required_measures(self) -> Set[Measure]:
+#         return set()
+
+#     def apply(self, tensor: Tensor) -> Tensor:
+#         assert isinstance(tensor.dtype, self.dtype)
+#         return tensor
+
+
+@dataclass
+class Binarize(_SimpleOperator):
+    """'output = input > threshold'."""
+
+    threshold: float
+
+    def _apply(self, input: Tensor, stat: Stat) -> xr.DataArray:
+        return input > self.threshold
+
+    # @classmethod
+    # def from_descr(cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr]):
+    #     return cls(threshold=descr.kwargs.threshold)
+
+    # def get_descr(self):
+    #     return v0_5.BinarizeDescr(kwargs=v0_5.BinarizeKwargs(threshold=self.threshold))
+    @classmethod
+    def from_proc_descr(cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], tensor_id: TensorId) -> Self:
+        return cls(input=tensor_id, output=tensor_id, threshold=descr.kwargs.threshold)
+
+
+@dataclass
+class Clip(_SimpleOperator):
+    min: Optional[float] = None
+    """minimum value for clipping"""
+    max: Optional[float] = None
+    """maximum value for clipping"""
+
+    def __post_init__(self):
+        assert self.min is not None or self.max is not None, "missing min or max value"
+        assert (
+            self.min is None or self.max is None or self.min < self.max
+        ), f"expected min < max, but {self.min} !< {self.max}"
+
+    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
+        return input.clip(self.min, self.max)
+
+    @classmethod
+    def from_proc_descr(cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], tensor_id: TensorId) -> Self:
+        return cls(input=tensor_id, output=tensor_id, min=descr.kwargs.min, max=descr.kwargs.max)
+
+
+@dataclass
+class EnsureDtype(_SimpleOperator):
+    dtype: DTypeLike
+
+    @classmethod
+    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, tensor_id: TensorId):
+        return cls(input=tensor_id, output=tensor_id, dtype=descr.kwargs.dtype)
+
+    def get_descr(self):
+        return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=str(self.dtype)))
+
+    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
+        return input.astype(self.dtype)
+
+
+@dataclass
+class ScaleLinear(_SimpleOperator):
+    gain: Union[float, xr.DataArray] = 1.0
+    """multiplicative factor"""
+
+    offset: Union[float, xr.DataArray] = 0.0
+    """additive term"""
+
+    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
+        return input * self.gain + self.offset
+
+    # @classmethod
+    # def from_descr(cls, descr: ScaleLinearDescr) -> Self:
+    #     ...
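+
+    # `from_proc_descr` below puts per-axis gain/offset values into an
+    # `xr.DataArray` whose single dimension is the named axis, so that
+    # `input * gain + offset` broadcasts along that axis; scalar values
+    # are kept as plain floats.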
+ + @classmethod + def from_proc_descr(cls, descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr], tensor_id: TensorId) -> Self: + kwargs = descr.kwargs + if isinstance(kwargs, v0_5.ScaleLinearKwargs): + axis = kwargs.axis + elif kwargs.axes is not None: + raise NotImplementedError("ScaleLinear operator from v0_4.ScaleLinearDescr with axes") + else: + axis = None + + if axis: + gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis) + offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis) + else: + assert isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1 + gain = kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0] + assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1 + offset = kwargs.offset if isinstance(kwargs.offset, (float, int)) else kwargs.offset[0] + + return cls(input=tensor_id, output=tensor_id, gain=gain, offset=offset) + + +@dataclass +class ScaleMeanVariance(_SimpleOperator): + axes: Optional[Sequence[AxisId]] = None + reference_tensor: Optional[TensorId] = None + eps: float = 1e-6 + mean: Union[SampleMean, DatasetMean] = field(init=False) + std: Union[SampleStd, DatasetStd] = field(init=False) + ref_mean: Union[SampleMean, DatasetMean] = field(init=False) + ref_std: Union[SampleStd, DatasetStd] = field(init=False) + + @property + def required_measures(self): + return {self.mean, self.std, self.ref_mean, self.ref_std} + + def __post_init__(self): + axes = None if self.axes is None else tuple(self.axes) + ref_tensor = self.reference_tensor or self.input + if axes is None or AxisId("batch") not in axes: + Mean = SampleMean + Std = SampleStd + else: + Mean = DatasetMean + Std = DatasetStd + + self.mean = Mean(tensor_id=self.input, axes=axes) + self.std = Std(tensor_id=self.input, axes=axes) + self.ref_mean = Mean(tensor_id=ref_tensor, axes=axes) + self.ref_std = Std(tensor_id=ref_tensor, axes=axes) + + def _apply(self, input: Tensor, stat: Stat) -> Tensor: + mean = stat[self.mean] + std = stat[self.std] + self.eps + ref_mean = stat[self.ref_mean] + ref_std = stat[self.ref_std] + self.eps + return (input - mean) / std * ref_std + ref_mean + + @classmethod + def from_proc_descr( + cls, descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr], tensor_id: TensorId + ) -> Self: + kwargs = descr.kwargs + axes = _get_axes(descr.kwargs) + + return cls( + input=tensor_id, + output=tensor_id, + reference_tensor=cast(TensorId, kwargs.reference_tensor), + axes=axes, + eps=kwargs.eps, + ) + + +def _get_axes( + kwargs: Union[ + v0_4.ZeroMeanUnitVarianceKwargs, + v0_5.ZeroMeanUnitVarianceKwargs, + v0_4.ScaleRangeKwargs, + v0_5.ScaleRangeKwargs, + v0_4.ScaleMeanVarianceKwargs, + v0_5.ScaleMeanVarianceKwargs, + ] +) -> Union[Tuple[AxisId, ...], None]: + if kwargs.axes is None: + axes = None + elif isinstance(kwargs.axes, str): + axes = convert_axis_ids(kwargs.axes, kwargs["mode"]) + elif isinstance(kwargs.axes, collections.abc.Sequence): # pyright: ignore[reportUnnecessaryIsInstance] + axes = tuple(kwargs.axes) + else: + assert_never(kwargs.axes) + + return axes + + +@dataclass +class ScaleRange(_SimpleOperator): + lower_percentile: InitVar[Optional[Union[SamplePercentile, DatasetPercentile]]] = None + upper_percentile: InitVar[Optional[Union[SamplePercentile, DatasetPercentile]]] = None + lower: Union[SamplePercentile, DatasetPercentile] = field(init=False) + upper: Union[SamplePercentile, DatasetPercentile] = field(init=False) + + eps: float = 1e-6 + + def __post_init__( + self, + lower_percentile: 
Optional[Union[SamplePercentile, DatasetPercentile]],
+        upper_percentile: Optional[Union[SamplePercentile, DatasetPercentile]],
+    ):
+        if lower_percentile is None:
+            tid = self.input if upper_percentile is None else upper_percentile.tensor_id
+            self.lower = DatasetPercentile(n=0, tensor_id=tid)
+        else:
+            self.lower = lower_percentile
+
+        if upper_percentile is None:
+            self.upper = DatasetPercentile(n=100, tensor_id=self.lower.tensor_id)
+        else:
+            self.upper = upper_percentile
+
+        assert self.lower.tensor_id == self.upper.tensor_id
+        assert self.lower.n < self.upper.n
+        assert self.lower.axes == self.upper.axes
+
+    @property
+    def required_measures(self):
+        return {self.lower, self.upper}
+
+    @classmethod
+    def from_proc_descr(cls, descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr], tensor_id: TensorId):
+        kwargs = descr.kwargs
+        ref_tensor = cast(TensorId, kwargs.reference_tensor) or tensor_id
+        axes = _get_axes(descr.kwargs)
+        if axes is None or AxisId("batch") in axes:
+            Percentile = DatasetPercentile
+        else:
+            Percentile = SamplePercentile
+
+        return cls(
+            input=tensor_id,
+            output=tensor_id,
+            lower_percentile=Percentile(kwargs.min_percentile, axes=axes, tensor_id=ref_tensor),
+            upper_percentile=Percentile(kwargs.max_percentile, axes=axes, tensor_id=ref_tensor),
+        )
+
+    def _apply(self, input: xr.DataArray, stat: Stat) -> xr.DataArray:
+        lower = stat[self.lower]
+        upper = stat[self.upper]
+        return (input - lower) / (upper - lower + self.eps)
+
+    def get_descr(self):
+        assert self.lower.axes == self.upper.axes
+        assert self.lower.tensor_id == self.upper.tensor_id
+
+        return v0_5.ScaleRangeDescr(
+            kwargs=v0_5.ScaleRangeKwargs(
+                axes=self.lower.axes,
+                min_percentile=self.lower.n,
+                max_percentile=self.upper.n,
+                eps=self.eps,
+                reference_tensor=self.lower.tensor_id,
+            )
+        )
+
+
+@dataclass
+class Sigmoid(_SimpleOperator):
+    """1 / (1 + e^(-input))."""
+
+    def _apply(self, input: xr.DataArray, stat: Stat) -> xr.DataArray:
+        return 1.0 / (1.0 + np.exp(-input))  # type: ignore
+
+    @classmethod
+    def from_proc_descr(cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], tensor_id: TensorId) -> Self:
+        assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
+        return cls(input=tensor_id, output=tensor_id)
+
+    def get_descr(self):
+        return v0_5.SigmoidDescr()
+
+
+@dataclass
+class ZeroMeanUnitVariance(_SimpleOperator):
+    """normalize to zero mean, unit variance."""
+
+    mean: Union[SampleMean, DatasetMean]
+    std: Union[SampleStd, DatasetStd]
+
+    eps: float = 1e-6
+
+    def __post_init__(self):
+        assert self.mean.axes == self.std.axes
+
+    @property
+    def required_measures(self) -> Collection[Measure]:
+        return {self.mean, self.std}
+
+    @classmethod
+    def from_proc_descr(
+        cls, descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr], tensor_id: TensorId
+    ):
+        axes = _get_axes(descr.kwargs)
+
+        if axes is None or AxisId("batch") in axes:
+            Mean = DatasetMean
+            Std = DatasetStd
+        else:
+            Mean = SampleMean
+            Std = SampleStd
+
+        return cls(
+            input=tensor_id,
+            output=tensor_id,
+            mean=Mean(axes=axes, tensor_id=tensor_id),
+            std=Std(axes=axes, tensor_id=tensor_id),
+        )
+
+    def _apply(self, input: xr.DataArray, stat: Stat) -> xr.DataArray:
+        mean = stat[self.mean]
+        std = stat[self.std]
+        return (input - mean) / (std + self.eps)
+
+    def get_descr(self):
+        return v0_5.ZeroMeanUnitVarianceDescr(kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps))
+
+
+@dataclass
+class FixedZeroMeanUnitVariance(_SimpleOperator):
+    """normalize to zero mean, unit variance with precomputed values."""
+
+    mean: 
Union[float, xr.DataArray] + std: Union[float, xr.DataArray] + + eps: float = 1e-6 + + def __post_init__(self): + assert ( + isinstance(self.mean, (int, float)) or isinstance(self.std, (int, float)) or self.mean.dims == self.std.dims + ) + + @classmethod + def from_proc_descr( + cls, + descr: v0_5.FixedZeroMeanUnitVarianceDescr, + tensor_id: TensorId, + ) -> Self: + return cls( + input=tensor_id, + output=tensor_id, + mean=xr.DataArray(descr.kwargs.mean, dims=(descr.kwargs.axis,)), + std=xr.DataArray(descr.kwargs.std, dims=(descr.kwargs.axis,)), + ) + + def get_descr(self): + if isinstance(self.mean, (int, float)): + assert isinstance(self.std, (int, float)) + axis = None + mean = self.mean + std = self.std + else: + assert isinstance(self.std, xr.DataArray) + assert len(self.mean.dims) == 1 + axis = AxisId(str(self.mean.dims[0])) + mean = tuple(self.mean) + std = tuple(self.std) + + return v0_5.FixedZeroMeanUnitVarianceDescr( + kwargs=v0_5.FixedZeroMeanUnitVarianceKwargs(axis=axis, mean=mean, std=std) + ) + + def _apply(self, input: xr.DataArray, stat: Stat) -> xr.DataArray: + return (input - self.mean) / (self.std + self.eps) + + +ProcDescr = Union[v0_4.PreprocessingDescr, v0_4.PostprocessingDescr, v0_5.PreprocessingDescr, v0_5.PostprocessingDescr] + +# get_impl_class which also returns the kwargs class +# def get_impl_class(proc_spec: ProcDescr): +# if isinstance(proc_spec, AssertDtype): +# return AssertDtypeImpl, AssertDtypeKwargs +# elif isinstance(proc_spec, v0_4.BinarizeDescr): +# return BinarizeImpl, v0_4.BinarizeKwargs +# elif isinstance(proc_spec, v0_5.BinarizeDescr): +# return BinarizeImpl, v0_5.BinarizeKwargs +# elif isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)): +# return ClipImpl, v0_5.ClipKwargs +# elif isinstance(proc_spec, v0_5.EnsureDtypeDescr): +# return EnsureDtypeImpl, v0_5.EnsureDtypeKwargs +# elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr): +# return FixedZeroMeanUnitVarianceImpl, v0_5.FixedZeroMeanUnitVarianceKwargs +# elif isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)): +# return ScaleLinearImpl, v0_5.ScaleLinearKwargs +# elif isinstance(proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)): +# return ScaleMeanVarianceImpl, v0_5.ScaleMeanVarianceKwargs +# elif isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)): +# return ScaleRangeImpl, v0_5.ScaleRangeKwargs +# elif isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)): +# return SigmoidImpl, v0_5.ProcessingKwargs +# elif isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr) and proc_spec.kwargs.mode == "fixed": +# return FixedZeroMeanUnitVarianceImpl, v0_5.FixedZeroMeanUnitVarianceKwargs +# elif isinstance( +# proc_spec, # pyright: ignore[reportUnnecessaryIsInstance +# (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr), +# ): +# return ZeroMeanUnitVarianceImpl, v0_5.ZeroMeanUnitVarianceKwargs +# else: +# assert_never(proc_spec) + +Processing = Union[ + Binarize, + Clip, + EnsureDtype, + FixedZeroMeanUnitVariance, + ScaleLinear, + ScaleMeanVariance, + ScaleRange, + Sigmoid, + ZeroMeanUnitVariance, +] + + +def get_proc_class(proc_spec: ProcDescr) -> Type[Processing]: + if isinstance(proc_spec, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)): + return Binarize + elif isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)): + return Clip + elif isinstance(proc_spec, v0_5.EnsureDtypeDescr): + return EnsureDtype + elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr): + return FixedZeroMeanUnitVariance + elif 
isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
+        return ScaleLinear
+    elif isinstance(proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)):
+        return ScaleMeanVariance
+    elif isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
+        return ScaleRange
+    elif isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
+        return Sigmoid
+    elif isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr) and proc_spec.kwargs.mode == "fixed":
+        return FixedZeroMeanUnitVariance
+    elif isinstance(
+        proc_spec,  # pyright: ignore[reportUnnecessaryIsInstance]
+        (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr),
+    ):
+        return ZeroMeanUnitVariance
+    else:
+        assert_never(proc_spec)
+
+
+def _get_complement_axis(tensor: xr.DataArray, axes: Optional[Sequence[Hashable]]) -> Optional[Hashable]:
+    if axes is None:
+        return None
+
+    v04_AXIS_TYPE_MAP = {
+        "b": "batch",
+        "t": "time",
+        "i": "index",
+        "c": "channel",
+        "x": "space",
+        "y": "space",
+        "z": "space",
+    }
+    converted_axes = [v04_AXIS_TYPE_MAP.get(a, a) for a in map(str, axes)] + ["batch"]
+    complement_axes = [a for a in tensor.dims if str(a) not in converted_axes]
+    if len(complement_axes) != 1:
+        raise ValueError(
+            f"Expected a single complement axis, but axes '{converted_axes}' (originally '{axes}') "
+            f"for tensor dims '{tensor.dims}' leave '{complement_axes}'."
+        )
+
+    return complement_axes[0]
diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py
index e0079b89..762e3385 100644
--- a/bioimageio/core/proc_setup.py
+++ b/bioimageio/core/proc_setup.py
@@ -13,13 +13,13 @@
 
 from typing_extensions import assert_never
 
-from bioimageio.core.common import ProcessingKwargs, RequiredMeasure, Sample
-from bioimageio.core.proc_impl import (
-    ProcessingImpl,
-    ProcessingImplBase,
-    get_impl_class,
+from bioimageio.core.common import ProcessingKwargs, Sample
+from bioimageio.core.proc_ops import (
+    Processing,
+    get_proc_class,
 )
 from bioimageio.core.stat_calculators import compute_measures
+from bioimageio.core.stat_measures import Measure
 from bioimageio.spec.model import v0_4, v0_5
 from bioimageio.spec.model.v0_5 import TensorId
 
@@ -28,34 +28,34 @@
 
 
 class _SetupProcessing(NamedTuple):
-    preprocessing: List[ProcessingImpl]
-    postprocessing: List[ProcessingImpl]
+    preprocessing: List[Processing]
+    postprocessing: List[Processing]
 
 
 def setup_pre_and_postprocessing(model: ModelDescr, dataset: Iterator[Sample]) -> _SetupProcessing:
-    Prepared = List[Tuple[Type[ProcessingImplBase[Any, Any, Any]], ProcessingKwargs, TensorId]]
+    Prepared = List[Tuple[Type[Processing], ProcessingKwargs, TensorId]]
 
-    required_measures: Set[RequiredMeasure[Any, Any]] = set()
+    required_measures: Set[Measure] = set()
 
     def prepare_procs(tensor_descrs: Sequence[TensorDescr]):
         prepared: Prepared = []
         for t_descr in tensor_descrs:
             if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
-                proc_specs = t_descr.preprocessing
+                proc_descrs = t_descr.preprocessing
             elif isinstance(
                 t_descr,  # pyright: ignore[reportUnnecessaryIsInstance]
                 (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr),
             ):
-                proc_specs = t_descr.postprocessing
+                proc_descrs = t_descr.postprocessing
             else:
                 assert_never(t_descr)
 
-            for proc_spec in proc_specs:
-                impl_class = get_impl_class(proc_spec)
+            for proc_d in proc_descrs:
+                proc_class = get_proc_class(proc_d)
                 tensor_id = cast(TensorId, t_descr.name) if isinstance(t_descr, v0_4.TensorDescrBase) else t_descr.id
-                req = impl_class.get_required_measures(tensor_id, 
proc_spec.kwargs) # type: ignore + req = proc_class.from_proc_descr(proc_d, tensor_id) required_measures.update(req.get_set()) - prepared.append((impl_class, proc_spec.kwargs, tensor_id)) + prepared.append((proc_class, proc_d.kwargs, tensor_id)) return prepared diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py index d465cd6c..8c12d29e 100644 --- a/bioimageio/core/stat_calculators.py +++ b/bioimageio/core/stat_calculators.py @@ -2,21 +2,15 @@ import collections import warnings -from abc import ABC, abstractmethod -from collections import defaultdict -from dataclasses import field from itertools import product from typing import ( Any, - ClassVar, - DefaultDict, + Collection, Dict, - Generic, Hashable, Iterable, Iterator, List, - Literal, Mapping, Optional, OrderedDict, @@ -25,103 +19,99 @@ Tuple, Type, Union, - assert_never, + cast, ) import numpy as np import xarray as xr from numpy.typing import NDArray +from typing_extensions import assert_never from bioimageio.core.common import ( - PER_DATASET, - PER_SAMPLE, AxisId, Sample, TensorId, ) from bioimageio.core.stat_measures import ( DatasetMean, - DatasetMeasureBase, - DatasetMeasureVar, + DatasetMeasure, DatasetPercentile, DatasetStd, DatasetVar, Measure, - MeasureVar, - Percentile, + MeasureValue, SampleMean, - SampleMeasureBase, + SampleMeasure, SamplePercentile, SampleStd, SampleVar, - Std, - Var, ) try: - import crick # type: ignore + import crick + except ImportError: crick = None -MeasureValue = Union[xr.DataArray, float] - - -# class SampleMeasureCalculator(ABC): -# """group of measures for more efficient computation of multiple measures per sample""" - -# @abstractmethod -# def compute(self, sample: Sample) -> Mapping[SampleMeasure, MeasureValue]: -# ... - + class TDigest: + def update(self, obj: Any): + pass -# class DatasetMeasureCalculator(ABC): -# """group of measures for more efficient computation of multiple measures per dataset""" + def quantile(self, q: Any) -> Any: + pass -# @abstractmethod -# def update_with_sample(self, sample: Sample) -> None: -# """update intermediate representation with a data sample""" -# ... - -# @abstractmethod -# def finalize(self) -> Mapping[DatasetMeasure, MeasureValue]: -# """compute statistics from intermediate representation""" -# ... 
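+# if crick imported successfully, bind its streaming TDigest for dataset
+# percentile estimation (used by CrickPercentilesCalculator below); otherwise
+# the stub above keeps the annotations valid and MeanPercentilesCalculator is
+# used instead.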
+else: + TDigest = crick.TDigest # type: ignore class MeanCalculator: def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]): super().__init__() - self._axes = None if axes is None else tuple(axes) - self._tensor_id = tensor_id self._n: int = 0 self._mean: Optional[xr.DataArray] = None + self._axes = None if axes is None else tuple(axes) + self._tensor_id = tensor_id + self._sample_mean = SampleMean(tensor_id=self._tensor_id, axes=self._axes) + self._dataset_mean = DatasetMean(tensor_id=self._tensor_id, axes=self._axes) - def compute(self, sample: Sample): - mean = SampleMean(axes=self._axes, tensor_id=self._tensor_id) - return {mean: mean.compute(sample)} + def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]: + return {self._sample_mean: self._compute_impl(sample)} + + def _compute_impl(self, sample: Sample) -> xr.DataArray: + tensor = sample.data[self._tensor_id].astype(np.float64, copy=False) + return tensor.mean(dim=self._axes) + + def update(self, sample: Sample) -> None: + mean = self._compute_impl(sample) + self._update_impl(sample.data[self._tensor_id], mean) + + def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]: + mean = self._compute_impl(sample) + self._update_impl(sample.data[self._tensor_id], mean) + return {self._sample_mean: mean} + + def _update_impl(self, tensor: xr.DataArray, tensor_mean: xr.DataArray): + assert tensor_mean.dtype == np.float64 + # reduced voxel count + n_b = np.prod(tensor.shape) / np.prod(tensor_mean.shape) # type: ignore - def update_with_sample(self, sample: Sample): - tensor = sample[self._tensor_id].astype(np.float64, copy=False) - mean_b = tensor.mean(dim=self._axes) - assert mean_b.dtype == np.float64 - n_b = np.prod(list(tensor.shape)) / np.prod(list(mean_b.shape)) # reduced voxel count if self._mean is None: assert self._n == 0 self._n = n_b - self._mean = mean_b + self._mean = tensor_mean else: assert self._n != 0 n_a = self._n - mean_a = self._mean - self._n = n = n_a + n_b - self._mean = (n_a * mean_a + n_b * mean_b) / n + mean_old = self._mean + self._n = n_a + n_b + self._mean = (n_a * mean_old + n_b * tensor_mean) / self._n assert self._mean.dtype == np.float64 - def finalize(self) -> Mapping[DatasetMeasureBase, MeasureValue]: + def finalize(self) -> Dict[DatasetMean, MeasureValue]: if self._mean is None: return {} else: - return {DatasetMean(axes=self._axes, tensor_id=self._tensor_id): self._mean} + return {self._dataset_mean: self._mean} class MeanVarStdCalculator: @@ -133,8 +123,8 @@ def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]): self._mean: Optional[xr.DataArray] = None self._m2: Optional[xr.DataArray] = None - def compute(self, sample: Sample): - tensor = sample[self._tensor_id] + def compute(self, sample: Sample) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]: + tensor = sample.data[self._tensor_id] mean = tensor.mean(dim=self._axes) c = tensor - mean if self._axes is None: @@ -142,16 +132,18 @@ def compute(self, sample: Sample): else: n = int(np.prod([tensor.sizes[d] for d in self._axes])) # type: ignore # FIXME: type annotation - var = xr.dot(c, c, dims=self._axes) / n - std = np.sqrt(var) + var: xr.DataArray = xr.dot(c, c, dims=self._axes) / n + assert isinstance(var, xr.DataArray) + std: xr.DataArray = np.sqrt(var) # type: ignore + assert isinstance(std, xr.DataArray) return { SampleMean(axes=self._axes, tensor_id=self._tensor_id): mean, SampleVar(axes=self._axes, tensor_id=self._tensor_id): var, SampleStd(axes=self._axes, 
tensor_id=self._tensor_id): std, } - def update_with_sample(self, sample: Sample): - tensor = sample[self._tensor_id].astype(np.float64, copy=False) + def update(self, sample: Sample): + tensor = sample.data[self._tensor_id].astype(np.float64, copy=False) mean_b = tensor.mean(dim=self._axes) assert mean_b.dtype == np.float64 # reduced voxel count @@ -174,7 +166,7 @@ def update_with_sample(self, sample: Sample): self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n assert self._m2.dtype == np.float64 - def finalize(self) -> Mapping[DatasetMeasureBase, MeasureValue]: + def finalize(self) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]: if self._mean is None: return {} else: @@ -189,7 +181,7 @@ def finalize(self) -> Mapping[DatasetMeasureBase, MeasureValue]: class SamplePercentilesCalculator: - def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Sequence[float]): + def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Collection[float]): super().__init__() assert all(0 <= n <= 100 for n in ns) self.ns = ns @@ -197,14 +189,14 @@ def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Se self._axes = None if axes is None else tuple(axes) self._tensor_id = tensor_id - def compute(self, sample: Sample): - tensor = sample[self._tensor_id] + def compute(self, sample: Sample) -> Dict[SamplePercentile, MeasureValue]: + tensor = sample.data[self._tensor_id] ps = tensor.quantile(self._qs, dim=self._axes) # type: ignore return {SamplePercentile(n=n, axes=self._axes, tensor_id=self._tensor_id): p for n, p in zip(self.ns, ps)} class MeanPercentilesCalculator: - def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Sequence[float]): + def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Collection[float]): super().__init__() assert all(0 <= n <= 100 for n in ns) self._ns = ns @@ -214,8 +206,8 @@ def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Se self._n: int = 0 self._estimates: Optional[xr.DataArray] = None - def update_with_sample(self, sample: Sample): - tensor = sample[self._tensor_id] + def update(self, sample: Sample): + tensor = sample.data[self._tensor_id] sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(np.float64, copy=False) # reduced voxel count @@ -230,7 +222,7 @@ def update_with_sample(self, sample: Sample): self._n += n - def finalize(self) -> Mapping[DatasetPercentile, MeasureValue]: + def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: if self._estimates is None: return {} else: @@ -242,7 +234,7 @@ def finalize(self) -> Mapping[DatasetPercentile, MeasureValue]: class CrickPercentilesCalculator: - def __init__(self, tensor_name: TensorId, axes: Optional[Sequence[AxisId]], ns: Sequence[float]): + def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Collection[float]): warnings.warn("Computing dataset percentiles with experimental 'crick' library.") super().__init__() assert all(0 <= n <= 100 for n in ns) @@ -250,8 +242,8 @@ def __init__(self, tensor_name: TensorId, axes: Optional[Sequence[AxisId]], ns: self._ns = ns self._qs = [n / 100 for n in ns] self._axes = None if axes is None else tuple(axes) - self._tensor_id = tensor_name - self._digest: Optional[List[crick.TDigest]] = None + self._tensor_id = tensor_id + self._digest: Optional[List[TDigest]] = None self._dims: Optional[Tuple[Hashable, ...]] = None self._indices: Optional[Iterator[Tuple[int, ...]]] = None self._shape: 
Optional[Tuple[int, ...]] = None @@ -266,11 +258,11 @@ def _initialize(self, tensor_sizes: Mapping[Hashable, int]): self._dims, self._shape = zip(*out_sizes.items()) d = int(np.prod(self._shape[1:])) # type: ignore - self._digest = [crick.TDigest() for _ in range(d)] + self._digest = [TDigest() for _ in range(d)] self._indices = product(*map(range, self._shape[1:])) - def update_with_sample(self, sample: Sample): - tensor = sample[self._tensor_id] + def update(self, sample: Sample): + tensor = sample.data[self._tensor_id] assert "_percentiles" not in tensor.dims if self._digest is None: self._initialize(tensor.sizes) @@ -286,7 +278,11 @@ def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: return {} else: assert self._dims is not None - vs: NDArray[Any] = np.asarray([[d.quantile(q) for d in self._digest] for q in self._qs]).reshape(self._shape) # type: ignore + assert self._shape is not None + + vs: NDArray[Any] = np.asarray([[d.quantile(q) for d in self._digest] for q in self._qs]).reshape( + self._shape + ) return { DatasetPercentile(n=n, axes=self._axes, tensor_id=self._tensor_id): xr.DataArray(v, dims=self._dims[1:]) for n, v in zip(self._ns, vs) @@ -294,29 +290,83 @@ def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: if crick is None: - DatasetPercentileCalculator: Type[ + DatasetPercentilesCalculator: Type[ Union[MeanPercentilesCalculator, CrickPercentilesCalculator] ] = MeanPercentilesCalculator else: - DatasetPercentileCalculator = CrickPercentilesCalculator + DatasetPercentilesCalculator = CrickPercentilesCalculator class NaivSampleMeasureCalculator: """wrapper for measures to match interface of other sample measure calculators""" - def __init__(self, tensor_id: TensorId, measure: SampleMeasureBase): + def __init__(self, tensor_id: TensorId, measure: SampleMeasure): super().__init__() self.tensor_name = tensor_id self.measure = measure - def compute(self, sample: Sample) -> Mapping[SampleMeasureBase, MeasureValue]: + def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]: return {self.measure: self.measure.compute(sample)} SampleMeasureCalculator = Union[ MeanCalculator, MeanVarStdCalculator, SamplePercentilesCalculator, NaivSampleMeasureCalculator ] -DatasetMeasureCalculator = Union[MeanCalculator, MeanVarStdCalculator, DatasetPercentileCalculator] +DatasetMeasureCalculator = Union[MeanCalculator, MeanVarStdCalculator, DatasetPercentilesCalculator] + + +class StatsCalculator: + """Estimates dataset statistics and computes sample statistics efficiently""" + + def __init__( + self, + *, + measures: Iterable[Measure], + ): + super().__init__() + self.sample_count = 0 + self.sample_calculators, self.dataset_calculators = get_measure_calculators(measures) + self._current_dataset_measures: Optional[Dict[DatasetMeasure, MeasureValue]] = None + + def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]: + ret: Dict[SampleMeasure, MeasureValue] = {} + for calc in self.sample_calculators: + values = calc.compute(sample) + ret.update(values.items()) + + return ret + + def _update(self, sample: Sample): + self.sample_count += 1 + for calc in self.dataset_calculators: + calc.update(sample) + self._current_dataset_measures = None + + def _compute_and_update(self, sample: Sample): + self._update(sample) + return self._compute(sample) + + def _finalize(self) -> Dict[DatasetMeasure, MeasureValue]: + """returns aggregated dataset statistics""" + if self._current_dataset_measures is None: + self._current_dataset_measures = {} + for calc in 
self.dataset_calculators: + values = calc.finalize() + self._current_dataset_measures.update(values.items()) + + return self._current_dataset_measures + + def update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]: + """Returns sample as well as updated dataset statistics""" + ret = cast(Dict[Measure, MeasureValue], self._compute_and_update(sample)) + ret.update(self._finalize().items()) + return ret + + def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]: + """Returns sample as well as previously computed dataset statistics""" + ret = cast(Dict[Measure, MeasureValue], self._compute(sample)) + ret.update(self._finalize().items()) + return ret def get_measure_calculators( @@ -332,8 +382,8 @@ def get_measure_calculators( required_dataset_means: Set[DatasetMean] = set() required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set() required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = set() - required_sample_percentiles: Set[SamplePercentile] = set() - required_dataset_percentiles: Set[DatasetPercentile] = set() + required_sample_percentiles: Dict[Tuple[TensorId, Optional[Tuple[AxisId, ...]]], Set[float]] = {} + required_dataset_percentiles: Dict[Tuple[TensorId, Optional[Tuple[AxisId, ...]]], Set[float]] = {} for rm in required_measures: if isinstance(rm, SampleMean): @@ -351,9 +401,9 @@ def get_measure_calculators( ) assert rm in required_dataset_mean_var_std elif isinstance(rm, SamplePercentile): - required_sample_percentiles.add(rm) + required_sample_percentiles.setdefault((rm.tensor_id, rm.axes), set()).add(rm.n) elif isinstance(rm, DatasetPercentile): # pyright: ignore[reportUnnecessaryIsInstance] - required_dataset_percentiles.add(rm) + required_dataset_percentiles.setdefault((rm.tensor_id, rm.axes), set()).add(rm.n) else: assert_never(rm) @@ -377,36 +427,29 @@ def get_measure_calculators( for rm in required_dataset_mean_var_std: dataset_calculators.append(MeanVarStdCalculator(tensor_id=rm.tensor_id, axes=rm.axes)) - for rm in required_sample_percentiles: - sample_calculators.append(SamplePercentilesCalculator(tensor_id=rm.tensor_id, axes=axes)) - for (tn, axes), ns in required_sample_percentiles.items(): - if mode == PER_SAMPLE: - calculators[mode].append(SamplePercentilesCalculator(tensor_id=tn, axes=axes, ns=ns)) - elif mode == PER_DATASET: - calculators[mode].append(DatasetPercentileCalculator(tensor_name=tn, axes=axes, ns=ns)) - else: - raise NotImplementedError(mode) - - return calculators - - -def compute_measures( - measures: Set[Measure], *, sample: Optional[Sample] = None, dataset: Iterator[Sample] = () -) -> ComputedMeasures: - ms_groups = get_measure_calculators(measures) - ret = {PER_SAMPLE: {}, PER_DATASET: {}} - if sample is not None: - for mg in ms_groups[PER_SAMPLE]: - assert isinstance(mg, SampleMeasureCalculator) - ret[PER_SAMPLE].update(mg.compute(sample)) + for (tid, axes), ns in required_sample_percentiles.items(): + sample_calculators.append(SamplePercentilesCalculator(tensor_id=tid, axes=axes, ns=ns)) + + for (tid, axes), ns in required_dataset_percentiles.items(): + dataset_calculators.append(DatasetPercentilesCalculator(tensor_id=tid, axes=axes, ns=ns)) + + return sample_calculators, dataset_calculators + + +def compute_dataset_measures( + *, measures: Iterable[DatasetMeasure], dataset: Iterable[Sample] +) -> Dict[DatasetMeasure, MeasureValue]: + """compute all dataset `measures` for the given `dataset`""" + sample_calculators, calculators = 
get_measure_calculators(measures) + assert not sample_calculators + + ret: Dict[DatasetMeasure, MeasureValue] = {} for sample in dataset: - for mg in ms_groups[PER_DATASET]: - assert isinstance(mg, DatasetMeasureCalculator) - mg.update_with_sample(sample) + for calc in calculators: + calc.update(sample) - for mg in ms_groups[PER_DATASET]: - assert isinstance(mg, DatasetMeasureCalculator) - ret[PER_DATASET].update(mg.finalize()) + for calc in calculators: + ret.update(calc.finalize().items()) return ret diff --git a/bioimageio/core/stat_measures.py b/bioimageio/core/stat_measures.py index 6f4f3aa9..96de3a9c 100644 --- a/bioimageio/core/stat_measures.py +++ b/bioimageio/core/stat_measures.py @@ -2,18 +2,17 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, Tuple, TypeVar, Union +from typing import Optional, Tuple, Union import xarray as xr -from bioimageio.core.common import Sample -from bioimageio.spec.model.v0_5 import AxisId, TensorId +from bioimageio.core.common import AxisId, Sample, TensorId MeasureValue = Union[float, xr.DataArray] @dataclass(frozen=True) -class MeasureBase(ABC): +class MeasureBase: tensor_id: TensorId @@ -31,55 +30,67 @@ class DatasetMeasureBase(MeasureBase, ABC): @dataclass(frozen=True) -class _Mean(MeasureBase): +class _Mean: axes: Optional[Tuple[AxisId, ...]] = None @dataclass(frozen=True) -class SampleMean(_Mean, SampleMeasureBase): +class SampleMean(SampleMeasureBase, _Mean): def compute(self, sample: Sample) -> MeasureValue: - return sample[self.tensor_id].mean(dim=self.axes) + return sample.data[self.tensor_id].mean(dim=self.axes) + + def __post_init__(self): + assert self.axes is None or AxisId("batch") not in self.axes @dataclass(frozen=True) -class DatasetMean(_Mean, DatasetMeasureBase): - pass +class DatasetMean(DatasetMeasureBase, _Mean): + def __post_init__(self): + assert self.axes is None or AxisId("batch") in self.axes @dataclass(frozen=True) -class _Std(MeasureBase): +class _Std: axes: Optional[Tuple[AxisId, ...]] = None @dataclass(frozen=True) -class SampleStd(_Std, SampleMeasureBase): +class SampleStd(SampleMeasureBase, _Std): def compute(self, sample: Sample) -> MeasureValue: - return sample[self.tensor_id].std(dim=self.axes) + return sample.data[self.tensor_id].std(dim=self.axes) + + def __post_init__(self): + assert self.axes is None or AxisId("batch") not in self.axes @dataclass(frozen=True) -class DatasetStd(_Std, DatasetMeasureBase): - pass +class DatasetStd(DatasetMeasureBase, _Std): + def __post_init__(self): + assert self.axes is None or AxisId("batch") in self.axes @dataclass(frozen=True) -class _Var(MeasureBase): +class _Var: axes: Optional[Tuple[AxisId, ...]] = None @dataclass(frozen=True) -class SampleVar(_Var, SampleMeasureBase): +class SampleVar(SampleMeasureBase, _Var): def compute(self, sample: Sample) -> MeasureValue: - return sample[self.tensor_id].var(dim=self.axes) + return sample.data[self.tensor_id].var(dim=self.axes) + + def __post_init__(self): + assert self.axes is None or AxisId("batch") not in self.axes @dataclass(frozen=True) -class DatasetVar(_Var, DatasetMeasureBase): - pass +class DatasetVar(DatasetMeasureBase, _Var): + def __post_init__(self): + assert self.axes is None or AxisId("batch") in self.axes @dataclass(frozen=True) -class _Percentile(MeasureBase): +class _Percentile: n: float axes: Optional[Tuple[AxisId, ...]] = None @@ -89,21 +100,22 @@ def __post_init__(self): @dataclass(frozen=True) -class SamplePercentile(_Percentile, SampleMeasureBase): +class 
SamplePercentile(SampleMeasureBase, _Percentile): def compute(self, sample: Sample) -> MeasureValue: - return sample[self.tensor_id].tensor.quantile(self.n / 100.0, dim=self.axes) + return sample.data[self.tensor_id].quantile(self.n / 100.0, dim=self.axes) + + def __post_init__(self): + super().__post_init__() + assert self.axes is None or AxisId("batch") not in self.axes @dataclass(frozen=True) -class DatasetPercentile(_Percentile, DatasetMeasureBase): - pass +class DatasetPercentile(DatasetMeasureBase, _Percentile): + def __post_init__(self): + super().__post_init__() + assert self.axes is None or AxisId("batch") in self.axes SampleMeasure = Union[SampleMean, SampleStd, SampleVar, SamplePercentile] DatasetMeasure = Union[DatasetMean, DatasetStd, DatasetVar, DatasetPercentile] Measure = Union[SampleMeasure, DatasetMeasure] - -# MeasureVar = TypeVar("MeasureVar", bound=MeasureBase) -# SampleMeasureVar = TypeVar("SampleMeasureVar", bound=SampleMeasureBase) -# DatasetMeasureVar = TypeVar("DatasetMeasureVar", bound=DatasetMeasureBase) -# ModeVar = TypeVar("ModeVar", bound=Literal["per_sample", "per_dataset"]) diff --git a/bioimageio/core/stat_state.py b/bioimageio/core/stat_state.py deleted file mode 100644 index 24f062c9..00000000 --- a/bioimageio/core/stat_state.py +++ /dev/null @@ -1,100 +0,0 @@ -from dataclasses import dataclass, field -from typing import Dict, Iterable, Literal, Optional, Union - -from tqdm import tqdm - -from bioimageio.core.common import PER_DATASET, PER_SAMPLE, RequiredMeasure, Sample, TensorId -from bioimageio.core.stat_calculators import MeasureGroups, MeasureValue, get_measure_calculators -from bioimageio.core.stat_measures import MeasureBase, MeasureValue - - -@dataclass -class StatsState: - """class to compute, hold and update dataset and sample statistics""" - - required_measures: Iterable[RequiredMeasure] - - -def compute_statistics(): - dataset: Iterable[Sample] - update_dataset_stats_after_n_samples: Optional[int] = None - update_dataset_stats_for_n_samples: Union[int, float] = float("inf") - -def - """iterates over dataset to compute dataset statistics (if required). The resulting dataset statistics are further updated with each new sample. A sample in this context may be a mini-batch. - - Args: - required_measures: measures to be computed - dataset: (partial) dataset to initialize dataset statistics with - update_dataset_stats_after_n_samples: Update dataset statistics for new samples S_i if i > n. - (default: len(dataset)) - This parameter allows to avoid weighting the first n processed - samples to count twice if they make up the given 'dataset'. 
- update_dataset_stats_for_n_samples: stop updating dataset statistics with new samples S_i if - i > for_n_samples (+ update_dataset_stats_after_n_samples) - """ - sample_count: int = field(init=False) - last_sample: Optional[Sample] = field(init=False) - measure_groups: MeasureGroups = field(init=False) - _n_start: Union[int, float] = field(init=False) - _n_stop: Union[int, float] = field(init=False) - _final_dataset_stats: Optional[Dict[RequiredMeasure, MeasureValue]] = field(init=False) - - def __init__( - self, - *, - ): - super().__init__() - self.required_measures = required_measures - self.update_dataset_stats_after_n_samples = update_dataset_stats_after_n_samples - self.update_dataset_stats_for_n_samples = update_dataset_stats_for_n_samples - self.reset(dataset) - - def reset(self, dataset: Iterable[Sample]): - self.sample_count = 0 - self.last_sample = None - self._final_dataset_stats = None - self.measure_groups = get_measure_calculators(self.required_measures) - - len_dataset = 0 - if self.measure_groups[PER_DATASET]: - for sample in tqdm(dataset, "computing dataset statistics"): - len_dataset += 1 - self._update_dataset_measure_groups(sample) - - if self.update_dataset_stats_after_n_samples is None: - self._n_start = len_dataset - else: - self._n_start = self.update_dataset_stats_after_n_samples - - self._n_stop = self._n_start + self.update_dataset_stats_for_n_samples - - def update_with_sample(self, sample: Sample): - self.last_sample = sample - self.sample_count += 1 - if self._n_start < self.sample_count <= self._n_stop: - self._update_dataset_measure_groups(sample) - - def _update_dataset_measure_groups(self, sample: Sample): - for mg in self.measure_groups[PER_DATASET]: - mg.update_with_sample(sample) - - def compute_measures(self) -> ComputedMeasures: - ret = {PER_SAMPLE: {}, PER_DATASET: {}} - if self.last_sample is not None: - for mg in self.measure_groups[PER_SAMPLE]: - ret[PER_SAMPLE].update(mg.compute(self.last_sample)) - - if self._final_dataset_stats is None: - dataset_stats = {} - for mg in self.measure_groups[PER_DATASET]: - dataset_stats.update(mg.finalize()) - - if self.sample_count > self._n_stop: - # stop recomputing final dataset statistics - self._final_dataset_stats = dataset_stats - else: - dataset_stats = self._final_dataset_stats - - ret[PER_DATASET] = dataset_stats - return ret diff --git a/setup.py b/setup.py index 82b05008..3a89a597 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ ], include_package_data=True, extras_require={ - "test": ["pytest", "black[jupyter]", "onnxruntime", "torch>=1.6", "torchvision"], + "test": ["pytest", "black[jupyter]", "onnxruntime", "torch>=1.6", "torchvision", "crick"], "dev": ["pre-commit"], "pytorch": ["torch>=1.6", "torchvision"], "tensorflow": ["tensorflow"], From 324be48d1a9bf05b21c46cc56c96ff7912529047 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 15 Dec 2023 22:24:10 +0100 Subject: [PATCH 080/244] remove add_weights --- bioimageio/core/build_spec/add_weights.py | 81 ----------------------- 1 file changed, 81 deletions(-) delete mode 100644 bioimageio/core/build_spec/add_weights.py diff --git a/bioimageio/core/build_spec/add_weights.py b/bioimageio/core/build_spec/add_weights.py deleted file mode 100644 index 7dbafe82..00000000 --- a/bioimageio/core/build_spec/add_weights.py +++ /dev/null @@ -1,81 +0,0 @@ -import os -from pathlib import Path -from shutil import copyfile -from typing import Dict, List, Optional, Union - -from pydantic import DirectoryPath, FilePath - -from bioimageio.core import 
export_resource_package -from bioimageio.core.io import FileSource, download, read_description, write_package_as_folder -from bioimageio.spec.model import AnyModel, v0_5 - -from .build_model import _get_weights - - -def add_weights( - model: Union[AnyModel, FileSource], - weight_file: FileSource, - output_path: DirectoryPath, - *, - weight_type: Optional[v0_5.WeightsFormat] = None, - architecture: Optional[str] = None, - model_kwargs: Optional[Dict[str, Union[int, float, str]]] = None, - tensorflow_version: Optional[str] = None, - opset_version: Optional[str] = None, - pytorch_version: Optional[str] = None, - attachments: Optional[Dict[str, Union[str, List[str]]]] = None, -): - """Add weight entry to bioimage.io model. - - Args: - model: the resource description of the model to which the weight format is added - weight_file: the weight file to be added - output_path: where to serialize the new model with additional weight format - weight_type: the format of the weights to be added - architecture: the file with the source code for the model architecture and the corresponding class. - Only required for models with pytorch_state_dict weight format. - model_kwargs: the keyword arguments for the model class. - Only required for models with pytorch_state_dict weight format. - tensorflow_version: the tensorflow version for this model. Only for tensorflow or keras weights. - opset_version: the opset version for this model. Only for onnx weights. - pytorch_version: the pytorch version for this model. Only for pytoch_state_dict or torchscript weights. - attachments: extra weight specific attachments. - """ - downloaded_weight_file = download(weight_file) - output_path = write_package_as_folder(model, output_path=output_path) - - # copy the weight path to the input model's root, otherwise it will - # not be found when packaging the new model - weight_out: FilePath = output_path / downloaded_weight_file.original_file_name # noqa: F821 - _ = copyfile(downloaded_weight_file.path, weight_out) - - new_weights, tmp_arch = _get_weights( - weight_out, - weight_type, - root=output_path, - architecture=architecture, - model_kwargs=model_kwargs, - tensorflow_version=tensorflow_version, - opset_version=opset_version, - pytorch_version=pytorch_version, - attachments=attachments, - ) - model.weights.update(new_weights) - - try: - model_package = export_resource_package(model, output_path=output_path) - model = read_description(model_package) - except Exception as e: - raise e - finally: - # clean up tmp files - if Path(weight_out).absolute() != Path(weight_file).absolute(): - os.remove(weight_out) - if tmp_arch is not None: - os.remove(tmp_arch) - # for some reason the weights are also copied to the cwd. 
- # not sure why this happens, but it needs to be cleaned up, unless these are the input weigths - weights_cwd = Path(os.path.split(weight_file)[1]) - if weights_cwd.exists() and weights_cwd.absolute() != Path(weight_file).absolute(): - os.remove(weights_cwd) - return model From f1af46fcff887e2777bffe09bba780aa4a5b81e2 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sat, 23 Dec 2023 01:18:18 +0100 Subject: [PATCH 081/244] improve typing --- .../core/weight_converter/keras/tensorflow.py | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/bioimageio/core/weight_converter/keras/tensorflow.py b/bioimageio/core/weight_converter/keras/tensorflow.py index 38af1ea2..0836172d 100644 --- a/bioimageio/core/weight_converter/keras/tensorflow.py +++ b/bioimageio/core/weight_converter/keras/tensorflow.py @@ -1,18 +1,18 @@ import os import shutil from pathlib import Path -from typing import Union +from typing import Union, no_type_check from zipfile import ZipFile import tensorflow from tensorflow import saved_model -from bioimageio.spec import AnyModel, load_description from bioimageio.spec._internal.io_utils import download +from bioimageio.spec.model.v0_5 import ModelDescr def _zip_model_bundle(model_bundle_folder: Path): - zipped_model_bundle = f"{model_bundle_folder}.zip" + zipped_model_bundle = model_bundle_folder.with_suffix(".zip") with ZipFile(zipped_model_bundle, "w") as zip_obj: for root, _, files in os.walk(model_bundle_folder): @@ -33,12 +33,13 @@ def _zip_model_bundle(model_bundle_folder: Path): def _convert_tf1(keras_weight_path: Path, output_path: Path, input_name: str, output_name: str, zip_weights: bool): try: # try to build the tf model with the keras import from tensorflow - from tensorflow import keras + from tensorflow import keras # type: ignore except Exception: # if the above fails try to export with the standalone keras import keras + @no_type_check def build_tf_model(): keras_model = keras.models.load_model(keras_weight_path) @@ -63,7 +64,7 @@ def build_tf_model(): return 0 -def _convert_tf2(keras_weight_path, output_path, zip_weights): +def _convert_tf2(keras_weight_path: Path, output_path: Path, zip_weights: bool): try: # try to build the tf model with the keras import from tensorflow from tensorflow import keras @@ -81,32 +82,30 @@ def _convert_tf2(keras_weight_path, output_path, zip_weights): return 0 -def convert_weights_to_tensorflow_saved_model_bundle( - model_spec: Union[str, Path, AnyModel], output_path: Union[str, Path] -): +def convert_weights_to_tensorflow_saved_model_bundle(model: ModelDescr, output_path: Path): """Convert model weights from format 'keras_hdf5' to 'tensorflow_saved_model_bundle'. Adapted from https://github.com/deepimagej/pydeepimagej/blob/5aaf0e71f9b04df591d5ca596f0af633a7e024f5/pydeepimagej/yaml/create_config.py Args: - model_spec: location of the resource for the input bioimageio model + model: The bioimageio model description output_path: where to save the tensorflow weights. This path must not exist yet. 
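+
+    Raises:
+        ValueError: if output_path already exists or if the model has no
+            keras_hdf5 weights entry to convert from.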
""" tf_major_ver = int(tensorflow.__version__.split(".")[0]) - path_ = Path(output_path) - if path_.suffix == ".zip": - path_ = Path(os.path.splitext(path_)[0]) + if output_path.suffix == ".zip": + output_path = output_path.with_suffix("") zip_weights = True else: zip_weights = False - if path_.exists(): - raise ValueError(f"The ouptut directory at {path_} must not exist.") + if output_path.exists(): + raise ValueError(f"The ouptut directory at {output_path} must not exist.") + + if model.weights.keras_hdf5 is None: + raise ValueError("Missing Keras Hdf5 weights to convert from.") - model = load_description(model_spec) - model.weights.keras_hdf5 is not None weight_spec = model.weights.keras_hdf5 weight_path = download(weight_spec.source).path @@ -120,6 +119,6 @@ def convert_weights_to_tensorflow_saved_model_bundle( raise NotImplementedError( "Weight conversion for models with multiple inputs or outputs is not yet implemented." ) - return _convert_tf1(weight_path, str(path_), model.inputs[0].name, model.outputs[0].name, zip_weights) + return _convert_tf1(weight_path, output_path, model.inputs[0].id, model.outputs[0].id, zip_weights) else: - return _convert_tf2(weight_path, str(path_), zip_weights) + return _convert_tf2(weight_path, output_path, zip_weights) From ca02570828265eae2097793aab05530e66502428 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 12 Feb 2024 13:01:17 +0100 Subject: [PATCH 082/244] ignore reportUnnecessaryIsInstance --- .../_torchscript_model_adapter.py | 11 +++++------ bioimageio/core/proc_ops.py | 6 +++--- bioimageio/core/proc_setup.py | 2 +- bioimageio/core/stat_calculators.py | 2 +- .../core/weight_converter/keras/tensorflow.py | 2 +- .../weight_converter/torch/torchscript.py | 2 +- pyproject.toml | 19 ++++++++++--------- 7 files changed, 22 insertions(+), 22 deletions(-) diff --git a/bioimageio/core/model_adapters/_torchscript_model_adapter.py b/bioimageio/core/model_adapters/_torchscript_model_adapter.py index 804c8503..7aef4fee 100644 --- a/bioimageio/core/model_adapters/_torchscript_model_adapter.py +++ b/bioimageio/core/model_adapters/_torchscript_model_adapter.py @@ -7,15 +7,17 @@ import xarray as xr from numpy.typing import NDArray -from bioimageio.spec.utils import download from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import RelativeFilePath +from bioimageio.spec.utils import download from ._model_adapter import ModelAdapter class TorchscriptModelAdapter(ModelAdapter): - def __init__(self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None): + def __init__( + self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None + ): super().__init__() if model_description.weights.torchscript is None: raise ValueError(f"No torchscript weights found for model {model_description.name}") @@ -50,10 +52,7 @@ def forward(self, *batch: xr.DataArray) -> List[xr.DataArray]: else: result = [_result] - result = [ - r.cpu().numpy() if not isinstance(r, np.ndarray) else r # pyright: ignore[reportUnnecessaryIsInstance] - for r in result - ] + result = [r.cpu().numpy() if not isinstance(r, np.ndarray) else r for r in result] assert len(result) == len(self._internal_output_axes) return [xr.DataArray(r, dims=axes) for r, axes in zip(result, self._internal_output_axes)] diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 1faf1be2..3f949d94 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -263,7 +263,7 @@ 
def _get_axes( axes = None elif isinstance(kwargs.axes, str): axes = convert_axis_ids(kwargs.axes, kwargs["mode"]) - elif isinstance(kwargs.axes, collections.abc.Sequence): # pyright: ignore[reportUnnecessaryIsInstance] + elif isinstance(kwargs.axes, collections.abc.Sequence): axes = tuple(kwargs.axes) else: assert_never(kwargs.axes) @@ -477,7 +477,7 @@ def _apply(self, input: xr.DataArray, stat: Stat) -> xr.DataArray: # elif isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr) and proc_spec.kwargs.mode == "fixed": # return FixedZeroMeanUnitVarianceImpl, v0_5.FixedZeroMeanUnitVarianceKwargs # elif isinstance( -# proc_spec, # pyright: ignore[reportUnnecessaryIsInstance +# proc_spec, # (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr), # ): # return ZeroMeanUnitVarianceImpl, v0_5.ZeroMeanUnitVarianceKwargs @@ -517,7 +517,7 @@ def get_proc_class(proc_spec: ProcDescr) -> Type[Processing]: elif isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr) and proc_spec.kwargs.mode == "fixed": return FixedZeroMeanUnitVariance elif isinstance( - proc_spec, # pyright: ignore[reportUnnecessaryIsInstance] + proc_spec, (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr), ): return ZeroMeanUnitVariance diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py index 762e3385..a138dd7e 100644 --- a/bioimageio/core/proc_setup.py +++ b/bioimageio/core/proc_setup.py @@ -43,7 +43,7 @@ def prepare_procs(tensor_descrs: Sequence[TensorDescr]): if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)): proc_descrs = t_descr.preprocessing elif isinstance( - t_descr, # pyright: ignore[reportUnnecessaryIsInstance] + t_descr, (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr), ): proc_descrs = t_descr.postprocessing diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py index 8c12d29e..af563b14 100644 --- a/bioimageio/core/stat_calculators.py +++ b/bioimageio/core/stat_calculators.py @@ -402,7 +402,7 @@ def get_measure_calculators( assert rm in required_dataset_mean_var_std elif isinstance(rm, SamplePercentile): required_sample_percentiles.setdefault((rm.tensor_id, rm.axes), set()).add(rm.n) - elif isinstance(rm, DatasetPercentile): # pyright: ignore[reportUnnecessaryIsInstance] + elif isinstance(rm, DatasetPercentile): required_dataset_percentiles.setdefault((rm.tensor_id, rm.axes), set()).add(rm.n) else: assert_never(rm) diff --git a/bioimageio/core/weight_converter/keras/tensorflow.py b/bioimageio/core/weight_converter/keras/tensorflow.py index 0836172d..5eed3797 100644 --- a/bioimageio/core/weight_converter/keras/tensorflow.py +++ b/bioimageio/core/weight_converter/keras/tensorflow.py @@ -1,7 +1,7 @@ import os import shutil from pathlib import Path -from typing import Union, no_type_check +from typing import no_type_check from zipfile import ZipFile import tensorflow diff --git a/bioimageio/core/weight_converter/torch/torchscript.py b/bioimageio/core/weight_converter/torch/torchscript.py index e01ac34f..451fcb3e 100644 --- a/bioimageio/core/weight_converter/torch/torchscript.py +++ b/bioimageio/core/weight_converter/torch/torchscript.py @@ -59,7 +59,7 @@ def _check(input_: Sequence[torch.Tensor]) -> None: step.append(0) elif isinstance(axis.size, (v0_5.AxisId, v0_5.TensorAxisId, type(None))): raise NotImplementedError(f"Can't verify inputs that don't specify their shape fully: {axis}") - elif isinstance(axis.size, v0_5.SizeReference): # pyright: ignore [reportUnnecessaryIsInstance] + elif isinstance(axis.size, v0_5.SizeReference): 
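+                # a SizeReference ties this axis' size to an axis of another
+                # tensor; resolving it would require that tensor's shape, which
+                # is not available when tracing a single input here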
raise NotImplementedError(f"Can't handle axes like '{axis}' yet") else: assert_never(axis.size) diff --git a/pyproject.toml b/pyproject.toml index e0296a1f..be6c4a92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,21 +3,22 @@ line-length = 120 target-version = ["py38", "py39", "py310", "py311"] [tool.pyright] -include = ["bioimageio", "scripts", "tests"] exclude = ["**/node_modules", "**/__pycache__", "tests/old_*"] -typeCheckingMode = "strict" -reportMissingSuperCall = "error" -reportUnnecessaryTypeIgnoreComment = "error" -reportUninitializedInstanceVariable = "error" -reportUnknownMemberType = false +include = ["bioimageio", "scripts", "tests"] +pythonPlatform = "All" +pythonVersion = "3.8" reportIncompatibleMethodOverride = true +reportMissingSuperCall = "error" reportMissingTypeArgument = true reportMissingTypeStubs = "warning" -useLibraryCodeForTypes = true +reportUninitializedInstanceVariable = "error" +reportUnknownMemberType = false +reportUnnecessaryIsInstance = false +reportUnnecessaryTypeIgnoreComment = "error" reportUnusedCallResult = "error" reportUnusedVariable = "error" -pythonVersion = "3.8" -pythonPlatform = "All" +typeCheckingMode = "strict" +useLibraryCodeForTypes = true [tool.pytest.ini_options] addopts = "--capture=no --doctest-modules --failed-first" From ce80896ca22c774d212c9b0c1473031aa48af24d Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 13 Feb 2024 15:33:23 +0100 Subject: [PATCH 083/244] WIP update adapters --- .../core/model_adapters/_model_adapter.py | 2 +- .../model_adapters/_onnx_model_adapter.py | 4 +++- .../_tensorflow_model_adapter.py | 19 +++++++++---------- .../_torchscript_model_adapter.py | 7 ++----- 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py index a2809293..52a33771 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ b/bioimageio/core/model_adapters/_model_adapter.py @@ -42,7 +42,7 @@ class ModelAdapter(ABC): @classmethod def create( cls, - model_description: Union[v0_4.Model, v0_5.Model], + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], *, devices: Optional[Sequence[str]] = None, weight_format_priority_order: NotEmpty[Sequence[WeightsFormat]] = DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER, diff --git a/bioimageio/core/model_adapters/_onnx_model_adapter.py b/bioimageio/core/model_adapters/_onnx_model_adapter.py index 14ed36d1..0a1caa36 100644 --- a/bioimageio/core/model_adapters/_onnx_model_adapter.py +++ b/bioimageio/core/model_adapters/_onnx_model_adapter.py @@ -14,7 +14,9 @@ class ONNXModelAdapter(ModelAdapter): - def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: Optional[Sequence[str]] = None): + def __init__( + self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None + ): super().__init__() self._internal_output_axes = [ tuple(out.axes) if isinstance(out.axes, str) else tuple(a.id for a in out.axes) diff --git a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py index 5ecd3cb3..96828016 100644 --- a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py +++ b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py @@ -6,10 +6,9 @@ import tensorflow as tf import xarray as xr -from bioimageio.spec.utils import download -from bioimageio.spec.generic.v0_3 import FileSource #FIXME: getre-export from somewhere? 
+from bioimageio.spec.common import FileSource, RelativeFilePath from bioimageio.spec.model import v0_4, v0_5 -from bioimageio.spec.model.v0_5 import RelativeFilePath +from bioimageio.spec.utils import download from ._model_adapter import ModelAdapter @@ -54,11 +53,7 @@ def __init__( if devices is not None: warnings.warn(f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}") - weight_file = self.require_unzipped( - weights.source.get_absolute(model_description.root) - if isinstance(weights.source, RelativeFilePath) - else weights.source - ) + weight_file = self.require_unzipped(weights.source) self._network = self._get_network(weight_file) self._internal_output_axes = [ tuple(out.axes) if isinstance(out.axes, str) else tuple(a.id for a in out.axes) @@ -149,7 +144,9 @@ def unload(self) -> None: class TensorflowModelAdapter(TensorflowModelAdapterBase): weight_format = "tensorflow_saved_model_bundle" - def __init__(self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None): + def __init__( + self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None + ): if model_description.weights.tensorflow_saved_model_bundle is None: raise ValueError("missing tensorflow_saved_model_bundle weights") @@ -163,7 +160,9 @@ def __init__(self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr] class KerasModelAdapter(TensorflowModelAdapterBase): weight_format = "keras_hdf5" - def __init__(self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None): + def __init__( + self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None + ): if model_description.weights.keras_hdf5 is None: raise ValueError("missing keras_hdf5 weights") diff --git a/bioimageio/core/model_adapters/_torchscript_model_adapter.py b/bioimageio/core/model_adapters/_torchscript_model_adapter.py index 7aef4fee..3d7d046f 100644 --- a/bioimageio/core/model_adapters/_torchscript_model_adapter.py +++ b/bioimageio/core/model_adapters/_torchscript_model_adapter.py @@ -7,8 +7,8 @@ import xarray as xr from numpy.typing import NDArray +from bioimageio.spec.common import RelativeFilePath from bioimageio.spec.model import v0_4, v0_5 -from bioimageio.spec.model.v0_5 import RelativeFilePath from bioimageio.spec.utils import download from ._model_adapter import ModelAdapter @@ -22,10 +22,7 @@ def __init__( if model_description.weights.torchscript is None: raise ValueError(f"No torchscript weights found for model {model_description.name}") - src = model_description.weights.torchscript.source - weight_path = download( - src.get_absolute(model_description.root) if isinstance(src, RelativeFilePath) else src - ).path + weight_path = download(model_description.weights.torchscript.source).path if devices is None: self.devices = ["cuda" if torch.cuda.is_available() else "cpu"] else: From 8704abd21aa6d7f07f94c8fcb03418ba1b142b66 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 19 Feb 2024 23:49:34 +0100 Subject: [PATCH 084/244] switch to ruyaml --- dev/env.yaml | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/env.yaml b/dev/env.yaml index 51c69875..7ee582f5 100644 --- a/dev/env.yaml +++ b/dev/env.yaml @@ -20,7 +20,7 @@ dependencies: - python-dateutil - python=3.8 - pytorch - - ruamel.yaml + - ruyaml - ruff - torchvision - tqdm diff --git a/setup.py b/setup.py index 3a89a597..f18a765b 
100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ "bioimageio.spec==0.4.9.*", "imageio>=2.5", "numpy", - "ruamel.yaml", + "ruyaml", "tifffile", "tqdm", "typer", From 8797bb00fc99d4af4a5a4d17e363a539ef972c04 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 21 Feb 2024 09:55:29 +0100 Subject: [PATCH 085/244] improve vscode setup --- .vscode/settings.json | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index a328cfa1..64b5cecc 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,4 +1,8 @@ { + "window.title": "bioimageio.core", + "python.analysis.extraPaths": [ + "../spec-bioimage-io", + ], "python.testing.unittestArgs": [ "-v", "-s", @@ -8,4 +12,4 @@ ], "python.testing.pytestEnabled": true, "python.testing.unittestEnabled": false, -} \ No newline at end of file +} From 5fddcb777bd6f7cf58417660c1f137103e9ffe56 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 21 Feb 2024 13:14:47 +0100 Subject: [PATCH 086/244] WIP update to new spec --- bioimageio/core/__init__.py | 22 +- bioimageio/core/__main__.py | 119 +- bioimageio/core/image_helper.py | 319 ++---- bioimageio/core/io.py | 24 +- .../model_adapters/_keras_model_adapter.py | 10 +- .../core/model_adapters/_model_adapter.py | 4 +- .../model_adapters/_onnx_model_adapter.py | 6 +- .../model_adapters/_pytorch_model_adapter.py | 12 +- bioimageio/core/op_base.py | 6 +- bioimageio/core/prediction.py | 1017 +++++++++-------- bioimageio/core/prediction_pipeline.py | 266 ++--- bioimageio/core/proc_ops.py | 14 +- bioimageio/core/proc_setup.py | 70 +- bioimageio/core/resource_tests.py | 541 +++------ bioimageio/core/stat_calculators.py | 50 +- bioimageio/core/utils/__init__.py | 7 +- bioimageio/core/utils/node_visitor.py | 73 -- bioimageio/core/utils/testing.py | 3 +- .../core/weight_converter/torch/onnx.py | 11 +- .../core/weight_converter/torch/utils.py | 3 +- tests/build_spec/test_build_spec.py | 16 +- tests/test_resource_tests/test_test_model.py | 10 +- 22 files changed, 1113 insertions(+), 1490 deletions(-) delete mode 100644 bioimageio/core/utils/node_visitor.py diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index 8fe9b81a..8268d6bf 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -1,20 +1,20 @@ import json -from bioimageio.core.io import build_description_and_validate as build_description_and_validate -from bioimageio.core.io import load_description as load_description -from bioimageio.core.io import load_description_and_validate as load_description_and_validate -from bioimageio.core.io import read_description as read_description -from bioimageio.core.io import resolve_source as resolve_source -from bioimageio.core.io import write_description as write_description -from bioimageio.core.io import write_package as write_package -from bioimageio.core.io import write_package_as_folder as write_package_as_folder from bioimageio.core.utils import files with files("bioimageio.core").joinpath("VERSION").open("r", encoding="utf-8") as f: __version__: str = json.load(f)["version"] assert isinstance(__version__, str) -# from .prediction import predict_image, predict_images, predict_with_padding, predict_with_tiling -from .prediction_pipeline import create_prediction_pipeline +from bioimageio.spec import build_description as build_description +from bioimageio.spec import dump_description as dump_description +from bioimageio.spec import load_description as load_description +from bioimageio.spec import 
load_description_and_validate_format_only as load_description_and_validate_format_only +from bioimageio.spec import save_bioimageio_package as save_bioimageio_package +from bioimageio.spec import save_bioimageio_package_as_folder as save_bioimageio_package_as_folder +from bioimageio.spec import save_bioimageio_yaml_only as save_bioimageio_yaml_only +from bioimageio.spec import validate_format as validate_format -# from .resource_tests import check_input_shape, check_output_shape, test_resource +from .prediction_pipeline import create_prediction_pipeline as create_prediction_pipeline +from .resource_tests import load_description_and_test as load_description_and_test +from .resource_tests import test_description as test_description diff --git a/bioimageio/core/__main__.py b/bioimageio/core/__main__.py index aabd2b05..651f5d20 100644 --- a/bioimageio/core/__main__.py +++ b/bioimageio/core/__main__.py @@ -5,16 +5,20 @@ import warnings from glob import glob from pathlib import Path -from pprint import pformat, pprint +from pprint import pformat from typing import List, Optional, get_args import typer +from typing_extensions import Annotated -from bioimageio.core import __version__, commands, load_raw_resource_description, prediction, resource_tests -from bioimageio.core.common import TestSummary -from bioimageio.spec.__main__ import app -from bioimageio.spec.__main__ import help_version as help_version_spec -from bioimageio.spec.model.raw_nodes import WeightsFormat +from bioimageio.core import __version__, prediction, resource_tests +from bioimageio.spec import load_description, save_bioimageio_package +from bioimageio.spec.collection import CollectionDescr +from bioimageio.spec.dataset import DatasetDescr +from bioimageio.spec.model import ModelDescr +from bioimageio.spec.model.v0_5 import WeightsFormat +from bioimageio.spec.notebook import NotebookDescr +from bioimageio.spec.summary import ValidationSummary try: with warnings.catch_warnings(): @@ -30,12 +34,20 @@ except ImportError: keras_converter = None +help_version = f"""bioimageio.core {__version__} +bioimageio.spec {__version__} +implementing: +\tcollection RDF {CollectionDescr.implemented_format_version} +\tdataset RDF {DatasetDescr.implemented_format_version} +\tmodel RDF {ModelDescr.implemented_format_version} +\tnotebook RDF {NotebookDescr.implemented_format_version}""" + -# extend help/version string by core version -help_version_core = f"bioimageio.core {__version__}" -help_version = f"{help_version_spec}\n{help_version_core}" # prevent rewrapping with \b\n: https://click.palletsprojects.com/en/7.x/documentation/#preventing-rewrapping -app.info.help = "\b\n" + help_version +app = typer.Typer( + help="\b\n" + help_version, + context_settings={"help_option_names": ["-h", "--help", "--version"]}, # make --version display help with version +) # https://typer.tiangolo.com/ @app.callback() @@ -43,42 +55,38 @@ def callback(): typer.echo(help_version) +# if we want to use something like "choice" for the weight formats, we need to use an enum, see: +# https://github.com/tiangolo/typer/issues/182 +WeightsFormatEnum = enum.Enum("WeightsFormatEnum", {wf: wf for wf in get_args(WeightsFormat)}) +# Enum with in values does not work with click.Choice: https://github.com/pallets/click/issues/784 +# so a simple Enum with auto int values is not an option: +# WeightsFormatEnum = enum.Enum("WeightsFormatEnum", get_args(WeightsFormat)) + + @app.command() def package( - rdf_source: str = typer.Argument(..., help="RDF source as relative file path or 
URI"), - path: Path = typer.Argument(Path() / "{src_name}-package.zip", help="Save package as"), - weights_priority_order: Optional[List[str]] = typer.Option( - None, - "--weights-priority-order", - "-wpo", - help="For model packages only. " - "If given only the first weights matching the given weight formats are included. " - "Defaults to include all weights present in source.", - show_default=False, - ), - verbose: bool = typer.Option(False, help="show traceback of exceptions"), + rdf_source: Annotated[str, typer.Argument(help="RDF source as relative file path or URI")], + path: Annotated[Path, typer.Argument(help="Save package as")] = Path() / "bioimageio-package.zip", + weights_priority_order: Annotated[ + Optional[List[WeightsFormatEnum]], + typer.Option( + "--weights-priority-order", + "-wpo", + help="For model packages only. " + "If given only the first weights matching the given weight formats are included. " + "Defaults to include all weights present in source.", + show_default=False, + ), + ] = None, + # verbose: Annotated[bool, typer.Option(help="show traceback of exceptions")] = False, ): # typer bug: typer returns empty tuple instead of None if weights_order_priority is not given - weights_priority_order = weights_priority_order or None - - ret_code = commands.package( - rdf_source=rdf_source, path=path, weights_priority_order=weights_priority_order, verbose=verbose - ) - sys.exit(ret_code) + weights_priority_order = weights_priority_order or None # TODO: check if this is still the case + _ = save_bioimageio_package(rdf_source, output_path=path, weights_priority_order=weights_priority_order) -package.__doc__ = commands.package.__doc__ - -# if we want to use something like "choice" for the weight formats, we need to use an enum, see: -# https://github.com/tiangolo/typer/issues/182 -WeightFormatEnum = enum.Enum("WeightFormatEnum", {wf: wf for wf in get_args(WeightsFormat)}) -# Enum with in values does not work with click.Choice: https://github.com/pallets/click/issues/784 -# so a simple Enum with auto int values is not an option: -# WeightFormatEnum = enum.Enum("WeightFormatEnum", get_args(WeightsFormat)) - - -def _log_test_summaries(summaries: List[TestSummary], msg: str): +def _log_test_summaries(summaries: List[ValidationSummary], msg: str): # todo: improve logging of multiple test summaries ret_code = 0 for summary in summaries: @@ -120,12 +128,12 @@ def show_part(part, show): @app.command() def test_model( - model_rdf: str = typer.Argument( - ..., help="Path or URL to the model resource description file (rdf.yaml) or zipped model." 
- ), - weight_format: Optional[WeightFormatEnum] = typer.Option(None, help="The weight format to use."), - devices: Optional[List[str]] = typer.Option(None, help="Devices for running the model."), - decimal: int = typer.Option(4, help="The test precision."), + model_rdf: Annotated[ + str, typer.Argument(help="Path or URL to the model resource description file (rdf.yaml) or zipped model.") + ], + weight_format: Annotated[Optional[WeightsFormatEnum], typer.Option(help="The weight format to use.")] = None, + devices: Annotated[Optional[List[str]], typer.Option(help="Devices for running the model.")] = None, + decimal: Annotated[int, typer.Option(help="The test precision.")] = 4, ): # this is a weird typer bug: default devices are empty tuple although they should be None devices = devices or None @@ -149,14 +157,14 @@ def test_resource( rdf: str = typer.Argument( ..., help="Path or URL to the resource description file (rdf.yaml) or zipped resource package." ), - weight_format: Optional[WeightFormatEnum] = typer.Option(None, help="(for model only) The weight format to use."), + weight_format: Optional[WeightsFormatEnum] = typer.Option(None, help="(for model only) The weight format to use."), devices: Optional[List[str]] = typer.Option(None, help="(for model only) Devices for running the model."), decimal: int = typer.Option(4, help="(for model only) The test precision."), ): # this is a weird typer bug: default devices are empty tuple although they should be None if len(devices) == 0: devices = None - summaries = resource_tests.test_resource( + summaries = resource_tests.test_description( rdf, weight_format=None if weight_format is None else weight_format.value, devices=devices, decimal=decimal ) print(f"\ntesting {rdf}...") @@ -164,7 +172,7 @@ def test_resource( sys.exit(ret_code) -test_resource.__doc__ = resource_tests.test_resource.__doc__ +test_resource.__doc__ = resource_tests.test_description.__doc__ @app.command() @@ -183,7 +191,7 @@ def predict_image( # ), padding: Optional[bool] = typer.Option(None, help="Whether to pad the image to a size suited for the model."), tiling: Optional[bool] = typer.Option(None, help="Whether to run prediction in tiling mode."), - weight_format: Optional[WeightFormatEnum] = typer.Option(None, help="The weight format to use."), + weight_format: Optional[WeightsFormatEnum] = typer.Option(None, help="The weight format to use."), devices: Optional[List[str]] = typer.Option(None, help="Devices for running the model."), ): if isinstance(padding, str): @@ -221,7 +229,7 @@ def predict_images( # ), padding: Optional[bool] = typer.Option(None, help="Whether to pad the image to a size suited for the model."), tiling: Optional[bool] = typer.Option(None, help="Whether to run prediction in tiling mode."), - weight_format: Optional[WeightFormatEnum] = typer.Option(None, help="The weight format to use."), + weight_format: Optional[WeightsFormatEnum] = typer.Option(None, help="The weight format to use."), devices: Optional[List[str]] = typer.Option(None, help="Devices for running the model."), ): input_files = glob(input_pattern) @@ -290,12 +298,13 @@ def convert_torch_weights_to_torchscript( @app.command() def convert_keras_weights_to_tensorflow( - model_rdf: Path = typer.Argument( - ..., help="Path to the model resource description file (rdf.yaml) or zipped model." 
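For reference, the typer Annotated style these commands are being migrated to, as a self-contained sketch (the command and its parameters are made up, not part of this patch):

from pathlib import Path

import typer
from typing_extensions import Annotated

app = typer.Typer()


@app.command()
def greet(
    name: Annotated[str, typer.Argument(help="who to greet")],
    out: Annotated[Path, typer.Option(help="where to write the greeting")] = Path("greeting.txt"),
):
    # with Annotated, the default lives on the parameter instead of inside typer.Option
    out.write_text(f"hello {name}\n")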
- ), - output_path: Path = typer.Argument(..., help="Where to save the tensorflow weights."), + model_rdf: Annotated[ + Path, typer.Argument(help="Path to the model resource description file (rdf.yaml) or zipped model.") + ], + output_path: Annotated[Path, typer.Argument(help="Where to save the tensorflow weights.")], ): - ret_code = keras_converter.convert_weights_to_tensorflow_saved_model_bundle(model_rdf, output_path) + rd = load_description(model_rdf) + ret_code = keras_converter.convert_weights_to_tensorflow_saved_model_bundle(rd, output_path) sys.exit(ret_code) convert_keras_weights_to_tensorflow.__doc__ = ( diff --git a/bioimageio/core/image_helper.py b/bioimageio/core/image_helper.py index 6fc3fd24..a3045588 100644 --- a/bioimageio/core/image_helper.py +++ b/bioimageio/core/image_helper.py @@ -1,193 +1,126 @@ -from __future__ import annotations - -import os -from copy import deepcopy -from typing import Dict, List, Optional, Sequence, Tuple, TypeVar, Union - -import imageio -import numpy as np -from numpy.typing import NDArray -from xarray import DataArray - -from bioimageio.spec._internal.io_utils import load_array -from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensor04 -from bioimageio.spec.model.v0_4 import OutputTensorDescr as OutputTensor04 -from bioimageio.spec.model.v0_5 import InputTensorDescr as InputTensor05 -from bioimageio.spec.model.v0_5 import OutputTensorDescr as OutputTensor05 - -InputTensor = Union[InputTensor04, InputTensor05] -OutputTensor = Union[OutputTensor04, OutputTensor05] - - -# -# helper functions to transform input images / output tensors to the required axes -# - - -DType = TypeVar("DType", bound=np.dtype) - - -def transpose_image(image: NDArray[DType], desired_axes: str, current_axes: Optional[str] = None) -> NDArray[DType]: - """Transform an image to match desired axes. - - Args: - image: the input image - desired_axes: the desired image axes - current_axes: the axes of the input image - """ - # if the image axes are not given deduce them from the required axes and image shape - if current_axes is None: - has_z_axis = "z" in desired_axes - ndim = image.ndim - if ndim == 2: - current_axes = "yx" - elif ndim == 3: - current_axes = "zyx" if has_z_axis else "cyx" - elif ndim == 4: - current_axes = "czyx" - elif ndim == 5: - current_axes = "bczyx" - else: - raise ValueError(f"Invalid number of image dimensions: {ndim}") - tensor = DataArray(image, dims=tuple(current_axes)) - # expand the missing image axes - missing_axes = tuple(set(desired_axes) - set(current_axes)) - tensor = tensor.expand_dims(dim=missing_axes) - # transpose to the correct axis order - tensor = tensor.transpose(*tuple(desired_axes)) - # return numpy array - return tensor.values - - -def _drop_axis_default(axis_name, axis_len): - # spatial axes: drop at middle coordinate - # other axes (channel or batch): drop at 0 coordinate - return axis_len // 2 if axis_name in "zyx" else 0 - - -def transform_output_tensor(tensor: NDArray, tensor_axes: str, output_axes: str, drop_function=_drop_axis_default): - """Transform output tensor into image with desired axes. 
- - Args: - tensor: the output tensor - tensor_axes: bioimageio model spec - output_axes: the desired output axes - drop_function: function that determines how to drop unwanted axes - """ - if len(tensor_axes) != tensor.ndim: - raise ValueError(f"Number of axes {len(tensor_axes)} and dimension of tensor {tensor.ndim} don't match") - shape = {ax_name: sh for ax_name, sh in zip(tensor_axes, tensor.shape)} - output = DataArray(tensor, dims=tuple(tensor_axes)) - # drop unwanted axes - drop_axis_names = tuple(set(tensor_axes) - set(output_axes)) - drop_axes = {ax_name: drop_function(ax_name, shape[ax_name]) for ax_name in drop_axis_names} - output = output[drop_axes] - # transpose to the desired axis order - output = output.transpose(*tuple(output_axes)) - # return numpy array - return output.values - - -def to_channel_last(image): - chan_id = image.dims.index("c") - if chan_id != image.ndim - 1: - target_axes = tuple(ax for ax in image.dims if ax != "c") + ("c",) - image = image.transpose(*target_axes) - return image - - -# -# helper functions for loading and saving images -# - - -def load_image(in_path, axes: Sequence[str]) -> DataArray: - ext = os.path.splitext(in_path)[1] - if ext == ".npy": - im = load_array(in_path) - else: - is_volume = "z" in axes - im = imageio.volread(in_path) if is_volume else imageio.imread(in_path) - im = transpose_image(im, axes) - return DataArray(im, dims=axes) - - -def load_tensors(sources, tensor_specs: List[Union[InputTensor, OutputTensor]]) -> List[DataArray]: - return [load_image(s, sspec.axes) for s, sspec in zip(sources, tensor_specs)] - - -def save_image(out_path, image): - ext = os.path.splitext(out_path)[1] - if ext == ".npy": - np.save(out_path, image) - else: - is_volume = "z" in image.dims - - # squeeze batch or channel axes if they are singletons - squeeze = {ax: 0 if (ax in "bc" and sh == 1) else slice(None) for ax, sh in zip(image.dims, image.shape)} - image = image[squeeze] - - if "b" in image.dims: - raise RuntimeError(f"Cannot save prediction with batchsize > 1 as {ext}-file") - if "c" in image.dims: # image formats need channel last - image = to_channel_last(image) - - save_function = imageio.volsave if is_volume else imageio.imsave - # most image formats only support channel dimensions of 1, 3 or 4; - # if not we need to save the channels separately - ndim = 3 if is_volume else 2 - save_as_single_image = image.ndim == ndim or (image.shape[-1] in (3, 4)) - - if save_as_single_image: - save_function(out_path, image) - else: - out_prefix, ext = os.path.splitext(out_path) - for c in range(image.shape[-1]): - chan_out_path = f"{out_prefix}-c{c}{ext}" - save_function(chan_out_path, image[..., c]) - - -# -# helper function for padding -# - - -def pad(image, axes: Sequence[str], padding, pad_right=True) -> Tuple[np.ndarray, Dict[str, slice]]: - assert image.ndim == len(axes), f"{image.ndim}, {len(axes)}" - - padding_ = deepcopy(padding) - mode = padding_.pop("mode", "dynamic") - assert mode in ("dynamic", "fixed") - - is_volume = "z" in axes - if is_volume: - assert len(padding_) == 3 - else: - assert len(padding_) == 2 - - if isinstance(pad_right, bool): - pad_right = len(axes) * [pad_right] - - pad_width = [] - crop = {} - for ax, dlen, pr in zip(axes, image.shape, pad_right): - if ax in "zyx": - pad_to = padding_[ax] - - if mode == "dynamic": - r = dlen % pad_to - pwidth = 0 if r == 0 else (pad_to - r) - else: - if pad_to < dlen: - msg = f"Padding for axis {ax} failed; pad shape {pad_to} is smaller than the image shape {dlen}." 
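The arithmetic behind the two padding modes used here, as a small standalone sketch (the helper name is illustrative, not part of this patch):

def pad_amount(dlen: int, pad_to: int, mode: str = "dynamic") -> int:
    if mode == "dynamic":  # grow to the next multiple of pad_to
        r = dlen % pad_to
        return 0 if r == 0 else pad_to - r
    assert mode == "fixed" and pad_to >= dlen  # fixed: grow to exactly pad_to
    return pad_to - dlen


assert pad_amount(100, 16) == 12  # 100 -> 112, a multiple of 16
assert pad_amount(100, 128, "fixed") == 28  # 100 -> 128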
- raise RuntimeError(msg) - pwidth = pad_to - dlen - - pad_width.append([0, pwidth] if pr else [pwidth, 0]) - crop[ax] = slice(0, dlen) if pr else slice(pwidth, None) - else: - pad_width.append([0, 0]) - crop[ax] = slice(None) - - image = np.pad(image, pad_width, mode="symmetric") - return image, crop +# # TODO: update + +# from __future__ import annotations + +# import os +# from copy import deepcopy +# from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union + +# import imageio +# import numpy as np +# from numpy.typing import ArrayLike, NDArray +# from xarray import DataArray + +# from bioimageio.spec._internal.io_utils import load_array +# from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensor04 +# from bioimageio.spec.model.v0_4 import OutputTensorDescr as OutputTensor04 +# from bioimageio.spec.model.v0_5 import InputTensorDescr as InputTensor05 +# from bioimageio.spec.model.v0_5 import OutputTensorDescr as OutputTensor05 + +# InputTensor = Union[InputTensor04, InputTensor05] +# OutputTensor = Union[OutputTensor04, OutputTensor05] + + +# # +# # helper functions to transform input images / output tensors to the required axes +# # + + +# def transpose_image(image: NDArray[Any], desired_axes: str, current_axes: Optional[str] = None) -> NDArray[Any]: +# """Transform an image to match desired axes. + +# Args: +# image: the input image +# desired_axes: the desired image axes +# current_axes: the axes of the input image +# """ +# # if the image axes are not given deduce them from the required axes and image shape +# if current_axes is None: +# has_z_axis = "z" in desired_axes +# ndim = image.ndim +# if ndim == 2: +# current_axes = "yx" +# elif ndim == 3: +# current_axes = "zyx" if has_z_axis else "cyx" +# elif ndim == 4: +# current_axes = "czyx" +# elif ndim == 5: +# current_axes = "bczyx" +# else: +# raise ValueError(f"Invalid number of image dimensions: {ndim}") + +# tensor = DataArray(image, dims=tuple(current_axes)) +# # expand the missing image axes +# missing_axes = tuple(set(desired_axes) - set(current_axes)) +# tensor = tensor.expand_dims(dim=missing_axes) +# # transpose to the correct axis order +# tensor = tensor.transpose(*tuple(desired_axes)) +# # return numpy array +# ret: NDArray[Any] = tensor.values +# return ret + + +# # +# # helper functions for loading and saving images +# # + + +# def load_image(in_path, axes: Sequence[str]) -> DataArray: +# ext = os.path.splitext(in_path)[1] +# if ext == ".npy": +# im = load_array(in_path) +# else: +# is_volume = "z" in axes +# im = imageio.volread(in_path) if is_volume else imageio.imread(in_path) +# im = transpose_image(im, axes) +# return DataArray(im, dims=axes) + + +# def load_tensors(sources, tensor_specs: List[Union[InputTensor, OutputTensor]]) -> List[DataArray]: +# return [load_image(s, sspec.axes) for s, sspec in zip(sources, tensor_specs)] + + +# # +# # helper function for padding +# # + + +# def pad(image, axes: Sequence[str], padding, pad_right=True) -> Tuple[np.ndarray, Dict[str, slice]]: +# assert image.ndim == len(axes), f"{image.ndim}, {len(axes)}" + +# padding_ = deepcopy(padding) +# mode = padding_.pop("mode", "dynamic") +# assert mode in ("dynamic", "fixed") + +# is_volume = "z" in axes +# if is_volume: +# assert len(padding_) == 3 +# else: +# assert len(padding_) == 2 + +# if isinstance(pad_right, bool): +# pad_right = len(axes) * [pad_right] + +# pad_width = [] +# crop = {} +# for ax, dlen, pr in zip(axes, image.shape, pad_right): +# if ax in "zyx": +# pad_to = padding_[ax] + +# 
if mode == "dynamic": +# r = dlen % pad_to +# pwidth = 0 if r == 0 else (pad_to - r) +# else: +# if pad_to < dlen: +# msg = f"Padding for axis {ax} failed; pad shape {pad_to} is smaller than the image shape {dlen}." +# raise RuntimeError(msg) +# pwidth = pad_to - dlen + +# pad_width.append([0, pwidth] if pr else [pwidth, 0]) +# crop[ax] = slice(0, dlen) if pr else slice(pwidth, None) +# else: +# pad_width.append([0, 0]) +# crop[ax] = slice(None) + +# image = np.pad(image, pad_width, mode="symmetric") +# return image, crop diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py index 092e6d6a..e811961c 100644 --- a/bioimageio/core/io.py +++ b/bioimageio/core/io.py @@ -1,39 +1,39 @@ from __future__ import annotations -from typing import List, Literal, Optional, Union +from typing import Literal, Optional, Union from bioimageio.spec import build_description from bioimageio.spec import load_description as load_description from bioimageio.spec._description import ResourceDescr from bioimageio.spec._internal.constants import DISCOVER -from bioimageio.spec._internal.validation_context import ValidationContext from bioimageio.spec._internal.io_utils import open_bioimageio_yaml -from bioimageio.spec.common import BioimageioYamlContent, FileSource, InvalidDescription +from bioimageio.spec._internal.validation_context import ValidationContext +from bioimageio.spec.common import BioimageioYamlContent, FileSource, InvalidDescr from bioimageio.spec.summary import ValidationSummary -def load_description_and_validate( +def load_description_and_test( source: FileSource, /, *, format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, -) -> Union[ResourceDescr, InvalidDescription]: +) -> Union[ResourceDescr, InvalidDescr]: opened = open_bioimageio_yaml(source) - return build_description_and_validate( + return build_description_and_test( opened.content, context=ValidationContext(root=opened.original_root, file_name=opened.original_file_name), format_version=format_version, ) -def build_description_and_validate( +def build_description_and_test( data: BioimageioYamlContent, /, *, context: Optional[ValidationContext] = None, format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, -) -> Union[ResourceDescr, InvalidDescription]: +) -> Union[ResourceDescr, InvalidDescr]: """load and validate a BioImage.IO description from the content of a resource description file (RDF)""" rd = build_description(data, context=context, format_version=format_version) # todo: add dynamic validation @@ -46,10 +46,10 @@ def validate( *, context: Optional[ValidationContext] = None, format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, -) -> List[ValidationSummary]: +) -> ValidationSummary: if isinstance(source, dict): - rd = build_description_and_validate(source, context=context, format_version=format_version) + rd = build_description_and_test(source, context=context, format_version=format_version) else: - rd = load_description_and_validate(source, format_version=format_version) + rd = load_description_and_test(source, format_version=format_version) - return rd.validation_summaries + return rd.validation_summary diff --git a/bioimageio/core/model_adapters/_keras_model_adapter.py b/bioimageio/core/model_adapters/_keras_model_adapter.py index 9c6f842a..177a6a28 100644 --- a/bioimageio/core/model_adapters/_keras_model_adapter.py +++ b/bioimageio/core/model_adapters/_keras_model_adapter.py @@ -16,16 +16,15 @@ tf_version = None import xarray as xr -from bioimageio.core.io 
import download +from bioimageio.spec._internal.io_utils import download from bioimageio.spec.model import v0_4, v0_5 -from bioimageio.spec.model.v0_5 import RelativeFilePath from ._model_adapter import ModelAdapter class KerasModelAdapter(ModelAdapter): def __init__( - self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: Optional[Sequence[str]] = None + self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None ) -> None: super().__init__() if model_description.weights.keras_hdf5 is None: @@ -45,10 +44,7 @@ def __init__( if devices is not None: warnings.warn(f"Device management is not implemented for keras yet, ignoring the devices {devices}") - src = model_description.weights.keras_hdf5.source - weight_path = download( - src.get_absolute(model_description.root) if isinstance(src, RelativeFilePath) else src - ).path + weight_path = download(model_description.weights.keras_hdf5.source).path self._network = keras.models.load_model(weight_path) self._output_axes = [tuple(out.axes) for out in model_description.outputs] diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py index 52a33771..09a346a4 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ b/bioimageio/core/model_adapters/_model_adapter.py @@ -45,7 +45,7 @@ def create( model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], *, devices: Optional[Sequence[str]] = None, - weight_format_priority_order: NotEmpty[Sequence[WeightsFormat]] = DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER, + weight_format_priority_order: Optional[Sequence[WeightsFormat]] = None, ): """ Creates model adapter based on the passed spec @@ -54,7 +54,7 @@ def create( """ weights = model_description.weights errors: List[Exception] = [] - for wf in weight_format_priority_order: + for wf in weight_format_priority_order or DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: if wf == "pytorch_state_dict" and weights.pytorch_state_dict is not None: try: from ._pytorch_model_adapter import PytorchModelAdapter diff --git a/bioimageio/core/model_adapters/_onnx_model_adapter.py b/bioimageio/core/model_adapters/_onnx_model_adapter.py index 0a1caa36..0d947dc9 100644 --- a/bioimageio/core/model_adapters/_onnx_model_adapter.py +++ b/bioimageio/core/model_adapters/_onnx_model_adapter.py @@ -35,9 +35,9 @@ def __init__( def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: assert len(input_tensors) == len(self._input_names) input_arrays = [ipt.data for ipt in input_tensors] - result: Union[ # pyright: ignore[reportUnknownVariableType] - Sequence[NDArray[Any]], NDArray[Any] - ] = self._session.run(None, dict(zip(self._input_names, input_arrays))) + result: Union[Sequence[NDArray[Any]], NDArray[Any]] = ( # pyright: ignore[reportUnknownVariableType] + self._session.run(None, dict(zip(self._input_names, input_arrays))) + ) if not isinstance(result, (list, tuple)): result = [] diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py index 3a9a3109..ba9210a9 100644 --- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py +++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py @@ -51,12 +51,16 @@ def unload(self) -> None: torch.cuda.empty_cache() # release reserved memory @staticmethod - def get_network(weight_spec: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr]) -> torch.nn.Module: + def get_network( + weight_spec: 
Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr] + ) -> torch.nn.Module: arch = import_callable( weight_spec.architecture, - sha256=weight_spec.architecture_sha256 - if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr) - else weight_spec.sha256, + sha256=( + weight_spec.architecture_sha256 + if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr) + else weight_spec.sha256 + ), ) model_kwargs = ( weight_spec.kwargs diff --git a/bioimageio/core/op_base.py b/bioimageio/core/op_base.py index 2e872a19..8392f8e5 100644 --- a/bioimageio/core/op_base.py +++ b/bioimageio/core/op_base.py @@ -9,10 +9,8 @@ @dataclass class Operator(ABC): @abstractmethod - def __call__(self, sample: Sample) -> None: - ... + def __call__(self, sample: Sample) -> None: ... @property @abstractmethod - def required_measures(self) -> Collection[Measure]: - ... + def required_measures(self) -> Collection[Measure]: ... diff --git a/bioimageio/core/prediction.py b/bioimageio/core/prediction.py index 68de04fe..49fb9d53 100644 --- a/bioimageio/core/prediction.py +++ b/bioimageio/core/prediction.py @@ -1,506 +1,511 @@ -import collections -import os -from fractions import Fraction -from itertools import product -from pathlib import Path -from typing import Any, Dict, Hashable, Iterator, List, NamedTuple, Optional, OrderedDict, Sequence, Tuple, Union - -import numpy as np -import xarray as xr -from bioimageio.spec import ResourceDescription -from bioimageio.spec.model.v0_5 import AxisType -from numpy.typing import NDArray -from pydantic import HttpUrl -from tqdm import tqdm - -from bioimageio.core import image_helper, load_resource_description -from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline -from bioimageio.core.resource_io.nodes import ImplicitOutputShape, Model, ResourceDescription - -Axis = Hashable - - -class TileDef(NamedTuple): - outer: Dict[Axis, slice] - inner: Dict[Axis, slice] - local: Dict[Axis, slice] - - -def get_tiling( - shape: Sequence[int], - tile_shape: Dict[Axis, int], - halo: Dict[Axis, int], - input_axes: Sequence[Axis], - axis_types: Dict[Axis, AxisType], - scaling: Dict[Axis, float], -) -> Iterator[TileDef]: - # outer_tile is the "input" tile, inner_tile is the "output" tile with the halo removed - # tile_shape is the shape of the outer_tile - assert len(shape) == len(input_axes) - scaling_fractions = {ax: Fraction(sc).limit_denominator() for ax, sc in scaling.items()} - - shape_ = [sh for sh, ax in zip(shape, input_axes) if axis_types[ax] == "space"] - spatial_axes = [ax for ax in input_axes if axis_types[ax] == "space"] - inner_tile_shape_ = [tile_shape[ax] - 2 * halo[ax] for ax in spatial_axes] - scaling_ = [scaling_fractions[ax] for ax in spatial_axes] - assert all([sh % fr.denominator == 0 for sh, fr in zip(shape_, scaling_)]) - assert all([ish % fr.denominator == 0 for ish, fr in zip(inner_tile_shape_, scaling_)]) - halo_ = [halo[ax] for ax in spatial_axes] - assert len(shape_) == len(inner_tile_shape_) == len(spatial_axes) == len(halo_) - - ranges = [range(sh // tsh if sh % tsh == 0 else sh // tsh + 1) for sh, tsh in zip(shape_, inner_tile_shape_)] - start_points = product(*ranges) - - for start_point in start_points: - positions = [sp * tsh for sp, tsh in zip(start_point, inner_tile_shape_)] - - inner_tile = { - ax: slice(int(pos * fr), int(min(pos + tsh, sh) * fr)) - for ax, pos, tsh, sh, fr in zip(spatial_axes, positions, inner_tile_shape_, shape_, scaling_) - } - # inner_tile["b"] = slice(None) - # 
inner_tile["c"] = slice(None) - - outer_tile = { - ax: slice(max(pos - ha, 0), min(pos + tsh + ha, sh)) - for ax, pos, tsh, sh, ha in zip(spatial_axes, positions, inner_tile_shape_, shape_, halo_) - } - # outer_tile["b"] = slice(None) - # outer_tile["c"] = slice(None) - - local_tile = { - ax: slice( - inner_tile[ax].start - int(outer_tile[ax].start * scaling[ax]), - -(int(outer_tile[ax].stop * scaling[ax]) - inner_tile[ax].stop) - if int(outer_tile[ax].stop * scaling[ax]) != inner_tile[ax].stop - else None, - ) - for ax in spatial_axes - } - # local_tile["b"] = slice(None) - # local_tile["c"] = slice(None) - - yield TileDef(outer_tile, inner_tile, local_tile) - - -def _predict_with_tiling_impl( - prediction_pipeline: PredictionPipeline, - inputs: Sequence[xr.DataArray], - outputs: Sequence[xr.DataArray], - tile_shapes: Sequence[Dict[str, int]], - halos: Sequence[Dict[str, int]], - scales: Sequence[Dict[str, Tuple[int, int]]], - verbose: bool = False, -): - if len(inputs) > 1: - raise NotImplementedError("Tiling with multiple inputs not implemented yet") - - if len(outputs) > 1: - raise NotImplementedError("Tiling with multiple outputs not implemented yet") - - assert len(tile_shapes) == len(outputs) - assert len(halos) == len(outputs) - - input_ = inputs[0] - output = outputs[0] - tile_shape = tile_shapes[0] - halo = halos[0] - scaling = scales[0] - - tiles = get_tiling(shape=input_.shape, tile_shape=tile_shape, halo=halo, input_axes=input_.dims, scaling=scaling) - - def load_tile(tile): - inp = input_[tile] - # whether to pad on the right or left of the dim for the spatial dims - # + placeholders for batch and axis dimension, where we don't pad - pad_right = [tile[ax].start == 0 if ax in "xyz" else None for ax in input_.dims] - return inp, pad_right - - if verbose: - shape = {ax: sh for ax, sh in zip(prediction_pipeline.input_specs[0].axes, input_.shape)} - n_tiles = int(np.prod([np.ceil(float(shape[ax]) / (tsh - 2 * halo[ax])) for ax, tsh in tile_shape.items()])) - tiles = tqdm(tiles, total=n_tiles, desc="prediction with tiling") - - # we need to use padded prediction for the individual tiles in case the - # border tiles don't match the requested tile shape - padding = {ax: tile_shape[ax] for ax in input_axes if ax in "xyz"} - padding["mode"] = "fixed" - for outer_tile, inner_tile, local_tile in tiles: - inp, pad_right = load_tile(outer_tile) - out = predict_with_padding(prediction_pipeline, inp, padding, pad_right) - assert len(out) == 1 - out = out[0] - output[inner_tile] = out[local_tile] - - -def predict( - prediction_pipeline: PredictionPipeline, - inputs: Union[ - xr.DataArray, List[xr.DataArray], Tuple[xr.DataArray], NDArray[Any], List[NDArray[Any]], Tuple[NDArray[Any]] - ], -) -> List[xr.DataArray]: - """Run prediction for a single set of input(s) with a bioimage.io model - - Args: - prediction_pipeline: the prediction pipeline for the input model. - inputs: the input(s) for this model represented as xarray data or numpy nd array. 
- """ - if not isinstance(inputs, (tuple, list)): - inputs = [inputs] - - assert len(inputs) == len(prediction_pipeline.input_specs) - tagged_data = [ - ipt if isinstance(ipt, xr.DataArray) else xr.DataArray(ipt, dims=ipt_spec.axes) - for ipt, ipt_spec in zip(inputs, prediction_pipeline.input_specs) - ] - return prediction_pipeline.forward(*tagged_data) - - -def _parse_padding(padding, input_specs): - if padding is None: # no padding - return padding - if len(input_specs) > 1: - raise NotImplementedError("Padding for multiple inputs not yet implemented") - - input_spec = input_specs[0] - pad_keys = tuple(input_spec.axes) + ("mode",) - - def check_padding(padding): - assert all(k in pad_keys for k in padding.keys()) - - if isinstance(padding, dict): # pre-defined padding - check_padding(padding) - elif isinstance(padding, bool): # determine padding from spec - if padding: - axes = input_spec.axes - shape = input_spec.shape - if isinstance(shape, list): # fixed padding - padding = {ax: sh for ax, sh in zip(axes, shape) if ax in "xyz"} - padding["mode"] = "fixed" - else: # dynamic padding - step = shape.step - padding = {ax: st for ax, st in zip(axes, step) if ax in "xyz"} - padding["mode"] = "dynamic" - check_padding(padding) - else: # no padding - padding = None - else: - raise ValueError(f"Invalid argument for padding: {padding}") - return padding - - -def predict_with_padding( - prediction_pipeline: PredictionPipeline, - inputs: Union[xr.DataArray, List[xr.DataArray], Tuple[xr.DataArray]], - padding: Union[bool, Dict[str, int]] = True, - pad_right: bool = True, -) -> List[xr.DataArray]: - """Run prediction with padding for a single set of input(s) with a bioimage.io model. - - Args: - prediction_pipeline: the prediction pipeline for the input model. - inputs: the input(s) for this model represented as xarray data. - padding: the padding settings. Pass True to derive from the model spec. - pad_right: whether to applying padding to the right or left of the input. 
- """ - if not padding: - raise ValueError - assert len(inputs) == len(prediction_pipeline.input_specs) - - output_spec = prediction_pipeline.output_specs[0] - if hasattr(output_spec.shape, "scale"): - scale = dict(zip(output_spec.axes, output_spec.shape.scale)) - offset = dict(zip(output_spec.axes, output_spec.shape.offset)) - network_resizes = any(sc != 1 for ax, sc in scale.items() if ax in "xyz") or any( - off != 0 for ax, off in offset.items() if ax in "xyz" - ) - else: - network_resizes = False - - padding = _parse_padding(padding, prediction_pipeline.input_specs) - if not isinstance(inputs, (tuple, list)): - inputs = [inputs] - if not isinstance(padding, (tuple, list)): - padding = [padding] - assert len(padding) == len(prediction_pipeline.input_specs) - inputs, crops = zip( - *[ - image_helper.pad(inp, spec.axes, p, pad_right=pad_right) - for inp, spec, p in zip(inputs, prediction_pipeline.input_specs, padding) - ] - ) - result = predict(prediction_pipeline, inputs) - if network_resizes: - crops = [ - { - ax: slice( - crp.start if crp.start is None else int(crp.start * scale[ax] + 2 * offset[ax]), - crp.stop if crp.stop is None else int(crp.stop * scale[ax] + 2 * offset[ax]), - ) - if ax in "xyz" - else crp - for ax, crp in crop.items() - } - for crop in crops - ] - return [res[crop] for res, crop in zip(result, crops)] - - -# simple heuristic to determine suitable shape from min and step -def _determine_shape(min_shape, step, axes): - is3d = "z" in axes - min_len = 64 if is3d else 256 - shape = [] - for ax, min_ax, step_ax in zip(axes, min_shape, step): - if ax in "zyx" and step_ax > 0: - len_ax = min_ax - while len_ax < min_len: - len_ax += step_ax - shape.append(len_ax) - else: - shape.append(min_ax) - return shape - - -def _parse_tiling(tiling, input_specs, output_specs): - if tiling is None: # no tiling - return tiling - if len(input_specs) > 1: - raise NotImplementedError("Tiling for multiple inputs not yet implemented") - if len(output_specs) > 1: - raise NotImplementedError("Tiling for multiple outputs not yet implemented") - - input_spec = input_specs[0] - output_spec = output_specs[0] - if isinstance(output_spec.shape, list): - assert isinstance(input_spec.shape, list) and input_spec.shape == output_spec.shape, ( - "When predicting with tiling, output_shape and input_shape must either be specified " - "explictly and must be identical, or output_shape must be" - "implicitly defined by input_shape, otherwise relationship between " - "input and output shapes per tile cannot be known." 
- ) - axes = input_spec.axes - - def check_tiling(tiling): - assert "halo" in tiling and "tile" in tiling - spatial_axes = [ax for ax in axes if ax in "xyz"] - halo = tiling["halo"] - tile = tiling["tile"] - scale = tiling.get("scale", dict()) - assert all(halo.get(ax, 0) >= 0 for ax in spatial_axes) - assert all(tile.get(ax, 0) > 0 for ax in spatial_axes) - assert all(scale.get(ax, 1) > 0 for ax in spatial_axes) - - if isinstance(tiling, dict) or (isinstance(tiling, bool) and tiling): - # NOTE we assume here that shape in input and output are the same - # for different input and output shapes, we should actually tile in the - # output space and then request the corresponding input tiles - # so we would need to apply the output scale and offset to the - # input shape to compute the tile size and halo here - shape = input_spec.shape - if not isinstance(shape, list): - shape = _determine_shape(shape.min, shape.step, axes) - assert isinstance(shape, list) - assert len(shape) == len(axes) - - scale = None - output_shape = output_spec.shape - scale = [1.0] * len(output_spec.shape) if isinstance(output_shape, list) else output_shape.scale - assert len(scale) == len(axes) - - halo = output_spec.halo - if not isinstance(halo, list): - halo = [0] * len(axes) - assert len(halo) == len(axes) - - default_tiling = { - "halo": {ax: ha for ax, ha in zip(axes, halo) if ax in "xyz"}, - "tile": {ax: sh for ax, sh in zip(axes, shape) if ax in "xyz"}, - "scale": {ax: sc for ax, sc in zip(axes, scale) if ax in "xyz"}, - } - - # override metadata defaults with provided dict - if isinstance(tiling, dict): - for key in ["halo", "tile", "scale"]: - default_tiling[key].update(tiling.get(key, dict())) - tiling = default_tiling - check_tiling(tiling) - - elif isinstance(tiling, bool) and not tiling: - raise NotImplementedError("Should be unreachable") - - else: - raise ValueError(f"Invalid argument for tiling: {tiling}") - - return tiling - - -def predict_with_tiling( - prediction_pipeline: PredictionPipeline, - inputs: Union[xr.DataArray, List[xr.DataArray], Tuple[xr.DataArray]], - tiling: Union[bool, Dict[str, Dict[str, int]]] = True, - verbose: bool = False, -) -> List[xr.DataArray]: - """Run prediction with tiling for a single set of input(s) with a bioimage.io model. - - Args: - prediction_pipeline: the prediction pipeline for the input model. - inputs: the input(s) for this model represented as xarray data. - tiling: the tiling settings. Pass True to derive from the model spec. - verbose: whether to print the prediction progress. 
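When passed explicitly, the tiling argument checked above has this shape (values are illustrative):

tiling = {
    "tile": {"y": 256, "x": 256},  # outer tile shape per spatial axis
    "halo": {"y": 16, "x": 16},  # cut away from each tile border after prediction
    "scale": {"y": 1.0, "x": 1.0},  # output size relative to input size per axis
}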
- """ - if not tiling: - raise ValueError("cannot call predict_with_tiling with tiling=False") - assert len(inputs) == len(prediction_pipeline.input_specs) - - tiling = _parse_tiling(tiling, prediction_pipeline.input_specs, prediction_pipeline.output_specs) - if not isinstance(inputs, (list, tuple)): - inputs = [inputs] - named_inputs: OrderedDict[str, xr.DataArray] = collections.OrderedDict( - **{ - ipt_spec.name: xr.DataArray(ipt_data, dims=tuple(ipt_spec.axes)) - for ipt_data, ipt_spec in zip(inputs, prediction_pipeline.input_specs) - } - ) - - outputs = [] - for output_spec in prediction_pipeline.output_specs: - if isinstance(output_spec.shape, ImplicitOutputShape): - scale = dict(zip(output_spec.axes, output_spec.shape.scale)) - offset = dict(zip(output_spec.axes, output_spec.shape.offset)) - - ref_input = named_inputs[output_spec.shape.reference_tensor] - ref_input_shape = dict(zip(ref_input.dims, ref_input.shape)) - output_shape = tuple(int(scale[ax] * ref_input_shape[ax] + 2 * offset[ax]) for ax in output_spec.axes) - else: - if len(inputs) > 1: - raise NotImplementedError - input_spec = prediction_pipeline.input_specs[0] - if input_spec.axes != output_spec.axes: - raise NotImplementedError("Tiling with a different output shape is not yet supported") - out_axes = output_spec.axes - fixed_shape = tuple(output_spec.shape) - if not all(fixed_shape[out_axes.index(ax)] == tile_shape for ax, tile_shape in tiling["tile"].items()): - raise NotImplementedError("Tiling with a different output shape is not yet supported") - - output_shape = list(inputs[0].shape) - chan_id = out_axes.index("c") - if fixed_shape[chan_id] != output_shape[chan_id]: - output_shape[chan_id] = fixed_shape[chan_id] - output_shape = tuple(output_shape) - - outputs.append(xr.DataArray(np.zeros(output_shape, dtype=output_spec.data_type), dims=tuple(output_spec.axes))) - - _predict_with_tiling_impl( - prediction_pipeline, - list(named_inputs.values()), - outputs, - tile_shapes=[tiling["tile"]], # todo: update tiling for multiple inputs/outputs - halos=[tiling["halo"]], - scales=[tiling["scale"]], - verbose=verbose, - ) - - return outputs - - -def _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling): - if padding and tiling: - raise ValueError("Only one of padding or tiling is supported") - - input_data = image_helper.load_tensors(inputs, prediction_pipeline.input_specs) - if padding is not None: - result = predict_with_padding(prediction_pipeline, input_data, padding) - elif tiling is not None: - result = predict_with_tiling(prediction_pipeline, input_data, tiling) - else: - result = predict(prediction_pipeline, input_data) - - assert isinstance(result, list) - assert len(result) == len(outputs) - for res, out in zip(result, outputs): - image_helper.save_image(out, res) - - -def predict_image( - model_rdf: DescriptionSource, - inputs: Union[Tuple[Path, ...], List[Path], Path], - outputs: Union[Tuple[Path, ...], List[Path], Path], - padding: Optional[Union[bool, Dict[str, int]]] = None, - tiling: Optional[Union[bool, Dict[str, Dict[str, int]]]] = None, - weight_format: Optional[str] = None, - devices: Optional[List[str]] = None, - verbose: bool = False, -): - """Run prediction for a single set of input image(s) with a bioimage.io model. - - Args: - model_rdf: the bioimageio model. - inputs: the filepaths for the input images. - outputs: the filepaths for saving the input images. - padding: the padding settings for prediction. By default no padding is used. - tiling: the tiling settings for prediction. 
By default no tiling is used. - weight_format: the weight format to use for predictions. - devices: the devices to use for prediction. - verbose: run prediction in verbose mode. - """ - if not isinstance(inputs, (tuple, list)): - inputs = [inputs] - - if not isinstance(outputs, (tuple, list)): - outputs = [outputs] - - model = load_resource_description(model_rdf) - assert isinstance(model, Model) - if len(model.inputs) != len(inputs): - raise ValueError - if len(model.outputs) != len(outputs): - raise ValueError - - with create_prediction_pipeline( - bioimageio_model=model, weight_format=weight_format, devices=devices - ) as prediction_pipeline: - _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling) - - -def predict_images( - model_rdf: DescriptionSource, - inputs: Sequence[Union[Tuple[Path, ...], List[Path], Path]], - outputs: Sequence[Union[Tuple[Path, ...], List[Path], Path]], - padding: Optional[Union[bool, Dict[str, int]]] = None, - tiling: Optional[Union[bool, Dict[str, Dict[str, int]]]] = None, - weight_format: Optional[str] = None, - devices: Optional[List[str]] = None, - verbose: bool = False, -): - """Predict multiple input images with a bioimage.io model. - - Args: - model_rdf: the bioimageio model. - inputs: the filepaths for the input images. - outputs: the filepaths for saving the input images. - padding: the padding settings for prediction. By default no padding is used. - tiling: the tiling settings for prediction. By default no tiling is used. - weight_format: the weight format to use for predictions. - devices: the devices to use for prediction. - verbose: run prediction in verbose mode. - """ - - model = load_resource_description(model_rdf) - assert isinstance(model, Model) - - with create_prediction_pipeline( - bioimageio_model=model, weight_format=weight_format, devices=devices - ) as prediction_pipeline: - prog = zip(inputs, outputs) - if verbose: - prog = tqdm(prog, total=len(inputs)) - - for inp, outp in prog: - if not isinstance(inp, (tuple, list)): - inp = [inp] - - if not isinstance(outp, (tuple, list)): - outp = [outp] - - _predict_sample(prediction_pipeline, inp, outp, padding, tiling) +# TODO: update +# import collections +# import os +# from fractions import Fraction +# from itertools import product +# from pathlib import Path +# from typing import Any, Dict, Hashable, Iterator, List, NamedTuple, Optional, OrderedDict, Sequence, Tuple, Union + +# import numpy as np +# import xarray as xr +# from bioimageio.spec import ResourceDescr +# from bioimageio.spec.model.v0_5 import AxisType +# from numpy.typing import NDArray +# from pydantic import HttpUrl +# from tqdm import tqdm + +# from bioimageio.core import image_helper, load_resource_description +# from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline +# from bioimageio.core.resource_io.nodes import ImplicitOutputShape, Model, ResourceDescr + +# Axis = Hashable + + +# class TileDef(NamedTuple): +# outer: Dict[Axis, slice] +# inner: Dict[Axis, slice] +# local: Dict[Axis, slice] + + +# def get_tiling( +# shape: Sequence[int], +# tile_shape: Dict[Axis, int], +# halo: Dict[Axis, int], +# input_axes: Sequence[Axis], +# axis_types: Dict[Axis, AxisType], +# scaling: Dict[Axis, float], +# ) -> Iterator[TileDef]: +# # outer_tile is the "input" tile, inner_tile is the "output" tile with the halo removed +# # tile_shape is the shape of the outer_tile +# assert len(shape) == len(input_axes) +# scaling_fractions = {ax: Fraction(sc).limit_denominator() for ax, sc in 
scaling.items()} + +# shape_ = [sh for sh, ax in zip(shape, input_axes) if axis_types[ax] == "space"] +# spatial_axes = [ax for ax in input_axes if axis_types[ax] == "space"] +# inner_tile_shape_ = [tile_shape[ax] - 2 * halo[ax] for ax in spatial_axes] +# scaling_ = [scaling_fractions[ax] for ax in spatial_axes] +# assert all([sh % fr.denominator == 0 for sh, fr in zip(shape_, scaling_)]) +# assert all([ish % fr.denominator == 0 for ish, fr in zip(inner_tile_shape_, scaling_)]) +# halo_ = [halo[ax] for ax in spatial_axes] +# assert len(shape_) == len(inner_tile_shape_) == len(spatial_axes) == len(halo_) + +# ranges = [range(sh // tsh if sh % tsh == 0 else sh // tsh + 1) for sh, tsh in zip(shape_, inner_tile_shape_)] +# start_points = product(*ranges) + +# for start_point in start_points: +# positions = [sp * tsh for sp, tsh in zip(start_point, inner_tile_shape_)] + +# inner_tile = { +# ax: slice(int(pos * fr), int(min(pos + tsh, sh) * fr)) +# for ax, pos, tsh, sh, fr in zip(spatial_axes, positions, inner_tile_shape_, shape_, scaling_) +# } +# # inner_tile["b"] = slice(None) +# # inner_tile["c"] = slice(None) + +# outer_tile = { +# ax: slice(max(pos - ha, 0), min(pos + tsh + ha, sh)) +# for ax, pos, tsh, sh, ha in zip(spatial_axes, positions, inner_tile_shape_, shape_, halo_) +# } +# # outer_tile["b"] = slice(None) +# # outer_tile["c"] = slice(None) + +# local_tile = { +# ax: slice( +# inner_tile[ax].start - int(outer_tile[ax].start * scaling[ax]), +# ( +# -(int(outer_tile[ax].stop * scaling[ax]) - inner_tile[ax].stop) +# if int(outer_tile[ax].stop * scaling[ax]) != inner_tile[ax].stop +# else None +# ), +# ) +# for ax in spatial_axes +# } +# # local_tile["b"] = slice(None) +# # local_tile["c"] = slice(None) + +# yield TileDef(outer_tile, inner_tile, local_tile) + + +# def _predict_with_tiling_impl( +# prediction_pipeline: PredictionPipeline, +# inputs: Sequence[xr.DataArray], +# outputs: Sequence[xr.DataArray], +# tile_shapes: Sequence[Dict[str, int]], +# halos: Sequence[Dict[str, int]], +# scales: Sequence[Dict[str, Tuple[int, int]]], +# verbose: bool = False, +# ): +# if len(inputs) > 1: +# raise NotImplementedError("Tiling with multiple inputs not implemented yet") + +# if len(outputs) > 1: +# raise NotImplementedError("Tiling with multiple outputs not implemented yet") + +# assert len(tile_shapes) == len(outputs) +# assert len(halos) == len(outputs) + +# input_ = inputs[0] +# output = outputs[0] +# tile_shape = tile_shapes[0] +# halo = halos[0] +# scaling = scales[0] + +# tiles = get_tiling(shape=input_.shape, tile_shape=tile_shape, halo=halo, input_axes=input_.dims, scaling=scaling) + +# def load_tile(tile): +# inp = input_[tile] +# # whether to pad on the right or left of the dim for the spatial dims +# # + placeholders for batch and axis dimension, where we don't pad +# pad_right = [tile[ax].start == 0 if ax in "xyz" else None for ax in input_.dims] +# return inp, pad_right + +# if verbose: +# shape = {ax: sh for ax, sh in zip(prediction_pipeline.input_specs[0].axes, input_.shape)} +# n_tiles = int(np.prod([np.ceil(float(shape[ax]) / (tsh - 2 * halo[ax])) for ax, tsh in tile_shape.items()])) +# tiles = tqdm(tiles, total=n_tiles, desc="prediction with tiling") + +# # we need to use padded prediction for the individual tiles in case the +# # border tiles don't match the requested tile shape +# padding = {ax: tile_shape[ax] for ax in input_axes if ax in "xyz"} +# padding["mode"] = "fixed" +# for outer_tile, inner_tile, local_tile in tiles: +# inp, pad_right = load_tile(outer_tile) +# 
out = predict_with_padding(prediction_pipeline, inp, padding, pad_right) +# assert len(out) == 1 +# out = out[0] +# output[inner_tile] = out[local_tile] + + +# def predict( +# prediction_pipeline: PredictionPipeline, +# inputs: Union[ +# xr.DataArray, List[xr.DataArray], Tuple[xr.DataArray], NDArray[Any], List[NDArray[Any]], Tuple[NDArray[Any]] +# ], +# ) -> List[xr.DataArray]: +# """Run prediction for a single set of input(s) with a bioimage.io model + +# Args: +# prediction_pipeline: the prediction pipeline for the input model. +# inputs: the input(s) for this model represented as xarray data or numpy nd array. +# """ +# if not isinstance(inputs, (tuple, list)): +# inputs = [inputs] + +# assert len(inputs) == len(prediction_pipeline.input_specs) +# tagged_data = [ +# ipt if isinstance(ipt, xr.DataArray) else xr.DataArray(ipt, dims=ipt_spec.axes) +# for ipt, ipt_spec in zip(inputs, prediction_pipeline.input_specs) +# ] +# return prediction_pipeline.forward(*tagged_data) + + +# def _parse_padding(padding, input_specs): +# if padding is None: # no padding +# return padding +# if len(input_specs) > 1: +# raise NotImplementedError("Padding for multiple inputs not yet implemented") + +# input_spec = input_specs[0] +# pad_keys = tuple(input_spec.axes) + ("mode",) + +# def check_padding(padding): +# assert all(k in pad_keys for k in padding.keys()) + +# if isinstance(padding, dict): # pre-defined padding +# check_padding(padding) +# elif isinstance(padding, bool): # determine padding from spec +# if padding: +# axes = input_spec.axes +# shape = input_spec.shape +# if isinstance(shape, list): # fixed padding +# padding = {ax: sh for ax, sh in zip(axes, shape) if ax in "xyz"} +# padding["mode"] = "fixed" +# else: # dynamic padding +# step = shape.step +# padding = {ax: st for ax, st in zip(axes, step) if ax in "xyz"} +# padding["mode"] = "dynamic" +# check_padding(padding) +# else: # no padding +# padding = None +# else: +# raise ValueError(f"Invalid argument for padding: {padding}") +# return padding + + +# def predict_with_padding( +# prediction_pipeline: PredictionPipeline, +# inputs: Union[xr.DataArray, List[xr.DataArray], Tuple[xr.DataArray]], +# padding: Union[bool, Dict[str, int]] = True, +# pad_right: bool = True, +# ) -> List[xr.DataArray]: +# """Run prediction with padding for a single set of input(s) with a bioimage.io model. + +# Args: +# prediction_pipeline: the prediction pipeline for the input model. +# inputs: the input(s) for this model represented as xarray data. +# padding: the padding settings. Pass True to derive from the model spec. +# pad_right: whether to applying padding to the right or left of the input. 
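+#
+#         Example (a hedged sketch; `pp` stands for a pipeline from
+#         create_prediction_pipeline, the "xy" axis names are assumed):
+#             padding = {"x": 16, "y": 16, "mode": "dynamic"}
+#             outputs = predict_with_padding(pp, input_, padding=padding)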
+#     """
+#     if not padding:
+#         raise ValueError("cannot call predict_with_padding with padding=False")
+#     assert len(inputs) == len(prediction_pipeline.input_specs)
+
+#     output_spec = prediction_pipeline.output_specs[0]
+#     if hasattr(output_spec.shape, "scale"):
+#         scale = dict(zip(output_spec.axes, output_spec.shape.scale))
+#         offset = dict(zip(output_spec.axes, output_spec.shape.offset))
+#         network_resizes = any(sc != 1 for ax, sc in scale.items() if ax in "xyz") or any(
+#             off != 0 for ax, off in offset.items() if ax in "xyz"
+#         )
+#     else:
+#         network_resizes = False
+
+#     padding = _parse_padding(padding, prediction_pipeline.input_specs)
+#     if not isinstance(inputs, (tuple, list)):
+#         inputs = [inputs]
+#     if not isinstance(padding, (tuple, list)):
+#         padding = [padding]
+#     assert len(padding) == len(prediction_pipeline.input_specs)
+#     inputs, crops = zip(
+#         *[
+#             image_helper.pad(inp, spec.axes, p, pad_right=pad_right)
+#             for inp, spec, p in zip(inputs, prediction_pipeline.input_specs, padding)
+#         ]
+#     )
+#     result = predict(prediction_pipeline, inputs)
+#     if network_resizes:
+#         crops = [
+#             {
+#                 ax: (
+#                     slice(
+#                         crp.start if crp.start is None else int(crp.start * scale[ax] + 2 * offset[ax]),
+#                         crp.stop if crp.stop is None else int(crp.stop * scale[ax] + 2 * offset[ax]),
+#                     )
+#                     if ax in "xyz"
+#                     else crp
+#                 )
+#                 for ax, crp in crop.items()
+#             }
+#             for crop in crops
+#         ]
+#     return [res[crop] for res, crop in zip(result, crops)]
+
+
+# # simple heuristic to determine suitable shape from min and step
+# def _determine_shape(min_shape, step, axes):
+#     is3d = "z" in axes
+#     min_len = 64 if is3d else 256
+#     shape = []
+#     for ax, min_ax, step_ax in zip(axes, min_shape, step):
+#         if ax in "zyx" and step_ax > 0:
+#             len_ax = min_ax
+#             while len_ax < min_len:
+#                 len_ax += step_ax
+#             shape.append(len_ax)
+#         else:
+#             shape.append(min_ax)
+#     return shape
+
+
+# def _parse_tiling(tiling, input_specs, output_specs):
+#     if tiling is None:  # no tiling
+#         return tiling
+#     if len(input_specs) > 1:
+#         raise NotImplementedError("Tiling for multiple inputs not yet implemented")
+#     if len(output_specs) > 1:
+#         raise NotImplementedError("Tiling for multiple outputs not yet implemented")
+
+#     input_spec = input_specs[0]
+#     output_spec = output_specs[0]
+#     if isinstance(output_spec.shape, list):
+#         assert isinstance(input_spec.shape, list) and input_spec.shape == output_spec.shape, (
+#             "When predicting with tiling, output_shape and input_shape must either be specified "
+#             "explicitly and must be identical, or output_shape must be "
+#             "implicitly defined by input_shape, otherwise relationship between "
+#             "input and output shapes per tile cannot be known."
+# ) +# axes = input_spec.axes + +# def check_tiling(tiling): +# assert "halo" in tiling and "tile" in tiling +# spatial_axes = [ax for ax in axes if ax in "xyz"] +# halo = tiling["halo"] +# tile = tiling["tile"] +# scale = tiling.get("scale", dict()) +# assert all(halo.get(ax, 0) >= 0 for ax in spatial_axes) +# assert all(tile.get(ax, 0) > 0 for ax in spatial_axes) +# assert all(scale.get(ax, 1) > 0 for ax in spatial_axes) + +# if isinstance(tiling, dict) or (isinstance(tiling, bool) and tiling): +# # NOTE we assume here that shape in input and output are the same +# # for different input and output shapes, we should actually tile in the +# # output space and then request the corresponding input tiles +# # so we would need to apply the output scale and offset to the +# # input shape to compute the tile size and halo here +# shape = input_spec.shape +# if not isinstance(shape, list): +# shape = _determine_shape(shape.min, shape.step, axes) +# assert isinstance(shape, list) +# assert len(shape) == len(axes) + +# scale = None +# output_shape = output_spec.shape +# scale = [1.0] * len(output_spec.shape) if isinstance(output_shape, list) else output_shape.scale +# assert len(scale) == len(axes) + +# halo = output_spec.halo +# if not isinstance(halo, list): +# halo = [0] * len(axes) +# assert len(halo) == len(axes) + +# default_tiling = { +# "halo": {ax: ha for ax, ha in zip(axes, halo) if ax in "xyz"}, +# "tile": {ax: sh for ax, sh in zip(axes, shape) if ax in "xyz"}, +# "scale": {ax: sc for ax, sc in zip(axes, scale) if ax in "xyz"}, +# } + +# # override metadata defaults with provided dict +# if isinstance(tiling, dict): +# for key in ["halo", "tile", "scale"]: +# default_tiling[key].update(tiling.get(key, dict())) +# tiling = default_tiling +# check_tiling(tiling) + +# elif isinstance(tiling, bool) and not tiling: +# raise NotImplementedError("Should be unreachable") + +# else: +# raise ValueError(f"Invalid argument for tiling: {tiling}") + +# return tiling + + +# def predict_with_tiling( +# prediction_pipeline: PredictionPipeline, +# inputs: Union[xr.DataArray, List[xr.DataArray], Tuple[xr.DataArray]], +# tiling: Union[bool, Dict[str, Dict[str, int]]] = True, +# verbose: bool = False, +# ) -> List[xr.DataArray]: +# """Run prediction with tiling for a single set of input(s) with a bioimage.io model. + +# Args: +# prediction_pipeline: the prediction pipeline for the input model. +# inputs: the input(s) for this model represented as xarray data. +# tiling: the tiling settings. Pass True to derive from the model spec. +# verbose: whether to print the prediction progress. 
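+#
+#         Example (a hedged sketch; `pp` and the "xy" axis names are assumed;
+#         "tile" sizes include the halo on both sides):
+#             tiling = {"tile": {"x": 256, "y": 256}, "halo": {"x": 16, "y": 16}}
+#             outputs = predict_with_tiling(pp, input_, tiling=tiling, verbose=True)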
+# """ +# if not tiling: +# raise ValueError("cannot call predict_with_tiling with tiling=False") +# assert len(inputs) == len(prediction_pipeline.input_specs) + +# tiling = _parse_tiling(tiling, prediction_pipeline.input_specs, prediction_pipeline.output_specs) +# if not isinstance(inputs, (list, tuple)): +# inputs = [inputs] +# named_inputs: OrderedDict[str, xr.DataArray] = collections.OrderedDict( +# **{ +# ipt_spec.name: xr.DataArray(ipt_data, dims=tuple(ipt_spec.axes)) +# for ipt_data, ipt_spec in zip(inputs, prediction_pipeline.input_specs) +# } +# ) + +# outputs = [] +# for output_spec in prediction_pipeline.output_specs: +# if isinstance(output_spec.shape, ImplicitOutputShape): +# scale = dict(zip(output_spec.axes, output_spec.shape.scale)) +# offset = dict(zip(output_spec.axes, output_spec.shape.offset)) + +# ref_input = named_inputs[output_spec.shape.reference_tensor] +# ref_input_shape = dict(zip(ref_input.dims, ref_input.shape)) +# output_shape = tuple(int(scale[ax] * ref_input_shape[ax] + 2 * offset[ax]) for ax in output_spec.axes) +# else: +# if len(inputs) > 1: +# raise NotImplementedError +# input_spec = prediction_pipeline.input_specs[0] +# if input_spec.axes != output_spec.axes: +# raise NotImplementedError("Tiling with a different output shape is not yet supported") +# out_axes = output_spec.axes +# fixed_shape = tuple(output_spec.shape) +# if not all(fixed_shape[out_axes.index(ax)] == tile_shape for ax, tile_shape in tiling["tile"].items()): +# raise NotImplementedError("Tiling with a different output shape is not yet supported") + +# output_shape = list(inputs[0].shape) +# chan_id = out_axes.index("c") +# if fixed_shape[chan_id] != output_shape[chan_id]: +# output_shape[chan_id] = fixed_shape[chan_id] +# output_shape = tuple(output_shape) + +# outputs.append(xr.DataArray(np.zeros(output_shape, dtype=output_spec.data_type), dims=tuple(output_spec.axes))) + +# _predict_with_tiling_impl( +# prediction_pipeline, +# list(named_inputs.values()), +# outputs, +# tile_shapes=[tiling["tile"]], # todo: update tiling for multiple inputs/outputs +# halos=[tiling["halo"]], +# scales=[tiling["scale"]], +# verbose=verbose, +# ) + +# return outputs + + +# def _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling): +# if padding and tiling: +# raise ValueError("Only one of padding or tiling is supported") + +# input_data = image_helper.load_tensors(inputs, prediction_pipeline.input_specs) +# if padding is not None: +# result = predict_with_padding(prediction_pipeline, input_data, padding) +# elif tiling is not None: +# result = predict_with_tiling(prediction_pipeline, input_data, tiling) +# else: +# result = predict(prediction_pipeline, input_data) + +# assert isinstance(result, list) +# assert len(result) == len(outputs) +# for res, out in zip(result, outputs): +# image_helper.save_image(out, res) + + +# def predict_image( +# model_rdf: DescriptionSource, +# inputs: Union[Tuple[Path, ...], List[Path], Path], +# outputs: Union[Tuple[Path, ...], List[Path], Path], +# padding: Optional[Union[bool, Dict[str, int]]] = None, +# tiling: Optional[Union[bool, Dict[str, Dict[str, int]]]] = None, +# weight_format: Optional[str] = None, +# devices: Optional[List[str]] = None, +# verbose: bool = False, +# ): +# """Run prediction for a single set of input image(s) with a bioimage.io model. + +# Args: +# model_rdf: the bioimageio model. +# inputs: the filepaths for the input images. +# outputs: the filepaths for saving the input images. +# padding: the padding settings for prediction. 
By default no padding is used. +# tiling: the tiling settings for prediction. By default no tiling is used. +# weight_format: the weight format to use for predictions. +# devices: the devices to use for prediction. +# verbose: run prediction in verbose mode. +# """ +# if not isinstance(inputs, (tuple, list)): +# inputs = [inputs] + +# if not isinstance(outputs, (tuple, list)): +# outputs = [outputs] + +# model = load_resource_description(model_rdf) +# assert isinstance(model, Model) +# if len(model.inputs) != len(inputs): +# raise ValueError +# if len(model.outputs) != len(outputs): +# raise ValueError + +# with create_prediction_pipeline( +# bioimageio_model=model, weight_format=weight_format, devices=devices +# ) as prediction_pipeline: +# _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling) + + +# def predict_images( +# model_rdf: DescriptionSource, +# inputs: Sequence[Union[Tuple[Path, ...], List[Path], Path]], +# outputs: Sequence[Union[Tuple[Path, ...], List[Path], Path]], +# padding: Optional[Union[bool, Dict[str, int]]] = None, +# tiling: Optional[Union[bool, Dict[str, Dict[str, int]]]] = None, +# weight_format: Optional[str] = None, +# devices: Optional[List[str]] = None, +# verbose: bool = False, +# ): +# """Predict multiple input images with a bioimage.io model. + +# Args: +# model_rdf: the bioimageio model. +# inputs: the filepaths for the input images. +# outputs: the filepaths for saving the input images. +# padding: the padding settings for prediction. By default no padding is used. +# tiling: the tiling settings for prediction. By default no tiling is used. +# weight_format: the weight format to use for predictions. +# devices: the devices to use for prediction. +# verbose: run prediction in verbose mode. +# """ + +# model = load_resource_description(model_rdf) +# assert isinstance(model, Model) + +# with create_prediction_pipeline( +# bioimageio_model=model, weight_format=weight_format, devices=devices +# ) as prediction_pipeline: +# prog = zip(inputs, outputs) +# if verbose: +# prog = tqdm(prog, total=len(inputs)) + +# for inp, outp in prog: +# if not isinstance(inp, (tuple, list)): +# inp = [inp] + +# if not isinstance(outp, (tuple, list)): +# outp = [outp] + +# _predict_sample(prediction_pipeline, inp, outp, padding, tiling) diff --git a/bioimageio/core/prediction_pipeline.py b/bioimageio/core/prediction_pipeline.py index fbe10e74..d01e0274 100644 --- a/bioimageio/core/prediction_pipeline.py +++ b/bioimageio/core/prediction_pipeline.py @@ -1,192 +1,128 @@ -import abc import warnings -from dataclasses import dataclass -from typing import Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Sequence import xarray as xr +from bioimageio.core.common import Sample, TensorId from bioimageio.core.model_adapters import ModelAdapter, create_model_adapter from bioimageio.core.model_adapters import get_weight_formats as get_weight_formats -from bioimageio.core.utils.node_visitor import resolve_raw_node -from bioimageio.spec.model import AnyModel, raw_nodes +from bioimageio.core.proc_ops import Processing +from bioimageio.core.proc_setup import setup_pre_and_postprocessing +from bioimageio.core.stat_calculators import StatsCalculator +from bioimageio.spec.model import AnyModelDescr, v0_4 +from bioimageio.spec.model.v0_5 import WeightsFormat -from ._combined_processing import CombinedProcessing -from ._utils import ComputedMeasures, Sample, TensorName -from .stat_state import StatsState - -@dataclass -class 
NamedImplicitOutputShape: - reference_input: TensorName - scale: List[Tuple[str, float]] - offset: List[Tuple[str, int]] - - def __len__(self): - return len(self.scale) - - -class PredictionPipeline(abc.ABC): +class PredictionPipeline: """ Represents model computation including preprocessing and postprocessing Note: Ideally use the PredictionPipeline as a context manager """ - @abc.abstractmethod - def __enter__(self): - ... - - @abc.abstractmethod - def __exit__(self, exc_type, exc_val, exc_tb): - ... - - @abc.abstractmethod - def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - """ - Compute predictions - """ - ... - - @property - @abc.abstractmethod - def name(self) -> str: - """ - Name of the pipeline - """ - ... - - @property - @abc.abstractmethod - def input_specs(self) -> List[nodes.InputTensor]: - """ - specs of inputs - """ - ... - - @property - @abc.abstractmethod - def output_specs(self) -> List[nodes.OutputTensor]: - """ - specs of outputs - """ - ... - - @abc.abstractmethod - def load(self) -> None: - """ - optional step: load model onto devices before calling forward if not using it as context manager - """ - ... - - @abc.abstractmethod - def unload(self) -> None: - """ - free any device memory in use - """ - ... - - -class _PredictionPipelineImpl(PredictionPipeline): def __init__( self, *, name: str, - bioimageio_model: AnyModel, - preprocessing: CombinedProcessing, - postprocessing: CombinedProcessing, - ipt_stats: StatsState, - out_stats: StatsState, + bioimageio_model: AnyModelDescr, + preprocessing: List[Processing], + postprocessing: List[Processing], + ipt_stats: StatsCalculator, + out_stats: StatsCalculator, model: ModelAdapter, ) -> None: + super().__init__() if bioimageio_model.run_mode: warnings.warn(f"Not yet implemented inference for run mode '{bioimageio_model.run_mode.name}'") - self._name = name - self._input_specs = bioimageio_model.inputs - self._output_specs = bioimageio_model.outputs - + self.name = name self._preprocessing = preprocessing self._postprocessing = postprocessing self._ipt_stats = ipt_stats self._out_stats = out_stats - self._model: ModelAdapter = model + if isinstance(bioimageio_model, v0_4.ModelDescr): + self._input_ids = [TensorId(d.name) for d in bioimageio_model.inputs] + self._output_ids = [TensorId(d.name) for d in bioimageio_model.outputs] + else: + self._input_ids = [d.id for d in bioimageio_model.inputs] + self._output_ids = [d.id for d in bioimageio_model.outputs] + + self._adapter: ModelAdapter = model - def __call__(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - return self.forward(*input_tensors) + def __call__(self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataArray) -> List[xr.DataArray]: + return self.forward(*input_tensors, **named_input_tensors) def __enter__(self): self.load() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore self.unload() return False - @property - def name(self): - return self._name - - @property - def input_specs(self): - return self._input_specs - - @property - def output_specs(self): - return self._output_specs - - def predict(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: + def predict(self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataArray) -> List[xr.DataArray]: """Predict input_tensor with the model without applying pre/postprocessing.""" - return self._model.forward(*input_tensors) - - def apply_preprocessing(self, sample: Sample, 
computed_measures: ComputedMeasures) -> None: - """apply preprocessing in-place, also updates given computed_measures""" - self._ipt_stats.update_with_sample(sample) - for mode, stats in self._ipt_stats.compute_measures().items(): - if mode not in computed_measures: - computed_measures[mode] = {} - computed_measures[mode].update(stats) - - self._preprocessing.apply(sample, computed_measures) - - def apply_postprocessing(self, sample: Sample, computed_measures: ComputedMeasures) -> None: - """apply postprocessing in-place, also updates given computed_measures""" - self._out_stats.update_with_sample(sample) - for mode, stats in self._out_stats.compute_measures().items(): - if mode not in computed_measures: - computed_measures[mode] = {} - computed_measures[mode].update(stats) - - self._postprocessing.apply(sample, computed_measures) - - def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - """Apply preprocessing, run prediction and apply postprocessing. - Note: The preprocessing might change input_tensors in-pace. - """ - input_sample = dict(zip([ipt.name for ipt in self.input_specs], input_tensors)) - computed_measures = {} - self.apply_preprocessing(input_sample, computed_measures) - - prediction_tensors = self.predict(*list(input_sample.values())) - prediction = dict(zip([out.name for out in self.output_specs], prediction_tensors)) - self.apply_postprocessing(prediction, computed_measures) - - return [prediction[tn] for tn in [out.name for out in self.output_specs]] + named_tensors = [named_input_tensors[k] for k in self._input_ids[len(input_tensors) :]] + return self._adapter.forward(*input_tensors, *named_tensors) + + def apply_preprocessing(self, sample: Sample) -> None: + """apply preprocessing in-place, also updates sample stats""" + sample.stat.update(self._ipt_stats.update_and_get_all(sample)) + for op in self._preprocessing: + op(sample) + + def apply_postprocessing(self, sample: Sample) -> None: + """apply postprocessing in-place, also updates samples stats""" + sample.stat.update(self._out_stats.update_and_get_all(sample)) + for op in self._postprocessing: + op(sample) + + def forward_sample(self, input_sample: Sample): + """Apply preprocessing, run prediction and apply postprocessing.""" + self.apply_preprocessing(input_sample) + + prediction_tensors = self.predict(**input_sample.data) + prediction = Sample(data=dict(zip(self._output_ids, prediction_tensors)), stat=input_sample.stat) + self.apply_postprocessing(prediction) + return prediction + + def forward_named( + self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataArray + ) -> Dict[TensorId, xr.DataArray]: + """Apply preprocessing, run prediction and apply postprocessing.""" + input_sample = Sample( + data={ + **dict(zip(self._input_ids, input_tensors)), + **{TensorId(k): v for k, v in named_input_tensors.items()}, + } + ) + return self.forward_sample(input_sample).data + + def forward(self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataArray) -> List[xr.DataArray]: + """Apply preprocessing, run prediction and apply postprocessing.""" + named_outputs = self.forward_named(*input_tensors, **named_input_tensors) + return [named_outputs[x] for x in self._output_ids] def load(self): - self._model.load() + """ + optional step: load model onto devices before calling forward if not using it as context manager + """ + self._adapter.load() def unload(self): - self._model.unload() + """ + free any device memory in use + """ + self._adapter.unload() def create_prediction_pipeline( - 
bioimageio_model: AnyModel, + bioimageio_model: AnyModelDescr, *, devices: Optional[Sequence[str]] = None, - weight_format: Optional[str] = None, + weight_format: Optional[WeightsFormat] = None, dataset_for_initial_statistics: Iterable[Sequence[xr.DataArray]] = tuple(), - update_dataset_stats_after_n_samples: Optional[int] = None, - update_dataset_stats_for_n_samples: int = float("inf"), model_adapter: Optional[ModelAdapter] = None, + **deprecated_kwargs: Any, ) -> PredictionPipeline: """ Creates prediction pipeline which includes: @@ -196,39 +132,29 @@ def create_prediction_pipeline( * computation of output statistics * postprocessing """ - model_adapter: ModelAdapter = model_adapter or create_model_adapter( - bioimageio_model=bioimageio_model, devices=devices, weight_format=weight_format + if deprecated_kwargs: + warnings.warn(f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}") + + model_adapter = model_adapter or create_model_adapter( + model_description=bioimageio_model, + devices=devices, + weight_format_priority_order=weight_format and (weight_format,), ) - if isinstance(bioimageio_model, nodes.Model): - ipts = bioimageio_model.inputs - outs = bioimageio_model.outputs + if isinstance(bioimageio_model, v0_4.ModelDescr): + input_ids = [TensorId(ipt.name) for ipt in bioimageio_model.inputs] else: - assert isinstance(bioimageio_model, raw_nodes.Model) - ipts = [resolve_raw_node(s, nodes) for s in bioimageio_model.inputs] - outs = [resolve_raw_node(s, nodes) for s in bioimageio_model.outputs] - - preprocessing = CombinedProcessing.from_tensor_specs(ipts) - - def sample_dataset(): - for tensors in dataset_for_initial_statistics: - yield dict(zip([ipt.name for ipt in bioimageio_model.inputs], tensors)) + input_ids = [ipt.id for ipt in bioimageio_model.inputs] - ipt_stats = StatsState( - preprocessing.required_measures, - dataset=sample_dataset(), - update_dataset_stats_after_n_samples=update_dataset_stats_after_n_samples, - update_dataset_stats_for_n_samples=update_dataset_stats_for_n_samples, - ) - postprocessing = CombinedProcessing.from_tensor_specs(outs) - out_stats = StatsState( - postprocessing.required_measures, - dataset=tuple(), - update_dataset_stats_after_n_samples=0, - update_dataset_stats_for_n_samples=ipt_stats.sample_count + update_dataset_stats_for_n_samples, - ) + preprocessing, postprocessing, pre_req_meas, post_req_meas = setup_pre_and_postprocessing(bioimageio_model) + ipt_stats = StatsCalculator(pre_req_meas) + out_stats = StatsCalculator(post_req_meas) + for tensors in dataset_for_initial_statistics: + sample = Sample(data=dict(zip(input_ids, tensors))) + ipt_stats.update(sample) + out_stats.update(sample) - return _PredictionPipelineImpl( + return PredictionPipeline( name=bioimageio_model.name, bioimageio_model=bioimageio_model, model=model_adapter, diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 3f949d94..535f085c 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -32,7 +32,6 @@ DatasetPercentile, DatasetStd, Measure, - MeasureValue, SampleMean, SamplePercentile, SampleStd, @@ -79,8 +78,7 @@ def __call__(self, sample: Sample) -> None: sample.data[self.output] = self._apply(sample.data[self.input], sample.stat) @abstractmethod - def _apply(self, input: Tensor, stat: Stat) -> Tensor: - ... + def _apply(self, input: Tensor, stat: Stat) -> Tensor: ... 
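# Note on the operator protocol in the hunk above: `_SimpleOperator.__call__`
# reads `sample.data[self.input]`, feeds it to `_apply(tensor, sample.stat)` and
# writes the result back to `sample.data[self.output]`. A minimal sketch of a
# custom stateless op in the same style (hypothetical, not part of this patch):
#
#     @dataclass
#     class Negate(_SimpleOperator):
#         @property
#         def required_measures(self) -> Collection[Measure]:
#             return set()  # needs no sample or dataset statistics
#
#         def _apply(self, input: Tensor, stat: Stat) -> Tensor:
#             return -input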
@dataclass
@@ -342,16 +340,20 @@ def get_descr(self):
 
 
 @dataclass
-class Sigmoid:
+class Sigmoid(_SimpleOperator):
     """1 / (1 + e^(-input))."""
 
-    def _apply(self, input: xr.DataArray) -> xr.DataArray:
+    def _apply(self, input: Tensor, stat: Stat) -> Tensor:
         return 1.0 / (1.0 + np.exp(-input))  # type: ignore
 
+    @property
+    def required_measures(self) -> Collection[Measure]:
+        return set()
+
     @classmethod
     def from_proc_descr(cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], tensor_id: TensorId) -> Self:
         assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr))
-        return cls()
+        return cls(input=tensor_id, output=tensor_id)
 
     def get_descr(self):
         return v0_5.SigmoidDescr()
diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py
index a138dd7e..e77673de 100644
--- a/bioimageio/core/proc_setup.py
+++ b/bioimageio/core/proc_setup.py
@@ -1,44 +1,42 @@
 from typing import (
-    Any,
-    Iterator,
     List,
     NamedTuple,
     Sequence,
     Set,
-    Tuple,
-    Type,
     Union,
     cast,
 )
 
 from typing_extensions import assert_never
 
-from bioimageio.core.common import ProcessingKwargs, Sample
-from bioimageio.core.proc_ops import (
-    Processing,
-    get_proc_class,
-)
-from bioimageio.core.stat_calculators import compute_measures
+from bioimageio.core.proc_ops import Processing, get_proc_class
 from bioimageio.core.stat_measures import Measure
-from bioimageio.spec.model import v0_4, v0_5
+from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
 from bioimageio.spec.model.v0_5 import TensorId
 
-ModelDescr = Union[v0_4.ModelDescr, v0_5.ModelDescr]
 TensorDescr = Union[v0_4.InputTensorDescr, v0_4.OutputTensorDescr, v0_5.InputTensorDescr, v0_5.OutputTensorDescr]
 
 
 class _SetupProcessing(NamedTuple):
     preprocessing: List[Processing]
     postprocessing: List[Processing]
+    preprocessing_req_measures: Set[Measure]
+    postprocessing_req_measures: Set[Measure]
 
 
-def setup_pre_and_postprocessing(model: ModelDescr, dataset: Iterator[Sample]) -> _SetupProcessing:
-    Prepared = List[Tuple[Type[Processing], ProcessingKwargs, TensorId]]
+def setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcessing:
+    pre_measures: Set[Measure] = set()
+    post_measures: Set[Measure] = set()
 
-    required_measures: Set[Measure] = set()
+    if isinstance(model, v0_4.ModelDescr):
+        input_ids = {TensorId(d.name) for d in model.inputs}
+        output_ids = {TensorId(d.name) for d in model.outputs}
+    else:
+        input_ids = {d.id for d in model.inputs}
+        output_ids = {d.id for d in model.outputs}
 
     def prepare_procs(tensor_descrs: Sequence[TensorDescr]):
-        prepared: Prepared = []
+        procs: List[Processing] = []
         for t_descr in tensor_descrs:
             if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)):
                 proc_descrs = t_descr.preprocessing
@@ -53,23 +51,23 @@ def prepare_procs(tensor_descrs: Sequence[TensorDescr]):
         for proc_d in proc_descrs:
             proc_class = get_proc_class(proc_d)
             tensor_id = cast(TensorId, t_descr.name) if isinstance(t_descr, v0_4.TensorDescrBase) else t_descr.id
-            req = proc_class.from_proc_descr(proc_d, tensor_id)
-            required_measures.update(req.get_set())
-            prepared.append((proc_class, proc_d.kwargs, tensor_id))
-
-        return prepared
-
-    prepared_preps = prepare_procs(model.inputs)
-    prepared_posts = prepare_procs(model.outputs)
-
-    computed_measures = compute_measures(required_measures, dataset=dataset)
-
-    def init_procs(prepared: Prepared):
-        initialized: List[ProcessingImpl] = []
-        for impl_class, kwargs, tensor_id in prepared:
-            impl = impl_class(tensor_id=tensor_id, kwargs=kwargs, computed_measures=computed_measures)
-            initialized.append(impl)
-
-        return initialized
-
-    return _SetupProcessing(init_procs(prepared_preps), init_procs(prepared_posts))
+            req = proc_class.from_proc_descr(proc_d, tensor_id)  # pyright: ignore[reportArgumentType]
+            for m in req.required_measures:
+                if m.tensor_id in input_ids:
+                    pre_measures.add(m)
+                elif m.tensor_id in output_ids:
+                    post_measures.add(m)
+                else:
+                    raise ValueError(f"Unexpected tensor id {m.tensor_id} in measure {m}")
+            procs.append(req)
+        return procs
+
+    pre_procs = prepare_procs(model.inputs)
+    post_procs = prepare_procs(model.outputs)
+
+    return _SetupProcessing(
+        preprocessing=pre_procs,
+        postprocessing=post_procs,
+        preprocessing_req_measures=pre_measures,
+        postprocessing_req_measures=post_measures,
+    )
diff --git a/bioimageio/core/resource_tests.py b/bioimageio/core/resource_tests.py
index 51d3dd53..dd899089 100644
--- a/bioimageio/core/resource_tests.py
+++ b/bioimageio/core/resource_tests.py
@@ -1,404 +1,225 @@
-import os
-import re
 import traceback
 import warnings
-from copy import deepcopy
-from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Union
 
-import numpy
 import numpy as np
-import xarray as xr
 
 from bioimageio.core import __version__ as bioimageio_core_version
-from bioimageio.core import load_raw_resource_description, load_resource_description
-from bioimageio.core._internal.validation_visitors import Sha256NodeChecker, SourceNodeChecker
-from bioimageio.core.common import TestSummary
 from bioimageio.core.prediction import predict
 from bioimageio.core.prediction_pipeline import create_prediction_pipeline
-from bioimageio.core.resource_io.nodes import (
-    URI,
-    ImplicitOutputShape,
-    Model,
-    ParametrizedInputShape,
-    ResourceDescription,
-)
-from bioimageio.spec import __version__ as bioimageio_spec_version
+from bioimageio.spec import InvalidDescr, ResourceDescr, build_description, dump_description, load_description
+from bioimageio.spec._internal.base_nodes import ResourceDescrBase
 from bioimageio.spec._internal.io_utils import load_array
-from bioimageio.spec.model.raw_nodes import WeightsFormat
-from bioimageio.spec.shared import resolve_source
-from bioimageio.spec.shared.common import ValidationWarning
-from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription
+from bioimageio.spec._internal.validation_context import validation_context_var
+from bioimageio.spec.common import BioimageioYamlContent, FileSource
+from bioimageio.spec.model import v0_4, v0_5
+from bioimageio.spec.model.v0_5 import WeightsFormat
+from bioimageio.spec.summary import ErrorEntry, InstalledPackage, ValidationDetail, ValidationSummary
 
 
 def test_model(
-    model_rdf: Union[URI, Path, str],
+    source: FileSource,
     weight_format: Optional[WeightsFormat] = None,
     devices: Optional[List[str]] = None,
     decimal: int = 4,
-) -> List[TestSummary]:
+) -> ValidationSummary:
     """Test whether the test output(s) of a model can be reproduced."""
-    return test_resource(
-        model_rdf, weight_format=weight_format, devices=devices, decimal=decimal, expected_type="model"
+    return test_description(
+        source, weight_format=weight_format, devices=devices, decimal=decimal, expected_type="model"
     )
 
 
-def check_input_shape(shape: Tuple[int, ...], shape_spec) -> bool:
-    if isinstance(shape_spec, list):
-        if shape != tuple(shape_spec):
-            return False
-    elif isinstance(shape_spec, ParametrizedInputShape):
-        assert len(shape_spec.min) == len(shape_spec.step)
-        if len(shape) != len(shape_spec.min):
-            return False
-        min_shape = shape_spec.min
- step = shape_spec.step - # check if the shape is valid for all dimension by seeing if it can be reached with an integer number of steps - # NOTE we allow that the valid shape is reached using a different number of steps for each axis here - # this is usually valid because dimensions are independent in neural networks - is_valid = [(sh - minsh) % st == 0 if st > 0 else sh == minsh for sh, st, minsh in zip(shape, step, min_shape)] - return all(is_valid) - else: - raise TypeError(f"Encountered unexpected shape description of type {type(shape_spec)}") - - return True - - -def check_output_shape(shape: Tuple[int, ...], shape_spec, input_shapes) -> bool: - if isinstance(shape_spec, list): - return shape == tuple(shape_spec) - elif isinstance(shape_spec, ImplicitOutputShape): - ref_tensor = shape_spec.reference_tensor - if ref_tensor not in input_shapes: - raise ValidationError(f"The reference tensor name {ref_tensor} is not in {input_shapes}") - ipt_shape = numpy.array(input_shapes[ref_tensor]) - scale = numpy.array([0.0 if sc is None else sc for sc in shape_spec.scale]) - offset = numpy.array(shape_spec.offset) - exp_shape = numpy.round_(ipt_shape * scale) + 2 * offset - - return shape == tuple(exp_shape) - else: - raise TypeError(f"Encountered unexpected shape description of type {type(shape_spec)}") - - -def _test_resource_urls(rd: RawResourceDescription) -> TestSummary: - assert isinstance(rd, RawResourceDescription), type(rd) - with warnings.catch_warnings(record=True) as all_warnings: - try: - SourceNodeChecker(root_path=rd.root_path).visit(rd) - except FileNotFoundError as e: - error = str(e) - tb = traceback.format_tb(e.__traceback__) - else: - error = None - tb = None - - return dict( - name="All URLs and paths available", - status="passed" if error is None else "failed", - error=error, - traceback=tb, - bioimageio_spec_version=bioimageio_spec_version, - bioimageio_core_version=bioimageio_core_version, - nested_errors=None, - source_name=rd.id or rd.id or rd.name if hasattr(rd, "id") else rd.name, - warnings={"SourceNodeChecker": [str(w.message) for w in all_warnings]} if all_warnings else {}, - ) - - -def _test_resource_integrity(rd: RawResourceDescription) -> TestSummary: - assert isinstance(rd, RawResourceDescription) - with warnings.catch_warnings(record=True) as all_warnings: - if isinstance(rd, ResourceDescription): - warnings.warn("Testing source file integrity of an already loaded resource!") - - try: - Sha256NodeChecker(root_path=rd.root_path).visit(rd) - except FileNotFoundError as e: - error = str(e) - tb = traceback.format_tb(e.__traceback__) - else: - error = None - tb = None - - return dict( - name="Integrity of source files", - status="passed" if error is None else "failed", - error=error, - traceback=tb, - bioimageio_spec_version=bioimageio_spec_version, - bioimageio_core_version=bioimageio_core_version, - nested_errors=None, - source_name=rd.id or rd.id or rd.name if hasattr(rd, "id") else rd.name, - warnings={"Sha256NodeChecker": [str(w.message) for w in all_warnings]} if all_warnings else {}, - ) - - -def _test_model_documentation(rd: ResourceDescription) -> TestSummary: - assert isinstance(rd, Model) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - doc_path: Path = resolve_source(rd.documentation, root_path=rd.root_path) - doc = doc_path.read_text() - wrn = "" - if not re.match("#.*[vV]alidation", doc): - wrn = "No '# Validation' (sub)section found." 
- - return dict( - name="Test documentation completeness.", - status="passed", - error=None, - traceback=None, - bioimageio_spec_version=bioimageio_spec_version, - bioimageio_core_version=bioimageio_core_version, - source_name=rd.id or rd.name if hasattr(rd, "id") else rd.name, - warnings={"documentation": wrn} if wrn else {}, - ) - - -def _test_model_inference(model: Model, weight_format: str, devices: Optional[List[str]], decimal: int) -> TestSummary: +def _test_model_inference( + model: Union[v0_4.ModelDescr, v0_5.ModelDescr], + weight_format: Optional[WeightsFormat], + devices: Optional[List[str]], + decimal: int, +) -> None: error: Optional[str] = None - tb: Optional = None - with warnings.catch_warnings(record=True) as all_warnings: - try: - inputs = [load_array(str(in_path)) for in_path in model.test_inputs] - expected = [load_array(str(out_path)) for out_path in model.test_outputs] - - assert len(inputs) == len(model.inputs) # should be checked by validation - input_shapes = {} - for idx, (ipt, ipt_spec) in enumerate(zip(inputs, model.inputs)): - if not check_input_shape(tuple(ipt.shape), ipt_spec.shape): - raise ValidationError( - f"Shape {tuple(ipt.shape)} of test input {idx} '{ipt_spec.name}' does not match " - f"input shape description: {ipt_spec.shape}." - ) - input_shapes[ipt_spec.name] = ipt.shape - - assert len(expected) == len(model.outputs) # should be checked by validation - for idx, (out, out_spec) in enumerate(zip(expected, model.outputs)): - if not check_output_shape(tuple(out.shape), out_spec.shape, input_shapes): - error = (error or "") + ( - f"Shape {tuple(out.shape)} of test output {idx} '{out_spec.name}' does not match " - f"output shape description: {out_spec.shape}." - ) - - with create_prediction_pipeline( - bioimageio_model=model, devices=devices, weight_format=weight_format - ) as prediction_pipeline: - results = predict(prediction_pipeline, inputs) - - if len(results) != len(expected): - error = (error or "") + ( - f"Number of outputs and number of expected outputs disagree: {len(results)} != {len(expected)}" - ) - else: - for res, exp in zip(results, expected): - try: - np.testing.assert_array_almost_equal(res, exp, decimal=decimal) - except AssertionError as e: - error = (error or "") + f"Output and expected output disagree:\n {e}" - except Exception as e: - error = str(e) - tb = traceback.format_tb(e.__traceback__) - - return dict( - name=f"Reproduce test outputs from test inputs (bioimageio.core {bioimageio_core_version})", - status="passed" if error is None else "failed", - error=error, - traceback=tb, - bioimageio_spec_version=bioimageio_spec_version, - bioimageio_core_version=bioimageio_core_version, - warnings=ValidationWarning.get_warning_summary(all_warnings), - source_name=model.id or model.name, - ) - - -def _test_load_raw_resource( - rdf: Union[RawResourceDescription, ResourceDescription, URI, Path, str] -) -> Tuple[Optional[ResourceDescription], TestSummary]: - if isinstance(rdf, (URI, os.PathLike)): - source_name = str(rdf) - elif isinstance(rdf, str): - source_name = rdf[:120] - else: - source_name = rdf.id if hasattr(rdf, "id") else rdf.name - - main_test_warnings = [] + tb: List[str] = [] try: - with warnings.catch_warnings(record=True) as all_warnings: - rd: Optional[ResourceDescription] = load_raw_resource_description(rdf) - - main_test_warnings += list(all_warnings) - except Exception as e: - rd = None - error: Optional[str] = str(e) - tb: Optional = traceback.format_tb(e.__traceback__) - else: - error = None - tb = None - - 
load_summary = TestSummary( - name="Load raw resource description", - status="passed" if error is None else "failed", - error=error, - nested_errors=None, - traceback=tb, - bioimageio_spec_version=bioimageio_spec_version, - bioimageio_core_version=bioimageio_core_version, - warnings={}, - source_name=source_name, - ) - - return rd, load_summary - + if isinstance(model, v0_4.ModelDescr): + inputs = [load_array(in_path) for in_path in model.test_inputs] + expected = [load_array(out_path) for out_path in model.test_outputs] + else: + inputs = [load_array(ipt.test_tensor.download().path) for ipt in model.inputs] + expected = [load_array(out.test_tensor.download().path) for out in model.outputs] -def _test_load_resource( - raw_rd: RawResourceDescription, - weight_format: Optional[WeightsFormat] = None, -) -> Tuple[Optional[ResourceDescription], TestSummary]: - source_name = getattr(raw_rd, "rdf_source", getattr(raw_rd, "id", raw_rd.name)) + with create_prediction_pipeline( + bioimageio_model=model, devices=devices, weight_format=weight_format + ) as prediction_pipeline: + results = predict(prediction_pipeline, inputs) - main_test_warnings = [] - try: - with warnings.catch_warnings(record=True) as all_warnings: - rd: Optional[ResourceDescription] = load_resource_description( - raw_rd, weights_priority_order=None if weight_format is None else [weight_format] + if len(results) != len(expected): + error = (error or "") + ( + f"Number of outputs and number of expected outputs disagree: {len(results)} != {len(expected)}" ) - - main_test_warnings += list(all_warnings) + else: + for res, exp in zip(results, expected): + try: + np.testing.assert_array_almost_equal(res, exp, decimal=decimal) + except AssertionError as e: + error = (error or "") + f"Output and expected output disagree:\n {e}" except Exception as e: - rd = None - error: Optional[str] = str(e) - tb: Optional = traceback.format_tb(e.__traceback__) - else: - error = None - tb = None - - load_summary = TestSummary( - name="Load resource description", - status="passed" if error is None else "failed", - error=error, - nested_errors=None, - traceback=tb, - bioimageio_spec_version=bioimageio_spec_version, - bioimageio_core_version=bioimageio_core_version, - warnings={}, - source_name=source_name, + error = str(e) + tb = traceback.format_tb(e.__traceback__) + + model.validation_summary.add_detail( + ValidationDetail( + name="Reproduce test outputs from test inputs", + status="passed" if error is None else "failed", + errors=( + [] + if error is None + else [ + ErrorEntry( + loc=("weights",) if weight_format is None else ("weights", weight_format), + msg=error, + type="bioimageio.core", + traceback=tb, + ) + ] + ), + ) ) - return rd, load_summary - -def _test_expected_resource_type(rd: RawResourceDescription, expected_type: str) -> TestSummary: +def _test_expected_resource_type(rd: Union[InvalidDescr, ResourceDescr], expected_type: str): has_expected_type = rd.type == expected_type - return dict( - name="Has expected resource type", - status="passed" if has_expected_type else "failed", - error=None if has_expected_type else f"expected type {expected_type}, found {rd.type}", - traceback=None, - source_name=rd.id or rd.name if hasattr(rd, "id") else rd.name, + rd.validation_summary.details.append( + ValidationDetail( + name="Has expected resource type", + status="passed" if has_expected_type else "failed", + errors=( + [] + if has_expected_type + else [ErrorEntry(loc=("type",), type="type", msg=f"expected type {expected_type}, found {rd.type}")] + ), 
+ ) ) -def test_resource( - rdf: Union[RawResourceDescription, ResourceDescription, URI, Path, str], +def test_description( + source: Union[ResourceDescr, FileSource, BioimageioYamlContent], *, + format_version: Union[Literal["discover", "latest"], str] = "discover", weight_format: Optional[WeightsFormat] = None, devices: Optional[List[str]] = None, decimal: int = 4, expected_type: Optional[str] = None, -) -> List[TestSummary]: - """Test RDF dynamically - - Returns: summary dict with keys: name, status, error, traceback, bioimageio_spec_version, bioimageio_core_version - """ - raw_rd, load_test = _test_load_raw_resource(rdf) - tests: List[TestSummary] = [load_test] - if raw_rd is None: - return tests - - if expected_type is not None: - tests.append(_test_expected_resource_type(raw_rd, expected_type)) - - tests.append(_test_resource_urls(raw_rd)) - if tests[-1]["status"] == "passed": - tests.append(_test_resource_integrity(raw_rd)) - - if tests[-1]["status"] != "passed": - return tests # stop testing if resource availability/integrity is an issue - - rd = _test_load_resource(raw_rd, weight_format) - if isinstance(rd, Model): - tests.append(_test_model_documentation(rd)) - tests.append(_test_model_inference(rd, weight_format, devices, decimal)) - - return tests +) -> ValidationSummary: + """Test RDF dynamically, e.g. model inference of test inputs""" + rd = load_description_and_test( + source, + format_version=format_version, + weight_format=weight_format, + devices=devices, + decimal=decimal, + expected_type=expected_type, + ) + return rd.validation_summary -def debug_model( - model_rdf: Union[RawResourceDescription, ResourceDescription, URI, Path, str], +def load_description_and_test( + source: Union[ResourceDescr, FileSource, BioimageioYamlContent], *, + format_version: Union[Literal["discover", "latest"], str] = "discover", weight_format: Optional[WeightsFormat] = None, devices: Optional[List[str]] = None, -): - """Run the model test and return dict with inputs, results, expected results and intermediates. - - Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff". 
- """ - inputs_raw: Optional = None - inputs_processed: Optional = None - outputs_raw: Optional = None - outputs: Optional = None - expected: Optional = None - diff: Optional = None - - model = load_resource_description( - model_rdf, weights_priority_order=None if weight_format is None else [weight_format] - ) - if not isinstance(model, Model): - raise ValueError(f"Not a bioimageio.model: {model_rdf}") - - prediction_pipeline = create_prediction_pipeline( - bioimageio_model=model, devices=devices, weight_format=weight_format - ) - inputs = [ - xr.DataArray(load_array(str(in_path)), dims=input_spec.axes) - for in_path, input_spec in zip(model.test_inputs, model.inputs) - ] - input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)} - - # keep track of the non-processed inputs - inputs_raw = [deepcopy(input) for input in inputs] - - computed_measures = {} - - prediction_pipeline.apply_preprocessing(input_dict, computed_measures) - inputs_processed = list(input_dict.values()) - outputs_raw = prediction_pipeline.predict(*inputs_processed) - output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)} - prediction_pipeline.apply_postprocessing(output_dict, computed_measures) - outputs = list(output_dict.values()) - - if isinstance(outputs, (np.ndarray, xr.DataArray)): - outputs = [outputs] - - expected = [ - xr.DataArray(load_array(str(out_path)), dims=output_spec.axes) - for out_path, output_spec in zip(model.test_outputs, model.outputs) - ] - if len(outputs) != len(expected): - error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}" - print(error) + decimal: int = 4, + expected_type: Optional[str] = None, +) -> Union[ResourceDescr, InvalidDescr]: + """Test RDF dynamically, e.g. model inference of test inputs""" + if ( + isinstance(source, ResourceDescrBase) + and format_version != "discover" + and source.format_version != format_version + ): + warnings.warn(f"deserializing source to ensure we validate and test using format {format_version}") + source = dump_description(source) + + if isinstance(source, ResourceDescrBase): + rd = source + elif isinstance(source, dict): + rd = build_description(source, format_version=format_version) else: - diff = [] - for res, exp in zip(outputs, expected): - diff.append(res - exp) + rd = load_description(source, format_version=format_version) - return { - "inputs": inputs_raw, - "inputs_processed": inputs_processed, - "outputs_raw": outputs_raw, - "outputs": outputs, - "expected": expected, - "diff": diff, - } + rd.validation_summary.env.append(InstalledPackage(name="bioimageio.core", version=bioimageio_core_version)) + + if expected_type is not None: + _test_expected_resource_type(rd, expected_type) + + if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)): + _test_model_inference(rd, weight_format, devices, decimal) + + return rd + + +# def debug_model( +# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str], +# *, +# weight_format: Optional[WeightsFormat] = None, +# devices: Optional[List[str]] = None, +# ): +# """Run the model test and return dict with inputs, results, expected results and intermediates. + +# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff". 
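+#
+#     Example (a hedged sketch; the RDF path is hypothetical):
+#         dbg = debug_model("path/to/rdf.yaml", weight_format="onnx")
+#         worst_mismatch = max(float(abs(d).max()) for d in dbg["diff"])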
+# """ +# inputs_raw: Optional = None +# inputs_processed: Optional = None +# outputs_raw: Optional = None +# outputs: Optional = None +# expected: Optional = None +# diff: Optional = None + +# model = load_resource_description( +# model_rdf, weights_priority_order=None if weight_format is None else [weight_format] +# ) +# if not isinstance(model, Model): +# raise ValueError(f"Not a bioimageio.model: {model_rdf}") + +# prediction_pipeline = create_prediction_pipeline( +# bioimageio_model=model, devices=devices, weight_format=weight_format +# ) +# inputs = [ +# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes) +# for in_path, input_spec in zip(model.test_inputs, model.inputs) +# ] +# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)} + +# # keep track of the non-processed inputs +# inputs_raw = [deepcopy(input) for input in inputs] + +# computed_measures = {} + +# prediction_pipeline.apply_preprocessing(input_dict, computed_measures) +# inputs_processed = list(input_dict.values()) +# outputs_raw = prediction_pipeline.predict(*inputs_processed) +# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)} +# prediction_pipeline.apply_postprocessing(output_dict, computed_measures) +# outputs = list(output_dict.values()) + +# if isinstance(outputs, (np.ndarray, xr.DataArray)): +# outputs = [outputs] + +# expected = [ +# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes) +# for out_path, output_spec in zip(model.test_outputs, model.outputs) +# ] +# if len(outputs) != len(expected): +# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}" +# print(error) +# else: +# diff = [] +# for res, exp in zip(outputs, expected): +# diff.append(res - exp) + +# return { +# "inputs": inputs_raw, +# "inputs_processed": inputs_processed, +# "outputs_raw": outputs_raw, +# "outputs": outputs, +# "expected": expected, +# "diff": diff, +# } diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py index af563b14..42a4fdc8 100644 --- a/bioimageio/core/stat_calculators.py +++ b/bioimageio/core/stat_calculators.py @@ -19,7 +19,6 @@ Tuple, Type, Union, - cast, ) import numpy as np @@ -93,7 +92,7 @@ def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]: def _update_impl(self, tensor: xr.DataArray, tensor_mean: xr.DataArray): assert tensor_mean.dtype == np.float64 # reduced voxel count - n_b = np.prod(tensor.shape) / np.prod(tensor_mean.shape) # type: ignore + n_b = int(np.prod(tensor.shape) / np.prod(tensor_mean.shape)) if self._mean is None: assert self._n == 0 @@ -130,7 +129,7 @@ def compute(self, sample: Sample) -> Dict[Union[SampleMean, SampleVar, SampleStd if self._axes is None: n = tensor.size else: - n = int(np.prod([tensor.sizes[d] for d in self._axes])) # type: ignore # FIXME: type annotation + n = int(np.prod([tensor.sizes[d] for d in self._axes])) var: xr.DataArray = xr.dot(c, c, dims=self._axes) / n assert isinstance(var, xr.DataArray) @@ -147,7 +146,7 @@ def update(self, sample: Sample): mean_b = tensor.mean(dim=self._axes) assert mean_b.dtype == np.float64 # reduced voxel count - n_b = int(np.prod(tensor.shape) / np.prod(mean_b.shape)) # type: ignore + n_b = int(np.prod(tensor.shape) / np.prod(mean_b.shape)) m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes) assert m2_b.dtype == np.float64 if self._mean is None: @@ -191,7 +190,7 @@ def __init__(self, tensor_id: TensorId, axes: 
Optional[Sequence[AxisId]], ns: Co def compute(self, sample: Sample) -> Dict[SamplePercentile, MeasureValue]: tensor = sample.data[self._tensor_id] - ps = tensor.quantile(self._qs, dim=self._axes) # type: ignore + ps = tensor.quantile(self._qs, dim=self._axes) return {SamplePercentile(n=n, axes=self._axes, tensor_id=self._tensor_id): p for n, p in zip(self.ns, ps)} @@ -211,7 +210,7 @@ def update(self, sample: Sample): sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(np.float64, copy=False) # reduced voxel count - n = int(np.prod(tensor.shape) / np.prod(sample_estimates.shape[1:])) # type: ignore + n = int(np.prod(tensor.shape) / np.prod(sample_estimates.shape[1:])) if self._estimates is None: assert self._n == 0 @@ -318,11 +317,7 @@ def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]: class StatsCalculator: """Estimates dataset statistics and computes sample statistics efficiently""" - def __init__( - self, - *, - measures: Iterable[Measure], - ): + def __init__(self, measures: Iterable[Measure]): super().__init__() self.sample_count = 0 self.sample_calculators, self.dataset_calculators = get_measure_calculators(measures) @@ -336,15 +331,20 @@ def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]: return ret - def _update(self, sample: Sample): + def update(self, sample: Union[Sample, Iterable[Sample]]) -> None: + _ = self._update(sample) + + def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]: self.sample_count += 1 - for calc in self.dataset_calculators: - calc.update(sample) - self._current_dataset_measures = None + samples = [sample] if isinstance(sample, Sample) else sample + last_sample = None + for s in samples: + last_sample = s + for calc in self.dataset_calculators: + calc.update(s) - def _compute_and_update(self, sample: Sample): - self._update(sample) - return self._compute(sample) + self._current_dataset_measures = None + return last_sample def _finalize(self) -> Dict[DatasetMeasure, MeasureValue]: """returns aggregated dataset statistics""" @@ -356,17 +356,17 @@ def _finalize(self) -> Dict[DatasetMeasure, MeasureValue]: return self._current_dataset_measures - def update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]: + def update_and_get_all(self, sample: Union[Sample, Iterable[Sample]]) -> Dict[Measure, MeasureValue]: """Returns sample as well as updated dataset statistics""" - ret = cast(Dict[Measure, MeasureValue], self._compute_and_update(sample)) - ret.update(self._finalize().items()) - return ret + last_sample = self._update(sample) + if last_sample is None: + raise ValueError("`sample` was not a `Sample`, nor did it yield any.") + + return {**self._compute(last_sample), **self._finalize()} def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]: """Returns sample as well as previously computed dataset statistics""" - ret = cast(Dict[Measure, MeasureValue], self._compute(sample)) - ret.update(self._finalize().items()) - return ret + return {**self._compute(sample), **self._finalize()} def get_measure_calculators( diff --git a/bioimageio/core/utils/__init__.py b/bioimageio/core/utils/__init__.py index eb1dbbfc..31875605 100644 --- a/bioimageio/core/utils/__init__.py +++ b/bioimageio/core/utils/__init__.py @@ -16,9 +16,8 @@ from typing_extensions import Unpack from bioimageio.core.io import FileSource, HashKwargs, download -from bioimageio.spec.model.v0_4 import CallableFromDepencency -from bioimageio.spec.model.v0_4 import CallableFromFile as 
CallableFromFile04 -from bioimageio.spec.model.v0_5 import CallableFromFile as CallableFromFile05 +from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile +from bioimageio.spec.model.v0_5 import ArchitectureFromFileDescr, ArchitectureFromLibraryDescr if sys.version_info < (3, 9): @@ -66,7 +65,7 @@ def import_from_dependency(node: CallableFromDepencency) -> Callable[..., Any]: @import_callable.register -def import_from_file04(node: CallableFromFile04, **kwargs: Unpack[HashKwargs]): +def import_from_file04(node: CallableFromFile, **kwargs: Unpack[HashKwargs]): return _import_from_file_impl(node.file, node.callable_name, **kwargs) diff --git a/bioimageio/core/utils/node_visitor.py b/bioimageio/core/utils/node_visitor.py deleted file mode 100644 index dd523a7f..00000000 --- a/bioimageio/core/utils/node_visitor.py +++ /dev/null @@ -1,73 +0,0 @@ -from abc import ABC, abstractmethod -from dataclasses import dataclass, replace -from functools import singledispatchmethod -from pathlib import Path, PurePath -from typing import Any, List, Optional, Tuple, Union - -import requests -from pydantic import AnyUrl, DirectoryPath -from pydantic.fields import FieldInfo - -from bioimageio.core.utils import get_sha256 -from bioimageio.spec._internal.base_nodes import Node -from bioimageio.spec._internal.constants import IN_PACKAGE_MESSAGE, KW_ONLY, SLOTS -from bioimageio.spec._internal.types import Sha256 -from bioimageio.spec.summary import ErrorEntry, Loc, WarningEntry - - -@dataclass(frozen=True, **SLOTS, **KW_ONLY) -class Memo: - loc: Loc = () - info: Optional[FieldInfo] = None - parent_nodes: Tuple[Node, ...] = () - - -class NodeVisitor: - def visit(self, obj: Any, /, memo: Memo = Memo()): - self._traverse(obj, memo=memo) - - @singledispatchmethod - def _traverse(self, obj: type, /, memo: Memo): - pass - - @_traverse.register - def _traverse_node(self, node: Node, memo: Memo): - for k, v in node: - self.visit( - v, - replace(memo, loc=memo.loc + (k,), info=node.model_fields[k], parent_nodes=memo.parent_nodes + (node,)), - ) - - @_traverse.register - def _traverse_list(self, lst: list, memo: Memo): # type: ignore - e: Any - for i, e in enumerate(lst): # type: ignore - self.visit(e, replace(memo, loc=memo.loc + (i,))) - - @_traverse.register - def _traverse_tuple(self, tup: tuple, memo: Memo): # type: ignore - e: Any - for i, e in enumerate(tup): # type: ignore - self.visit(e, replace(memo, loc=memo.loc + (i,))) - - @_traverse.register - def _traverse_dict(self, dict_: dict, memo: Memo): # type: ignore - v: Any - for k, v in dict_.items(): # type: ignore - self.visit(v, replace(memo, loc=memo.loc + (k,))) - - -class ValidationVisitor(NodeVisitor, ABC): - def __init__(self) -> None: - super().__init__() - self.errors: List[ErrorEntry] = [] - self.warnings: List[WarningEntry] = [] - - def visit(self, obj: Any, /, memo: Memo = Memo()): - self.validate(obj, memo=memo) - return super().visit(obj, memo) - - @singledispatchmethod - @abstractmethod - def validate(self, obj: type, /, memo: Memo): - ... diff --git a/bioimageio/core/utils/testing.py b/bioimageio/core/utils/testing.py index c61fa62f..2659a2e7 100644 --- a/bioimageio/core/utils/testing.py +++ b/bioimageio/core/utils/testing.py @@ -3,8 +3,7 @@ class test_func(Protocol): - def __call__(*args: Any, **kwargs: Any): - ... + def __call__(*args: Any, **kwargs: Any): ... 
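# A usage sketch for the skip_on decorator defined below, assuming (from its
# signature alone) that it turns the named exception into a pytest skip rather
# than a failure:
#
#   @skip_on(ImportError, reason="optional backend not installed")
#   def test_with_optional_backend():
#       import some_optional_backend  # hypothetical module name
#       ...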
def skip_on(exception: Type[Exception], reason: str): diff --git a/bioimageio/core/weight_converter/torch/onnx.py b/bioimageio/core/weight_converter/torch/onnx.py index acdecc41..394a4825 100644 --- a/bioimageio/core/weight_converter/torch/onnx.py +++ b/bioimageio/core/weight_converter/torch/onnx.py @@ -12,6 +12,7 @@ from bioimageio.spec.common import InvalidDescription from bioimageio.spec.utils import download + def add_onnx_weights( model_spec: "str | Path | v0_4.ModelDescr | v0_5.ModelDescr", *, @@ -34,7 +35,9 @@ def add_onnx_weights( if isinstance(loaded_spec, InvalidDescription): raise ValueError(f"Bad resource description: {loaded_spec}") if not isinstance(loaded_spec, (v0_4.ModelDescr, v0_5.ModelDescr)): - raise TypeError(f"Path {model_spec} is a {loaded_spec.__class__.__name__}, expected a v0_4.ModelDescr or v0_5.ModelDescr") + raise TypeError( + f"Path {model_spec} is a {loaded_spec.__class__.__name__}, expected a v0_4.ModelDescr or v0_5.ModelDescr" + ) model_spec = loaded_spec state_dict_weights_descr = model_spec.weights.pytorch_state_dict @@ -69,7 +72,7 @@ def add_onnx_weights( raise NotImplementedError try: - import onnxruntime as rt # pyright: ignore [reportMissingTypeStubs] + import onnxruntime as rt # pyright: ignore [reportMissingTypeStubs] except ImportError: msg = "The onnx weights were exported, but onnx rt is not available and weights cannot be checked." warnings.warn(msg) @@ -77,11 +80,11 @@ def add_onnx_weights( # check the onnx model sess = rt.InferenceSession(str(output_path)) - onnx_input_node_args = cast(List[Any], sess.get_inputs()) # fixme: remove cast, try using rt.NodeArg instead of Any + onnx_input_node_args = cast(List[Any], sess.get_inputs()) # fixme: remove cast, try using rt.NodeArg instead of Any onnx_inputs: Dict[str, np.ndarray[Any, Any]] = { input_name.name: inp for input_name, inp in zip(onnx_input_node_args, input_data) } - outputs = cast(Sequence[np.ndarray[Any, Any]], sess.run(None, onnx_inputs)) #FIXME: remove cast + outputs = cast(Sequence[np.ndarray[Any, Any]], sess.run(None, onnx_inputs)) # FIXME: remove cast try: for exp, out in zip(expected_outputs, outputs): diff --git a/bioimageio/core/weight_converter/torch/utils.py b/bioimageio/core/weight_converter/torch/utils.py index 14e08514..413ba629 100644 --- a/bioimageio/core/weight_converter/torch/utils.py +++ b/bioimageio/core/weight_converter/torch/utils.py @@ -5,11 +5,10 @@ from bioimageio.spec.utils import download - # additional convenience for pytorch state dict, eventually we want this in python-bioimageio too # and for each weight format def load_model(node: "v0_4.PytorchStateDictWeightsDescr | v0_5.PytorchStateDictWeightsDescr"): model = PytorchModelAdapter.get_network(node) state = torch.load(download(node.source).path, map_location="cpu") - _ = model.load_state_dict(state) #FIXME: check incompatible keys? + _ = model.load_state_dict(state) # FIXME: check incompatible keys? 
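    # A sketch of the stricter handling the FIXME above hints at, relying on
    # torch's documented load_state_dict return value (a named tuple of
    # missing_keys and unexpected_keys):
    #
    #   incompatible = model.load_state_dict(state, strict=False)
    #   if incompatible.missing_keys or incompatible.unexpected_keys:
    #       raise ValueError(f"incompatible state dict: {incompatible}")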
return model.eval() diff --git a/tests/build_spec/test_build_spec.py b/tests/build_spec/test_build_spec.py index 7d842509..ad281ac1 100644 --- a/tests/build_spec/test_build_spec.py +++ b/tests/build_spec/test_build_spec.py @@ -67,15 +67,19 @@ def _test_build_spec( input_axes = [input_.axes for input_ in model_spec.inputs] output_axes = [output.axes for output in model_spec.outputs] preprocessing = [ - None - if input_.preprocessing is missing - else [{"name": preproc.name, "kwargs": preproc.kwargs} for preproc in input_.preprocessing] + ( + None + if input_.preprocessing is missing + else [{"name": preproc.name, "kwargs": preproc.kwargs} for preproc in input_.preprocessing] + ) for input_ in model_spec.inputs ] postprocessing = [ - None - if output.postprocessing is missing - else [{"name": preproc.name, "kwargs": preproc.kwargs} for preproc in output.preprocessing] + ( + None + if output.postprocessing is missing + else [{"name": preproc.name, "kwargs": preproc.kwargs} for preproc in output.preprocessing] + ) for output in model_spec.outputs ] diff --git a/tests/test_resource_tests/test_test_model.py b/tests/test_resource_tests/test_test_model.py index 1d168472..c5f3cf5c 100644 --- a/tests/test_resource_tests/test_test_model.py +++ b/tests/test_resource_tests/test_test_model.py @@ -34,19 +34,19 @@ def test_test_model(any_model): def test_test_resource(any_model): - from bioimageio.core.resource_tests import test_resource + from bioimageio.core.resource_tests import test_description - summary = test_resource(any_model) + summary = test_description(any_model) assert all([s["status"] for s in summary]) def test_validation_section_warning(unet2d_nuclei_broad_model, tmp_path: pathlib.Path): - from bioimageio.core.resource_tests import test_resource from bioimageio.core import load_resource_description + from bioimageio.core.resource_tests import test_description model = load_resource_description(unet2d_nuclei_broad_model) - summary = test_resource(model)[2] + summary = test_description(model)[2] assert summary["name"] == "Test documentation completeness." assert summary["warnings"] == {"documentation": "No '# Validation' (sub)section found."} assert summary["status"] == "passed" @@ -54,7 +54,7 @@ def test_validation_section_warning(unet2d_nuclei_broad_model, tmp_path: pathlib doc_with_validation = tmp_path / "doc.md" doc_with_validation.write_text("# Validation\nThis is a section about how to validate the model on new data") model.documentation = doc_with_validation - summary = test_resource(model)[2] + summary = test_description(model)[2] assert summary["name"] == "Test documentation completeness." 
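    # These assertions assume the summary shape produced by test_description:
    # an ordered list of per-check dicts with at least "name", "status" and
    # "warnings" keys (inferred from the checks in this test, not from a
    # documented schema).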
assert summary["warnings"] == {} assert summary["status"] == "passed" From fcb6bcd674e211c61161105eaf3552cf3cb804ec Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 21 Feb 2024 13:56:38 +0100 Subject: [PATCH 087/244] add test_description alias --- bioimageio/core/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index 8268d6bf..29116eaa 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -18,3 +18,5 @@ from .prediction_pipeline import create_prediction_pipeline as create_prediction_pipeline from .resource_tests import load_description_and_test as load_description_and_test from .resource_tests import test_description as test_description + +test_resource = test_description From 154a7d6f3c585045bb23a5a0fcfc6f418355bc1a Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 21 Feb 2024 14:14:09 +0100 Subject: [PATCH 088/244] update version and spec pinning --- bioimageio/core/VERSION | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bioimageio/core/VERSION b/bioimageio/core/VERSION index 167d3d30..424d6096 100644 --- a/bioimageio/core/VERSION +++ b/bioimageio/core/VERSION @@ -1,3 +1,3 @@ { - "version": "0.5.11" + "version": "0.6.0" } diff --git a/setup.py b/setup.py index f18a765b..cd612047 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ ], packages=find_namespace_packages(exclude=["tests"]), install_requires=[ - "bioimageio.spec==0.4.9.*", + "bioimageio.spec==0.5.0.*", "imageio>=2.5", "numpy", "ruyaml", From a9493768bffbf96f359cd44540acd350954b11e3 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 21 Feb 2024 14:40:51 +0100 Subject: [PATCH 089/244] WIP cleanup utils --- bioimageio/core/utils/__init__.py | 87 +++++++++---------------------- 1 file changed, 26 insertions(+), 61 deletions(-) diff --git a/bioimageio/core/utils/__init__.py b/bioimageio/core/utils/__init__.py index 31875605..20dbd866 100644 --- a/bioimageio/core/utils/__init__.py +++ b/bioimageio/core/utils/__init__.py @@ -1,23 +1,7 @@ -# todo: cleanup __init__: move stuff to util submodules or elsewhere from __future__ import annotations -import hashlib -import importlib.util -import os import sys -from contextlib import AbstractContextManager -from functools import singledispatch from pathlib import Path -from types import TracebackType -from typing import Any, Callable -from urllib.parse import urlsplit, urlunsplit - -from pydantic import AnyUrl, HttpUrl -from typing_extensions import Unpack - -from bioimageio.core.io import FileSource, HashKwargs, download -from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile -from bioimageio.spec.model.v0_5 import ArchitectureFromFileDescr, ArchitectureFromLibraryDescr if sys.version_info < (3, 9): @@ -29,58 +13,39 @@ def files(package_name: str): from importlib.resources import files as files -class TemporaryInsertionIntoPythonPath(AbstractContextManager[None]): - def __init__(self, path: Path): - super().__init__() - self.path = str(path) - - def __enter__(self): - super().__enter__() - sys.path.insert(0, self.path) - - def __exit__( - self, - __exc_type: "type[BaseException] | None", - __exc_value: "BaseException | None", - __traceback: "TracebackType | None", - ) -> "bool | None": - assert sys.path[0] == self.path - _ = sys.path.pop(0) - return super().__exit__(__exc_type, __exc_value, __traceback) - - -@singledispatch -def import_callable(node: type, /) -> Callable[..., Any]: - raise TypeError(type(node)) +# TODO: import helpers +# 
@singledispatch +# def import_callable(node: type, /) -> Callable[..., Any]: +# raise TypeError(type(node)) -@import_callable.register -def import_from_dependency(node: CallableFromDepencency) -> Callable[..., Any]: - module = importlib.import_module(node.module_name) - c = getattr(module, node.callable_name) - if not callable(c): - raise ValueError(f"{node} (imported: {c}) is not callable") +# @import_callable.register +# def import_from_dependency(node: CallableFromDepencency) -> Callable[..., Any]: +# module = importlib.import_module(node.module_name) +# c = getattr(module, node.callable_name) +# if not callable(c): +# raise ValueError(f"{node} (imported: {c}) is not callable") - return c +# return c -@import_callable.register -def import_from_file04(node: CallableFromFile, **kwargs: Unpack[HashKwargs]): - return _import_from_file_impl(node.file, node.callable_name, **kwargs) +# @import_callable.register +# def import_from_file04(node: CallableFromFile, **kwargs: Unpack[HashKwargs]): +# return _import_from_file_impl(node.file, node.callable_name, **kwargs) -@import_callable.register -def import_from_file05(node: CallableFromFile05, **kwargs: Unpack[HashKwargs]): - return _import_from_file_impl(node.source_file, node.callable_name, **kwargs) +# @import_callable.register +# def import_from_file05(node: CallableFromFile05, **kwargs: Unpack[HashKwargs]): +# return _import_from_file_impl(node.source_file, node.callable_name, **kwargs) -def _import_from_file_impl(source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]): - local_file = download(source, **kwargs) - module_name = local_file.path.stem - importlib_spec = importlib.util.spec_from_file_location(module_name, local_file.path) - if importlib_spec is None: - raise ImportError(f"Failed to import {module_name} from {source}.") +# def _import_from_file_impl(source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]): +# local_file = download(source, **kwargs) +# module_name = local_file.path.stem +# importlib_spec = importlib.util.spec_from_file_location(module_name, local_file.path) +# if importlib_spec is None: +# raise ImportError(f"Failed to import {module_name} from {source}.") - dep = importlib.util.module_from_spec(importlib_spec) - importlib_spec.loader.exec_module(dep) # type: ignore # todo: possible to use "loader.load_module"? - return getattr(dep, callable_name) +# dep = importlib.util.module_from_spec(importlib_spec) +# importlib_spec.loader.exec_module(dep) # type: ignore # todo: possible to use "loader.load_module"? 
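# (On the todo above: loader.load_module is deprecated in importlib, and the
# spec_from_file_location / module_from_spec / exec_module sequence used here
# is its documented replacement, so exec_module is the idiomatic call.)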
+# return getattr(dep, callable_name) From 546df6601f2dc2c2eace86ff09f288249e935d7d Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 21 Feb 2024 15:50:26 +0100 Subject: [PATCH 090/244] fix errors on import and more utils cleanup --- bioimageio/core/common.py | 11 +-- .../model_adapters/_keras_model_adapter.py | 12 ++-- bioimageio/core/resource_tests.py | 19 +++-- bioimageio/core/stat_calculators.py | 8 +-- bioimageio/core/stat_measures.py | 16 ++--- bioimageio/core/utils.py | 0 bioimageio/core/utils/__init__.py | 71 ++++++++++++------- 7 files changed, 83 insertions(+), 54 deletions(-) delete mode 100644 bioimageio/core/utils.py diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index db9b73fd..1f0bcd84 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -1,19 +1,20 @@ -from dataclasses import field -from typing import Dict, Union +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Dict, Union import xarray as xr -from attr import dataclass -from bioimageio.core.stat_measures import Measure, MeasureValue from bioimageio.spec.model import v0_4, v0_5 +if TYPE_CHECKING: + from bioimageio.core.stat_measures import Measure, MeasureValue + TensorId = v0_5.TensorId AxisId = v0_5.AxisId Tensor = xr.DataArray Data = Dict[TensorId, Tensor] -Stat = Dict[Measure, MeasureValue] +Stat = Dict["Measure", "MeasureValue"] @dataclass diff --git a/bioimageio/core/model_adapters/_keras_model_adapter.py b/bioimageio/core/model_adapters/_keras_model_adapter.py index 177a6a28..e353df17 100644 --- a/bioimageio/core/model_adapters/_keras_model_adapter.py +++ b/bioimageio/core/model_adapters/_keras_model_adapter.py @@ -11,7 +11,10 @@ tf_version = Version(tf.__version__) except Exception: - import keras + try: + import keras + except Exception: + keras = None tf_version = None import xarray as xr @@ -26,6 +29,7 @@ class KerasModelAdapter(ModelAdapter): def __init__( self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None ) -> None: + assert keras is not None super().__init__() if model_description.weights.keras_hdf5 is None: raise ValueError("model has not keras_hdf5 weights specified") @@ -50,9 +54,9 @@ def __init__( self._output_axes = [tuple(out.axes) for out in model_description.outputs] def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - _result: Union[ # pyright: ignore[reportUnknownVariableType] - Sequence[NDArray[Any]], NDArray[Any] - ] = self._network.predict(*input_tensors) + _result: Union[Sequence[NDArray[Any]], NDArray[Any]] = ( # pyright: ignore[reportUnknownVariableType] + self._network.predict(*input_tensors) + ) if isinstance(_result, (tuple, list)): result: Sequence[NDArray[Any]] = _result else: diff --git a/bioimageio/core/resource_tests.py b/bioimageio/core/resource_tests.py index dd899089..8acf1406 100644 --- a/bioimageio/core/resource_tests.py +++ b/bioimageio/core/resource_tests.py @@ -3,14 +3,13 @@ from typing import List, Literal, Optional, Union import numpy as np +import xarray as xr from bioimageio.core import __version__ as bioimageio_core_version -from bioimageio.core.prediction import predict from bioimageio.core.prediction_pipeline import create_prediction_pipeline from bioimageio.spec import InvalidDescr, ResourceDescr, build_description, dump_description, load_description from bioimageio.spec._internal.base_nodes import ResourceDescrBase from bioimageio.spec._internal.io_utils import load_array -from bioimageio.spec._internal.validation_context 
import validation_context_var from bioimageio.spec.common import BioimageioYamlContent, FileSource from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import WeightsFormat @@ -39,16 +38,22 @@ def _test_model_inference( tb: List[str] = [] try: if isinstance(model, v0_4.ModelDescr): - inputs = [load_array(in_path) for in_path in model.test_inputs] - expected = [load_array(out_path) for out_path in model.test_outputs] + inputs = [xr.DataArray(load_array(src), dims=d.axes) for src, d in zip(model.test_inputs, model.inputs)] + expected = [xr.DataArray(load_array(src), dims=d.axes) for src, d in zip(model.test_outputs, model.outputs)] else: - inputs = [load_array(ipt.test_tensor.download().path) for ipt in model.inputs] - expected = [load_array(out.test_tensor.download().path) for out in model.outputs] + inputs = [ + xr.DataArray(load_array(d.test_tensor.download().path), dims=tuple(a.id for a in d.axes)) + for d in model.inputs + ] + expected = [ + xr.DataArray(load_array(d.test_tensor.download().path), dims=tuple(a.id for a in d.axes)) + for d in model.outputs + ] with create_prediction_pipeline( bioimageio_model=model, devices=devices, weight_format=weight_format ) as prediction_pipeline: - results = predict(prediction_pipeline, inputs) + results = prediction_pipeline.forward(*inputs) if len(results) != len(expected): error = (error or "") + ( diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py index 42a4fdc8..54e601d0 100644 --- a/bioimageio/core/stat_calculators.py +++ b/bioimageio/core/stat_calculators.py @@ -49,7 +49,7 @@ try: import crick -except ImportError: +except Exception: crick = None class TDigest: @@ -289,9 +289,9 @@ def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: if crick is None: - DatasetPercentilesCalculator: Type[ - Union[MeanPercentilesCalculator, CrickPercentilesCalculator] - ] = MeanPercentilesCalculator + DatasetPercentilesCalculator: Type[Union[MeanPercentilesCalculator, CrickPercentilesCalculator]] = ( + MeanPercentilesCalculator + ) else: DatasetPercentilesCalculator = CrickPercentilesCalculator diff --git a/bioimageio/core/stat_measures.py b/bioimageio/core/stat_measures.py index 96de3a9c..4acb8f31 100644 --- a/bioimageio/core/stat_measures.py +++ b/bioimageio/core/stat_measures.py @@ -35,7 +35,7 @@ class _Mean: @dataclass(frozen=True) -class SampleMean(SampleMeasureBase, _Mean): +class SampleMean(_Mean, SampleMeasureBase): def compute(self, sample: Sample) -> MeasureValue: return sample.data[self.tensor_id].mean(dim=self.axes) @@ -44,7 +44,7 @@ def __post_init__(self): @dataclass(frozen=True) -class DatasetMean(DatasetMeasureBase, _Mean): +class DatasetMean(_Mean, DatasetMeasureBase): def __post_init__(self): assert self.axes is None or AxisId("batch") in self.axes @@ -55,7 +55,7 @@ class _Std: @dataclass(frozen=True) -class SampleStd(SampleMeasureBase, _Std): +class SampleStd(_Std, SampleMeasureBase): def compute(self, sample: Sample) -> MeasureValue: return sample.data[self.tensor_id].std(dim=self.axes) @@ -64,7 +64,7 @@ def __post_init__(self): @dataclass(frozen=True) -class DatasetStd(DatasetMeasureBase, _Std): +class DatasetStd(_Std, DatasetMeasureBase): def __post_init__(self): assert self.axes is None or AxisId("batch") in self.axes @@ -75,7 +75,7 @@ class _Var: @dataclass(frozen=True) -class SampleVar(SampleMeasureBase, _Var): +class SampleVar(_Var, SampleMeasureBase): def compute(self, sample: Sample) -> MeasureValue: return sample.data[self.tensor_id].var(dim=self.axes) @@ -84,7 +84,7 
@@ def __post_init__(self): @dataclass(frozen=True) -class DatasetVar(DatasetMeasureBase, _Var): +class DatasetVar(_Var, DatasetMeasureBase): def __post_init__(self): assert self.axes is None or AxisId("batch") in self.axes @@ -100,7 +100,7 @@ def __post_init__(self): @dataclass(frozen=True) -class SamplePercentile(SampleMeasureBase, _Percentile): +class SamplePercentile(_Percentile, SampleMeasureBase): def compute(self, sample: Sample) -> MeasureValue: return sample.data[self.tensor_id].quantile(self.n / 100.0, dim=self.axes) @@ -110,7 +110,7 @@ def __post_init__(self): @dataclass(frozen=True) -class DatasetPercentile(DatasetMeasureBase, _Percentile): +class DatasetPercentile(_Percentile, DatasetMeasureBase): def __post_init__(self): super().__post_init__() assert self.axes is None or AxisId("batch") in self.axes diff --git a/bioimageio/core/utils.py b/bioimageio/core/utils.py deleted file mode 100644 index e69de29b..00000000 diff --git a/bioimageio/core/utils/__init__.py b/bioimageio/core/utils/__init__.py index 20dbd866..bcb713e2 100644 --- a/bioimageio/core/utils/__init__.py +++ b/bioimageio/core/utils/__init__.py @@ -1,7 +1,17 @@ from __future__ import annotations +import importlib.util import sys +from functools import singledispatch from pathlib import Path +from typing import Any, Callable + +from typing_extensions import Unpack + +from bioimageio.spec._internal.io_utils import HashKwargs, download +from bioimageio.spec.common import FileSource +from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile +from bioimageio.spec.model.v0_5 import ArchitectureFromFileDescr, ArchitectureFromLibraryDescr if sys.version_info < (3, 9): @@ -13,39 +23,48 @@ def files(package_name: str): from importlib.resources import files as files -# TODO: import helpers -# @singledispatch -# def import_callable(node: type, /) -> Callable[..., Any]: -# raise TypeError(type(node)) +@singledispatch +def import_callable(node: type, /) -> Callable[..., Any]: + raise TypeError(type(node)) + + +@import_callable.register +def import_from_dependency04(node: CallableFromDepencency) -> Callable[..., Any]: + module = importlib.import_module(node.module_name) + c = getattr(module, node.callable_name) + if not callable(c): + raise ValueError(f"{node} (imported: {c}) is not callable") + + return c -# @import_callable.register -# def import_from_dependency(node: CallableFromDepencency) -> Callable[..., Any]: -# module = importlib.import_module(node.module_name) -# c = getattr(module, node.callable_name) -# if not callable(c): -# raise ValueError(f"{node} (imported: {c}) is not callable") +@import_callable.register +def import_from_dependency05(node: ArchitectureFromLibraryDescr) -> Callable[..., Any]: + module = importlib.import_module(node.import_from) + c = getattr(module, node.callable) + if not callable(c): + raise ValueError(f"{node} (imported: {c}) is not callable") -# return c + return c -# @import_callable.register -# def import_from_file04(node: CallableFromFile, **kwargs: Unpack[HashKwargs]): -# return _import_from_file_impl(node.file, node.callable_name, **kwargs) +@import_callable.register +def import_from_file04(node: CallableFromFile, **kwargs: Unpack[HashKwargs]): + return _import_from_file_impl(node.file, node.callable_name, **kwargs) -# @import_callable.register -# def import_from_file05(node: CallableFromFile05, **kwargs: Unpack[HashKwargs]): -# return _import_from_file_impl(node.source_file, node.callable_name, **kwargs) +@import_callable.register +def import_from_file05(node: 
ArchitectureFromFileDescr, **kwargs: Unpack[HashKwargs]): + return _import_from_file_impl(node.source, node.callable, sha256=node.sha256) -# def _import_from_file_impl(source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]): -# local_file = download(source, **kwargs) -# module_name = local_file.path.stem -# importlib_spec = importlib.util.spec_from_file_location(module_name, local_file.path) -# if importlib_spec is None: -# raise ImportError(f"Failed to import {module_name} from {source}.") +def _import_from_file_impl(source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]): + local_file = download(source, **kwargs) + module_name = local_file.path.stem + importlib_spec = importlib.util.spec_from_file_location(module_name, local_file.path) + if importlib_spec is None: + raise ImportError(f"Failed to import {module_name} from {source}.") -# dep = importlib.util.module_from_spec(importlib_spec) -# importlib_spec.loader.exec_module(dep) # type: ignore # todo: possible to use "loader.load_module"? -# return getattr(dep, callable_name) + dep = importlib.util.module_from_spec(importlib_spec) + importlib_spec.loader.exec_module(dep) # type: ignore # todo: possible to use "loader.load_module"? + return getattr(dep, callable_name) From 462b4fe94be4f1198f08b96c172ebeb7caabb65a Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 21 Feb 2024 16:06:57 +0100 Subject: [PATCH 091/244] import additional deps more carefully --- .../core/model_adapters/_model_adapter.py | 7 +++---- .../model_adapters/_onnx_model_adapter.py | 7 ++++++- .../model_adapters/_pytorch_model_adapter.py | 7 ++++++- .../_tensorflow_model_adapter.py | 9 +++++++-- .../_torchscript_model_adapter.py | 8 ++++++-- .../core/weight_converter/keras/tensorflow.py | 19 ++++++++++++------- .../core/weight_converter/torch/__init__.py | 2 +- .../core/weight_converter/torch/onnx.py | 6 +++--- .../weight_converter/torch/torchscript.py | 3 --- 9 files changed, 44 insertions(+), 24 deletions(-) diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py index 09a346a4..dabaff5f 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ b/bioimageio/core/model_adapters/_model_adapter.py @@ -4,7 +4,6 @@ import xarray as xr -from bioimageio.spec._internal.types import NotEmpty from bioimageio.spec.model import v0_4, v0_5 WeightsFormat = Union[v0_4.WeightsFormat, v0_5.WeightsFormat] @@ -90,9 +89,9 @@ def create( # we try to first import the keras model adapter using the separate package and, # if it is not available, try to load the one using tf try: - try: - from ._keras_model_adapter import KerasModelAdapter - except ImportError: + from ._keras_model_adapter import KerasModelAdapter, keras + + if keras is None: from ._tensorflow_model_adapter import KerasModelAdapter return KerasModelAdapter(model_description=model_description, devices=devices) diff --git a/bioimageio/core/model_adapters/_onnx_model_adapter.py b/bioimageio/core/model_adapters/_onnx_model_adapter.py index 0d947dc9..d1f51946 100644 --- a/bioimageio/core/model_adapters/_onnx_model_adapter.py +++ b/bioimageio/core/model_adapters/_onnx_model_adapter.py @@ -2,7 +2,6 @@ import warnings from typing import Any, List, Optional, Sequence, Union -import onnxruntime as rt import xarray as xr from numpy.typing import NDArray @@ -10,6 +9,11 @@ from ._model_adapter import ModelAdapter +try: + import onnxruntime as rt +except Exception: + rt = None + logger = logging.getLogger(__name__) @@ -17,6 +21,7 @@ class 
ONNXModelAdapter(ModelAdapter): def __init__( self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None ): + assert rt is not None super().__init__() self._internal_output_axes = [ tuple(out.axes) if isinstance(out.axes, str) else tuple(a.id for a in out.axes) diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py index ba9210a9..a9b7701b 100644 --- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py +++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py @@ -2,7 +2,6 @@ import warnings from typing import Any, List, Optional, Sequence, Tuple, Union -import torch import xarray as xr from bioimageio.core.utils import import_callable @@ -11,6 +10,11 @@ from ._model_adapter import ModelAdapter +try: + import torch +except Exception: + torch = None + class PytorchModelAdapter(ModelAdapter): def __init__( @@ -20,6 +24,7 @@ def __init__( weights: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr], devices: Optional[Sequence[str]] = None, ): + assert torch is not None super().__init__() self.output_dims = [tuple(a if isinstance(a, str) else a.id for a in out.axes) for out in outputs] self._network = self.get_network(weights) diff --git a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py index 96828016..a845f380 100644 --- a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py +++ b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py @@ -3,15 +3,19 @@ from typing import List, Literal, Optional, Sequence, Union import numpy as np -import tensorflow as tf import xarray as xr -from bioimageio.spec.common import FileSource, RelativeFilePath +from bioimageio.spec.common import FileSource from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.utils import download from ._model_adapter import ModelAdapter +try: + import tensorflow as tf +except Exception: + tf = None + class TensorflowModelAdapterBase(ModelAdapter): weight_format: Literal["keras_hdf5", "tensorflow_saved_model_bundle"] @@ -28,6 +32,7 @@ def __init__( ], model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], ): + assert tf is not None super().__init__() self.model_description = model_description tf_version = v0_5.Version(tf.__version__) diff --git a/bioimageio/core/model_adapters/_torchscript_model_adapter.py b/bioimageio/core/model_adapters/_torchscript_model_adapter.py index 3d7d046f..876136b8 100644 --- a/bioimageio/core/model_adapters/_torchscript_model_adapter.py +++ b/bioimageio/core/model_adapters/_torchscript_model_adapter.py @@ -3,21 +3,25 @@ from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np -import torch import xarray as xr from numpy.typing import NDArray -from bioimageio.spec.common import RelativeFilePath from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.utils import download from ._model_adapter import ModelAdapter +try: + import torch +except Exception: + torch = None + class TorchscriptModelAdapter(ModelAdapter): def __init__( self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None ): + assert torch is not None super().__init__() if model_description.weights.torchscript is None: raise ValueError(f"No torchscript weights found for model {model_description.name}") diff --git a/bioimageio/core/weight_converter/keras/tensorflow.py 
b/bioimageio/core/weight_converter/keras/tensorflow.py index 5eed3797..e6476a46 100644 --- a/bioimageio/core/weight_converter/keras/tensorflow.py +++ b/bioimageio/core/weight_converter/keras/tensorflow.py @@ -4,8 +4,10 @@ from typing import no_type_check from zipfile import ZipFile -import tensorflow -from tensorflow import saved_model +try: + import tensorflow.saved_model +except Exception: + tensorflow = None from bioimageio.spec._internal.io_utils import download from bioimageio.spec.model.v0_5 import ModelDescr @@ -42,16 +44,18 @@ def _convert_tf1(keras_weight_path: Path, output_path: Path, input_name: str, ou @no_type_check def build_tf_model(): keras_model = keras.models.load_model(keras_weight_path) - - builder = saved_model.builder.SavedModelBuilder(output_path) - signature = saved_model.signature_def_utils.predict_signature_def( + assert tensorflow is not None + builder = tensorflow.saved_model.builder.SavedModelBuilder(output_path) + signature = tensorflow.saved_model.signature_def_utils.predict_signature_def( inputs={input_name: keras_model.input}, outputs={output_name: keras_model.output} ) - signature_def_map = {saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature} + signature_def_map = {tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature} builder.add_meta_graph_and_variables( - keras.backend.get_session(), [saved_model.tag_constants.SERVING], signature_def_map=signature_def_map + keras.backend.get_session(), + [tensorflow.saved_model.tag_constants.SERVING], + signature_def_map=signature_def_map, ) builder.save() @@ -92,6 +96,7 @@ def convert_weights_to_tensorflow_saved_model_bundle(model: ModelDescr, output_p model: The bioimageio model description output_path: where to save the tensorflow weights. This path must not exist yet. 
""" + assert tensorflow is not None tf_major_ver = int(tensorflow.__version__.split(".")[0]) if output_path.suffix == ".zip": diff --git a/bioimageio/core/weight_converter/torch/__init__.py b/bioimageio/core/weight_converter/torch/__init__.py index 27b20c99..c7bda015 100644 --- a/bioimageio/core/weight_converter/torch/__init__.py +++ b/bioimageio/core/weight_converter/torch/__init__.py @@ -1,2 +1,2 @@ -from .onnx import convert_weights_to_onnx +from .onnx import add_onnx_weights from .torchscript import convert_weights_to_torchscript diff --git a/bioimageio/core/weight_converter/torch/onnx.py b/bioimageio/core/weight_converter/torch/onnx.py index 394a4825..3606cd74 100644 --- a/bioimageio/core/weight_converter/torch/onnx.py +++ b/bioimageio/core/weight_converter/torch/onnx.py @@ -6,10 +6,10 @@ import torch from numpy.testing import assert_array_almost_equal +from bioimageio.core.weight_converter.torch.utils import load_model from bioimageio.spec import load_description +from bioimageio.spec.common import InvalidDescr from bioimageio.spec.model import v0_4, v0_5 -from bioimageio.core.weight_converter.torch.utils import load_model -from bioimageio.spec.common import InvalidDescription from bioimageio.spec.utils import download @@ -32,7 +32,7 @@ def add_onnx_weights( """ if isinstance(model_spec, (str, Path)): loaded_spec = load_description(Path(model_spec)) - if isinstance(loaded_spec, InvalidDescription): + if isinstance(loaded_spec, InvalidDescr): raise ValueError(f"Bad resource description: {loaded_spec}") if not isinstance(loaded_spec, (v0_4.ModelDescr, v0_5.ModelDescr)): raise TypeError( diff --git a/bioimageio/core/weight_converter/torch/torchscript.py b/bioimageio/core/weight_converter/torch/torchscript.py index 451fcb3e..a517e17b 100644 --- a/bioimageio/core/weight_converter/torch/torchscript.py +++ b/bioimageio/core/weight_converter/torch/torchscript.py @@ -6,11 +6,8 @@ from numpy.testing import assert_array_almost_equal from typing_extensions import Any, assert_never -from bioimageio.spec import load_description -from bioimageio.spec.common import InvalidDescription from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import Version -from bioimageio.spec.utils import download from .utils import load_model From d7eabfb9dc65a897ad43b939406950462c92c3d3 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 21 Feb 2024 16:10:41 +0100 Subject: [PATCH 092/244] guard main --- scripts/setup_dev_env.py | 21 +++++++++++---------- scripts/show_diff.py | 27 ++++++++++++++------------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/scripts/setup_dev_env.py b/scripts/setup_dev_env.py index ed4502cc..315306a2 100644 --- a/scripts/setup_dev_env.py +++ b/scripts/setup_dev_env.py @@ -8,13 +8,14 @@ def run(prompt: str): _ = subprocess.run(prompt, check=True, capture_output=True) -repo_dir = Path(__file__).parent.parent.parent -cur_dir = Path().resolve() -chdir(str(repo_dir)) -try: - run("mamba env create --file core-bioimage-io/dev/env.yaml") - run("pip install --no-deps --config-settings editable_mode=compat -e spec-bioimage-io") - run("pip install --no-deps --config-settings editable_mode=compat -e core-bioimage-io") -except Exception: - chdir(cur_dir) - raise +if __name__ == "__main__": + repo_dir = Path(__file__).parent.parent.parent + cur_dir = Path().resolve() + chdir(str(repo_dir)) + try: + run("mamba env create --file core-bioimage-io/dev/env.yaml") + run("pip install --no-deps --config-settings editable_mode=compat -e spec-bioimage-io") + run("pip install 
--no-deps --config-settings editable_mode=compat -e core-bioimage-io") + except Exception: + chdir(cur_dir) + raise diff --git a/scripts/show_diff.py b/scripts/show_diff.py index f0fb20d8..7cdbeaa4 100644 --- a/scripts/show_diff.py +++ b/scripts/show_diff.py @@ -6,19 +6,20 @@ from bioimageio.core import load_description, write_description -rdf_source = "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/pydantic_axes/example_specs/models/unet2d_nuclei_broad/rdf_v0_4_9.yaml" +if __name__ == "__main__": + rdf_source = "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/pydantic_axes/example_specs/models/unet2d_nuclei_broad/rdf_v0_4_9.yaml" -local_source = Path(pooch.retrieve(rdf_source, None)) # type: ignore -model_as_is, summary_as_is = load_description(rdf_source, format_version="discover") -assert model_as_is is not None, summary_as_is -model_latest, summary_latest = load_description(rdf_source, format_version="latest") -print(summary_latest) -assert model_latest is not None + local_source = Path(pooch.retrieve(rdf_source, None)) # type: ignore + model_as_is, summary_as_is = load_description(rdf_source, format_version="discover") + assert model_as_is is not None, summary_as_is + model_latest, summary_latest = load_description(rdf_source, format_version="latest") + print(summary_latest) + assert model_latest is not None -with TemporaryDirectory() as tmp: - as_is = Path(tmp) / "as_is.bioimageio.yaml" - write_description(model_as_is, as_is) # write out as is to avoid sorting diff - latest = Path(tmp) / "latest.bioimageio.yaml" - write_description(model_latest, latest) + with TemporaryDirectory() as tmp: + as_is = Path(tmp) / "as_is.bioimageio.yaml" + write_description(model_as_is, as_is) # write out as is to avoid sorting diff + latest = Path(tmp) / "latest.bioimageio.yaml" + write_description(model_latest, latest) - _ = subprocess.run(f"git diff --no-index --ignore-all-space {as_is} {latest}") + _ = subprocess.run(f"git diff --no-index --ignore-all-space {as_is} {latest}") From b9f67b260811c3c3c85ef808d6f9e0c9b48813ba Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 21 Feb 2024 16:16:37 +0100 Subject: [PATCH 093/244] update show_diff.py --- scripts/show_diff.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/scripts/show_diff.py b/scripts/show_diff.py index 7cdbeaa4..77623343 100644 --- a/scripts/show_diff.py +++ b/scripts/show_diff.py @@ -4,22 +4,21 @@ import pooch -from bioimageio.core import load_description, write_description +from bioimageio.core import load_description, save_bioimageio_yaml_only if __name__ == "__main__": rdf_source = "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/pydantic_axes/example_specs/models/unet2d_nuclei_broad/rdf_v0_4_9.yaml" local_source = Path(pooch.retrieve(rdf_source, None)) # type: ignore - model_as_is, summary_as_is = load_description(rdf_source, format_version="discover") - assert model_as_is is not None, summary_as_is - model_latest, summary_latest = load_description(rdf_source, format_version="latest") - print(summary_latest) - assert model_latest is not None + model_as_is = load_description(rdf_source, format_version="discover") + model_latest = load_description(rdf_source, format_version="latest") + print(model_latest.validation_summary) with TemporaryDirectory() as tmp: as_is = Path(tmp) / "as_is.bioimageio.yaml" - write_description(model_as_is, as_is) # write out as is to avoid sorting diff + + save_bioimageio_yaml_only(model_as_is, file=as_is) # write out as is to avoid 
sorting diff latest = Path(tmp) / "latest.bioimageio.yaml" - write_description(model_latest, latest) + save_bioimageio_yaml_only(model_latest, file=latest) _ = subprocess.run(f"git diff --no-index --ignore-all-space {as_is} {latest}") From ce0a83305e5e2a3aeefc81df3f44af63879fa1ae Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 21 Feb 2024 16:21:29 +0100 Subject: [PATCH 094/244] fix skipping of test_bioimageio_spec_version --- tests/test_bioimageio_spec_version.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_bioimageio_spec_version.py b/tests/test_bioimageio_spec_version.py index 87cbaed6..444bac31 100644 --- a/tests/test_bioimageio_spec_version.py +++ b/tests/test_bioimageio_spec_version.py @@ -1,14 +1,17 @@ import json import subprocess import sys +from typing import Optional import pytest from packaging.version import Version @pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python 3.8") -@pytest.mark.skipif(pytest.mamba_cmd is None, reason="requires mamba") -def test_bioimageio_spec_version(): +def test_bioimageio_spec_version(mamba_cmd: Optional[str]): + if mamba_cmd is None: + pytest.skip("requires mamba") + from importlib.metadata import metadata # get latest released bioimageio.spec version From 2fc6f3aba2a847388bbf80a6a391efb05d0ae1e6 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 21 Feb 2024 16:24:02 +0100 Subject: [PATCH 095/244] load_resource_description -> load_description --- bioimageio/core/prediction.py | 6 +-- bioimageio/core/resource_tests.py | 2 +- tests/build_spec/test_add_weights.py | 7 ++-- tests/build_spec/test_build_spec.py | 6 +-- .../test_device_management.py | 4 +- .../test_prediction_pipeline.py | 4 +- tests/resource_io/test_load_rdf.py | 40 +++++++++---------- tests/test_cli.py | 4 +- tests/test_prediction.py | 10 ++--- tests/test_resource_tests/test_test_model.py | 6 +-- 10 files changed, 45 insertions(+), 44 deletions(-) diff --git a/bioimageio/core/prediction.py b/bioimageio/core/prediction.py index 49fb9d53..228bfa63 100644 --- a/bioimageio/core/prediction.py +++ b/bioimageio/core/prediction.py @@ -14,7 +14,7 @@ # from pydantic import HttpUrl # from tqdm import tqdm -# from bioimageio.core import image_helper, load_resource_description +# from bioimageio.core import image_helper, load_description # from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline # from bioimageio.core.resource_io.nodes import ImplicitOutputShape, Model, ResourceDescr @@ -455,7 +455,7 @@ # if not isinstance(outputs, (tuple, list)): # outputs = [outputs] -# model = load_resource_description(model_rdf) +# model = load_description(model_rdf) # assert isinstance(model, Model) # if len(model.inputs) != len(inputs): # raise ValueError @@ -491,7 +491,7 @@ # verbose: run prediction in verbose mode. 
# """ -# model = load_resource_description(model_rdf) +# model = load_description(model_rdf) # assert isinstance(model, Model) # with create_prediction_pipeline( diff --git a/bioimageio/core/resource_tests.py b/bioimageio/core/resource_tests.py index 8acf1406..2ac15c71 100644 --- a/bioimageio/core/resource_tests.py +++ b/bioimageio/core/resource_tests.py @@ -178,7 +178,7 @@ def load_description_and_test( # expected: Optional = None # diff: Optional = None -# model = load_resource_description( +# model = load_description( # model_rdf, weights_priority_order=None if weight_format is None else [weight_format] # ) # if not isinstance(model, Model): diff --git a/tests/build_spec/test_add_weights.py b/tests/build_spec/test_add_weights.py index 2f8300b0..4bba5f87 100644 --- a/tests/build_spec/test_add_weights.py +++ b/tests/build_spec/test_add_weights.py @@ -1,5 +1,6 @@ import os -from bioimageio.core import export_resource_package, load_raw_resource_description, load_resource_description + +from bioimageio.core import export_resource_package, load_description, load_raw_resource_description from bioimageio.core.resource_tests import test_model as _test_model @@ -10,7 +11,7 @@ def _test_add_weights(model, tmp_path, base_weights, added_weights, **kwargs): assert base_weights in rdf.weights assert added_weights in rdf.weights - weight_path = load_resource_description(model).weights[added_weights].source + weight_path = load_description(model).weights[added_weights].source assert weight_path.exists() drop_weights = set(rdf.weights.keys()) - {base_weights} @@ -25,7 +26,7 @@ def _test_add_weights(model, tmp_path, base_weights, added_weights, **kwargs): add_weights(in_path, weight_path, weight_type=added_weights, output_path=out_path, **kwargs) assert out_path.exists() - new_rdf = load_resource_description(out_path) + new_rdf = load_description(out_path) assert set(new_rdf.weights.keys()) == {base_weights, added_weights} for weight in new_rdf.weights.values(): assert weight.source.exists() diff --git a/tests/build_spec/test_build_spec.py b/tests/build_spec/test_build_spec.py index ad281ac1..669eeb8a 100644 --- a/tests/build_spec/test_build_spec.py +++ b/tests/build_spec/test_build_spec.py @@ -1,9 +1,9 @@ from typing import Optional -import bioimageio.spec as spec from marshmallow import missing -from bioimageio.core import load_raw_resource_description, load_resource_description +import bioimageio.spec as spec +from bioimageio.core import load_description, load_raw_resource_description from bioimageio.core._internal.validation_visitors import resolve_source from bioimageio.core.resource_io import nodes from bioimageio.core.resource_tests import test_model as _test_model @@ -131,7 +131,7 @@ def _test_build_spec( build_model(**kwargs) assert out_path.exists() - loaded_model = load_resource_description(out_path) + loaded_model = load_description(out_path) assert isinstance(loaded_model, nodes.Model) if add_deepimagej_config: loaded_config = loaded_model.config diff --git a/tests/prediction_pipeline/test_device_management.py b/tests/prediction_pipeline/test_device_management.py index bbe907ad..fb2d9d35 100644 --- a/tests/prediction_pipeline/test_device_management.py +++ b/tests/prediction_pipeline/test_device_management.py @@ -3,7 +3,7 @@ import xarray as xr from numpy.testing import assert_array_almost_equal -from bioimageio.core import load_resource_description +from bioimageio.core import load_description from bioimageio.core._internal.pytest_utils import skip_on from bioimageio.core.resource_io.nodes 
import Model @@ -20,7 +20,7 @@ def _test_device_management(model_package, weight_format): from bioimageio.core.prediction_pipeline import create_prediction_pipeline - bio_model = load_resource_description(model_package) + bio_model = load_description(model_package) assert isinstance(bio_model, Model) pred_pipe = create_prediction_pipeline(bioimageio_model=bio_model, weight_format=weight_format, devices=["cuda:0"]) diff --git a/tests/prediction_pipeline/test_prediction_pipeline.py b/tests/prediction_pipeline/test_prediction_pipeline.py index ac3c6a65..3a2c57aa 100644 --- a/tests/prediction_pipeline/test_prediction_pipeline.py +++ b/tests/prediction_pipeline/test_prediction_pipeline.py @@ -2,14 +2,14 @@ import xarray as xr from numpy.testing import assert_array_almost_equal -from bioimageio.core import load_resource_description +from bioimageio.core import load_description from bioimageio.core.resource_io.nodes import Model def _test_prediction_pipeline(model_package, weight_format): from bioimageio.core.prediction_pipeline import create_prediction_pipeline - bio_model = load_resource_description(model_package) + bio_model = load_description(model_package) assert isinstance(bio_model, Model) pp = create_prediction_pipeline(bioimageio_model=bio_model, weight_format=weight_format) diff --git a/tests/resource_io/test_load_rdf.py b/tests/resource_io/test_load_rdf.py index a9ea2441..873d9b26 100644 --- a/tests/resource_io/test_load_rdf.py +++ b/tests/resource_io/test_load_rdf.py @@ -8,12 +8,12 @@ def test_load_non_existing_rdf(): - from bioimageio.core import load_resource_description + from bioimageio.core import load_description spec_path = Path("some/none/existing/path/to/spec.model.yaml") with pytest.raises(FileNotFoundError): - load_resource_description(spec_path) + load_description(spec_path) def test_load_raw_model(any_model): @@ -24,73 +24,73 @@ def test_load_raw_model(any_model): def test_load_model(any_model): - from bioimageio.core import load_resource_description + from bioimageio.core import load_description - model = load_resource_description(any_model) + model = load_description(any_model) assert model def test_load_model_with_abs_path_source(unet2d_nuclei_broad_model): - from bioimageio.core.resource_io import load_raw_resource_description, load_resource_description + from bioimageio.core.resource_io import load_description, load_raw_resource_description raw_rd = load_raw_resource_description(unet2d_nuclei_broad_model) path_source = (raw_rd.root_path / "rdf.yaml").absolute() assert path_source.is_absolute() - model = load_resource_description(path_source) + model = load_description(path_source) assert model def test_load_model_with_rel_path_source(unet2d_nuclei_broad_model): - from bioimageio.core.resource_io import load_raw_resource_description, load_resource_description + from bioimageio.core.resource_io import load_description, load_raw_resource_description raw_rd = load_raw_resource_description(unet2d_nuclei_broad_model) path_source = pathlib.Path(os.path.relpath(raw_rd.root_path / "rdf.yaml", os.curdir)) assert not path_source.is_absolute() - model = load_resource_description(path_source) + model = load_description(path_source) assert model def test_load_model_with_abs_str_source(unet2d_nuclei_broad_model): - from bioimageio.core.resource_io import load_raw_resource_description, load_resource_description + from bioimageio.core.resource_io import load_description, load_raw_resource_description raw_rd = load_raw_resource_description(unet2d_nuclei_broad_model) path_source = 
(raw_rd.root_path / "rdf.yaml").absolute() assert path_source.is_absolute() - model = load_resource_description(str(path_source)) + model = load_description(str(path_source)) assert model def test_load_model_with_rel_str_source(unet2d_nuclei_broad_model): - from bioimageio.core.resource_io import load_raw_resource_description, load_resource_description + from bioimageio.core.resource_io import load_description, load_raw_resource_description raw_rd = load_raw_resource_description(unet2d_nuclei_broad_model) path_source = pathlib.Path(os.path.relpath(raw_rd.root_path / "rdf.yaml", os.curdir)) assert not path_source.is_absolute() - model = load_resource_description(str(path_source)) + model = load_description(str(path_source)) assert model -@pytest.mark.skipif(pytest.skip_torch, reason="remote model is a pytorch model") -def test_load_remote_rdf(): - from bioimageio.core import load_resource_description +def test_load_remote_rdf(unet2d_nuclei_broad_model): + # remote model is a pytorch model, needing unet2d_nuclei_broad_model skips the test when needed + _ = unet2d_nuclei_broad_model + from bioimageio.core import load_description from bioimageio.core.resource_io import nodes remote_rdf = "https://zenodo.org/api/files/63b44f05-a187-4fc9-81c8-c4568535531b/rdf.yaml" - model = load_resource_description(remote_rdf) + model = load_description(remote_rdf) assert isinstance(model, nodes.Model) @pytest.mark.skipif(True, reason="No suitable test model available yet") def test_load_remote_rdf_with_folders(): - from bioimageio.spec.model import raw_nodes - - from bioimageio.core import load_raw_resource_description, load_resource_description + from bioimageio.core import load_description, load_raw_resource_description from bioimageio.core.resource_io import nodes + from bioimageio.spec.model import raw_nodes rdf_doi = "" raw_model = load_raw_resource_description(rdf_doi, update_to_format="latest") assert isinstance(raw_model, raw_nodes.Model) - model = load_resource_description(rdf_doi) + model = load_description(rdf_doi) assert isinstance(model, nodes.Model) # test for field value with folder, e.g. 
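# A minimal end-to-end sketch of the renamed loading API exercised by the
# tests above; the URL is the one from test_load_remote_rdf, and network
# access is assumed:
#
#   from bioimageio.core import load_description
#   from bioimageio.core.resource_io import nodes
#
#   model = load_description(
#       "https://zenodo.org/api/files/63b44f05-a187-4fc9-81c8-c4568535531b/rdf.yaml"
#   )
#   assert isinstance(model, nodes.Model)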
diff --git a/tests/test_cli.py b/tests/test_cli.py index c0de99d4..384f5120 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from bioimageio.core import load_resource_description +from bioimageio.core import load_description def run_subprocess(commands: Sequence[str], **kwargs) -> subprocess.CompletedProcess: @@ -61,7 +61,7 @@ def test_cli_test_resource_with_weight_format(unet2d_nuclei_broad_model): def _test_cli_predict_image(model, tmp_path, extra_kwargs=None): - spec = load_resource_description(model) + spec = load_description(model) in_path = spec.test_inputs[0] out_path = tmp_path.with_suffix(".npy") cmd = ["bioimageio", "predict-image", model, "--inputs", str(in_path), "--outputs", str(out_path)] diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 7992f9f5..b73d0f82 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -4,14 +4,14 @@ import numpy as np from numpy.testing import assert_array_almost_equal -from bioimageio.core import load_resource_description +from bioimageio.core import load_description from bioimageio.core.resource_io.nodes import Model def test_predict_image(any_model, tmpdir): from bioimageio.core.prediction import predict_image - spec = load_resource_description(any_model) + spec = load_description(any_model) assert isinstance(spec, Model) inputs = spec.test_inputs @@ -29,7 +29,7 @@ def test_predict_image(any_model, tmpdir): def test_predict_image_with_weight_format(unet2d_fixed_shape_or_not, tmpdir): from bioimageio.core.prediction import predict_image - spec = load_resource_description(unet2d_fixed_shape_or_not) + spec = load_description(unet2d_fixed_shape_or_not) assert isinstance(spec, Model) inputs = spec.test_inputs @@ -47,7 +47,7 @@ def test_predict_image_with_weight_format(unet2d_fixed_shape_or_not, tmpdir): def _test_predict_with_padding(model, tmp_path): from bioimageio.core.prediction import predict_image - spec = load_resource_description(model) + spec = load_description(model) assert isinstance(spec, Model) input_spec, output_spec = spec.inputs[0], spec.outputs[0] @@ -136,7 +136,7 @@ def test_predict_image_with_padding_channel_last(stardist, tmp_path): def _test_predict_image_with_tiling(model, tmp_path: Path, exp_mean_deviation): from bioimageio.core.prediction import predict_image - spec = load_resource_description(model) + spec = load_description(model) assert isinstance(spec, Model) inputs = spec.test_inputs assert len(inputs) == 1 diff --git a/tests/test_resource_tests/test_test_model.py b/tests/test_resource_tests/test_test_model.py index c5f3cf5c..f1d5894f 100644 --- a/tests/test_resource_tests/test_test_model.py +++ b/tests/test_resource_tests/test_test_model.py @@ -41,10 +41,10 @@ def test_test_resource(any_model): def test_validation_section_warning(unet2d_nuclei_broad_model, tmp_path: pathlib.Path): - from bioimageio.core import load_resource_description + from bioimageio.core import load_description from bioimageio.core.resource_tests import test_description - model = load_resource_description(unet2d_nuclei_broad_model) + model = load_description(unet2d_nuclei_broad_model) summary = test_description(model)[2] assert summary["name"] == "Test documentation completeness." 
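# Note on the status checks in this file: the truthiness-only form
# `all([s["status"] for s in summary])` used above can never fail, since
# non-empty strings such as "failed" are truthy; the explicit comparison in
# the next hunk is the stricter pattern. A sketch of applying it throughout:
#
#   assert all(s["status"] == "passed" for s in summary)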
@@ -67,6 +67,6 @@ def test_issue289(): from bioimageio.core.resource_tests import test_model doi = "10.5281/zenodo.6287342" - model_resource = bioimageio.core.load_resource_description(doi) + model_resource = bioimageio.core.load_description(doi) test_result = test_model(model_resource) assert all([t["status"] == "passed" for t in test_result]) From 371a1fc91aa401d8740d8a6a682005187e369cb3 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 21 Feb 2024 16:32:22 +0100 Subject: [PATCH 096/244] avoid pytest.mark.skipif(pytest.X ... --- tests/test_cli.py | 5 ++--- tests/test_resource_tests/test_test_model.py | 6 ++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index 384f5120..473a1b1a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -126,10 +126,9 @@ def test_torch_to_torchscript(unet2d_nuclei_broad_model, tmp_path): assert out_path.exists() -@pytest.mark.skipif(pytest.skip_onnx, reason="requires torch and onnx") -def test_torch_to_onnx(unet2d_nuclei_broad_model, tmp_path): +def test_torch_to_onnx(convert_to_onnx, tmp_path): out_path = tmp_path.with_suffix(".onnx") - ret = run_subprocess(["bioimageio", "convert-torch-weights-to-onnx", str(unet2d_nuclei_broad_model), str(out_path)]) + ret = run_subprocess(["bioimageio", "convert-torch-weights-to-onnx", str(convert_to_onnx), str(out_path)]) assert ret.returncode == 0, ret.stdout assert out_path.exists() diff --git a/tests/test_resource_tests/test_test_model.py b/tests/test_resource_tests/test_test_model.py index f1d5894f..9498e8ab 100644 --- a/tests/test_resource_tests/test_test_model.py +++ b/tests/test_resource_tests/test_test_model.py @@ -60,9 +60,11 @@ def test_validation_section_warning(unet2d_nuclei_broad_model, tmp_path: pathlib assert summary["status"] == "passed" -@pytest.mark.skipif(pytest.skip_torch, reason="requires torch") -def test_issue289(): +def test_issue289(unet2d_nuclei_broad_model): """test for failure case from https://github.com/bioimage-io/core-bioimage-io-python/issues/289""" + # remote model is a pytorch model, needing unet2d_nuclei_broad_model skips the test when needed + _ = unet2d_nuclei_broad_model + import bioimageio.core from bioimageio.core.resource_tests import test_model From 3a7875b5debc2d52b2fc87f6579afe217e1c7280 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 21 Feb 2024 16:46:21 +0100 Subject: [PATCH 097/244] use test_description in build_description_and_test --- bioimageio/core/io.py | 13 ++++++++++--- tests/prediction_pipeline/test_device_management.py | 3 --- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py index e811961c..1f27b60e 100644 --- a/bioimageio/core/io.py +++ b/bioimageio/core/io.py @@ -1,8 +1,9 @@ from __future__ import annotations +from contextlib import nullcontext from typing import Literal, Optional, Union -from bioimageio.spec import build_description +from bioimageio.core.resource_tests import test_description from bioimageio.spec import load_description as load_description from bioimageio.spec._description import ResourceDescr from bioimageio.spec._internal.constants import DISCOVER @@ -35,8 +36,14 @@ def build_description_and_test( format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, ) -> Union[ResourceDescr, InvalidDescr]: """load and validate a BioImage.IO description from the content of a resource description file (RDF)""" - rd = build_description(data, context=context, format_version=format_version) - # todo: add dynamic validation 
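Note: the replacement lines in the next hunk rely on contextlib.nullcontext to make the validation context optional. A minimal standalone sketch of that pattern (helper name and body are invented here, not part of the patch):

from contextlib import nullcontext

def run_under_optional_context(work, context=None):
    # fall back to a no-op context manager when no context is given,
    # so the body always uses a single `with` statement
    val_context = nullcontext() if context is None else context
    with val_context:
        return work()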
+ if context is None: + val_context = nullcontext() + else: + val_context = context + + with val_context: + rd = test_description(data, format_version=format_version) + return rd diff --git a/tests/prediction_pipeline/test_device_management.py b/tests/prediction_pipeline/test_device_management.py index fb2d9d35..d8f65eca 100644 --- a/tests/prediction_pipeline/test_device_management.py +++ b/tests/prediction_pipeline/test_device_management.py @@ -61,19 +61,16 @@ def test_device_management_torchscript(any_torchscript_model): _test_device_management(any_torchscript_model, "torchscript") -@pytest.mark.skipif(pytest.skip_torch, reason="requires torch for device discovery") @skip_on(TooFewDevicesException, reason="Too few devices") def test_device_management_onnx(any_onnx_model): _test_device_management(any_onnx_model, "onnx") -@pytest.mark.skipif(pytest.skip_torch, reason="requires torch for device discovery") @skip_on(TooFewDevicesException, reason="Too few devices") def test_device_management_tensorflow(any_tensorflow_model): _test_device_management(any_tensorflow_model, "tensorflow_saved_model_bundle") -@pytest.mark.skipif(pytest.skip_torch, reason="requires torch for device discovery") @skip_on(TooFewDevicesException, reason="Too few devices") def test_device_management_keras(any_keras_model): _test_device_management(any_keras_model, "keras_hdf5") From 9fd124893f0b3bc647559e5f0ffba75d826e08b0 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 26 Feb 2024 16:06:40 +0100 Subject: [PATCH 098/244] WIP update tests --- bioimageio/core/proc_ops.py | 15 +- bioimageio/core/stat_calculators.py | 34 +- bioimageio/core/stat_measures.py | 11 +- tests/build_spec/test_add_weights.py | 3 - tests/build_spec/test_build_spec.py | 11 +- tests/conftest.py | 2 +- .../test_combined_processing.py | 13 +- .../test_device_management.py | 28 +- tests/prediction_pipeline/test_measures.py | 103 ------ .../test_postprocessing.py | 70 ----- .../test_prediction_pipeline.py | 4 +- .../prediction_pipeline/test_preprocessing.py | 212 ------------- tests/prediction_pipeline/test_processing.py | 55 ---- tests/resource_io/test_load_rdf.py | 41 +-- tests/resource_io/test_utils.py | 11 +- tests/test_cli.py | 1 - tests/test_export_package.py | 60 ---- .../test_internal/test_validation_visitors.py | 39 --- tests/test_prediction.py | 8 +- tests/test_proc_ops.py | 295 ++++++++++++++++++ tests/test_resource_tests/test_test_model.py | 54 ++-- tests/test_stat_measures.py | 39 +++ .../weight_converter/keras/test_tensorflow.py | 16 +- tests/weight_converter/torch/test_onnx.py | 4 +- 24 files changed, 478 insertions(+), 651 deletions(-) delete mode 100644 tests/prediction_pipeline/test_measures.py delete mode 100644 tests/prediction_pipeline/test_postprocessing.py delete mode 100644 tests/prediction_pipeline/test_preprocessing.py delete mode 100644 tests/prediction_pipeline/test_processing.py delete mode 100644 tests/test_export_package.py delete mode 100644 tests/test_internal/test_validation_visitors.py create mode 100644 tests/test_proc_ops.py create mode 100644 tests/test_stat_measures.py diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 535f085c..fd3ef2ee 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -3,6 +3,7 @@ from dataclasses import InitVar, dataclass, field from typing import ( Collection, + Generic, Hashable, Literal, Optional, @@ -31,10 +32,12 @@ DatasetMean, DatasetPercentile, DatasetStd, + MeanMeasure, Measure, SampleMean, SamplePercentile, SampleStd, + StdMeasure, ) 
from bioimageio.spec.model import v0_4, v0_5 @@ -166,7 +169,7 @@ class ScaleLinear(_SimpleOperator): offset: Union[float, xr.DataArray] = 0.0 """additive term""" - def apply(self, input: Tensor, stat: Stat) -> Tensor: + def _apply(self, input: Tensor, stat: Stat) -> Tensor: return input * self.gain + self.offset # @classmethod @@ -315,8 +318,8 @@ def from_proc_descr(cls, descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr return cls( input=tensor_id, output=tensor_id, - lower_percentile=Percentile(kwargs.min_percentile, axes=axes, tensor_id=ref_tensor), - upper_percentile=Percentile(kwargs.max_percentile, axes=axes, tensor_id=ref_tensor), + lower_percentile=Percentile(n=kwargs.min_percentile, axes=axes, tensor_id=ref_tensor), + upper_percentile=Percentile(n=kwargs.max_percentile, axes=axes, tensor_id=ref_tensor), ) def _apply(self, input: xr.DataArray, stat: Stat) -> xr.DataArray: @@ -363,8 +366,8 @@ def get_descr(self): class ZeroMeanUnitVariance(_SimpleOperator): """normalize to zero mean, unit variance.""" - mean: Union[SampleMean, DatasetMean] - std: Union[SampleStd, DatasetStd] + mean: MeanMeasure + std: StdMeasure eps: float = 1e-6 @@ -372,7 +375,7 @@ def __post_init__(self): assert self.mean.axes == self.std.axes @property - def required_measures(self) -> Collection[Measure]: + def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]: return {self.mean, self.std} @classmethod diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py index 54e601d0..e3ccdc16 100644 --- a/bioimageio/core/stat_calculators.py +++ b/bioimageio/core/stat_calculators.py @@ -437,7 +437,7 @@ def get_measure_calculators( def compute_dataset_measures( - *, measures: Iterable[DatasetMeasure], dataset: Iterable[Sample] + measures: Iterable[DatasetMeasure], dataset: Iterable[Sample] ) -> Dict[DatasetMeasure, MeasureValue]: """compute all dataset `measures` for the given `dataset`""" sample_calculators, calculators = get_measure_calculators(measures) @@ -453,3 +453,35 @@ def compute_dataset_measures( ret.update(calc.finalize().items()) return ret + +def compute_sample_measures(measures: Iterable[SampleMeasure], sample: Sample) -> Dict[SampleMeasure, MeasureValue]: + """compute all sample `measures` for the given `sample`""" + calculators, dataset_calculators = get_measure_calculators(measures) + assert not dataset_calculators + ret: Dict[SampleMeasure, MeasureValue] = {} + + for calc in calculators: + ret.update(calc.compute(sample).items()) + + return ret + + +def compute_measures(measures: Iterable[Measure], dataset: Iterable[Sample]) -> Dict[Measure, MeasureValue]: + """compute all `measures` for the given `dataset` + sample measures are computed for the last sample in `dataset`""" + sample_calculators, dataset_calculators = get_measure_calculators(measures) + ret: Dict[Measure, MeasureValue] = {} + sample = None + for sample in dataset: + for calc in dataset_calculators: + calc.update(sample) + if sample is None: + raise ValueError("empty dataset") + + for calc in dataset_calculators: + ret.update(calc.finalize().items()) + + for calc in sample_calculators: + ret.update(calc.compute(sample).items()) + + return ret diff --git a/bioimageio/core/stat_measures.py b/bioimageio/core/stat_measures.py index 4acb8f31..726c90cf 100644 --- a/bioimageio/core/stat_measures.py +++ b/bioimageio/core/stat_measures.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, 
TypeVar, Union import xarray as xr @@ -119,3 +119,12 @@ def __post_init__(self): SampleMeasure = Union[SampleMean, SampleStd, SampleVar, SamplePercentile] DatasetMeasure = Union[DatasetMean, DatasetStd, DatasetVar, DatasetPercentile] Measure = Union[SampleMeasure, DatasetMeasure] + +MeanMeasure = Union[SampleMean, DatasetMean] +StdMeasure = Union[SampleStd, DatasetStd] +VarMeasure = Union[SampleVar, DatasetVar] +PercentileMeasure = Union[SamplePercentile, DatasetPercentile] +MeanMeasureT = TypeVar("MeanMeasureT", bound=MeanMeasure) +StdMeasureT = TypeVar("StdMeasureT", bound=StdMeasure) +VarMeasureT = TypeVar("VarMeasureT", bound=VarMeasure) +PercentileMeasureT = TypeVar("PercentileMeasureT", bound=PercentileMeasure) diff --git a/tests/build_spec/test_add_weights.py b/tests/build_spec/test_add_weights.py index 4bba5f87..e3df4b80 100644 --- a/tests/build_spec/test_add_weights.py +++ b/tests/build_spec/test_add_weights.py @@ -1,8 +1,5 @@ import os -from bioimageio.core import export_resource_package, load_description, load_raw_resource_description -from bioimageio.core.resource_tests import test_model as _test_model - def _test_add_weights(model, tmp_path, base_weights, added_weights, **kwargs): from bioimageio.core.build_spec import add_weights diff --git a/tests/build_spec/test_build_spec.py b/tests/build_spec/test_build_spec.py index 669eeb8a..b1fa85ab 100644 --- a/tests/build_spec/test_build_spec.py +++ b/tests/build_spec/test_build_spec.py @@ -1,12 +1,11 @@ from typing import Optional -from marshmallow import missing - import bioimageio.spec as spec -from bioimageio.core import load_description, load_raw_resource_description -from bioimageio.core._internal.validation_visitors import resolve_source -from bioimageio.core.resource_io import nodes -from bioimageio.core.resource_tests import test_model as _test_model + +# from bioimageio.core import load_description, load_raw_resource_description +# from bioimageio.core._internal.validation_visitors import resolve_source +# from bioimageio.core.resource_io import nodes +# from bioimageio.core.resource_tests import test_model as _test_model try: import tensorflow diff --git a/tests/conftest.py b/tests/conftest.py index 6a8366ae..0f44a94d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -132,7 +132,7 @@ @fixture(scope="session") -def model_packages(): +def model_packages() -> MappingProxyType[str, FilePath]: return MappingProxyType({name: save_bioimageio_package(MODEL_SOURCES[name]) for name in load_model_packages}) diff --git a/tests/prediction_pipeline/test_combined_processing.py b/tests/prediction_pipeline/test_combined_processing.py index 744e236c..7a590991 100644 --- a/tests/prediction_pipeline/test_combined_processing.py +++ b/tests/prediction_pipeline/test_combined_processing.py @@ -1,11 +1,12 @@ import numpy as np import xarray as xr -from bioimageio.core.resource_io import nodes +def test_postprocessing_dtype(): # TODO: remove? 
+ from bioimageio.core.common import TensorId + from bioimageio.spec.model.v0_5 import BinarizeDescr, BinarizeKwargs, OutputTensorDescr -def test_postprocessing_dtype(): - from bioimageio.core.prediction_pipeline._combined_processing import CombinedProcessing + # from bioimageio.core.prediction_pipeline._combined_processing import CombinedProcessing shape = [3, 32, 32] axes = ("c", "y", "x") @@ -17,12 +18,12 @@ def test_postprocessing_dtype(): for dtype in ("float32", "float64", "uint8", "uint16"): outputs = [ - nodes.OutputTensor( - "out1", + OutputTensorDescr( + id=TensorId("out1"), data_type=dtype, axes=axes, shape=shape, - postprocessing=[nodes.Postprocessing("binarize", dict(threshold=threshold))], + postprocessing=[BinarizeDescr(kwargs=BinarizeKwargs(threshold=threshold))], ) ] com_proc = CombinedProcessing.from_tensor_specs(outputs) diff --git a/tests/prediction_pipeline/test_device_management.py b/tests/prediction_pipeline/test_device_management.py index d8f65eca..bda4af08 100644 --- a/tests/prediction_pipeline/test_device_management.py +++ b/tests/prediction_pipeline/test_device_management.py @@ -1,18 +1,21 @@ +from pathlib import Path + import numpy as np -import pytest import xarray as xr from numpy.testing import assert_array_almost_equal -from bioimageio.core import load_description -from bioimageio.core._internal.pytest_utils import skip_on -from bioimageio.core.resource_io.nodes import Model +from bioimageio.core.utils.testing import skip_on +from bioimageio.spec import load_description +from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr04 +from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat +from bioimageio.spec.utils import load_array class TooFewDevicesException(Exception): pass -def _test_device_management(model_package, weight_format): +def _test_device_management(model_package: Path, weight_format: WeightsFormat): import torch if torch.cuda.device_count() == 0: @@ -21,13 +24,18 @@ def _test_device_management(model_package, weight_format): from bioimageio.core.prediction_pipeline import create_prediction_pipeline bio_model = load_description(model_package) - assert isinstance(bio_model, Model) + assert isinstance(bio_model, (ModelDescr, ModelDescr04)) pred_pipe = create_prediction_pipeline(bioimageio_model=bio_model, weight_format=weight_format, devices=["cuda:0"]) - inputs = [ - xr.DataArray(np.load(str(test_tensor)), dims=tuple(spec.axes)) - for test_tensor, spec in zip(bio_model.test_inputs, bio_model.inputs) - ] + if isinstance(bio_model, ModelDescr04): + inputs = [ + xr.DataArray(np.load(str(test_tensor)), dims=tuple(spec.axes)) + for test_tensor, spec in zip(bio_model.test_inputs, bio_model.inputs) + ] + else: + inputs = [ + xr.DataArray(load_array(ipt.test_tensor), dims=tuple(a.id for a in ipt.axes)) for ipt in bio_model.inputs + ] with pred_pipe as pp: outputs = pp.forward(*inputs) diff --git a/tests/prediction_pipeline/test_measures.py b/tests/prediction_pipeline/test_measures.py deleted file mode 100644 index 37916424..00000000 --- a/tests/prediction_pipeline/test_measures.py +++ /dev/null @@ -1,103 +0,0 @@ -import dataclasses -from itertools import product - -import numpy as np -import numpy.testing -import pytest -import xarray as xr - -from bioimageio.core import stat_measures -from bioimageio.core.prediction_pipeline._measure_groups import get_measure_groups -from bioimageio.core.prediction_pipeline._utils import PER_DATASET, PER_SAMPLE -from bioimageio.core.stat_measures import Mean, Percentile, Std, Var - - 
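Note: the measure-group tests removed in this deletion were written against machinery that the compute_measures helpers added to stat_calculators above supersede. A minimal usage sketch of the replacement API (tensor id and data are invented, not part of the patch):

import numpy as np
import xarray as xr

from bioimageio.core.common import Sample, TensorId
from bioimageio.core.stat_calculators import compute_measures
from bioimageio.core.stat_measures import SampleMean, SampleStd

tid = TensorId("raw")  # hypothetical tensor id
samples = [Sample(data={tid: xr.DataArray(np.random.rand(4, 5), dims=("x", "y"))}) for _ in range(2)]
# dataset measures aggregate over all samples; sample measures are computed for the last sample
stat = compute_measures({SampleMean(tid), SampleStd(tid)}, samples)
mean_value = stat[SampleMean(tid)]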
-@pytest.mark.parametrize("name_axes", product(["mean", "var", "std"], [None, ("x", "y")])) -def test_individual_normal_measure(name_axes): - name, axes = name_axes - measure = getattr(stat_measures, name.title())(axes=axes) - data = xr.DataArray(np.random.random((5, 6, 3)), dims=("x", "y", "c")) - - expected = getattr(data, name)(dim=axes) - actual = measure.compute(data) - xr.testing.assert_allclose(expected, actual) - - -@pytest.mark.parametrize("axes_n", product([None, ("x", "y")], [0, 10, 50, 100])) -def test_individual_percentile_measure(axes_n): - axes, n = axes_n - measure = stat_measures.Percentile(axes=axes, n=n) - data = xr.DataArray(np.random.random((5, 6, 3)), dims=("x", "y", "c")) - - expected = data.quantile(q=n / 100, dim=axes) - actual = measure.compute(data) - xr.testing.assert_allclose(expected, actual) - - -@pytest.mark.parametrize( - "measures_mode", - product( - [ - {"t1": {Mean()}, "t2": {Mean(), Std()}}, - {"t1": {Mean(), Var(), Std()}, "t2": {Std(axes=("x", "y"))}}, - {"t1": {Mean(axes=("x", "y"))}, "t2": {Mean(), Std(axes=("x", "y"))}}, - { - "t1": {Percentile(n=10), Percentile(n=35), Percentile(n=10, axes=("x", "y"))}, - "t2": {Percentile(n=10, axes=("x", "y")), Percentile(n=35, axes=("x", "y")), Percentile(n=10)}, - }, - ], - [PER_SAMPLE, PER_DATASET], - ), -) -def test_measure_groups(measures_mode): - measures, mode = measures_mode - - def get_sample(): - return { - "t1": xr.DataArray(np.random.random((2, 500, 600, 3)), dims=("b", "x", "y", "c")), - "t2": xr.DataArray(np.random.random((1, 500, 600)), dims=("c", "x", "y")), - } - - sample = get_sample() - dataset_seq = [sample, get_sample()] - dataset_full = {tn: xr.concat([s[tn] for s in dataset_seq], dim="dataset") for tn in sample.keys()} - - # compute independently - expected = {} - for tn, ms in measures.items(): - for m in ms: - if mode == PER_SAMPLE: - expected[(tn, m)] = m.compute(sample[tn]) - elif mode == PER_DATASET: - if m.axes is None: - m_d = m - else: - m_d = dataclasses.replace(m, axes=("dataset",) + m.axes) - - expected[(tn, m)] = m_d.compute(dataset_full[tn]) - else: - raise NotImplementedError(mode) - - groups = get_measure_groups({mode: measures})[mode] - actual = {} - for g in groups: - if mode == PER_SAMPLE: - res = g.compute(sample) - elif mode == PER_DATASET: - for s in dataset_seq: - g.update_with_sample(s) - - res = g.finalize() - else: - raise NotImplementedError(mode) - - for tn, vs in res.items(): - for m, v in vs.items(): - actual[(tn, m)] = v - - # discard additionally computed measures by groups - actual = {k: v for k, v in actual.items() if k in expected} - - for k in expected.keys(): - assert k in actual - numpy.testing.assert_array_almost_equal(expected[k].data, actual[k].data, decimal=2) diff --git a/tests/prediction_pipeline/test_postprocessing.py b/tests/prediction_pipeline/test_postprocessing.py deleted file mode 100644 index 52c3e151..00000000 --- a/tests/prediction_pipeline/test_postprocessing.py +++ /dev/null @@ -1,70 +0,0 @@ -import numpy as np -import pytest -import xarray as xr - -from bioimageio.core.prediction_pipeline._measure_groups import compute_measures - - -def test_binarize(): - from bioimageio.core.prediction_pipeline._processing import Binarize - - shape = (3, 32, 32) - axes = ("c", "y", "x") - np_data = np.random.rand(*shape) - data = xr.DataArray(np_data, dims=axes) - - threshold = 0.5 - exp = xr.DataArray(np_data > threshold, dims=axes) - - binarize = Binarize("data_name", threshold=threshold) - res = binarize(data) - xr.testing.assert_allclose(res, exp) 
- - -@pytest.mark.parametrize("axes", [None, tuple("cy"), tuple("cyx"), tuple("x")]) -def test_scale_mean_variance(axes): - from bioimageio.core.prediction_pipeline._processing import ScaleMeanVariance - - shape = (3, 32, 46) - ipt_axes = ("c", "y", "x") - np_data = np.random.rand(*shape) - ipt_data = xr.DataArray(np_data, dims=ipt_axes) - ref_data = xr.DataArray((np_data * 2) + 3, dims=ipt_axes) - - scale_mean_variance = ScaleMeanVariance("data_name", reference_tensor="ref_name", axes=axes) - required = scale_mean_variance.get_required_measures() - computed = compute_measures(required, sample={"data_name": ipt_data, "ref_name": ref_data}) - scale_mean_variance.set_computed_measures(computed) - - res = scale_mean_variance(ipt_data) - xr.testing.assert_allclose(res, ref_data) - - -@pytest.mark.parametrize("axes", [None, tuple("cy"), tuple("y"), tuple("yx")]) -def test_scale_mean_variance_per_channel(axes): - from bioimageio.core.prediction_pipeline._processing import ScaleMeanVariance - - shape = (3, 32, 46) - ipt_axes = ("c", "y", "x") - np_data = np.random.rand(*shape) - ipt_data = xr.DataArray(np_data, dims=ipt_axes) - - # set different mean, std per channel - np_ref_data = np.stack([d * i + i for i, d in enumerate(np_data, start=2)]) - print(np_ref_data.shape) - ref_data = xr.DataArray(np_ref_data, dims=ipt_axes) - - scale_mean_variance = ScaleMeanVariance("data_name", reference_tensor="ref_name", axes=axes) - required = scale_mean_variance.get_required_measures() - computed = compute_measures(required, sample={"data_name": ipt_data, "ref_name": ref_data}) - scale_mean_variance.set_computed_measures(computed) - - res = scale_mean_variance(ipt_data) - - if axes is not None and "c" not in axes: - # mean,std per channel should match exactly - xr.testing.assert_allclose(res, ref_data) - else: - # mean,std across channels should not match - with pytest.raises(AssertionError): - xr.testing.assert_allclose(res, ref_data) diff --git a/tests/prediction_pipeline/test_prediction_pipeline.py b/tests/prediction_pipeline/test_prediction_pipeline.py index 3a2c57aa..2c196401 100644 --- a/tests/prediction_pipeline/test_prediction_pipeline.py +++ b/tests/prediction_pipeline/test_prediction_pipeline.py @@ -2,8 +2,8 @@ import xarray as xr from numpy.testing import assert_array_almost_equal -from bioimageio.core import load_description -from bioimageio.core.resource_io.nodes import Model +# from bioimageio.core import load_description +# from bioimageio.core.resource_io.nodes import Model def _test_prediction_pipeline(model_package, weight_format): diff --git a/tests/prediction_pipeline/test_preprocessing.py b/tests/prediction_pipeline/test_preprocessing.py deleted file mode 100644 index fb8efa06..00000000 --- a/tests/prediction_pipeline/test_preprocessing.py +++ /dev/null @@ -1,212 +0,0 @@ -import numpy as np -import xarray as xr - -from bioimageio.core.prediction_pipeline._measure_groups import compute_measures -from bioimageio.core.prediction_pipeline._utils import PER_SAMPLE - - -def test_scale_linear(): - from bioimageio.core.prediction_pipeline._processing import ScaleLinear - - preprocessing = ScaleLinear("data_name", offset=[1, 2, 42], gain=[1, 2, 3], axes="yx") - data = xr.DataArray(np.arange(6).reshape((1, 2, 3)), dims=("x", "y", "c")) - expected = xr.DataArray(np.array([[[1, 4, 48], [4, 10, 57]]]), dims=("x", "y", "c")) - result = preprocessing.apply(data) - xr.testing.assert_allclose(expected, result) - - -def test_scale_linear_no_channel(): - from bioimageio.core.prediction_pipeline._processing 
import ScaleLinear - - preprocessing = ScaleLinear("data_name", offset=1, gain=2, axes="yx") - data = xr.DataArray(np.arange(6).reshape(2, 3), dims=("x", "y")) - expected = xr.DataArray(np.array([[1, 3, 5], [7, 9, 11]]), dims=("x", "y")) - result = preprocessing.apply(data) - xr.testing.assert_allclose(expected, result) - - -def test_zero_mean_unit_variance_preprocessing(): - from bioimageio.core.prediction_pipeline._processing import ZeroMeanUnitVariance - - data = xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y")) - - preprocessing = ZeroMeanUnitVariance("data_name", mode=PER_SAMPLE) - required = preprocessing.get_required_measures() - computed = compute_measures(required, sample={"data_name": data}) - preprocessing.set_computed_measures(computed) - - expected = xr.DataArray( - np.array( - [ - [-1.54919274, -1.16189455, -0.77459637], - [-0.38729818, 0.0, 0.38729818], - [0.77459637, 1.16189455, 1.54919274], - ] - ), - dims=("x", "y"), - ) - result = preprocessing(data) - xr.testing.assert_allclose(expected, result) - - -def test_zero_mean_unit_variance_preprocessing_fixed(): - from bioimageio.core.prediction_pipeline._processing import ZeroMeanUnitVariance - - preprocessing = ZeroMeanUnitVariance( - "data_name", mode="fixed", axes=["y"], mean=[1, 4, 7], std=[0.81650, 0.81650, 0.81650] - ) - data = xr.DataArray(np.arange(9).reshape((1, 1, 3, 3)), dims=("b", "c", "x", "y")) - expected = xr.DataArray( - np.array([[-1.224743, 0.0, 1.224743], [-1.224743, 0.0, 1.224743], [-1.224743, 0.0, 1.224743]])[None, None], - dims=("b", "c", "x", "y"), - ) - result = preprocessing(data) - xr.testing.assert_allclose(expected, result) - - -def test_zero_mean_unit_across_axes(): - from bioimageio.core.prediction_pipeline._processing import ZeroMeanUnitVariance - - data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) - - axes = ("x", "y") - preprocessing = ZeroMeanUnitVariance("data_name", axes=axes, mode=PER_SAMPLE) - required = preprocessing.get_required_measures() - computed = compute_measures(required, sample={"data_name": data}) - preprocessing.set_computed_measures(computed) - - expected = xr.DataArray( - np.array( - [ - [-1.54919274, -1.16189455, -0.77459637], - [-0.38729818, 0.0, 0.38729818], - [0.77459637, 1.16189455, 1.54919274], - ] - ), - dims=("x", "y"), - ) - result = preprocessing(data) - xr.testing.assert_allclose(expected, result[dict(c=0)]) - - -def test_zero_mean_unit_variance_fixed(): - from bioimageio.core.prediction_pipeline._processing import ZeroMeanUnitVariance - - np_data = np.arange(9).reshape(3, 3) - mean = np_data.mean() - std = np_data.mean() - eps = 1.0e-7 - preprocessing = ZeroMeanUnitVariance("data_name", mode="fixed", mean=mean, std=std, eps=eps) - - data = xr.DataArray(np_data, dims=("x", "y")) - expected = xr.DataArray((np_data - mean) / (std + eps), dims=("x", "y")) - result = preprocessing(data) - xr.testing.assert_allclose(expected, result) - - -def test_binarize(): - from bioimageio.core.prediction_pipeline._processing import Binarize - - preprocessing = Binarize("data_name", threshold=14) - data = xr.DataArray(np.arange(30).reshape((2, 3, 5)), dims=("x", "y", "c")) - expected = xr.zeros_like(data) - expected[{"x": slice(1, None)}] = 1 - result = preprocessing(data) - xr.testing.assert_allclose(expected, result) - - -def test_clip_preprocessing(): - from bioimageio.core.prediction_pipeline._processing import Clip - - preprocessing = Clip("data_name", min=3, max=5) - data = xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y")) - expected 
= xr.DataArray(np.array([[3, 3, 3], [3, 4, 5], [5, 5, 5]]), dims=("x", "y")) - result = preprocessing(data) - xr.testing.assert_equal(expected, result) - - -def test_combination_of_preprocessing_steps_with_dims_specified(): - from bioimageio.core.prediction_pipeline._processing import ZeroMeanUnitVariance - - data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) - axes = ("x", "y") - preprocessing = ZeroMeanUnitVariance("data_name", axes=axes, mode=PER_SAMPLE) - required = preprocessing.get_required_measures() - computed = compute_measures(required, sample={"data_name": data}) - preprocessing.set_computed_measures(computed) - - expected = xr.DataArray( - np.array( - [ - [-1.54919274, -1.16189455, -0.77459637], - [-0.38729818, 0.0, 0.38729818], - [0.77459637, 1.16189455, 1.54919274], - ] - ), - dims=("x", "y"), - ) - - result = preprocessing(data) - xr.testing.assert_allclose(expected, result[dict(c=0)]) - - -def test_scale_range(): - from bioimageio.core.prediction_pipeline._processing import ScaleRange - - preprocessing = ScaleRange("data_name") - np_data = np.arange(9).reshape(3, 3).astype("float32") - data = xr.DataArray(np_data, dims=("x", "y")) - required = preprocessing.get_required_measures() - computed = compute_measures(required, sample={"data_name": data}) - preprocessing.set_computed_measures(computed) - - eps = 1.0e-6 - mi, ma = np_data.min(), np_data.max() - exp_data = (np_data - mi) / (ma - mi + eps) - expected = xr.DataArray(exp_data, dims=("x", "y")) - - result = preprocessing(data) - # NOTE xarray.testing.assert_allclose compares irrelavant properties here and fails although the result is correct - np.testing.assert_allclose(expected, result) - - -def test_scale_range_axes(): - from bioimageio.core.prediction_pipeline._processing import ScaleRange - - min_percentile = 1.0 - max_percentile = 99.0 - preprocessing = ScaleRange( - "data_name", axes=("x", "y"), min_percentile=min_percentile, max_percentile=max_percentile - ) - - np_data = np.arange(18).reshape((2, 3, 3)).astype("float32") - data = xr.DataArray(np_data, dims=("c", "x", "y")) - - required = preprocessing.get_required_measures() - computed = compute_measures(required, sample={"data_name": data}) - preprocessing.set_computed_measures(computed) - - eps = 1.0e-6 - p_low = np.percentile(np_data, min_percentile, axis=(1, 2), keepdims=True) - p_up = np.percentile(np_data, max_percentile, axis=(1, 2), keepdims=True) - exp_data = (np_data - p_low) / (p_up - p_low + eps) - expected = xr.DataArray(exp_data, dims=("c", "x", "y")) - - result = preprocessing(data) - # NOTE xarray.testing.assert_allclose compares irrelavant properties here and fails although the result is correct - np.testing.assert_allclose(expected, result) - - -def test_sigmoid(): - from bioimageio.core.prediction_pipeline._processing import Sigmoid - - shape = (3, 32, 32) - axes = ("c", "y", "x") - np_data = np.random.rand(*shape) - data = xr.DataArray(np_data, dims=axes) - - sigmoid = Sigmoid("data_name") - res = sigmoid(data) - - exp = xr.DataArray(1.0 / (1 + np.exp(-np_data)), dims=axes) - xr.testing.assert_allclose(res, exp) diff --git a/tests/prediction_pipeline/test_processing.py b/tests/prediction_pipeline/test_processing.py deleted file mode 100644 index 819982a2..00000000 --- a/tests/prediction_pipeline/test_processing.py +++ /dev/null @@ -1,55 +0,0 @@ -import dataclasses - -import numpy as np -import pytest -import xarray as xr - -from bioimageio.core.prediction_pipeline._processing import IMPLEMENTED_PROCESSING -from 
bioimageio.core.prediction_pipeline._utils import FIXED - -try: - from typing import get_args -except ImportError: - from typing_extensions import get_args # type: ignore - - -def test_assert_dtype(): - from bioimageio.core.prediction_pipeline._processing import AssertDtype - - proc = AssertDtype("test_tensor", dtype="uint8") - tensor = xr.DataArray(np.zeros((1,), dtype="uint8"), dims=tuple("c")) - out = proc(tensor) - assert out is tensor - - tensor = tensor.astype("uint16") - with pytest.raises(AssertionError): - out = proc(tensor) - assert out is tensor - - -@pytest.mark.parametrize( - "proc", - list(IMPLEMENTED_PROCESSING["pre"].values()) + list(IMPLEMENTED_PROCESSING["post"].values()), -) -def test_no_req_measures_for_mode_fixed(proc): - # check if mode=fixed is valid for this proc - for f in dataclasses.fields(proc): - if f.name == "mode": - break - else: - raise AttributeError("Processing is missing mode attribute") - # mode is always annotated as literals (or literals of literals) - valid_modes = get_args(f.type) - for inner in get_args(f.type): - valid_modes += get_args(inner) - - if FIXED not in valid_modes: - return - - # we might be missing required kwargs. These have marshmallow.missing value as default - # and raise a TypeError is in __post_init__() - proc.__post_init__ = lambda self: None # ignore missing kwargs - - proc_instance = proc(tensor_name="tensor_name", mode=FIXED) - req_measures = proc_instance.get_required_measures() - assert not req_measures diff --git a/tests/resource_io/test_load_rdf.py b/tests/resource_io/test_load_rdf.py index 873d9b26..0d86d1a2 100644 --- a/tests/resource_io/test_load_rdf.py +++ b/tests/resource_io/test_load_rdf.py @@ -4,43 +4,18 @@ import pytest -from bioimageio.core._internal.validation_visitors import resolve_source +def test_load_model_with_abs_path_source(unet2d_nuclei_broad_model: Path): + from bioimageio.spec import load_description -def test_load_non_existing_rdf(): - from bioimageio.core import load_description - - spec_path = Path("some/none/existing/path/to/spec.model.yaml") - - with pytest.raises(FileNotFoundError): - load_description(spec_path) - - -def test_load_raw_model(any_model): - from bioimageio.core import load_raw_resource_description - - raw_model = load_raw_resource_description(any_model) - assert raw_model - - -def test_load_model(any_model): - from bioimageio.core import load_description - - model = load_description(any_model) - assert model - - -def test_load_model_with_abs_path_source(unet2d_nuclei_broad_model): - from bioimageio.core.resource_io import load_description, load_raw_resource_description - - raw_rd = load_raw_resource_description(unet2d_nuclei_broad_model) - path_source = (raw_rd.root_path / "rdf.yaml").absolute() + raw_rd = load_description(unet2d_nuclei_broad_model) + path_source = (raw_rd.root / "rdf.yaml").absolute() assert path_source.is_absolute() model = load_description(path_source) assert model -def test_load_model_with_rel_path_source(unet2d_nuclei_broad_model): +def test_load_model_with_rel_path_source(unet2d_nuclei_broad_model: Path): from bioimageio.core.resource_io import load_description, load_raw_resource_description raw_rd = load_raw_resource_description(unet2d_nuclei_broad_model) @@ -50,7 +25,7 @@ def test_load_model_with_rel_path_source(unet2d_nuclei_broad_model): assert model -def test_load_model_with_abs_str_source(unet2d_nuclei_broad_model): +def test_load_model_with_abs_str_source(unet2d_nuclei_broad_model: Path): from bioimageio.core.resource_io import load_description, 
load_raw_resource_description raw_rd = load_raw_resource_description(unet2d_nuclei_broad_model) @@ -60,7 +35,7 @@ def test_load_model_with_abs_str_source(unet2d_nuclei_broad_model): assert model -def test_load_model_with_rel_str_source(unet2d_nuclei_broad_model): +def test_load_model_with_rel_str_source(unet2d_nuclei_broad_model: Path): from bioimageio.core.resource_io import load_description, load_raw_resource_description raw_rd = load_raw_resource_description(unet2d_nuclei_broad_model) @@ -70,7 +45,7 @@ def test_load_model_with_rel_str_source(unet2d_nuclei_broad_model): assert model -def test_load_remote_rdf(unet2d_nuclei_broad_model): +def test_load_remote_rdf(unet2d_nuclei_broad_model: Path): # remote model is a pytorch model, needing unet2d_nuclei_broad_model skips the test when needed _ = unet2d_nuclei_broad_model from bioimageio.core import load_description diff --git a/tests/resource_io/test_utils.py b/tests/resource_io/test_utils.py index d1e570cc..ff834edc 100644 --- a/tests/resource_io/test_utils.py +++ b/tests/resource_io/test_utils.py @@ -2,12 +2,13 @@ from pathlib import Path import pytest -from bioimageio.spec.shared import raw_nodes -from bioimageio.spec.shared.raw_nodes import RawNode -from bioimageio.core._internal import validation_visitors -from bioimageio.core._internal.validation_visitors import Sha256NodeChecker -from bioimageio.core.resource_io import nodes +# from bioimageio.spec.shared import raw_nodes +# from bioimageio.spec.shared.raw_nodes import RawNode + +# from bioimageio.core._internal import validation_visitors +# from bioimageio.core._internal.validation_visitors import Sha256NodeChecker +# from bioimageio.core.resource_io import nodes def test_resolve_import_path(tmpdir): diff --git a/tests/test_cli.py b/tests/test_cli.py index 473a1b1a..2ed9f894 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -3,7 +3,6 @@ from typing import Sequence import numpy as np -import pytest from bioimageio.core import load_description diff --git a/tests/test_export_package.py b/tests/test_export_package.py deleted file mode 100644 index a04aa994..00000000 --- a/tests/test_export_package.py +++ /dev/null @@ -1,60 +0,0 @@ -import shutil -from pathlib import Path -from tempfile import TemporaryDirectory -from zipfile import ZipFile - -from marshmallow import missing - -from bioimageio.spec.model import raw_nodes - - -def test_export_package(any_onnx_model): - from bioimageio.core import export_resource_package, load_raw_resource_description - - package_path = export_resource_package(any_onnx_model, weights_priority_order=["onnx"]) - assert isinstance(package_path, Path), package_path - assert package_path.exists(), package_path - - raw_model = load_raw_resource_description(package_path) - assert isinstance(raw_model, raw_nodes.Model) - - -def test_package_with_folder(unet2d_nuclei_broad_model): - from bioimageio.core import export_resource_package, load_raw_resource_description - - with TemporaryDirectory() as tmp_dir: - tmp_dir = Path(tmp_dir) - - # extract package (to not cache to BIOIMAGEIO_CACHE) - package_folder = tmp_dir / "package" - with ZipFile(unet2d_nuclei_broad_model) as zf: - zf.extractall(package_folder) - - # load package - model = load_raw_resource_description(package_folder / "rdf.yaml") - assert isinstance(model, raw_nodes.Model) - - # alter package to have its documentation in a nested folder - doc = model.documentation - assert doc is not missing - doc = doc.relative_to(model.root_path) - assert not doc.is_absolute() - new_doc = Path("nested") / 
"folder" / doc - (package_folder / new_doc).parent.mkdir(parents=True) - shutil.move(package_folder / doc, package_folder / new_doc) - model.documentation = new_doc - - # export altered package - altered_package = tmp_dir / "altered_package.zip" - altered_package = export_resource_package(model, output_path=altered_package, weights_priority_order=["onnx"]) - - # extract altered package (to not cache to BIOIMAGEIO_CACHE) - altered_package_folder = tmp_dir / "altered_package" - with ZipFile(altered_package) as zf: - zf.extractall(altered_package_folder) - - # load altered package - reloaded_model = load_raw_resource_description(altered_package_folder / "rdf.yaml") - assert isinstance(reloaded_model, raw_nodes.Model) - assert reloaded_model.documentation.as_posix().endswith(new_doc.as_posix()) - assert reloaded_model.documentation.exists() diff --git a/tests/test_internal/test_validation_visitors.py b/tests/test_internal/test_validation_visitors.py deleted file mode 100644 index 7988f658..00000000 --- a/tests/test_internal/test_validation_visitors.py +++ /dev/null @@ -1,39 +0,0 @@ -from functools import singledispatchmethod - -from bioimageio.core._internal.validation_visitors import Note, ValidationVisitor -from bioimageio.spec._internal.base_nodes import Node -from bioimageio.spec.summary import ErrorEntry - - -def test_traversing_nodes(): - class MyVisitor(ValidationVisitor): - @singledispatchmethod - def visit(self, obj: type, note: Note = Note()): - super().visit(obj, note) - - @visit.register - def _visit_int(self, nr: int, note: Note = Note()): - super().visit(nr, note) - self.errors.append(ErrorEntry(loc=note.loc, msg=f"nr: {nr}", type="got-int")) - - class NestedNode(Node): - leaf: int - - class MyNode(Node): - nested: NestedNode - - tree = { - "a": MyNode(nested=NestedNode(leaf=1)), - "b": [NestedNode(leaf=2), NestedNode(leaf=3)], - "c": (NestedNode(leaf=4),), - "d": {"deep": MyNode(nested=NestedNode(leaf=5))}, - } - visitor = MyVisitor() - visitor.visit(tree) - assert len(visitor.errors) == [ - ErrorEntry(loc=("a", "nested", "leaf"), msg="nr: 1", type="got-int"), - ErrorEntry(loc=("b", 0, "leaf"), msg="nr: 2", type="got-int"), - ErrorEntry(loc=("b", 1, "leaf"), msg="nr: 3", type="got-int"), - ErrorEntry(loc=("c", 0, "leaf"), msg="nr: 4", type="got-int"), - ErrorEntry(loc=("d", "deep", "nested", "leaf"), msg="nr: 5", type="got-int"), - ] diff --git a/tests/test_prediction.py b/tests/test_prediction.py index b73d0f82..a0e34b08 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -4,15 +4,15 @@ import numpy as np from numpy.testing import assert_array_almost_equal -from bioimageio.core import load_description -from bioimageio.core.resource_io.nodes import Model +from bioimageio.spec import load_description +from bioimageio.spec.model.v0_5 import ModelDescr -def test_predict_image(any_model, tmpdir): +def test_predict_image(any_model: Path, tmpdir: Path): from bioimageio.core.prediction import predict_image spec = load_description(any_model) - assert isinstance(spec, Model) + assert isinstance(spec, ModelDescr) inputs = spec.test_inputs outputs = [Path(tmpdir) / f"out{i}.npy" for i in range(len(spec.test_outputs))] diff --git a/tests/test_proc_ops.py b/tests/test_proc_ops.py new file mode 100644 index 00000000..f029517e --- /dev/null +++ b/tests/test_proc_ops.py @@ -0,0 +1,295 @@ +from typing import Iterable, Optional, Tuple, Type, TypeVar + +import numpy as np +import pytest +import xarray as xr +from typing_extensions import TypeGuard + +from bioimageio.core.common 
import AxisId, Sample, TensorId +from bioimageio.core.stat_calculators import compute_measures +from bioimageio.core.stat_measures import SampleMean, SamplePercentile, SampleStd + + +@pytest.fixture(scope="module") +def tid(): + return TensorId("data123") + + +def test_scale_linear(tid: TensorId): + from bioimageio.core.proc_ops import ScaleLinear + + offset = xr.DataArray([1, 2, 42], dims=("c")) + gain = xr.DataArray([1, 2, 3], dims=("c")) + data = xr.DataArray(np.arange(6).reshape((1, 2, 3)), dims=("x", "y", "c")) + sample = Sample(data={tid: data}) + + op = ScaleLinear(input=tid, output=tid, offset=offset, gain=gain) + op(sample) + + expected = xr.DataArray(np.array([[[1, 4, 48], [4, 10, 57]]]), dims=("x", "y", "c")) + xr.testing.assert_allclose(expected, sample.data[tid]) + + +def test_scale_linear_no_channel(tid: TensorId): + from bioimageio.core.proc_ops import ScaleLinear + + op = ScaleLinear(tid, tid, offset=1, gain=2) + data = xr.DataArray(np.arange(6).reshape(2, 3), dims=("x", "y")) + sample = Sample(data={tid: data}) + op(sample) + + expected = xr.DataArray(np.array([[1, 3, 5], [7, 9, 11]]), dims=("x", "y")) + xr.testing.assert_allclose(expected, sample.data[tid]) + + +T = TypeVar("T") + + +def is_iterable(val: Iterable[T], inner: Type[T]) -> TypeGuard[Iterable[T]]: + """Determine whether all objects in `val` are instances of `inner`""" + return all(isinstance(x, inner) for x in val) + + +def test_zero_mean_unit_variance(tid: TensorId): + from bioimageio.core.proc_ops import ZeroMeanUnitVariance + + data = xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y")) + sample = Sample(data={tid: data}) + m = SampleMean(tid) + std = SampleStd(tid) + op = ZeroMeanUnitVariance(tid, tid, m, std) + req = op.required_measures + sample.stat = compute_measures(req, [sample]) + op(sample) + + expected = xr.DataArray( + np.array( + [ + [-1.54919274, -1.16189455, -0.77459637], + [-0.38729818, 0.0, 0.38729818], + [0.77459637, 1.16189455, 1.54919274], + ] + ), + dims=("x", "y"), + ) + xr.testing.assert_allclose(expected, sample.data[tid]) + + +def test_zero_mean_unit_variance_fixed(tid: TensorId): + from bioimageio.core.proc_ops import FixedZeroMeanUnitVariance + + op = FixedZeroMeanUnitVariance( + tid, tid, mean=xr.DataArray([1, 4, 7], dims=("y")), std=xr.DataArray([0.81650, 0.81650, 0.81650], dims=("y")) + ) + data = xr.DataArray(np.arange(9).reshape((1, 1, 3, 3)), dims=("b", "c", "x", "y")) + expected = xr.DataArray( + np.array([[-1.224743, 0.0, 1.224743], [-1.224743, 0.0, 1.224743], [-1.224743, 0.0, 1.224743]])[None, None], + dims=("b", "c", "x", "y"), + ) + sample = Sample(data={tid: data}) + op(sample) + xr.testing.assert_allclose(expected, sample.data[tid]) + + +def test_zero_mean_unit_across_axes(tid: TensorId): + from bioimageio.core.proc_ops import ZeroMeanUnitVariance + + data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) + + op = ZeroMeanUnitVariance(tid, tid, SampleMean(tid, (AxisId("c"),)), SampleStd(tid, (AxisId("c"),))) + sample = Sample(data={tid: data}) + sample.stat = compute_measures(op.required_measures, [sample]) + + expected = xr.DataArray( + np.array( + [ + [-1.54919274, -1.16189455, -0.77459637], + [-0.38729818, 0.0, 0.38729818], + [0.77459637, 1.16189455, 1.54919274], + ] + ), + dims=("x", "y"), + ) + op(sample) + xr.testing.assert_allclose(expected, sample.data[tid]) + + +def test_zero_mean_unit_variance_fixed2(tid: TensorId): + from bioimageio.core.proc_ops import FixedZeroMeanUnitVariance + + np_data = np.arange(9).reshape(3, 3) + mean =
float(np_data.mean()) + std = float(np_data.mean()) + eps = 1.0e-7 + op = FixedZeroMeanUnitVariance(tid, tid, mean=mean, std=std, eps=eps) + + data = xr.DataArray(np_data, dims=("x", "y")) + sample = Sample(data={tid: data}) + expected = xr.DataArray((np_data - mean) / (std + eps), dims=("x", "y")) + op(sample) + xr.testing.assert_allclose(expected, sample.data[tid]) + + +def test_binarize(tid: TensorId): + from bioimageio.core.proc_ops import Binarize + + op = Binarize(tid, tid, threshold=14) + data = xr.DataArray(np.arange(30).reshape((2, 3, 5)), dims=("x", "y", "c")) + sample = Sample(data={tid: data}) + expected = xr.zeros_like(data) + expected[{"x": slice(1, None)}] = 1 + op(sample) + xr.testing.assert_allclose(expected, sample.data[tid]) + + +def test_binarize2(tid: TensorId): + from bioimageio.core.proc_ops import Binarize + + shape = (3, 32, 32) + axes = ("c", "y", "x") + np_data = np.random.rand(*shape) + data = xr.DataArray(np_data, dims=axes) + + threshold = 0.5 + exp = xr.DataArray(np_data > threshold, dims=axes) + + sample = Sample(data={tid: data}) + binarize = Binarize(tid, tid, threshold=threshold) + binarize(sample) + xr.testing.assert_allclose(exp, sample.data[tid]) + + +def test_clip(tid: TensorId): + from bioimageio.core.proc_ops import Clip + + op = Clip(tid, tid, min=3, max=5) + data = xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y")) + sample = Sample(data={tid: data}) + + expected = xr.DataArray(np.array([[3, 3, 3], [3, 4, 5], [5, 5, 5]]), dims=("x", "y")) + op(sample) + xr.testing.assert_equal(expected, sample.data[tid]) + + +def test_combination_of_op_steps_with_dims_specified(tid: TensorId): + from bioimageio.core.proc_ops import ZeroMeanUnitVariance + + data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) + sample = Sample(data={tid: data}) + op = ZeroMeanUnitVariance(tid, tid, SampleMean(tid, (AxisId("c"),)), SampleStd(tid, (AxisId("c"),))) + sample.stat = compute_measures(op.required_measures, [sample]) + + expected = xr.DataArray( + np.array( + [ + [-1.54919274, -1.16189455, -0.77459637], + [-0.38729818, 0.0, 0.38729818], + [0.77459637, 1.16189455, 1.54919274], + ] + ), + dims=("x", "y"), + ) + + op(sample) + xr.testing.assert_allclose(expected, sample.data[tid]) + + +@pytest.mark.parametrize("axes", [None, tuple(map(AxisId, "cy")), tuple(map(AxisId, "cyx")), tuple(map(AxisId, "x"))]) +def test_scale_mean_variance(tid: TensorId, axes: Optional[Tuple[AxisId, ...]]): + from bioimageio.core.proc_ops import ScaleMeanVariance + + shape = (3, 32, 46) + ipt_axes = ("c", "y", "x") + np_data = np.random.rand(*shape) + ipt_data = xr.DataArray(np_data, dims=ipt_axes) + ref_data = xr.DataArray((np_data * 2) + 3, dims=ipt_axes) + + op = ScaleMeanVariance(tid, tid, reference_tensor=TensorId("ref_name"), axes=axes) + sample = Sample(data={tid: ipt_data, TensorId("ref_name"): ref_data}) + sample.stat = compute_measures(op.required_measures, [sample]) + op(sample) + xr.testing.assert_allclose(ref_data, sample.data[tid]) + + +@pytest.mark.parametrize("axes", [None, tuple(map(AxisId, "cy")), tuple(map(AxisId, "y")), tuple(map(AxisId, "yx"))]) +def test_scale_mean_variance_per_channel(tid: TensorId, axes: Optional[Tuple[AxisId, ...]]): + from bioimageio.core.proc_ops import ScaleMeanVariance + + shape = (3, 32, 46) + ipt_axes = ("c", "y", "x") + np_data = np.random.rand(*shape) + ipt_data = xr.DataArray(np_data, dims=ipt_axes) + + # set different mean, std per channel + np_ref_data = np.stack([d * i + i for i, d in enumerate(np_data, start=2)]) + 
ref_data = xr.DataArray(np_ref_data, dims=ipt_axes) + + op = ScaleMeanVariance(tid, tid, reference_tensor=TensorId("ref_name"), axes=axes) + sample = Sample(data={tid: ipt_data, TensorId("ref_name"): ref_data}) + sample.stat = compute_measures(op.required_measures, [sample]) + op(sample) + + if axes is not None and AxisId("c") not in axes: + # mean,std per channel should match exactly + xr.testing.assert_allclose(ref_data, sample.data[tid]) + else: + # mean,std across channels should not match + with pytest.raises(AssertionError): + xr.testing.assert_allclose(ref_data, sample.data[tid]) + + +def test_scale_range(tid: TensorId): + from bioimageio.core.proc_ops import ScaleRange + + op = ScaleRange(tid, tid) + np_data = np.arange(9).reshape(3, 3).astype("float32") + data = xr.DataArray(np_data, dims=("x", "y")) + sample = Sample(data={tid: data}) + sample.stat = compute_measures(op.required_measures, [sample]) + + eps = 1.0e-6 + mi, ma = np_data.min(), np_data.max() + exp_data = (np_data - mi) / (ma - mi + eps) + expected = xr.DataArray(exp_data, dims=("x", "y")) + + op(sample) + # NOTE xarray.testing.assert_allclose compares irrelevant properties here and fails although the result is correct + np.testing.assert_allclose(expected, sample.data[tid]) + + +def test_scale_range_axes(tid: TensorId): + from bioimageio.core.proc_ops import ScaleRange + + lower_percentile = SamplePercentile(tid, 1, axes=(AxisId("c"),)) + upper_percentile = SamplePercentile(tid, 100, axes=(AxisId("c"),)) + op = ScaleRange(tid, tid, lower_percentile, upper_percentile) + + np_data = np.arange(18).reshape((2, 3, 3)).astype("float32") + data = xr.DataArray(np_data, dims=("c", "x", "y")) + sample = Sample(data={tid: data}) + sample.stat = compute_measures(op.required_measures, [sample]) + + eps = 1.0e-6 + p_low = np.percentile(np_data, lower_percentile.n, axis=(1, 2), keepdims=True) + p_up = np.percentile(np_data, upper_percentile.n, axis=(1, 2), keepdims=True) + exp_data = (np_data - p_low) / (p_up - p_low + eps) + expected = xr.DataArray(exp_data, dims=("c", "x", "y")) + + op(sample) + # NOTE xarray.testing.assert_allclose compares irrelevant properties here and fails although the result is correct + np.testing.assert_allclose(expected, sample.data[tid]) + + +def test_sigmoid(tid: TensorId): + from bioimageio.core.proc_ops import Sigmoid + + shape = (3, 32, 32) + axes = ("c", "y", "x") + np_data = np.random.rand(*shape) + data = xr.DataArray(np_data, dims=axes) + sample = Sample(data={tid: data}) + sigmoid = Sigmoid(tid, tid) + sigmoid(sample) + + exp = xr.DataArray(1.0 / (1 + np.exp(-np_data)), dims=axes) + xr.testing.assert_allclose(exp, sample.data[tid]) diff --git a/tests/test_resource_tests/test_test_model.py b/tests/test_resource_tests/test_test_model.py index 9498e8ab..f83baf52 100644 --- a/tests/test_resource_tests/test_test_model.py +++ b/tests/test_resource_tests/test_test_model.py @@ -1,74 +1,72 @@ -import pathlib +from pathlib import Path -import pytest +from bioimageio.spec import InvalidDescr -def test_error_for_wrong_shape(stardist_wrong_shape): +def test_error_for_wrong_shape(stardist_wrong_shape: Path): from bioimageio.core.resource_tests import test_model - summary = test_model(stardist_wrong_shape)[-1] + summary = test_model(stardist_wrong_shape) expected_error_message = ( "Shape (1, 512, 512, 33) of test output 0 'output' does not match output shape description: " "ImplicitOutputShape(reference_tensor='input', " "scale=[1.0, 1.0, 1.0, 0.0], offset=[1.0, 1.0, 1.0, 33.0])."
) - assert summary["error"] == expected_error_message + assert summary.details[0].errors[0].msg == expected_error_message -def test_error_for_wrong_shape2(stardist_wrong_shape2): +def test_error_for_wrong_shape2(stardist_wrong_shape2: Path): from bioimageio.core.resource_tests import test_model - summary = test_model(stardist_wrong_shape2)[-1] + summary = test_model(stardist_wrong_shape2) expected_error_message = ( "Shape (1, 512, 512, 1) of test input 0 'input' does not match input shape description: " "ParametrizedInputShape(min=[1, 80, 80, 1], step=[0, 17, 17, 0])." ) - assert summary["error"] == expected_error_message + assert summary.details[0].errors[0].msg == expected_error_message -def test_test_model(any_model): +def test_test_model(any_model: Path): from bioimageio.core.resource_tests import test_model summary = test_model(any_model) - assert all([s["status"] for s in summary]) + assert summary.status == "passed" -def test_test_resource(any_model): +def test_test_resource(any_model: Path): from bioimageio.core.resource_tests import test_description summary = test_description(any_model) - assert all([s["status"] for s in summary]) + assert summary.status == "passed" -def test_validation_section_warning(unet2d_nuclei_broad_model, tmp_path: pathlib.Path): +def test_validation_section_warning(unet2d_nuclei_broad_model: str, tmp_path: Path): from bioimageio.core import load_description from bioimageio.core.resource_tests import test_description model = load_description(unet2d_nuclei_broad_model) - - summary = test_description(model)[2] - assert summary["name"] == "Test documentation completeness." - assert summary["warnings"] == {"documentation": "No '# Validation' (sub)section found."} - assert summary["status"] == "passed" + assert not isinstance(model, InvalidDescr) + summary = test_description(model) + assert summary.name == "Test documentation completeness." + assert summary.warnings == {"documentation": "No '# Validation' (sub)section found."} + assert summary.status == "passed" doc_with_validation = tmp_path / "doc.md" - doc_with_validation.write_text("# Validation\nThis is a section about how to validate the model on new data") + _ = doc_with_validation.write_text("# Validation\nThis is a section about how to validate the model on new data") model.documentation = doc_with_validation - summary = test_description(model)[2] - assert summary["name"] == "Test documentation completeness." - assert summary["warnings"] == {} - assert summary["status"] == "passed" + summary = test_description(model) + assert summary.name == "Test documentation completeness." 
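Note: the updated assertions in this hunk reflect the switch from a list of per-test summary dicts to one structured summary object. A sketch of reading the new return value (the model source is a placeholder, not part of the patch):

from bioimageio.core.resource_tests import test_model

summary = test_model("path/to/model.zip")  # placeholder source
if summary.status != "passed":
    # drill into the first recorded error, as the updated assertions above do
    print(summary.details[0].errors[0].msg)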
+ assert summary.warnings == {} + assert summary.status == "passed" -def test_issue289(unet2d_nuclei_broad_model): +def test_issue289(unet2d_nuclei_broad_model: str): """test for failure case from https://github.com/bioimage-io/core-bioimage-io-python/issues/289""" # remote model is a pytorch model, needing unet2d_nuclei_broad_model skips the test when needed _ = unet2d_nuclei_broad_model - import bioimageio.core from bioimageio.core.resource_tests import test_model doi = "10.5281/zenodo.6287342" - model_resource = bioimageio.core.load_description(doi) - test_result = test_model(model_resource) - assert all([t["status"] == "passed" for t in test_result]) + summary = test_model(doi) + assert summary.status == "passed" diff --git a/tests/test_stat_measures.py b/tests/test_stat_measures.py new file mode 100644 index 00000000..7845da89 --- /dev/null +++ b/tests/test_stat_measures.py @@ -0,0 +1,39 @@ +from itertools import product +from typing import Optional, Tuple + +import numpy as np +import pytest +import xarray as xr + +from bioimageio.core import stat_measures +from bioimageio.core.common import AxisId, Sample, TensorId +from bioimageio.core.stat_calculators import SamplePercentilesCalculator, get_measure_calculators +from bioimageio.core.stat_measures import SamplePercentile + + +@pytest.mark.parametrize("name, axes", product(["mean", "var", "std"], [None, (AxisId("x"), AxisId("y"))])) +def test_individual_normal_measure(name: str, axes: Optional[Tuple[AxisId, AxisId]]): + measure = getattr(stat_measures, name.title() + "Measure")(axes=axes) + data = xr.DataArray(np.random.random((5, 6, 3)), dims=("x", "y", "c")) + + expected = getattr(data, name)(dim=axes) + actual = measure.compute(data) + xr.testing.assert_allclose(expected, actual) + + +@pytest.mark.parametrize("axes", [None, (AxisId("x"), AxisId("y"))]) +def test_individual_percentile_measure(axes: Optional[Tuple[AxisId, ...]]): + ns = [0, 10, 50, 100] + tid = TensorId("tensor") + + measures = [SamplePercentile(tensor_id=tid, axes=axes, n=n) for n in ns] + calcs, _ = get_measure_calculators(measures) + assert len(calcs) == 1 + calc = calcs[0] + assert isinstance(calc, SamplePercentilesCalculator) + + data = xr.DataArray(np.random.random((5, 6, 3)), dims=("x", "y", "c")) + actual = calc.compute(Sample(data={tid: data})) + for m in measures: + expected = data.quantile(q=m.n / 100, dim=m.axes) + xr.testing.assert_allclose(expected, actual[m]) diff --git a/tests/weight_converter/keras/test_tensorflow.py b/tests/weight_converter/keras/test_tensorflow.py index 712263fa..5cc7f297 100644 --- a/tests/weight_converter/keras/test_tensorflow.py +++ b/tests/weight_converter/keras/test_tensorflow.py @@ -1,22 +1,30 @@ import zipfile +from pathlib import Path +from bioimageio.spec import load_description +from bioimageio.spec.model.v0_5 import ModelDescr -def test_tensorflow_converter(any_keras_model, tmp_path): + +def test_tensorflow_converter(any_keras_model: Path, tmp_path: Path): from bioimageio.core.weight_converter.keras import convert_weights_to_tensorflow_saved_model_bundle out_path = tmp_path / "weights" - ret_val = convert_weights_to_tensorflow_saved_model_bundle(any_keras_model, out_path) + model = load_description(any_keras_model) + assert isinstance(model, ModelDescr), model.validation_summary.format() + ret_val = convert_weights_to_tensorflow_saved_model_bundle(model, out_path) assert out_path.exists() assert (out_path / "variables").exists() assert (out_path / "saved_model.pb").exists() assert ret_val == 0 # check for correctness is 
done in converter and returns 0 if it passes


-def test_tensorflow_converter_zipped(any_keras_model, tmp_path):
+def test_tensorflow_converter_zipped(any_keras_model: Path, tmp_path: Path):
     from bioimageio.core.weight_converter.keras import convert_weights_to_tensorflow_saved_model_bundle

     out_path = tmp_path / "weights.zip"
-    ret_val = convert_weights_to_tensorflow_saved_model_bundle(any_keras_model, out_path)
+    model = load_description(any_keras_model)
+    assert isinstance(model, ModelDescr), model.validation_summary.format()
+    ret_val = convert_weights_to_tensorflow_saved_model_bundle(model, out_path)
     assert out_path.exists()
     assert ret_val == 0  # check for correctness is done in converter and returns 0 if it passes
diff --git a/tests/weight_converter/torch/test_onnx.py b/tests/weight_converter/torch/test_onnx.py
index 5a26c916..bc757806 100644
--- a/tests/weight_converter/torch/test_onnx.py
+++ b/tests/weight_converter/torch/test_onnx.py
@@ -1,9 +1,11 @@
 import os
+from pathlib import Path
+
 import pytest


 # todo: test with 'any_torch_model'
-def test_onnx_converter(convert_to_onnx, tmp_path):
+def test_onnx_converter(convert_to_onnx: Path, tmp_path: Path):
     from bioimageio.core.weight_converter.torch.onnx import convert_weights_to_onnx

     out_path = tmp_path / "weights.onnx"

From 15d0cdd84e0cb8d30275f3fb4f77cce1c0ba74c1 Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Tue, 27 Feb 2024 13:19:31 +0100
Subject: [PATCH 099/244] add AddKnownDatasetStats and UpdateStats

---
 .../model_adapters/_pytorch_model_adapter.py  |   4 +-
 bioimageio/core/prediction_pipeline.py        |  46 ++--
 bioimageio/core/proc_ops.py                   | 112 +++++---
 bioimageio/core/proc_setup.py                 |  70 ++++-
 bioimageio/core/stat_calculators.py           |  70 +++--
 bioimageio/core/utils/__init__.py             |  61 +----
 bioimageio/core/utils/_import_callable.py     |  59 +++++
 bioimageio/core/utils/_tensor_io.py           |  20 ++
 .../core/weight_converter/torch/onnx.py       |  23 +-
 .../weight_converter/torch/torchscript.py     |   8 +-
 .../core/weight_converter/torch/utils.py      |   2 +-
 tests/build_spec/test_build_spec.py           | 241 ------------------
 .../test_combined_processing.py               |  35 ---
 .../test_prediction_pipeline.py               |  23 +-
 ..._prediction_pipeline_device_management.py} |   0
 tests/test_resource_tests/test_test_model.py  |   2 +-
 .../test_add_weights.py                       |   0
 17 files changed, 308 insertions(+), 468 deletions(-)
 create mode 100644 bioimageio/core/utils/_import_callable.py
 create mode 100644 bioimageio/core/utils/_tensor_io.py
 delete mode 100644 tests/build_spec/test_build_spec.py
 delete mode 100644 tests/prediction_pipeline/test_combined_processing.py
 rename tests/{prediction_pipeline => }/test_prediction_pipeline.py (68%)
 rename tests/{prediction_pipeline/test_device_management.py => test_prediction_pipeline_device_management.py} (100%)
 rename tests/{build_spec => weight_converter}/test_add_weights.py (100%)

diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py
index a9b7701b..95f3de50 100644
--- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py
+++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py
@@ -58,7 +58,7 @@ def unload(self) -> None:
     @staticmethod
     def get_network(
         weight_spec: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr]
-    ) -> torch.nn.Module:
+    ) -> "torch.nn.Module":
         arch = import_callable(
             weight_spec.architecture,
             sha256=(
@@ -79,7 +79,7 @@ def get_network(
         return network

     @staticmethod
-    def get_devices(devices: Optional[Sequence[str]] = None) -> List[torch.device]:
+    def 
get_devices(devices: Optional[Sequence[str]] = None) -> List["torch.device"]: if not devices: torch_devices = [torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")] else: diff --git a/bioimageio/core/prediction_pipeline.py b/bioimageio/core/prediction_pipeline.py index d01e0274..4f7db9e2 100644 --- a/bioimageio/core/prediction_pipeline.py +++ b/bioimageio/core/prediction_pipeline.py @@ -1,5 +1,6 @@ import warnings -from typing import Any, Dict, Iterable, List, Optional, Sequence +from types import MappingProxyType +from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union import xarray as xr @@ -8,7 +9,7 @@ from bioimageio.core.model_adapters import get_weight_formats as get_weight_formats from bioimageio.core.proc_ops import Processing from bioimageio.core.proc_setup import setup_pre_and_postprocessing -from bioimageio.core.stat_calculators import StatsCalculator +from bioimageio.core.stat_measures import DatasetMeasure, MeasureValue from bioimageio.spec.model import AnyModelDescr, v0_4 from bioimageio.spec.model.v0_5 import WeightsFormat @@ -26,8 +27,6 @@ def __init__( bioimageio_model: AnyModelDescr, preprocessing: List[Processing], postprocessing: List[Processing], - ipt_stats: StatsCalculator, - out_stats: StatsCalculator, model: ModelAdapter, ) -> None: super().__init__() @@ -37,8 +36,6 @@ def __init__( self.name = name self._preprocessing = preprocessing self._postprocessing = postprocessing - self._ipt_stats = ipt_stats - self._out_stats = out_stats if isinstance(bioimageio_model, v0_4.ModelDescr): self._input_ids = [TensorId(d.name) for d in bioimageio_model.inputs] self._output_ids = [TensorId(d.name) for d in bioimageio_model.outputs] @@ -66,13 +63,11 @@ def predict(self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataAr def apply_preprocessing(self, sample: Sample) -> None: """apply preprocessing in-place, also updates sample stats""" - sample.stat.update(self._ipt_stats.update_and_get_all(sample)) for op in self._preprocessing: op(sample) def apply_postprocessing(self, sample: Sample) -> None: """apply postprocessing in-place, also updates samples stats""" - sample.stat.update(self._out_stats.update_and_get_all(sample)) for op in self._postprocessing: op(sample) @@ -85,7 +80,7 @@ def forward_sample(self, input_sample: Sample): self.apply_postprocessing(prediction) return prediction - def forward_named( + def forward_tensors( self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataArray ) -> Dict[TensorId, xr.DataArray]: """Apply preprocessing, run prediction and apply postprocessing.""" @@ -99,7 +94,7 @@ def forward_named( def forward(self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataArray) -> List[xr.DataArray]: """Apply preprocessing, run prediction and apply postprocessing.""" - named_outputs = self.forward_named(*input_tensors, **named_input_tensors) + named_outputs = self.forward_tensors(*input_tensors, **named_input_tensors) return [named_outputs[x] for x in self._output_ids] def load(self): @@ -120,7 +115,10 @@ def create_prediction_pipeline( *, devices: Optional[Sequence[str]] = None, weight_format: Optional[WeightsFormat] = None, - dataset_for_initial_statistics: Iterable[Sequence[xr.DataArray]] = tuple(), + weights_format: Optional[WeightsFormat] = None, + dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[xr.DataArray]]] = tuple(), + keep_updating_initial_dataset_statistics: bool = False, + fixed_dataset_statistics: Mapping[DatasetMeasure, MeasureValue] = 
MappingProxyType({}), model_adapter: Optional[ModelAdapter] = None, **deprecated_kwargs: Any, ) -> PredictionPipeline: @@ -132,13 +130,15 @@ def create_prediction_pipeline( * computation of output statistics * postprocessing """ + weights_format = weight_format or weights_format + del weight_format if deprecated_kwargs: warnings.warn(f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}") model_adapter = model_adapter or create_model_adapter( model_description=bioimageio_model, devices=devices, - weight_format_priority_order=weight_format and (weight_format,), + weight_format_priority_order=weights_format and (weights_format,), ) if isinstance(bioimageio_model, v0_4.ModelDescr): @@ -146,13 +146,19 @@ def create_prediction_pipeline( else: input_ids = [ipt.id for ipt in bioimageio_model.inputs] - preprocessing, postprocessing, pre_req_meas, post_req_meas = setup_pre_and_postprocessing(bioimageio_model) - ipt_stats = StatsCalculator(pre_req_meas) - out_stats = StatsCalculator(post_req_meas) - for tensors in dataset_for_initial_statistics: - sample = Sample(data=dict(zip(input_ids, tensors))) - ipt_stats.update(sample) - out_stats.update(sample) + def dataset(): + for x in dataset_for_initial_statistics: + if isinstance(x, Sample): + yield x + else: + yield Sample(data=dict(zip(input_ids, x))) + + preprocessing, postprocessing = setup_pre_and_postprocessing( + bioimageio_model, + dataset(), + keep_updating_initial_dataset_stats=keep_updating_initial_dataset_statistics, + fixed_dataset_stats=fixed_dataset_statistics, + ) return PredictionPipeline( name=bioimageio_model.name, @@ -160,6 +166,4 @@ def create_prediction_pipeline( model=model_adapter, preprocessing=preprocessing, postprocessing=postprocessing, - ipt_stats=ipt_stats, - out_stats=out_stats, ) diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index fd3ef2ee..d055c059 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -3,14 +3,13 @@ from dataclasses import InitVar, dataclass, field from typing import ( Collection, - Generic, Hashable, Literal, + Mapping, Optional, Sequence, Set, Tuple, - Type, Union, cast, ) @@ -28,12 +27,15 @@ TensorId, ) from bioimageio.core.op_base import Operator +from bioimageio.core.stat_calculators import StatsCalculator from bioimageio.core.stat_measures import ( DatasetMean, + DatasetMeasure, DatasetPercentile, DatasetStd, MeanMeasure, Measure, + MeasureValue, SampleMean, SamplePercentile, SampleStd, @@ -85,24 +87,80 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor: ... @dataclass -class Dataset(Operator): +class AddKnownDatasetStats(Operator): + dataset_stats: Mapping[DatasetMeasure, MeasureValue] + @property def required_measures(self) -> Set[Measure]: return set() + def __call__(self, sample: Sample) -> None: + sample.stat.update(self.dataset_stats.items()) + # @dataclass -# class AssertDtype(Operator): -# tensor: TensorId -# dtype: Union[Type[DTypeLike], Tuple[Type[DTypeLike], ...]] +# class UpdateStats(Operator): +# """Calculates sample and/or dataset measures""" + +# measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]] +# """sample and dataset `measuers` to be calculated by this operator. Initial/fixed +# dataset measure values may be given, see `keep_updating_dataset_stats` for details. +# """ +# keep_updating_dataset_stats: Optional[bool] = None +# """indicates if operator calls should keep updating dataset statistics or not + +# default (None): if `measures` is a `Mapping` (i.e. 
initial measure values are
+# given) no further updates to dataset statistics is conducted, otherwise (w.o.
+# initial measure values) dataset statistics are updated by each processed sample.
+# """
+# _keep_updating_dataset_stats: bool = field(init=False)
+# _stats_calculator: StatsCalculator = field(init=False)

 #     @property
 #     def required_measures(self) -> Set[Measure]:
 #         return set()

+#     def __post_init__(self):
+#         self._stats_calculator = StatsCalculator(self.measures)
+#         if self.keep_updating_dataset_stats is None:
+#             self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping)
+#         else:
+#             self._keep_updating_dataset_stats = self.keep_updating_dataset_stats
+
+#     def __call__(self, sample: Sample) -> None:
+#         if self._keep_updating_dataset_stats:
+#             sample.stat.update(self._stats_calculator.update_and_get_all(sample))
+#         else:
+#             sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample))
+
+
+@dataclass
+class UpdateStats(Operator):
+    """Calculates sample and/or dataset measures"""
+
+    stats_calculator: StatsCalculator
+    """`StatsCalculator` to be used by this operator."""
+    keep_updating_initial_dataset_stats: bool = False
+    """indicates if operator calls should keep updating initial dataset statistics or not;
+    if the `stats_calculator` was not provided with any initial dataset statistics,
+    these are always updated with every new sample.
+    """
+    _keep_updating_dataset_stats: bool = field(init=False)
+
+    @property
+    def required_measures(self) -> Set[Measure]:
+        return set()
+
+    def __post_init__(self):
+        self._keep_updating_dataset_stats = (
+            self.keep_updating_initial_dataset_stats or not self.stats_calculator.has_dataset_measures
+        )
+
+    def __call__(self, sample: Sample) -> None:
+        if self._keep_updating_dataset_stats:
+            sample.stat.update(self.stats_calculator.update_and_get_all(sample))
+        else:
+            sample.stat.update(self.stats_calculator.skip_update_and_get_all(sample))


 @dataclass
@@ -457,39 +515,8 @@ def _apply(self, input: xr.DataArray, stat: Stat) -> xr.DataArray:
 ProcDescr = Union[v0_4.PreprocessingDescr, v0_4.PostprocessingDescr, v0_5.PreprocessingDescr, v0_5.PostprocessingDescr]

-# get_impl_class which also returns the kwargs class
-# def get_impl_class(proc_spec: ProcDescr):
-#     if isinstance(proc_spec, AssertDtype):
-#         return AssertDtypeImpl, AssertDtypeKwargs
-#     elif isinstance(proc_spec, v0_4.BinarizeDescr):
-#         return BinarizeImpl, v0_4.BinarizeKwargs
-#     elif isinstance(proc_spec, v0_5.BinarizeDescr):
-#         return BinarizeImpl, v0_5.BinarizeKwargs
-#     elif isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)):
-#         return ClipImpl, v0_5.ClipKwargs
-#     elif isinstance(proc_spec, v0_5.EnsureDtypeDescr):
-#         return EnsureDtypeImpl, v0_5.EnsureDtypeKwargs
-#     elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr):
-#         return FixedZeroMeanUnitVarianceImpl, v0_5.FixedZeroMeanUnitVarianceKwargs
-#     elif isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)):
-#         return ScaleLinearImpl, v0_5.ScaleLinearKwargs
-#     elif isinstance(proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)):
-#         return ScaleMeanVarianceImpl, v0_5.ScaleMeanVarianceKwargs
-#     elif isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)):
-#         return ScaleRangeImpl, v0_5.ScaleRangeKwargs
-#     elif isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)):
-#         return SigmoidImpl, v0_5.ProcessingKwargs
-#     elif 
isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr) and proc_spec.kwargs.mode == "fixed": -# return FixedZeroMeanUnitVarianceImpl, v0_5.FixedZeroMeanUnitVarianceKwargs -# elif isinstance( -# proc_spec, -# (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr), -# ): -# return ZeroMeanUnitVarianceImpl, v0_5.ZeroMeanUnitVarianceKwargs -# else: -# assert_never(proc_spec) - Processing = Union[ + AddKnownDatasetStats, Binarize, Clip, EnsureDtype, @@ -498,11 +525,12 @@ def _apply(self, input: xr.DataArray, stat: Stat) -> xr.DataArray: ScaleMeanVariance, ScaleRange, Sigmoid, + UpdateStats, ZeroMeanUnitVariance, ] -def get_proc_class(proc_spec: ProcDescr) -> Type[Processing]: +def get_proc_class(proc_spec: ProcDescr): if isinstance(proc_spec, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)): return Binarize elif isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)): diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py index e77673de..4c504681 100644 --- a/bioimageio/core/proc_setup.py +++ b/bioimageio/core/proc_setup.py @@ -1,5 +1,8 @@ +from types import MappingProxyType from typing import ( + Iterable, List, + Mapping, NamedTuple, Sequence, Set, @@ -9,22 +12,64 @@ from typing_extensions import assert_never -from bioimageio.core.proc_ops import Processing, get_proc_class -from bioimageio.core.stat_measures import Measure +from bioimageio.core.common import Sample +from bioimageio.core.proc_ops import AddKnownDatasetStats, Processing, UpdateStats, get_proc_class +from bioimageio.core.stat_calculators import StatsCalculator +from bioimageio.core.stat_measures import DatasetMeasure, Measure, MeasureValue from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 from bioimageio.spec.model.v0_5 import TensorId TensorDescr = Union[v0_4.InputTensorDescr, v0_4.OutputTensorDescr, v0_5.InputTensorDescr, v0_5.OutputTensorDescr] +class PreAndPostprocessing(NamedTuple): + pre: List[Processing] + post: List[Processing] + + class _SetupProcessing(NamedTuple): - preprocessing: List[Processing] - postprocessing: List[Processing] - preprocessing_req_measures: Set[Measure] - postprocessing_req_measures: Set[Measure] + pre: List[Processing] + post: List[Processing] + pre_measures: Set[Measure] + post_measures: Set[Measure] + + +def setup_pre_and_postprocessing( + model: AnyModelDescr, + dataset_for_initial_statistics: Iterable[Sample], + keep_updating_initial_dataset_stats: bool = False, + fixed_dataset_stats: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType({}), +) -> PreAndPostprocessing: + prep, post, prep_meas, post_meas = _prepare_setup_pre_and_postprocessing(model) + + missing_dataset_stats = {m for m in prep_meas | post_meas if m not in fixed_dataset_stats} + initial_stats_calc = StatsCalculator(missing_dataset_stats) + for sample in dataset_for_initial_statistics: + initial_stats_calc.update(sample) + initial_stats = initial_stats_calc.finalize() + prep.insert( + 0, + UpdateStats( + StatsCalculator(prep_meas, initial_stats), + keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats, + ), + ) + post.insert( + 0, + UpdateStats( + StatsCalculator(post_meas, initial_stats), + keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats, + ), + ) + if fixed_dataset_stats: + prep.insert(0, AddKnownDatasetStats(fixed_dataset_stats)) + post.insert(0, AddKnownDatasetStats(fixed_dataset_stats)) + + return PreAndPostprocessing(prep, post) -def setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcessing: + +def 
_prepare_setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcessing: pre_measures: Set[Measure] = set() post_measures: Set[Measure] = set() @@ -62,12 +107,9 @@ def prepare_procs(tensor_descrs: Sequence[TensorDescr]): procs.append(req) return procs - pre_procs = prepare_procs(model.inputs) - post_procs = prepare_procs(model.outputs) - return _SetupProcessing( - preprocessing=pre_procs, - postprocessing=post_procs, - preprocessing_req_measures=pre_measures, - postprocessing_req_measures=post_measures, + pre=prepare_procs(model.inputs), + post=prepare_procs(model.outputs), + pre_measures=pre_measures, + post_measures=post_measures, ) diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py index e3ccdc16..3b0045da 100644 --- a/bioimageio/core/stat_calculators.py +++ b/bioimageio/core/stat_calculators.py @@ -1,6 +1,6 @@ from __future__ import annotations -import collections +import collections.abc import warnings from itertools import product from typing import ( @@ -34,6 +34,7 @@ from bioimageio.core.stat_measures import ( DatasetMean, DatasetMeasure, + DatasetMeasureBase, DatasetPercentile, DatasetStd, DatasetVar, @@ -317,36 +318,34 @@ def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]: class StatsCalculator: """Estimates dataset statistics and computes sample statistics efficiently""" - def __init__(self, measures: Iterable[Measure]): + def __init__( + self, + measures: Collection[Measure], + initial_dataset_measures: Optional[Mapping[DatasetMeasure, MeasureValue]] = None, + ): super().__init__() self.sample_count = 0 self.sample_calculators, self.dataset_calculators = get_measure_calculators(measures) - self._current_dataset_measures: Optional[Dict[DatasetMeasure, MeasureValue]] = None - - def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]: - ret: Dict[SampleMeasure, MeasureValue] = {} - for calc in self.sample_calculators: - values = calc.compute(sample) - ret.update(values.items()) + if initial_dataset_measures is None: + self._current_dataset_measures: Optional[Dict[DatasetMeasure, MeasureValue]] = None + else: + missing_dataset_meas = { + m for m in measures if isinstance(m, DatasetMeasureBase) and m not in initial_dataset_measures + } + if missing_dataset_meas: + warnings.warn(f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}") + self._current_dataset_measures = None + else: + self._current_dataset_measures = dict(initial_dataset_measures) - return ret + @property + def has_dataset_measures(self): + return self._current_dataset_measures is not None def update(self, sample: Union[Sample, Iterable[Sample]]) -> None: _ = self._update(sample) - def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]: - self.sample_count += 1 - samples = [sample] if isinstance(sample, Sample) else sample - last_sample = None - for s in samples: - last_sample = s - for calc in self.dataset_calculators: - calc.update(s) - - self._current_dataset_measures = None - return last_sample - - def _finalize(self) -> Dict[DatasetMeasure, MeasureValue]: + def finalize(self) -> Dict[DatasetMeasure, MeasureValue]: """returns aggregated dataset statistics""" if self._current_dataset_measures is None: self._current_dataset_measures = {} @@ -362,11 +361,31 @@ def update_and_get_all(self, sample: Union[Sample, Iterable[Sample]]) -> Dict[Me if last_sample is None: raise ValueError("`sample` was not a `Sample`, nor did it yield any.") - return {**self._compute(last_sample), **self._finalize()} + 
return {**self._compute(last_sample), **self.finalize()} def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]: """Returns sample as well as previously computed dataset statistics""" - return {**self._compute(sample), **self._finalize()} + return {**self._compute(sample), **self.finalize()} + + def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]: + ret: Dict[SampleMeasure, MeasureValue] = {} + for calc in self.sample_calculators: + values = calc.compute(sample) + ret.update(values.items()) + + return ret + + def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]: + self.sample_count += 1 + samples = [sample] if isinstance(sample, Sample) else sample + last_sample = None + for s in samples: + last_sample = s + for calc in self.dataset_calculators: + calc.update(s) + + self._current_dataset_measures = None + return last_sample def get_measure_calculators( @@ -454,6 +473,7 @@ def compute_dataset_measures( return ret + def compute_sample_measures(measures: Iterable[SampleMeasure], sample: Sample) -> Dict[SampleMeasure, MeasureValue]: """compute all sample `measures` for the given `sample`""" calculators, dataset_calculators = get_measure_calculators(measures) diff --git a/bioimageio/core/utils/__init__.py b/bioimageio/core/utils/__init__.py index bcb713e2..426c8591 100644 --- a/bioimageio/core/utils/__init__.py +++ b/bioimageio/core/utils/__init__.py @@ -1,17 +1,9 @@ -from __future__ import annotations - -import importlib.util import sys -from functools import singledispatch from pathlib import Path -from typing import Any, Callable - -from typing_extensions import Unpack -from bioimageio.spec._internal.io_utils import HashKwargs, download -from bioimageio.spec.common import FileSource -from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile -from bioimageio.spec.model.v0_5 import ArchitectureFromFileDescr, ArchitectureFromLibraryDescr +from ._import_callable import import_callable as import_callable +from ._tensor_io import get_test_inputs as get_test_inputs +from ._tensor_io import get_test_outputs as get_test_outputs if sys.version_info < (3, 9): @@ -21,50 +13,3 @@ def files(package_name: str): else: from importlib.resources import files as files - - -@singledispatch -def import_callable(node: type, /) -> Callable[..., Any]: - raise TypeError(type(node)) - - -@import_callable.register -def import_from_dependency04(node: CallableFromDepencency) -> Callable[..., Any]: - module = importlib.import_module(node.module_name) - c = getattr(module, node.callable_name) - if not callable(c): - raise ValueError(f"{node} (imported: {c}) is not callable") - - return c - - -@import_callable.register -def import_from_dependency05(node: ArchitectureFromLibraryDescr) -> Callable[..., Any]: - module = importlib.import_module(node.import_from) - c = getattr(module, node.callable) - if not callable(c): - raise ValueError(f"{node} (imported: {c}) is not callable") - - return c - - -@import_callable.register -def import_from_file04(node: CallableFromFile, **kwargs: Unpack[HashKwargs]): - return _import_from_file_impl(node.file, node.callable_name, **kwargs) - - -@import_callable.register -def import_from_file05(node: ArchitectureFromFileDescr, **kwargs: Unpack[HashKwargs]): - return _import_from_file_impl(node.source, node.callable, sha256=node.sha256) - - -def _import_from_file_impl(source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]): - local_file = download(source, **kwargs) - module_name = 
local_file.path.stem - importlib_spec = importlib.util.spec_from_file_location(module_name, local_file.path) - if importlib_spec is None: - raise ImportError(f"Failed to import {module_name} from {source}.") - - dep = importlib.util.module_from_spec(importlib_spec) - importlib_spec.loader.exec_module(dep) # type: ignore # todo: possible to use "loader.load_module"? - return getattr(dep, callable_name) diff --git a/bioimageio/core/utils/_import_callable.py b/bioimageio/core/utils/_import_callable.py new file mode 100644 index 00000000..40ff1c45 --- /dev/null +++ b/bioimageio/core/utils/_import_callable.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import importlib.util +from functools import singledispatch +from typing import Any, Callable + +from typing_extensions import Unpack + +from bioimageio.spec._internal.io_utils import HashKwargs, download +from bioimageio.spec.common import FileSource +from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile +from bioimageio.spec.model.v0_5 import ArchitectureFromFileDescr, ArchitectureFromLibraryDescr + + +@singledispatch +def import_callable(node: type, /) -> Callable[..., Any]: + raise TypeError(type(node)) + + +@import_callable.register +def import_from_dependency04(node: CallableFromDepencency) -> Callable[..., Any]: + module = importlib.import_module(node.module_name) + c = getattr(module, node.callable_name) + if not callable(c): + raise ValueError(f"{node} (imported: {c}) is not callable") + + return c + + +@import_callable.register +def import_from_dependency05(node: ArchitectureFromLibraryDescr) -> Callable[..., Any]: + module = importlib.import_module(node.import_from) + c = getattr(module, node.callable) + if not callable(c): + raise ValueError(f"{node} (imported: {c}) is not callable") + + return c + + +@import_callable.register +def import_from_file04(node: CallableFromFile, **kwargs: Unpack[HashKwargs]): + return _import_from_file_impl(node.file, node.callable_name, **kwargs) + + +@import_callable.register +def import_from_file05(node: ArchitectureFromFileDescr, **kwargs: Unpack[HashKwargs]): + return _import_from_file_impl(node.source, node.callable, sha256=node.sha256) + + +def _import_from_file_impl(source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]): + local_file = download(source, **kwargs) + module_name = local_file.path.stem + importlib_spec = importlib.util.spec_from_file_location(module_name, local_file.path) + if importlib_spec is None: + raise ImportError(f"Failed to import {module_name} from {source}.") + + dep = importlib.util.module_from_spec(importlib_spec) + importlib_spec.loader.exec_module(dep) # type: ignore # todo: possible to use "loader.load_module"? 
+ return getattr(dep, callable_name) diff --git a/bioimageio/core/utils/_tensor_io.py b/bioimageio/core/utils/_tensor_io.py new file mode 100644 index 00000000..ad41789f --- /dev/null +++ b/bioimageio/core/utils/_tensor_io.py @@ -0,0 +1,20 @@ +from typing import List + +import xarray as xr + +from bioimageio.spec.model import AnyModelDescr, v0_4 +from bioimageio.spec.utils import load_array + + +def get_test_inputs(model: AnyModelDescr) -> List[xr.DataArray]: + if isinstance(model, v0_4.ModelDescr): + return [xr.DataArray(load_array(tt), dims=tuple(d.axes)) for d, tt in zip(model.inputs, model.test_inputs)] + else: + return [xr.DataArray(load_array(d.test_tensor), dims=tuple(a.id for a in d.axes)) for d in model.inputs] + + +def get_test_outputs(model: AnyModelDescr) -> List[xr.DataArray]: + if isinstance(model, v0_4.ModelDescr): + return [xr.DataArray(load_array(tt), dims=tuple(d.axes)) for d, tt in zip(model.outputs, model.test_outputs)] + else: + return [xr.DataArray(load_array(d.test_tensor), dims=tuple(a.id for a in d.axes)) for d in model.outputs] diff --git a/bioimageio/core/weight_converter/torch/onnx.py b/bioimageio/core/weight_converter/torch/onnx.py index 3606cd74..9fa90de1 100644 --- a/bioimageio/core/weight_converter/torch/onnx.py +++ b/bioimageio/core/weight_converter/torch/onnx.py @@ -1,16 +1,16 @@ import warnings from pathlib import Path -from typing import Any, Dict, List, Sequence, cast +from typing import Any, List, Sequence, cast import numpy as np import torch from numpy.testing import assert_array_almost_equal -from bioimageio.core.weight_converter.torch.utils import load_model +from bioimageio.core.utils import get_test_inputs +from bioimageio.core.weight_converter.torch.utils import load_torch_model from bioimageio.spec import load_description from bioimageio.spec.common import InvalidDescr from bioimageio.spec.model import v0_4, v0_5 -from bioimageio.spec.utils import download def add_onnx_weights( @@ -42,18 +42,13 @@ def add_onnx_weights( state_dict_weights_descr = model_spec.weights.pytorch_state_dict if state_dict_weights_descr is None: - raise ValueError(f"The provided model does not have weights in the pytorch state dict format") + raise ValueError("The provided model does not have weights in the pytorch state dict format") with torch.no_grad(): - if isinstance(model_spec, v0_4.ModelDescr): - downloaded_test_inputs = [download(inp) for inp in model_spec.test_inputs] - else: - downloaded_test_inputs = [inp.test_tensor.download() for inp in model_spec.inputs] - - input_data: List[np.ndarray[Any, Any]] = [np.load(dl.path).astype("float32") for dl in downloaded_test_inputs] - input_tensors = [torch.from_numpy(inp) for inp in input_data] - model = load_model(state_dict_weights_descr) + input_data = [t.data for t in get_test_inputs(model_spec)] + input_tensors = [torch.from_numpy(d) for d in input_data] + model = load_torch_model(state_dict_weights_descr) expected_tensors = model(*input_tensors) if isinstance(expected_tensors, torch.Tensor): @@ -81,9 +76,7 @@ def add_onnx_weights( # check the onnx model sess = rt.InferenceSession(str(output_path)) onnx_input_node_args = cast(List[Any], sess.get_inputs()) # fixme: remove cast, try using rt.NodeArg instead of Any - onnx_inputs: Dict[str, np.ndarray[Any, Any]] = { - input_name.name: inp for input_name, inp in zip(onnx_input_node_args, input_data) - } + onnx_inputs = {input_name.name: inp for input_name, inp in zip(onnx_input_node_args, input_data)} outputs = cast(Sequence[np.ndarray[Any, Any]], sess.run(None, 
onnx_inputs)) # FIXME: remove cast try: diff --git a/bioimageio/core/weight_converter/torch/torchscript.py b/bioimageio/core/weight_converter/torch/torchscript.py index a517e17b..0dd23442 100644 --- a/bioimageio/core/weight_converter/torch/torchscript.py +++ b/bioimageio/core/weight_converter/torch/torchscript.py @@ -9,7 +9,7 @@ from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import Version -from .utils import load_model +from .utils import load_torch_model # FIXME: remove Any @@ -40,7 +40,7 @@ def _check(input_: Sequence[torch.Tensor]) -> None: input_descr = model_spec.inputs[0] if isinstance(input_descr, v0_4.InputTensorDescr): - if not isinstance(input_descr.shape, v0_4.ParametrizedInputShape): + if not isinstance(input_descr.shape, v0_4.ParameterizedInputShape): return min_shape = input_descr.shape.min step = input_descr.shape.step @@ -54,7 +54,7 @@ def _check(input_: Sequence[torch.Tensor]) -> None: elif isinstance(axis.size, int): min_shape.append(axis.size) step.append(0) - elif isinstance(axis.size, (v0_5.AxisId, v0_5.TensorAxisId, type(None))): + elif axis.size is None: raise NotImplementedError(f"Can't verify inputs that don't specify their shape fully: {axis}") elif isinstance(axis.size, v0_5.SizeReference): raise NotImplementedError(f"Can't handle axes like '{axis}' yet") @@ -94,7 +94,7 @@ def convert_weights_to_torchscript( with torch.no_grad(): input_data = [torch.from_numpy(inp.astype("float32")) for inp in input_data] - model = load_model(state_dict_weights_descr) + model = load_torch_model(state_dict_weights_descr) # FIXME: remove Any if use_tracing: diff --git a/bioimageio/core/weight_converter/torch/utils.py b/bioimageio/core/weight_converter/torch/utils.py index 413ba629..4b5debad 100644 --- a/bioimageio/core/weight_converter/torch/utils.py +++ b/bioimageio/core/weight_converter/torch/utils.py @@ -7,7 +7,7 @@ # additional convenience for pytorch state dict, eventually we want this in python-bioimageio too # and for each weight format -def load_model(node: "v0_4.PytorchStateDictWeightsDescr | v0_5.PytorchStateDictWeightsDescr"): +def load_torch_model(node: "v0_4.PytorchStateDictWeightsDescr | v0_5.PytorchStateDictWeightsDescr"): model = PytorchModelAdapter.get_network(node) state = torch.load(download(node.source).path, map_location="cpu") _ = model.load_state_dict(state) # FIXME: check incompatible keys? 
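
Taken together, this commit moves dataset-measure computation out of PredictionPipeline and into the operator lists built by setup_pre_and_postprocessing. A minimal sketch of the intended call pattern, assuming a local "model.zip" package with pytorch_state_dict weights (the path and weights format are placeholders, not part of the patch):

    from bioimageio.core.prediction_pipeline import create_prediction_pipeline
    from bioimageio.core.utils import get_test_inputs
    from bioimageio.spec import load_description
    from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr04
    from bioimageio.spec.model.v0_5 import ModelDescr

    model = load_description("model.zip")  # placeholder path
    assert isinstance(model, (ModelDescr, ModelDescr04))

    # the bundled test tensors double as a one-sample dataset to seed the
    # dataset statistics required by e.g. zero_mean_unit_variance preprocessing
    test_inputs = get_test_inputs(model)
    pipeline = create_prediction_pipeline(
        model,
        weights_format="pytorch_state_dict",
        dataset_for_initial_statistics=[test_inputs],
    )
    outputs = pipeline.forward(*test_inputs)

Passing fixed_dataset_statistics (a mapping of DatasetMeasure to value) would skip the initial pass entirely; that is what AddKnownDatasetStats injects into the operator lists.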
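
The initial_dataset_measures argument introduced on StatsCalculator is what lets UpdateStats skip per-sample recomputation of dataset statistics. A sketch of the two-stage flow, assuming DatasetMean accepts tensor_id and axes keywords like the SamplePercentile used in the new tests:

    import numpy as np
    import xarray as xr

    from bioimageio.core.common import Sample, TensorId
    from bioimageio.core.stat_calculators import StatsCalculator
    from bioimageio.core.stat_measures import DatasetMean

    tid = TensorId("input")
    mean = DatasetMean(tensor_id=tid, axes=None)  # assumed constructor, see above

    calc = StatsCalculator([mean])
    for _ in range(3):  # one pass over a (toy) dataset accumulates the measure
        sample = Sample(data={tid: xr.DataArray(np.random.rand(8, 8), dims=("y", "x"))})
        calc.update(sample)

    initial_stats = calc.finalize()  # {DatasetMean(...): <aggregated value>}

    # a calculator seeded with these values reports has_dataset_measures, so
    # UpdateStats will call skip_update_and_get_all instead of updating further
    seeded = StatsCalculator([mean], initial_dataset_measures=initial_stats)
    assert seeded.has_dataset_measures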
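
import_callable, now in its own module, resolves callables via single dispatch: library references go through importlib.import_module, while file sources are downloaded and loaded as a module. A tiny illustration with a stdlib callable, assuming the v0_4 node can be constructed directly from keyword arguments:

    from bioimageio.core.utils import import_callable
    from bioimageio.spec.model.v0_4 import CallableFromDepencency

    # resolves to math.sqrt via importlib.import_module under the hood
    sqrt = import_callable(CallableFromDepencency(module_name="math", callable_name="sqrt"))
    assert sqrt(4.0) == 2.0
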
diff --git a/tests/build_spec/test_build_spec.py b/tests/build_spec/test_build_spec.py deleted file mode 100644 index b1fa85ab..00000000 --- a/tests/build_spec/test_build_spec.py +++ /dev/null @@ -1,241 +0,0 @@ -from typing import Optional - -import bioimageio.spec as spec - -# from bioimageio.core import load_description, load_raw_resource_description -# from bioimageio.core._internal.validation_visitors import resolve_source -# from bioimageio.core.resource_io import nodes -# from bioimageio.core.resource_tests import test_model as _test_model - -try: - import tensorflow -except ImportError: - tf_version = None -else: - tf_version: Optional[str] = ".".join(tensorflow.__version__.split(".")[:2]) - - -def _test_build_spec( - spec_path, - out_path, - weight_type, - tensorflow_version=None, - opset_version=None, - use_implicit_output_shape=False, - add_deepimagej_config=False, - use_original_covers=False, - training_data=None, - parent=None, -): - from bioimageio.core.build_spec import build_model - - model_spec = load_raw_resource_description(spec_path, update_to_format="latest") - root = model_spec.root_path - assert isinstance(model_spec, spec.model.raw_nodes.Model) - weight_source = model_spec.weights[weight_type].source - - cite = [] - for entry in model_spec.cite: - entry_ = {"text": entry.text} - has_url = entry.url is not missing - has_doi = entry.doi is not missing - assert has_url != has_doi - if has_doi: - entry_["doi"] = entry.doi - else: - entry_["url"] = entry.url - cite.append(entry_) - - weight_spec = model_spec.weights[weight_type] - dep_file = None if weight_spec.dependencies is missing else resolve_source(weight_spec.dependencies.file, root) - if weight_type == "pytorch_state_dict": - model_kwargs = None if weight_spec.kwargs is missing else weight_spec.kwargs - architecture = str(weight_spec.architecture) - weight_type_ = None # the weight type can be auto-detected - elif weight_type == "torchscript": - architecture = None - model_kwargs = None - weight_type_ = "torchscript" # the weight type CANNOT be auto-detected - else: - architecture = None - model_kwargs = None - weight_type_ = None # the weight type can be auto-detected - - authors = [{"name": auth.name, "affiliation": auth.affiliation} for auth in model_spec.authors] - - input_axes = [input_.axes for input_ in model_spec.inputs] - output_axes = [output.axes for output in model_spec.outputs] - preprocessing = [ - ( - None - if input_.preprocessing is missing - else [{"name": preproc.name, "kwargs": preproc.kwargs} for preproc in input_.preprocessing] - ) - for input_ in model_spec.inputs - ] - postprocessing = [ - ( - None - if output.postprocessing is missing - else [{"name": preproc.name, "kwargs": preproc.kwargs} for preproc in output.preprocessing] - ) - for output in model_spec.outputs - ] - - kwargs = dict( - weight_uri=weight_source, - test_inputs=resolve_source(model_spec.test_inputs, root), - test_outputs=resolve_source(model_spec.test_outputs, root), - name=model_spec.name, - description=model_spec.description, - authors=authors, - tags=model_spec.tags, - license=model_spec.license, - documentation=model_spec.documentation, - dependencies=dep_file, - cite=cite, - root=model_spec.root_path, - weight_type=weight_type_, - input_axes=input_axes, - output_axes=output_axes, - preprocessing=preprocessing, - postprocessing=postprocessing, - output_path=out_path, - add_deepimagej_config=add_deepimagej_config, - maintainers=[{"github_user": "jane_doe"}], - input_names=[inp.name for inp in model_spec.inputs], - 
output_names=[out.name for out in model_spec.outputs], - ) - if architecture is not None: - kwargs["architecture"] = architecture - if model_kwargs is not None: - kwargs["model_kwargs"] = model_kwargs - if tensorflow_version is not None: - kwargs["tensorflow_version"] = tensorflow_version - if opset_version is not None: - kwargs["opset_version"] = opset_version - if use_implicit_output_shape: - kwargs["input_names"] = ["input"] - kwargs["output_reference"] = ["input"] - kwargs["output_scale"] = [[1.0, 1.0, 1.0, 1.0]] - kwargs["output_offset"] = [[0.0, 0.0, 0.0, 0.0]] - if add_deepimagej_config: - kwargs["pixel_sizes"] = [{"x": 5.0, "y": 5.0}] - if use_original_covers: - kwargs["covers"] = resolve_source(model_spec.covers, root) - if training_data is not None: - kwargs["training_data"] = training_data - if parent is not None: - kwargs["parent"] = parent - - build_model(**kwargs) - assert out_path.exists() - loaded_model = load_description(out_path) - assert isinstance(loaded_model, nodes.Model) - if add_deepimagej_config: - loaded_config = loaded_model.config - assert "deepimagej" in loaded_config - - if loaded_model.sample_inputs is not missing: - for sample in loaded_model.sample_inputs: - assert sample.exists() - if loaded_model.sample_outputs is not missing: - for sample in loaded_model.sample_outputs: - assert sample.exists() - - assert loaded_model.maintainers[0].github_user == "jane_doe" - - attachments = loaded_model.attachments - if attachments is not missing and attachments.files is not missing: - for attached_file in attachments.files: - assert attached_file.exists() - - # make sure there is one attachment per pre/post-processing - if add_deepimagej_config: - preproc, postproc = preprocessing[0], postprocessing[0] - n_processing = 0 - if preproc is not None: - n_processing += len(preproc) - if postproc is not None: - n_processing += len(postproc) - if n_processing > 0: - assert attachments.files is not missing - assert n_processing == len(attachments.files) - - # test inference for the model to ensure that the weights were written correctly - test_res = _test_model(out_path) - assert all([s["status"] == "passed" for s in test_res]) - - -def test_build_spec_pytorch(any_torch_model, tmp_path): - _test_build_spec(any_torch_model, tmp_path / "model.zip", "pytorch_state_dict") - - -def test_build_spec_implicit_output_shape(unet2d_nuclei_broad_model, tmp_path): - _test_build_spec( - unet2d_nuclei_broad_model, tmp_path / "model.zip", "pytorch_state_dict", use_implicit_output_shape=True - ) - - -def test_build_spec_torchscript(any_torchscript_model, tmp_path): - _test_build_spec(any_torchscript_model, tmp_path / "model.zip", "torchscript") - - -def test_build_spec_onnx(any_onnx_model, tmp_path): - _test_build_spec(any_onnx_model, tmp_path / "model.zip", "onnx", opset_version=12) - - -def test_build_spec_keras(any_keras_model, tmp_path): - _test_build_spec( - any_keras_model, tmp_path / "model.zip", "keras_hdf5", tensorflow_version=tf_version - ) # todo: keras for tf 2?? 
- - -def test_build_spec_tf(any_tensorflow_model, tmp_path): - _test_build_spec( - any_tensorflow_model, tmp_path / "model.zip", "tensorflow_saved_model_bundle", tensorflow_version=tf_version - ) # check tf version - - -def test_build_spec_tfjs(any_tensorflow_js_model, tmp_path): - _test_build_spec(any_tensorflow_js_model, tmp_path / "model.zip", "tensorflow_js", tensorflow_version=tf_version) - - -def test_build_spec_deepimagej(unet2d_nuclei_broad_model, tmp_path): - _test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", add_deepimagej_config=True) - - -def test_build_spec_training_data1(unet2d_nuclei_broad_model, tmp_path): - training_data = {"id": "ilastik/stradist_dsb_training_data"} - _test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", training_data=training_data) - - -def test_build_spec_training_data2(unet2d_nuclei_broad_model, tmp_path): - training_data = { - "type": "dataset", - "name": "nucleus-training-data", - "description": "stardist nucleus training data", - "source": "https://github.com/stardist/stardist/releases/download/0.1.0/dsb2018.zip", - } - _test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", training_data=training_data) - - -def test_build_spec_parent1(unet2d_nuclei_broad_model, tmp_path): - parent = {"uri": "https://doi.org/10.5281/zenodo.5764892"} - _test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", parent=parent) - - -def test_build_spec_parent2(unet2d_nuclei_broad_model, tmp_path): - parent = {"id": "10.5281/zenodo.5764892"} - _test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", parent=parent) - - -def test_build_spec_deepimagej_keras(unet2d_keras, tmp_path): - _test_build_spec( - unet2d_keras, tmp_path / "model.zip", "keras_hdf5", add_deepimagej_config=True, tensorflow_version=tf_version - ) - - -# test with original covers -def test_build_spec_with_original_covers(unet2d_nuclei_broad_model, tmp_path): - _test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", use_original_covers=True) diff --git a/tests/prediction_pipeline/test_combined_processing.py b/tests/prediction_pipeline/test_combined_processing.py deleted file mode 100644 index 7a590991..00000000 --- a/tests/prediction_pipeline/test_combined_processing.py +++ /dev/null @@ -1,35 +0,0 @@ -import numpy as np -import xarray as xr - - -def test_postprocessing_dtype(): # TODO: remove? 
- from bioimageio.core.common import TensorId - from bioimageio.spec.model.v0_5 import BinarizeDescr, BinarizeKwargs, OutputTensorDescr - - # from bioimageio.core.prediction_pipeline._combined_processing import CombinedProcessing - - shape = [3, 32, 32] - axes = ("c", "y", "x") - np_data = np.random.rand(*shape) - data = xr.DataArray(np_data, dims=axes) - - threshold = 0.5 - exp = xr.DataArray(np_data > threshold, dims=axes) - - for dtype in ("float32", "float64", "uint8", "uint16"): - outputs = [ - OutputTensorDescr( - id=TensorId("out1"), - data_type=dtype, - axes=axes, - shape=shape, - postprocessing=[BinarizeDescr(kwargs=BinarizeKwargs(threshold=threshold))], - ) - ] - com_proc = CombinedProcessing.from_tensor_specs(outputs) - - sample = {"out1": data} - com_proc.apply(sample, {}) - res = sample["out1"] - assert np.dtype(res.dtype) == np.dtype(dtype) - xr.testing.assert_allclose(res, exp.astype(dtype)) diff --git a/tests/prediction_pipeline/test_prediction_pipeline.py b/tests/test_prediction_pipeline.py similarity index 68% rename from tests/prediction_pipeline/test_prediction_pipeline.py rename to tests/test_prediction_pipeline.py index 2c196401..b569c517 100644 --- a/tests/prediction_pipeline/test_prediction_pipeline.py +++ b/tests/test_prediction_pipeline.py @@ -1,22 +1,27 @@ +from pathlib import Path import numpy as np import xarray as xr from numpy.testing import assert_array_almost_equal -# from bioimageio.core import load_description -# from bioimageio.core.resource_io.nodes import Model +from bioimageio.spec import load_description +from bioimageio.spec.model.v0_5 import WeightsFormat, ModelDescr +from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr04 -def _test_prediction_pipeline(model_package, weight_format): +def _test_prediction_pipeline(model_package: Path, weights_format: WeightsFormat): from bioimageio.core.prediction_pipeline import create_prediction_pipeline bio_model = load_description(model_package) - assert isinstance(bio_model, Model) - pp = create_prediction_pipeline(bioimageio_model=bio_model, weight_format=weight_format) + assert isinstance(bio_model, (ModelDescr, ModelDescr04)) + pp = create_prediction_pipeline(bioimageio_model=bio_model, weight_format=weights_format) + + if isinstance(bio_model, ModelDescr04): + inputs = [ + xr.DataArray(np.load(str(test_tensor)), dims=tuple(spec.axes)) + for test_tensor, spec in zip(bio_model.test_inputs, bio_model.inputs) + ] + else: - inputs = [ - xr.DataArray(np.load(str(test_tensor)), dims=tuple(spec.axes)) - for test_tensor, spec in zip(bio_model.test_inputs, bio_model.inputs) - ] outputs = pp.forward(*inputs) assert isinstance(outputs, list) diff --git a/tests/prediction_pipeline/test_device_management.py b/tests/test_prediction_pipeline_device_management.py similarity index 100% rename from tests/prediction_pipeline/test_device_management.py rename to tests/test_prediction_pipeline_device_management.py diff --git a/tests/test_resource_tests/test_test_model.py b/tests/test_resource_tests/test_test_model.py index f83baf52..970bf2e2 100644 --- a/tests/test_resource_tests/test_test_model.py +++ b/tests/test_resource_tests/test_test_model.py @@ -21,7 +21,7 @@ def test_error_for_wrong_shape2(stardist_wrong_shape2: Path): summary = test_model(stardist_wrong_shape2) expected_error_message = ( "Shape (1, 512, 512, 1) of test input 0 'input' does not match input shape description: " - "ParametrizedInputShape(min=[1, 80, 80, 1], step=[0, 17, 17, 0])." 
+ "ParameterizedInputShape(min=[1, 80, 80, 1], step=[0, 17, 17, 0])." ) assert summary.details[0].errors[0].msg == expected_error_message diff --git a/tests/build_spec/test_add_weights.py b/tests/weight_converter/test_add_weights.py similarity index 100% rename from tests/build_spec/test_add_weights.py rename to tests/weight_converter/test_add_weights.py From 81531d8b1d31128647cc709ce38520ef22d1fed0 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 27 Feb 2024 14:30:07 +0100 Subject: [PATCH 100/244] udpate tests and cli --- bioimageio/core/__main__.py | 148 ++++++++---------- tests/conftest.py | 43 ++--- tests/resource_io/test_load_rdf.py | 72 --------- tests/resource_io/test_utils.py | 85 ---------- tests/test_bioimageio_spec_version.py | 2 - tests/test_cli.py | 54 ++++--- tests/test_prediction_pipeline.py | 29 ++-- ...t_test_model.py => test_resource_tests.py} | 0 8 files changed, 121 insertions(+), 312 deletions(-) delete mode 100644 tests/resource_io/test_load_rdf.py delete mode 100644 tests/resource_io/test_utils.py rename tests/{test_resource_tests/test_test_model.py => test_resource_tests.py} (100%) diff --git a/bioimageio/core/__main__.py b/bioimageio/core/__main__.py index 651f5d20..54ab9425 100644 --- a/bioimageio/core/__main__.py +++ b/bioimageio/core/__main__.py @@ -5,8 +5,7 @@ import warnings from glob import glob from pathlib import Path -from pprint import pformat -from typing import List, Optional, get_args +from typing import List, Optional import typer from typing_extensions import Annotated @@ -16,9 +15,7 @@ from bioimageio.spec.collection import CollectionDescr from bioimageio.spec.dataset import DatasetDescr from bioimageio.spec.model import ModelDescr -from bioimageio.spec.model.v0_5 import WeightsFormat from bioimageio.spec.notebook import NotebookDescr -from bioimageio.spec.summary import ValidationSummary try: with warnings.catch_warnings(): @@ -57,73 +54,41 @@ def callback(): # if we want to use something like "choice" for the weight formats, we need to use an enum, see: # https://github.com/tiangolo/typer/issues/182 -WeightsFormatEnum = enum.Enum("WeightsFormatEnum", {wf: wf for wf in get_args(WeightsFormat)}) -# Enum with in values does not work with click.Choice: https://github.com/pallets/click/issues/784 -# so a simple Enum with auto int values is not an option: -# WeightsFormatEnum = enum.Enum("WeightsFormatEnum", get_args(WeightsFormat)) + + +class WeightsFormatEnum(enum.Enum): + keras_hdf5 = "keras_hdf5" + onnx = "onnx" + pytorch_state_dict = "pytorch_state_dict" + tensorflow_js = "tensorflow_js" + tensorflow_saved_model_bundle = "tensorflow_saved_model_bundle" + torchscript = "torchscript" + + +# Enum with int values does not work with click.Choice: https://github.com/pallets/click/issues/784 +# so a simple Enum with auto int values is not an option. @app.command() def package( - rdf_source: Annotated[str, typer.Argument(help="RDF source as relative file path or URI")], - path: Annotated[Path, typer.Argument(help="Save package as")] = Path() / "bioimageio-package.zip", + source: Annotated[str, typer.Argument(help="path or url to a bioimageio RDF")], + path: Annotated[Path, typer.Argument(help="Save package as")] = Path("bioimageio-package.zip"), weights_priority_order: Annotated[ Optional[List[WeightsFormatEnum]], typer.Option( "--weights-priority-order", "-wpo", help="For model packages only. " - "If given only the first weights matching the given weight formats are included. 
" - "Defaults to include all weights present in source.", + "If given, only the first matching weights entry is included. " + "Defaults to including all weights present in source.", show_default=False, ), ] = None, - # verbose: Annotated[bool, typer.Option(help="show traceback of exceptions")] = False, ): # typer bug: typer returns empty tuple instead of None if weights_order_priority is not given weights_priority_order = weights_priority_order or None # TODO: check if this is still the case - _ = save_bioimageio_package(rdf_source, output_path=path, weights_priority_order=weights_priority_order) - - -def _log_test_summaries(summaries: List[ValidationSummary], msg: str): - # todo: improve logging of multiple test summaries - ret_code = 0 - for summary in summaries: - print(f"{summary['name']}: {summary['status']}") - if summary["status"] != "passed": - s = { - k: v - for k, v in summary.items() - if k not in ("name", "status", "bioimageio_spec_version", "bioimageio_core_version") - } - tb = s.pop("traceback") - if tb: - print("traceback:") - print("".join(tb)) - - def show_part(part, show): - if show: - line = f"{part}: " - print(line + pformat(show, width=min(80, 120 - len(line))).replace("\n", " " * len(line) + "\n")) - - for part in ["error", "warnings", "source_name"]: - show_part(part, s.pop(part, None)) - - for part in sorted(s.keys()): - show_part(part, s[part]) - - ret_code = 1 - - if ret_code: - result = "FAILED!" - icon = "❌" - else: - result = "passed." - icon = "✔️" - - print(msg.format(icon=icon, result=result)) - return ret_code + _ = save_bioimageio_package(source, output_path=path, weights_priority_order=weights_priority_order) @app.command() @@ -138,15 +103,15 @@ def test_model( # this is a weird typer bug: default devices are empty tuple although they should be None devices = devices or None - summaries = resource_tests.test_model( + summary = resource_tests.test_model( model_rdf, weight_format=None if weight_format is None else weight_format.value, devices=devices, decimal=decimal, ) print(f"\ntesting model {model_rdf}...") - ret_code = _log_test_summaries(summaries, f"\n{{icon}} Model {model_rdf} {{result}}") - sys.exit(ret_code) + print(summary.format()) + sys.exit(0 if summary.status == "passed" else 1) test_model.__doc__ = resource_tests.test_model.__doc__ @@ -154,22 +119,26 @@ def test_model( @app.command() def test_resource( - rdf: str = typer.Argument( - ..., help="Path or URL to the resource description file (rdf.yaml) or zipped resource package." 
-    ),
-    weight_format: Optional[WeightsFormatEnum] = typer.Option(None, help="(for model only) The weight format to use."),
-    devices: Optional[List[str]] = typer.Option(None, help="(for model only) Devices for running the model."),
-    decimal: int = typer.Option(4, help="(for model only) The test precision."),
+    rdf: Annotated[
+        str, typer.Argument(help="Path or URL to the resource description file (rdf.yaml) or zipped resource package.")
+    ],
+    weight_format: Annotated[
+        Optional[WeightsFormatEnum], typer.Option(help="(for model only) The weight format to use.")
+    ] = None,
+    devices: Annotated[
+        Optional[List[str]], typer.Option(help="(for model only) Devices for running the model.")
+    ] = None,
+    decimal: Annotated[int, typer.Option(help="(for model only) The test precision.")] = 4,
 ):
     # this is a weird typer bug: default devices are empty tuple although they should be None
     if len(devices) == 0:
         devices = None

-    summaries = resource_tests.test_description(
+    print(f"\ntesting {rdf}...")
+    summary = resource_tests.test_description(
         rdf, weight_format=None if weight_format is None else weight_format.value, devices=devices, decimal=decimal
     )
-    print(f"\ntesting {rdf}...")
-    ret_code = _log_test_summaries(summaries, f"{{icon}} Resource test for {rdf} has {{result}}")
-    sys.exit(ret_code)
+    print(summary.format())
+    sys.exit(0 if summary.status == "passed" else 1)


 test_resource.__doc__ = resource_tests.test_description.__doc__


 @app.command()
 def predict_image(
-    model_rdf: Path = typer.Argument(
-        ..., help="Path to the model resource description file (rdf.yaml) or zipped model."
-    ),
-    inputs: List[Path] = typer.Option(..., help="Path(s) to the model input(s)."),
-    outputs: List[Path] = typer.Option(..., help="Path(s) for saveing the model output(s)."),
+    model_rdf: Annotated[
+        Path, typer.Argument(help="Path to the model resource description file (rdf.yaml) or zipped model.")
+    ],
+    inputs: Annotated[List[Path], typer.Option(help="Path(s) to the model input(s).")],
+    outputs: Annotated[List[Path], typer.Option(help="Path(s) for saving the model output(s).")],
     # NOTE: typer currently doesn't support union types, so we only support boolean here
     # padding: Optional[Union[str, bool]] = typer.Argument(
     #    None, help="Padding to apply in each dimension passed as json encoded string."
# ), - padding: Optional[bool] = typer.Option(None, help="Whether to pad the image to a size suited for the model."), - tiling: Optional[bool] = typer.Option(None, help="Whether to run prediction in tiling mode."), - weight_format: Optional[WeightsFormatEnum] = typer.Option(None, help="The weight format to use."), - devices: Optional[List[str]] = typer.Option(None, help="Devices for running the model."), + padding: Annotated[ + Optional[bool], typer.Option(help="Whether to pad the image to a size suited for the model.") + ] = None, + tiling: Annotated[Optional[bool], typer.Option(help="Whether to run prediction in tiling mode.")] = None, + weight_format: Annotated[Optional[WeightsFormatEnum], typer.Option(help="The weight format to use.")] = None, + devices: Annotated[Optional[List[str]], typer.Option(help="Devices for running the model.")] = None, ): if isinstance(padding, str): padding = json.loads(padding.replace("'", '"')) @@ -202,8 +173,9 @@ def predict_image( assert isinstance(tiling, dict) # this is a weird typer bug: default devices are empty tuple although they should be None - if len(devices) == 0: + if devices is None or len(devices) == 0: devices = None + prediction.predict_image( model_rdf, inputs, outputs, padding, tiling, None if weight_format is None else weight_format.value, devices ) @@ -214,12 +186,12 @@ def predict_image( @app.command() def predict_images( - model_rdf: Path = typer.Argument( - ..., help="Path to the model resource description file (rdf.yaml) or zipped model." - ), - input_pattern: str = typer.Argument(..., help="Glob pattern for the input images."), - output_folder: str = typer.Argument(..., help="Folder to save the outputs."), - output_extension: Optional[str] = typer.Argument(None, help="Optional output extension."), + model_rdf: Annotated[ + Path, typer.Argument(help="Path to the model resource description file (rdf.yaml) or zipped model.") + ], + input_pattern: Annotated[str, typer.Argument(help="Glob pattern for the input images.")], + output_folder: Annotated[str, typer.Argument(help="Folder to save the outputs.")], + output_extension: Annotated[Optional[str], typer.Argument(help="Optional output extension.")] = None, # NOTE: typer currently doesn't support union types, so we only support boolean here # padding: Optional[Union[str, bool]] = typer.Argument( # None, help="Padding to apply in each dimension passed as json encoded string." @@ -227,10 +199,12 @@ def predict_images( # tiling: Optional[Union[str, bool]] = typer.Argument( # None, help="Padding to apply in each dimension passed as json encoded string." 
# ), - padding: Optional[bool] = typer.Option(None, help="Whether to pad the image to a size suited for the model."), - tiling: Optional[bool] = typer.Option(None, help="Whether to run prediction in tiling mode."), - weight_format: Optional[WeightsFormatEnum] = typer.Option(None, help="The weight format to use."), - devices: Optional[List[str]] = typer.Option(None, help="Devices for running the model."), + padding: Annotated[ + Optional[bool], typer.Option(help="Whether to pad the image to a size suited for the model.") + ] = None, + tiling: Annotated[Optional[bool], typer.Option(help="Whether to run prediction in tiling mode.")] = None, + weight_format: Annotated[Optional[WeightsFormatEnum], typer.Option(help="The weight format to use.")] = None, + devices: Annotated[Optional[List[str]], typer.Option(help="Devices for running the model.")] = None, ): input_files = glob(input_pattern) input_names = [os.path.split(infile)[1] for infile in input_files] diff --git a/tests/conftest.py b/tests/conftest.py index 0f44a94d..dcf8e8d5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import subprocess import warnings from types import MappingProxyType -from typing import Set +from typing import List, Set from pydantic import FilePath from pytest import FixtureRequest, fixture @@ -31,7 +31,7 @@ TENSORFLOW2_MODELS = ["unet2d_keras_tf2"] KERAS_TF1_MODELS = ["unet2d_keras"] KERAS_TF2_MODELS = ["unet2d_keras_tf2"] -TENSORFLOW_JS_MODELS = [] +TENSORFLOW_JS_MODELS: List[str] = [] MODEL_SOURCES = { @@ -108,31 +108,32 @@ tensorflow = None tf_major_version = None + skip_tensorflow = tensorflow is None skip_tensorflow_js = True # TODO: add a tensorflow_js example model -# load all model packages we need for testing -load_model_packages: Set[str] = set() -if not skip_torch: - load_model_packages |= set(TORCH_MODELS + TORCHSCRIPT_MODELS) - -if not skip_onnx: - load_model_packages |= set(ONNX_MODELS) - -if not skip_tensorflow: - load_model_packages |= set(TENSORFLOW_JS_MODELS) - if tf_major_version == 1: - load_model_packages |= set(KERAS_TF1_MODELS) - load_model_packages |= set(TENSORFLOW1_MODELS) - load_model_packages.add("stardist_wrong_shape") - load_model_packages.add("stardist_wrong_shape2") - elif tf_major_version == 2: - load_model_packages |= set(KERAS_TF2_MODELS) - load_model_packages |= set(TENSORFLOW2_MODELS) - @fixture(scope="session") def model_packages() -> MappingProxyType[str, FilePath]: + # load all model packages we need for testing + load_model_packages: Set[str] = set() + if not skip_torch: + load_model_packages |= set(TORCH_MODELS + TORCHSCRIPT_MODELS) + + if not skip_onnx: + load_model_packages |= set(ONNX_MODELS) + + if not skip_tensorflow: + load_model_packages |= set(TENSORFLOW_JS_MODELS) + if tf_major_version == 1: + load_model_packages |= set(KERAS_TF1_MODELS) + load_model_packages |= set(TENSORFLOW1_MODELS) + load_model_packages.add("stardist_wrong_shape") + load_model_packages.add("stardist_wrong_shape2") + elif tf_major_version == 2: + load_model_packages |= set(KERAS_TF2_MODELS) + load_model_packages |= set(TENSORFLOW2_MODELS) + return MappingProxyType({name: save_bioimageio_package(MODEL_SOURCES[name]) for name in load_model_packages}) diff --git a/tests/resource_io/test_load_rdf.py b/tests/resource_io/test_load_rdf.py deleted file mode 100644 index 0d86d1a2..00000000 --- a/tests/resource_io/test_load_rdf.py +++ /dev/null @@ -1,72 +0,0 @@ -import os.path -import pathlib -from pathlib import Path - -import pytest - - -def 
test_load_model_with_abs_path_source(unet2d_nuclei_broad_model: Path): - from bioimageio.spec import load_description - - raw_rd = load_description(unet2d_nuclei_broad_model) - path_source = (raw_rd.root / "rdf.yaml").absolute() - assert path_source.is_absolute() - model = load_description(path_source) - assert model - - -def test_load_model_with_rel_path_source(unet2d_nuclei_broad_model: Path): - from bioimageio.core.resource_io import load_description, load_raw_resource_description - - raw_rd = load_raw_resource_description(unet2d_nuclei_broad_model) - path_source = pathlib.Path(os.path.relpath(raw_rd.root_path / "rdf.yaml", os.curdir)) - assert not path_source.is_absolute() - model = load_description(path_source) - assert model - - -def test_load_model_with_abs_str_source(unet2d_nuclei_broad_model: Path): - from bioimageio.core.resource_io import load_description, load_raw_resource_description - - raw_rd = load_raw_resource_description(unet2d_nuclei_broad_model) - path_source = (raw_rd.root_path / "rdf.yaml").absolute() - assert path_source.is_absolute() - model = load_description(str(path_source)) - assert model - - -def test_load_model_with_rel_str_source(unet2d_nuclei_broad_model: Path): - from bioimageio.core.resource_io import load_description, load_raw_resource_description - - raw_rd = load_raw_resource_description(unet2d_nuclei_broad_model) - path_source = pathlib.Path(os.path.relpath(raw_rd.root_path / "rdf.yaml", os.curdir)) - assert not path_source.is_absolute() - model = load_description(str(path_source)) - assert model - - -def test_load_remote_rdf(unet2d_nuclei_broad_model: Path): - # remote model is a pytorch model, needing unet2d_nuclei_broad_model skips the test when needed - _ = unet2d_nuclei_broad_model - from bioimageio.core import load_description - from bioimageio.core.resource_io import nodes - - remote_rdf = "https://zenodo.org/api/files/63b44f05-a187-4fc9-81c8-c4568535531b/rdf.yaml" - model = load_description(remote_rdf) - assert isinstance(model, nodes.Model) - - -@pytest.mark.skipif(True, reason="No suitable test model available yet") -def test_load_remote_rdf_with_folders(): - from bioimageio.core import load_description, load_raw_resource_description - from bioimageio.core.resource_io import nodes - from bioimageio.spec.model import raw_nodes - - rdf_doi = "" - raw_model = load_raw_resource_description(rdf_doi, update_to_format="latest") - assert isinstance(raw_model, raw_nodes.Model) - model = load_description(rdf_doi) - assert isinstance(model, nodes.Model) - - # test for field value with folder, e.g. 
- assert resolve_source(raw_model.documentation) == model.documentation diff --git a/tests/resource_io/test_utils.py b/tests/resource_io/test_utils.py deleted file mode 100644 index ff834edc..00000000 --- a/tests/resource_io/test_utils.py +++ /dev/null @@ -1,85 +0,0 @@ -import dataclasses -from pathlib import Path - -import pytest - -# from bioimageio.spec.shared import raw_nodes -# from bioimageio.spec.shared.raw_nodes import RawNode - -# from bioimageio.core._internal import validation_visitors -# from bioimageio.core._internal.validation_visitors import Sha256NodeChecker -# from bioimageio.core.resource_io import nodes - - -def test_resolve_import_path(tmpdir): - tmpdir = Path(tmpdir) - manifest_path = tmpdir / "manifest.yaml" - manifest_path.touch() - source_file = Path("my_mod.py") - (tmpdir / str(source_file)).write_text("class Foo: pass", encoding="utf8") - node = raw_nodes.ImportableSourceFile(source_file=source_file, callable_name="Foo") - uri_transformed = validation_visitors.UriNodeTransformer(root_path=tmpdir).transform(node) - source_transformed = validation_visitors.SourceNodeTransformer().transform(uri_transformed) - assert isinstance(source_transformed, nodes.ImportedSource), type(source_transformed) - Foo = source_transformed.factory - assert Foo.__name__ == "Foo", Foo.__name__ - assert isinstance(Foo, type), type(Foo) - - -def test_resolve_directory_uri(tmpdir): - node = raw_nodes.URI(Path(tmpdir).as_uri()) - uri_transformed = validation_visitors.UriNodeTransformer(root_path=Path(tmpdir)).transform(node) - assert uri_transformed == Path(tmpdir) - - -def test_uri_available(): - pass # todo - - -def test_all_uris_available(): - from bioimageio.core._internal.validation_visitors import all_sources_available - - not_available = { - "uri": raw_nodes.URI(scheme="file", path="non_existing_file_in/non_existing_dir/ftw"), - "uri_exists": raw_nodes.URI(scheme="file", path="."), - } - assert not all_sources_available(not_available) - - -def test_uri_node_transformer_is_ok_with_abs_path(): - from bioimageio.core._internal.validation_visitors import UriNodeTransformer - - # note: the call of .absolute() is required to add the drive letter for windows paths, which are relative otherwise - tree = {"rel_path": Path("something/relative"), "abs_path": Path("/something/absolute").absolute()} - assert not tree["rel_path"].is_absolute() - assert tree["abs_path"].is_absolute() - - root = Path("/root").absolute() - print(root) - - tree = UriNodeTransformer(root_path=root).transform(tree) - assert tree["rel_path"].is_absolute() - assert tree["rel_path"] == Path("/root/something/relative").absolute() - assert tree["abs_path"].is_absolute() - assert tree["abs_path"] == Path("/something/absolute").absolute() - - -def test_sha256_checker(tmpdir): - root = Path(tmpdir) - src1 = root / "meh.txt" - src2 = root / "muh.txt" - src1.write_text(src1.stem, encoding="utf-8") - src2.write_text(src2.stem, encoding="utf-8") - - @dataclasses.dataclass - class TestNode(RawNode): - source: Path = src1 - sha256: str = "f65255094d7773ed8dd417badc9fc045c1f80fdc5b2d25172b031ce6933e039a" - my_src: Path = src2 - my_src_sha256: str = "8cf5844c38045aa19aae00d689002549d308de07a777c2ea34355d65283255ac" - - checker = Sha256NodeChecker(root_path=root) - checker.visit(TestNode()) - - with pytest.raises(ValueError): - checker.visit(TestNode(my_src_sha256="nope")) diff --git a/tests/test_bioimageio_spec_version.py b/tests/test_bioimageio_spec_version.py index 444bac31..af47dfea 100644 --- a/tests/test_bioimageio_spec_version.py 
+++ b/tests/test_bioimageio_spec_version.py @@ -1,13 +1,11 @@ import json import subprocess -import sys from typing import Optional import pytest from packaging.version import Version -@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python 3.8") def test_bioimageio_spec_version(mamba_cmd: Optional[str]): if mamba_cmd is None: pytest.skip("requires mamba") diff --git a/tests/test_cli.py b/tests/test_cli.py index 2ed9f894..f821e0c9 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,85 +1,87 @@ import os import subprocess -from typing import Sequence +from pathlib import Path +from typing import Any, List, Optional, Sequence import numpy as np from bioimageio.core import load_description -def run_subprocess(commands: Sequence[str], **kwargs) -> subprocess.CompletedProcess: +def run_subprocess(commands: Sequence[str], **kwargs: Any) -> "subprocess.CompletedProcess[str]": return subprocess.run(commands, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8", **kwargs) -def test_validate_model(unet2d_nuclei_broad_model): +def test_validate_model(unet2d_nuclei_broad_model: Path): ret = run_subprocess(["bioimageio", "validate", unet2d_nuclei_broad_model]) assert ret.returncode == 0, ret.stdout -def test_cli_package(unet2d_nuclei_broad_model, tmp_path): +def test_cli_package(unet2d_nuclei_broad_model: Path, tmp_path: Path): out_path = tmp_path / "model.zip" ret = run_subprocess(["bioimageio", "package", unet2d_nuclei_broad_model, str(out_path)]) assert ret.returncode == 0, ret.stdout assert out_path.exists() -def test_cli_package_wo_cache(unet2d_nuclei_broad_model): +def test_cli_package_wo_cache(unet2d_nuclei_broad_model: Path): env = os.environ.copy() env["BIOIMAGEIO_USE_CACHE"] = "false" ret = run_subprocess(["bioimageio", "package", unet2d_nuclei_broad_model], env=env) assert ret.returncode == 0, ret.stdout -def test_cli_test_model(unet2d_nuclei_broad_model): +def test_cli_test_model(unet2d_nuclei_broad_model: Path): ret = run_subprocess(["bioimageio", "test-model", unet2d_nuclei_broad_model]) assert ret.returncode == 0, ret.stdout -def test_cli_test_model_fail(stardist_wrong_shape): +def test_cli_test_model_fail(stardist_wrong_shape: Path): ret = run_subprocess(["bioimageio", "test-model", stardist_wrong_shape]) assert ret.returncode == 1 -def test_cli_test_model_with_weight_format(unet2d_nuclei_broad_model): +def test_cli_test_model_with_weight_format(unet2d_nuclei_broad_model: Path): ret = run_subprocess( ["bioimageio", "test-model", unet2d_nuclei_broad_model, "--weight-format", "pytorch_state_dict"] ) assert ret.returncode == 0, ret.stdout -def test_cli_test_resource(unet2d_nuclei_broad_model): +def test_cli_test_resource(unet2d_nuclei_broad_model: Path): ret = run_subprocess(["bioimageio", "test-resource", unet2d_nuclei_broad_model]) assert ret.returncode == 0, ret.stdout -def test_cli_test_resource_with_weight_format(unet2d_nuclei_broad_model): +def test_cli_test_resource_with_weight_format(unet2d_nuclei_broad_model: Path): ret = run_subprocess( ["bioimageio", "test-resource", unet2d_nuclei_broad_model, "--weight-format", "pytorch_state_dict"] ) assert ret.returncode == 0, ret.stdout -def _test_cli_predict_image(model, tmp_path, extra_kwargs=None): +def _test_cli_predict_image(model: Path, tmp_path: Path, extra_cmd_args: Optional[List[str]] = None): spec = load_description(model) in_path = spec.test_inputs[0] + out_path = tmp_path.with_suffix(".npy") - cmd = ["bioimageio", "predict-image", model, "--inputs", str(in_path), "--outputs", str(out_path)] - 
if extra_kwargs is not None: - cmd.extend(extra_kwargs) + cmd = ["bioimageio", "predict-image", model, "--input", str(in_path), "--output", str(out_path)] + if extra_cmd_args is not None: + cmd.extend(extra_cmd_args) ret = run_subprocess(cmd) assert ret.returncode == 0, ret.stdout assert out_path.exists() -def test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path): +def test_cli_predict_image(unet2d_nuclei_broad_model: Path, tmp_path: Path): _test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path) -def test_cli_predict_image_with_weight_format(unet2d_nuclei_broad_model, tmp_path): +def test_cli_predict_image_with_weight_format(unet2d_nuclei_broad_model: Path, tmp_path: Path): _test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path, ["--weight-format", "pytorch_state_dict"]) -def _test_cli_predict_images(model, tmp_path, extra_kwargs=None): +def _test_cli_predict_images(model: Path, tmp_path: Path, extra_cmd_args: Optional[List[str]] = None): n_images = 3 shape = (1, 1, 128, 128) expected_shape = (1, 1, 128, 128) @@ -89,7 +91,7 @@ def _test_cli_predict_images(model, tmp_path, extra_kwargs=None): out_folder = tmp_path / "outputs" out_folder.mkdir() - expected_outputs = [] + expected_outputs: List[Path] = [] for i in range(n_images): path = in_folder / f"im-{i}.npy" im = np.random.randint(0, 255, size=shape).astype("uint8") @@ -97,9 +99,9 @@ def _test_cli_predict_images(model, tmp_path, extra_kwargs=None): expected_outputs.append(out_folder / f"im-{i}.npy") input_pattern = str(in_folder / "*.npy") - cmd = ["bioimageio", "predict-images", model, input_pattern, str(out_folder)] - if extra_kwargs is not None: - cmd.extend(extra_kwargs) + cmd = ["bioimageio", "predict-images", str(model), input_pattern, str(out_folder)] + if extra_cmd_args is not None: + cmd.extend(extra_cmd_args) ret = run_subprocess(cmd) assert ret.returncode == 0, ret.stdout @@ -108,15 +110,15 @@ def _test_cli_predict_images(model, tmp_path, extra_kwargs=None): assert np.load(out_path).shape == expected_shape -def test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path): +def test_cli_predict_images(unet2d_nuclei_broad_model: Path, tmp_path: Path): _test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path) -def test_cli_predict_images_with_weight_format(unet2d_nuclei_broad_model, tmp_path): +def test_cli_predict_images_with_weight_format(unet2d_nuclei_broad_model: Path, tmp_path: Path): _test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path, ["--weight-format", "pytorch_state_dict"]) -def test_torch_to_torchscript(unet2d_nuclei_broad_model, tmp_path): +def test_torch_to_torchscript(unet2d_nuclei_broad_model: Path, tmp_path: Path): out_path = tmp_path.with_suffix(".pt") ret = run_subprocess( ["bioimageio", "convert-torch-weights-to-torchscript", str(unet2d_nuclei_broad_model), str(out_path)] @@ -125,14 +127,14 @@ def test_torch_to_torchscript(unet2d_nuclei_broad_model, tmp_path): assert out_path.exists() -def test_torch_to_onnx(convert_to_onnx, tmp_path): +def test_torch_to_onnx(convert_to_onnx: Path, tmp_path: Path): out_path = tmp_path.with_suffix(".onnx") ret = run_subprocess(["bioimageio", "convert-torch-weights-to-onnx", str(convert_to_onnx), str(out_path)]) assert ret.returncode == 0, ret.stdout assert out_path.exists() -def test_keras_to_tf(unet2d_keras, tmp_path): +def test_keras_to_tf(unet2d_keras: Path, tmp_path: Path): out_path = tmp_path / "weights.zip" ret = run_subprocess(["bioimageio", "convert-keras-weights-to-tensorflow", str(unet2d_keras), str(out_path)]) assert ret.returncode == 0, 
ret.stdout diff --git a/tests/test_prediction_pipeline.py b/tests/test_prediction_pipeline.py index b569c517..d4b16c12 100644 --- a/tests/test_prediction_pipeline.py +++ b/tests/test_prediction_pipeline.py @@ -1,11 +1,11 @@ from pathlib import Path -import numpy as np -import xarray as xr + from numpy.testing import assert_array_almost_equal +from bioimageio.core.utils import get_test_inputs, get_test_outputs from bioimageio.spec import load_description -from bioimageio.spec.model.v0_5 import WeightsFormat, ModelDescr from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr04 +from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat def _test_prediction_pipeline(model_package: Path, weights_format: WeightsFormat): @@ -15,41 +15,32 @@ def _test_prediction_pipeline(model_package: Path, weights_format: WeightsFormat assert isinstance(bio_model, (ModelDescr, ModelDescr04)) pp = create_prediction_pipeline(bioimageio_model=bio_model, weight_format=weights_format) - if isinstance(bio_model, ModelDescr04): - inputs = [ - xr.DataArray(np.load(str(test_tensor)), dims=tuple(spec.axes)) - for test_tensor, spec in zip(bio_model.test_inputs, bio_model.inputs) - ] - else: - + inputs = get_test_inputs(bio_model) outputs = pp.forward(*inputs) assert isinstance(outputs, list) - expected_outputs = [ - xr.DataArray(np.load(str(test_tensor)), dims=tuple(spec.axes)) - for test_tensor, spec in zip(bio_model.test_outputs, bio_model.outputs) - ] + expected_outputs = get_test_outputs(bio_model) assert len(outputs) == len(expected_outputs) for out, exp in zip(outputs, expected_outputs): assert_array_almost_equal(out, exp, decimal=4) -def test_prediction_pipeline_torch(any_torch_model): +def test_prediction_pipeline_torch(any_torch_model: Path): _test_prediction_pipeline(any_torch_model, "pytorch_state_dict") -def test_prediction_pipeline_torchscript(any_torchscript_model): +def test_prediction_pipeline_torchscript(any_torchscript_model: Path): _test_prediction_pipeline(any_torchscript_model, "torchscript") -def test_prediction_pipeline_onnx(any_onnx_model): +def test_prediction_pipeline_onnx(any_onnx_model: Path): _test_prediction_pipeline(any_onnx_model, "onnx") -def test_prediction_pipeline_tensorflow(any_tensorflow_model): +def test_prediction_pipeline_tensorflow(any_tensorflow_model: Path): _test_prediction_pipeline(any_tensorflow_model, "tensorflow_saved_model_bundle") -def test_prediction_pipeline_keras(any_keras_model): +def test_prediction_pipeline_keras(any_keras_model: Path): _test_prediction_pipeline(any_keras_model, "keras_hdf5") diff --git a/tests/test_resource_tests/test_test_model.py b/tests/test_resource_tests.py similarity index 100% rename from tests/test_resource_tests/test_test_model.py rename to tests/test_resource_tests.py From d50d68a098ab042ab8dcc4cbb9e0e30ce04ad280 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 27 Feb 2024 14:46:01 +0100 Subject: [PATCH 101/244] add docstrings and make stuff private --- bioimageio/core/__init__.py | 7 +- bioimageio/core/__main__.py | 327 +++++++++--------- bioimageio/core/{op_base.py => _op_base.py} | 0 ...on_pipeline.py => _prediction_pipeline.py} | 0 .../{resource_tests.py => _resource_tests.py} | 116 +++---- bioimageio/core/common.py | 13 +- bioimageio/core/io.py | 62 ---- bioimageio/core/prediction.py | 2 + bioimageio/core/proc_ops.py | 2 +- bioimageio/core/proc_setup.py | 3 + bioimageio/core/stat_calculators.py | 15 +- bioimageio/core/stat_measures.py | 20 ++ bioimageio/core/utils/__init__.py | 4 +- .../utils/{_tensor_io.py => 
_digest_spec.py} | 0 bioimageio/core/{ => utils}/image_helper.py | 0 bioimageio/core/weight_converter/__init__.py | 1 + .../core/weight_converter/keras/__init__.py | 2 +- .../keras/{tensorflow.py => _tensorflow.py} | 20 +- .../core/weight_converter/torch/__init__.py | 3 +- .../torch/{onnx.py => _onnx.py} | 2 +- .../torch/{torchscript.py => _torchscript.py} | 2 +- .../torch/{utils.py => _utils.py} | 0 tests/conftest.py | 37 +- tests/test_cli.py | 179 +++++----- tests/test_prediction_pipeline.py | 2 +- ...t_prediction_pipeline_device_management.py | 2 +- tests/test_resource_tests.py | 12 +- tests/{ => utils}/test_image_helper.py | 4 +- .../weight_converter/keras/test_tensorflow.py | 4 + tests/weight_converter/test_add_weights.py | 67 ++-- tests/weight_converter/torch/test_onnx.py | 6 +- .../torch/test_torchscript.py | 4 + 32 files changed, 435 insertions(+), 483 deletions(-) rename bioimageio/core/{op_base.py => _op_base.py} (100%) rename bioimageio/core/{prediction_pipeline.py => _prediction_pipeline.py} (100%) rename bioimageio/core/{resource_tests.py => _resource_tests.py} (99%) delete mode 100644 bioimageio/core/io.py rename bioimageio/core/utils/{_tensor_io.py => _digest_spec.py} (100%) rename bioimageio/core/{ => utils}/image_helper.py (100%) rename bioimageio/core/weight_converter/keras/{tensorflow.py => _tensorflow.py} (85%) rename bioimageio/core/weight_converter/torch/{onnx.py => _onnx.py} (97%) rename bioimageio/core/weight_converter/torch/{torchscript.py => _torchscript.py} (99%) rename bioimageio/core/weight_converter/torch/{utils.py => _utils.py} (100%) rename tests/{ => utils}/test_image_helper.py (84%) diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index 29116eaa..926c61d7 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -15,8 +15,9 @@ from bioimageio.spec import save_bioimageio_yaml_only as save_bioimageio_yaml_only from bioimageio.spec import validate_format as validate_format -from .prediction_pipeline import create_prediction_pipeline as create_prediction_pipeline -from .resource_tests import load_description_and_test as load_description_and_test -from .resource_tests import test_description as test_description +from ._prediction_pipeline import create_prediction_pipeline as create_prediction_pipeline +from ._resource_tests import load_description_and_test as load_description_and_test +from ._resource_tests import test_description as test_description +from ._resource_tests import test_model as test_model test_resource = test_description diff --git a/bioimageio/core/__main__.py b/bioimageio/core/__main__.py index 54ab9425..9e767ef1 100644 --- a/bioimageio/core/__main__.py +++ b/bioimageio/core/__main__.py @@ -1,36 +1,20 @@ import enum -import json -import os import sys -import warnings -from glob import glob from pathlib import Path from typing import List, Optional import typer from typing_extensions import Annotated -from bioimageio.core import __version__, prediction, resource_tests -from bioimageio.spec import load_description, save_bioimageio_package +from bioimageio.core import __version__ +from bioimageio.core import test_description as _test_description +from bioimageio.core import test_model as _test_model +from bioimageio.spec import save_bioimageio_package from bioimageio.spec.collection import CollectionDescr from bioimageio.spec.dataset import DatasetDescr from bioimageio.spec.model import ModelDescr from bioimageio.spec.notebook import NotebookDescr -try: - with warnings.catch_warnings(): - 
warnings.simplefilter("ignore") - from bioimageio.core.weight_converter import torch as torch_converter -except ImportError: - torch_converter = None - -try: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - from bioimageio.core.weight_converter import keras as keras_converter -except ImportError: - keras_converter = None - help_version = f"""bioimageio.core {__version__} bioimageio.spec {__version__} implementing: @@ -88,7 +72,11 @@ def package( # typer bug: typer returns empty tuple instead of None if weights_order_priority is not given weights_priority_order = weights_priority_order or None # TODO: check if this is still the case - _ = save_bioimageio_package(source, output_path=path, weights_priority_order=weights_priority_order) + _ = save_bioimageio_package( + source, + output_path=path, + weights_priority_order=None if weights_priority_order is None else [wpo.name for wpo in weights_priority_order], + ) @app.command() @@ -103,7 +91,7 @@ def test_model( # this is a weird typer bug: default devices are empty tuple although they should be None devices = devices or None - summary = resource_tests.test_model( + summary = _test_model( model_rdf, weight_format=None if weight_format is None else weight_format.value, devices=devices, @@ -114,7 +102,7 @@ def test_model( sys.exit(0 if summary.status == "passed" else 1) -test_model.__doc__ = resource_tests.test_model.__doc__ +test_model.__doc__ = _test_model.__doc__ @app.command() @@ -131,159 +119,160 @@ def test_resource( decimal: Annotated[int, typer.Option(help="(for model only) The test precision.")] = 4, ): # this is a weird typer bug: default devices are empty tuple although they should be None - if len(devices) == 0: + if devices is None or len(devices) == 0: devices = None - print(f"\ntesting {rdf}...") - summary = resource_tests.test_description( + + summary = _test_description( rdf, weight_format=None if weight_format is None else weight_format.value, devices=devices, decimal=decimal ) print(summary.format()) sys.exit(0 if summary.status == "passed" else 1) -test_resource.__doc__ = resource_tests.test_description.__doc__ - - -@app.command() -def predict_image( - model_rdf: Annotated[ - Path, typer.Argument(help="Path to the model resource description file (rdf.yaml) or zipped model.") - ], - inputs: Annotated[List[Path], typer.Option(help="Path(s) to the model input(s).")], - outputs: Annotated[List[Path], typer.Option(help="Path(s) for saveing the model output(s).")], - # NOTE: typer currently doesn't support union types, so we only support boolean here - # padding: Optional[Union[str, bool]] = typer.Argument( - # None, help="Padding to apply in each dimension passed as json encoded string." - # ), - # tiling: Optional[Union[str, bool]] = typer.Argument( - # None, help="Padding to apply in each dimension passed as json encoded string." 
- # ), - padding: Annotated[ - Optional[bool], typer.Option(help="Whether to pad the image to a size suited for the model.") - ] = None, - tiling: Annotated[Optional[bool], typer.Option(help="Whether to run prediction in tiling mode.")] = None, - weight_format: Annotated[Optional[WeightsFormatEnum], typer.Option(help="The weight format to use.")] = None, - devices: Annotated[Optional[List[str]], typer.Option(help="Devices for running the model.")] = None, -): - if isinstance(padding, str): - padding = json.loads(padding.replace("'", '"')) - assert isinstance(padding, dict) - if isinstance(tiling, str): - tiling = json.loads(tiling.replace("'", '"')) - assert isinstance(tiling, dict) - - # this is a weird typer bug: default devices are empty tuple although they should be None - if devices is None or len(devices) == 0: - devices = None - - prediction.predict_image( - model_rdf, inputs, outputs, padding, tiling, None if weight_format is None else weight_format.value, devices - ) - - -predict_image.__doc__ = prediction.predict_image.__doc__ - - -@app.command() -def predict_images( - model_rdf: Annotated[ - Path, typer.Argument(help="Path to the model resource description file (rdf.yaml) or zipped model.") - ], - input_pattern: Annotated[str, typer.Argument(help="Glob pattern for the input images.")], - output_folder: Annotated[str, typer.Argument(help="Folder to save the outputs.")], - output_extension: Annotated[Optional[str], typer.Argument(help="Optional output extension.")] = None, - # NOTE: typer currently doesn't support union types, so we only support boolean here - # padding: Optional[Union[str, bool]] = typer.Argument( - # None, help="Padding to apply in each dimension passed as json encoded string." - # ), - # tiling: Optional[Union[str, bool]] = typer.Argument( - # None, help="Padding to apply in each dimension passed as json encoded string." - # ), - padding: Annotated[ - Optional[bool], typer.Option(help="Whether to pad the image to a size suited for the model.") - ] = None, - tiling: Annotated[Optional[bool], typer.Option(help="Whether to run prediction in tiling mode.")] = None, - weight_format: Annotated[Optional[WeightsFormatEnum], typer.Option(help="The weight format to use.")] = None, - devices: Annotated[Optional[List[str]], typer.Option(help="Devices for running the model.")] = None, -): - input_files = glob(input_pattern) - input_names = [os.path.split(infile)[1] for infile in input_files] - output_files = [os.path.join(output_folder, fname) for fname in input_names] - if output_extension is not None: - output_files = [f"{os.path.splitext(outfile)[0]}{output_extension}" for outfile in output_files] - - if isinstance(padding, str): - padding = json.loads(padding.replace("'", '"')) - assert isinstance(padding, dict) - if isinstance(tiling, str): - tiling = json.loads(tiling.replace("'", '"')) - assert isinstance(tiling, dict) - - # this is a weird typer bug: default devices are empty tuple although they should be None - if len(devices) == 0: - devices = None - prediction.predict_images( - model_rdf, - input_files, - output_files, - padding=padding, - tiling=tiling, - weight_format=None if weight_format is None else weight_format.value, - devices=devices, - verbose=True, - ) - - -predict_images.__doc__ = prediction.predict_images.__doc__ - - -if torch_converter is not None: - - @app.command() - def convert_torch_weights_to_onnx( - model_rdf: Path = typer.Argument( - ..., help="Path to the model resource description file (rdf.yaml) or zipped model." 
-        ),
-        output_path: Path = typer.Argument(..., help="Where to save the onnx weights."),
-        opset_version: Optional[int] = typer.Argument(12, help="Onnx opset version."),
-        use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."),
-        verbose: bool = typer.Option(True, help="Verbosity"),
-    ):
-        ret_code = torch_converter.convert_weights_to_onnx(model_rdf, output_path, opset_version, use_tracing, verbose)
-        sys.exit(ret_code)
-
-    convert_torch_weights_to_onnx.__doc__ = torch_converter.convert_weights_to_onnx.__doc__
-
-    @app.command()
-    def convert_torch_weights_to_torchscript(
-        model_rdf: Path = typer.Argument(
-            ..., help="Path to the model resource description file (rdf.yaml) or zipped model."
-        ),
-        output_path: Path = typer.Argument(..., help="Where to save the torchscript weights."),
-        use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."),
-    ):
-        torch_converter.convert_weights_to_torchscript(model_rdf, output_path, use_tracing)
-        sys.exit(0)
-
-    convert_torch_weights_to_torchscript.__doc__ = torch_converter.convert_weights_to_torchscript.__doc__
-
-
-if keras_converter is not None:
-
-    @app.command()
-    def convert_keras_weights_to_tensorflow(
-        model_rdf: Annotated[
-            Path, typer.Argument(help="Path to the model resource description file (rdf.yaml) or zipped model.")
-        ],
-        output_path: Annotated[Path, typer.Argument(help="Where to save the tensorflow weights.")],
-    ):
-        rd = load_description(model_rdf)
-        ret_code = keras_converter.convert_weights_to_tensorflow_saved_model_bundle(rd, output_path)
-        sys.exit(ret_code)
-
-    convert_keras_weights_to_tensorflow.__doc__ = (
-        keras_converter.convert_weights_to_tensorflow_saved_model_bundle.__doc__
-    )
+test_resource.__doc__ = _test_description.__doc__
+
+
+# TODO: add predict commands
+# @app.command()
+# def predict_image(
+#     model_rdf: Annotated[
+#         Path, typer.Argument(help="Path to the model resource description file (rdf.yaml) or zipped model.")
+#     ],
+#     inputs: Annotated[List[Path], typer.Option(help="Path(s) to the model input(s).")],
+#     outputs: Annotated[List[Path], typer.Option(help="Path(s) for saving the model output(s).")],
+#     # NOTE: typer currently doesn't support union types, so we only support boolean here
+#     # padding: Optional[Union[str, bool]] = typer.Argument(
+#     #     None, help="Padding to apply in each dimension passed as json encoded string."
+#     # ),
+#     # tiling: Optional[Union[str, bool]] = typer.Argument(
+#     #     None, help="Padding to apply in each dimension passed as json encoded string."
+# # ), +# padding: Annotated[ +# Optional[bool], typer.Option(help="Whether to pad the image to a size suited for the model.") +# ] = None, +# tiling: Annotated[Optional[bool], typer.Option(help="Whether to run prediction in tiling mode.")] = None, +# weight_format: Annotated[Optional[WeightsFormatEnum], typer.Option(help="The weight format to use.")] = None, +# devices: Annotated[Optional[List[str]], typer.Option(help="Devices for running the model.")] = None, +# ): +# if isinstance(padding, str): +# padding = json.loads(padding.replace("'", '"')) +# assert isinstance(padding, dict) +# if isinstance(tiling, str): +# tiling = json.loads(tiling.replace("'", '"')) +# assert isinstance(tiling, dict) + +# # this is a weird typer bug: default devices are empty tuple although they should be None +# if devices is None or len(devices) == 0: +# devices = None + +# prediction.predict_image( +# model_rdf, inputs, outputs, padding, tiling, None if weight_format is None else weight_format.value, devices +# ) + + +# predict_image.__doc__ = prediction.predict_image.__doc__ + + +# @app.command() +# def predict_images( +# model_rdf: Annotated[ +# Path, typer.Argument(help="Path to the model resource description file (rdf.yaml) or zipped model.") +# ], +# input_pattern: Annotated[str, typer.Argument(help="Glob pattern for the input images.")], +# output_folder: Annotated[str, typer.Argument(help="Folder to save the outputs.")], +# output_extension: Annotated[Optional[str], typer.Argument(help="Optional output extension.")] = None, +# # NOTE: typer currently doesn't support union types, so we only support boolean here +# # padding: Optional[Union[str, bool]] = typer.Argument( +# # None, help="Padding to apply in each dimension passed as json encoded string." +# # ), +# # tiling: Optional[Union[str, bool]] = typer.Argument( +# # None, help="Padding to apply in each dimension passed as json encoded string." 
+# # ), +# padding: Annotated[ +# Optional[bool], typer.Option(help="Whether to pad the image to a size suited for the model.") +# ] = None, +# tiling: Annotated[Optional[bool], typer.Option(help="Whether to run prediction in tiling mode.")] = None, +# weight_format: Annotated[Optional[WeightsFormatEnum], typer.Option(help="The weight format to use.")] = None, +# devices: Annotated[Optional[List[str]], typer.Option(help="Devices for running the model.")] = None, +# ): +# input_files = glob(input_pattern) +# input_names = [os.path.split(infile)[1] for infile in input_files] +# output_files = [os.path.join(output_folder, fname) for fname in input_names] +# if output_extension is not None: +# output_files = [f"{os.path.splitext(outfile)[0]}{output_extension}" for outfile in output_files] + +# if isinstance(padding, str): +# padding = json.loads(padding.replace("'", '"')) +# assert isinstance(padding, dict) +# if isinstance(tiling, str): +# tiling = json.loads(tiling.replace("'", '"')) +# assert isinstance(tiling, dict) + +# # this is a weird typer bug: default devices are empty tuple although they should be None +# if len(devices) == 0: +# devices = None +# prediction.predict_images( +# model_rdf, +# input_files, +# output_files, +# padding=padding, +# tiling=tiling, +# weight_format=None if weight_format is None else weight_format.value, +# devices=devices, +# verbose=True, +# ) + + +# predict_images.__doc__ = prediction.predict_images.__doc__ + + +# if torch_converter is not None: + +# @app.command() +# def convert_torch_weights_to_onnx( +# model_rdf: Path = typer.Argument( +# ..., help="Path to the model resource description file (rdf.yaml) or zipped model." +# ), +# output_path: Path = typer.Argument(..., help="Where to save the onnx weights."), +# opset_version: Optional[int] = typer.Argument(12, help="Onnx opset version."), +# use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."), +# verbose: bool = typer.Option(True, help="Verbosity"), +# ): +# ret_code = torch_converter.convert_weights_to_onnx(model_rdf, output_path, opset_version, use_tracing, verbose) +# sys.exit(ret_code) + +# convert_torch_weights_to_onnx.__doc__ = torch_converter.convert_weights_to_onnx.__doc__ + +# @app.command() +# def convert_torch_weights_to_torchscript( +# model_rdf: Path = typer.Argument( +# ..., help="Path to the model resource description file (rdf.yaml) or zipped model." 
+# ), +# output_path: Path = typer.Argument(..., help="Where to save the torchscript weights."), +# use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."), +# ): +# torch_converter.convert_weights_to_torchscript(model_rdf, output_path, use_tracing) +# sys.exit(0) + +# convert_torch_weights_to_torchscript.__doc__ = torch_converter.convert_weights_to_torchscript.__doc__ + + +# if keras_converter is not None: + +# @app.command() +# def convert_keras_weights_to_tensorflow( +# model_rdf: Annotated[ +# Path, typer.Argument(help="Path to the model resource description file (rdf.yaml) or zipped model.") +# ], +# output_path: Annotated[Path, typer.Argument(help="Where to save the tensorflow weights.")], +# ): +# rd = load_description(model_rdf) +# ret_code = keras_converter.convert_weights_to_tensorflow_saved_model_bundle(rd, output_path) +# sys.exit(ret_code) + +# convert_keras_weights_to_tensorflow.__doc__ = ( +# keras_converter.convert_weights_to_tensorflow_saved_model_bundle.__doc__ +# ) if __name__ == "__main__": diff --git a/bioimageio/core/op_base.py b/bioimageio/core/_op_base.py similarity index 100% rename from bioimageio/core/op_base.py rename to bioimageio/core/_op_base.py diff --git a/bioimageio/core/prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py similarity index 100% rename from bioimageio/core/prediction_pipeline.py rename to bioimageio/core/_prediction_pipeline.py diff --git a/bioimageio/core/resource_tests.py b/bioimageio/core/_resource_tests.py similarity index 99% rename from bioimageio/core/resource_tests.py rename to bioimageio/core/_resource_tests.py index 2ac15c71..0f30b550 100644 --- a/bioimageio/core/resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -6,7 +6,7 @@ import xarray as xr from bioimageio.core import __version__ as bioimageio_core_version -from bioimageio.core.prediction_pipeline import create_prediction_pipeline +from bioimageio.core._prediction_pipeline import create_prediction_pipeline from bioimageio.spec import InvalidDescr, ResourceDescr, build_description, dump_description, load_description from bioimageio.spec._internal.base_nodes import ResourceDescrBase from bioimageio.spec._internal.io_utils import load_array @@ -28,6 +28,63 @@ def test_model( ) +def test_description( + source: Union[ResourceDescr, FileSource, BioimageioYamlContent], + *, + format_version: Union[Literal["discover", "latest"], str] = "discover", + weight_format: Optional[WeightsFormat] = None, + devices: Optional[List[str]] = None, + decimal: int = 4, + expected_type: Optional[str] = None, +) -> ValidationSummary: + """Test RDF dynamically, e.g. model inference of test inputs""" + rd = load_description_and_test( + source, + format_version=format_version, + weight_format=weight_format, + devices=devices, + decimal=decimal, + expected_type=expected_type, + ) + return rd.validation_summary + + +def load_description_and_test( + source: Union[ResourceDescr, FileSource, BioimageioYamlContent], + *, + format_version: Union[Literal["discover", "latest"], str] = "discover", + weight_format: Optional[WeightsFormat] = None, + devices: Optional[List[str]] = None, + decimal: int = 4, + expected_type: Optional[str] = None, +) -> Union[ResourceDescr, InvalidDescr]: + """Test RDF dynamically, e.g. 
model inference of test inputs""" + if ( + isinstance(source, ResourceDescrBase) + and format_version != "discover" + and source.format_version != format_version + ): + warnings.warn(f"deserializing source to ensure we validate and test using format {format_version}") + source = dump_description(source) + + if isinstance(source, ResourceDescrBase): + rd = source + elif isinstance(source, dict): + rd = build_description(source, format_version=format_version) + else: + rd = load_description(source, format_version=format_version) + + rd.validation_summary.env.append(InstalledPackage(name="bioimageio.core", version=bioimageio_core_version)) + + if expected_type is not None: + _test_expected_resource_type(rd, expected_type) + + if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)): + _test_model_inference(rd, weight_format, devices, decimal) + + return rd + + def _test_model_inference( model: Union[v0_4.ModelDescr, v0_5.ModelDescr], weight_format: Optional[WeightsFormat], @@ -104,63 +161,6 @@ def _test_expected_resource_type(rd: Union[InvalidDescr, ResourceDescr], expecte ) -def test_description( - source: Union[ResourceDescr, FileSource, BioimageioYamlContent], - *, - format_version: Union[Literal["discover", "latest"], str] = "discover", - weight_format: Optional[WeightsFormat] = None, - devices: Optional[List[str]] = None, - decimal: int = 4, - expected_type: Optional[str] = None, -) -> ValidationSummary: - """Test RDF dynamically, e.g. model inference of test inputs""" - rd = load_description_and_test( - source, - format_version=format_version, - weight_format=weight_format, - devices=devices, - decimal=decimal, - expected_type=expected_type, - ) - return rd.validation_summary - - -def load_description_and_test( - source: Union[ResourceDescr, FileSource, BioimageioYamlContent], - *, - format_version: Union[Literal["discover", "latest"], str] = "discover", - weight_format: Optional[WeightsFormat] = None, - devices: Optional[List[str]] = None, - decimal: int = 4, - expected_type: Optional[str] = None, -) -> Union[ResourceDescr, InvalidDescr]: - """Test RDF dynamically, e.g. 
model inference of test inputs""" - if ( - isinstance(source, ResourceDescrBase) - and format_version != "discover" - and source.format_version != format_version - ): - warnings.warn(f"deserializing source to ensure we validate and test using format {format_version}") - source = dump_description(source) - - if isinstance(source, ResourceDescrBase): - rd = source - elif isinstance(source, dict): - rd = build_description(source, format_version=format_version) - else: - rd = load_description(source, format_version=format_version) - - rd.validation_summary.env.append(InstalledPackage(name="bioimageio.core", version=bioimageio_core_version)) - - if expected_type is not None: - _test_expected_resource_type(rd, expected_type) - - if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)): - _test_model_inference(rd, weight_format, devices, decimal) - - return rd - - # def debug_model( # model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str], # *, diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index 1f0bcd84..1981f2c5 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -1,9 +1,9 @@ from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, Union +from typing import TYPE_CHECKING, Dict import xarray as xr -from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.model import v0_5 if TYPE_CHECKING: from bioimageio.core.stat_measures import Measure, MeasureValue @@ -19,9 +19,10 @@ @dataclass class Sample: - data: Data = field(default_factory=dict) - stat: Stat = field(default_factory=dict) + """A (dataset) sample""" + data: Data = field(default_factory=dict) + """the samples tensors""" -ProcessingDescrBase = Union[v0_4.ProcessingDescrBase, v0_5.ProcessingDescrBase] -ProcessingKwargs = Union[v0_4.ProcessingKwargs, v0_5.ProcessingKwargs] + stat: Stat = field(default_factory=dict) + """sample and dataset statistics""" diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py deleted file mode 100644 index 1f27b60e..00000000 --- a/bioimageio/core/io.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from contextlib import nullcontext -from typing import Literal, Optional, Union - -from bioimageio.core.resource_tests import test_description -from bioimageio.spec import load_description as load_description -from bioimageio.spec._description import ResourceDescr -from bioimageio.spec._internal.constants import DISCOVER -from bioimageio.spec._internal.io_utils import open_bioimageio_yaml -from bioimageio.spec._internal.validation_context import ValidationContext -from bioimageio.spec.common import BioimageioYamlContent, FileSource, InvalidDescr -from bioimageio.spec.summary import ValidationSummary - - -def load_description_and_test( - source: FileSource, - /, - *, - format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, -) -> Union[ResourceDescr, InvalidDescr]: - opened = open_bioimageio_yaml(source) - - return build_description_and_test( - opened.content, - context=ValidationContext(root=opened.original_root, file_name=opened.original_file_name), - format_version=format_version, - ) - - -def build_description_and_test( - data: BioimageioYamlContent, - /, - *, - context: Optional[ValidationContext] = None, - format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, -) -> Union[ResourceDescr, InvalidDescr]: - """load and validate a BioImage.IO description from the content of a resource description file (RDF)""" - if context is None: - val_context = nullcontext() 
- else: - val_context = context - - with val_context: - rd = test_description(data, format_version=format_version) - - return rd - - -def validate( - source: "FileSource | BioimageioYamlContent", - /, - *, - context: Optional[ValidationContext] = None, - format_version: Union[Literal["discover"], Literal["latest"], str] = DISCOVER, -) -> ValidationSummary: - if isinstance(source, dict): - rd = build_description_and_test(source, context=context, format_version=format_version) - else: - rd = load_description_and_test(source, format_version=format_version) - - return rd.validation_summary diff --git a/bioimageio/core/prediction.py b/bioimageio/core/prediction.py index 228bfa63..e9ec7256 100644 --- a/bioimageio/core/prediction.py +++ b/bioimageio/core/prediction.py @@ -1,3 +1,5 @@ +"""coming soon""" + # TODO: update # import collections # import os diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index d055c059..a0a2a1f1 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -19,6 +19,7 @@ from numpy.typing import DTypeLike from typing_extensions import Self, assert_never +from bioimageio.core._op_base import Operator from bioimageio.core.common import ( AxisId, Sample, @@ -26,7 +27,6 @@ Tensor, TensorId, ) -from bioimageio.core.op_base import Operator from bioimageio.core.stat_calculators import StatsCalculator from bioimageio.core.stat_measures import ( DatasetMean, diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py index 4c504681..a71ba023 100644 --- a/bioimageio/core/proc_setup.py +++ b/bioimageio/core/proc_setup.py @@ -40,6 +40,9 @@ def setup_pre_and_postprocessing( keep_updating_initial_dataset_stats: bool = False, fixed_dataset_stats: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType({}), ) -> PreAndPostprocessing: + """ + Get pre- and postprocessing operators for a `model` description. 
+    used in `bioimageio.core.create_prediction_pipeline`"""
     prep, post, prep_meas, post_meas = _prepare_setup_pre_and_postprocessing(model)

     missing_dataset_stats = {m for m in prep_meas | post_meas if m not in fixed_dataset_stats}
diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py
index 3b0045da..3cc4d67e 100644
--- a/bioimageio/core/stat_calculators.py
+++ b/bioimageio/core/stat_calculators.py
@@ -65,6 +65,8 @@ def quantile(self, q: Any) -> Any:


 class MeanCalculator:
+    """to calculate sample and dataset mean"""
+
     def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]):
         super().__init__()
         self._n: int = 0
@@ -115,6 +117,8 @@ def finalize(self) -> Dict[DatasetMean, MeasureValue]:


 class MeanVarStdCalculator:
+    """to calculate sample and dataset mean, variance or standard deviation"""
+
     def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]):
         super().__init__()
         self._axes = None if axes is None else tuple(axes)
@@ -181,6 +185,8 @@ def finalize(self) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureVa


 class SamplePercentilesCalculator:
+    """to calculate sample percentiles"""
+
     def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Collection[float]):
         super().__init__()
         assert all(0 <= n <= 100 for n in ns)
@@ -196,6 +202,9 @@ def compute(self, sample: Sample) -> Dict[SamplePercentile, MeasureValue]:


 class MeanPercentilesCalculator:
+    """to calculate dataset percentiles heuristically by averaging across samples
+    **note**: the returned dataset percentiles are an estimate and **not mathematically correct**"""
+
     def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Collection[float]):
         super().__init__()
         assert all(0 <= n <= 100 for n in ns)
@@ -234,6 +243,8 @@ def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:


 class CrickPercentilesCalculator:
+    """to calculate dataset percentiles with the experimental [crick library](https://github.com/dask/crick)"""
+
     def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Collection[float]):
         warnings.warn("Computing dataset percentiles with experimental 'crick' library.")
         super().__init__()
@@ -297,7 +308,7 @@ def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
     DatasetPercentilesCalculator = CrickPercentilesCalculator


-class NaivSampleMeasureCalculator:
+class NaiveSampleMeasureCalculator:
     """wrapper for measures to match interface of other sample measure calculators"""

     def __init__(self, tensor_id: TensorId, measure: SampleMeasure):
@@ -310,7 +321,7 @@ def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:


 SampleMeasureCalculator = Union[
-    MeanCalculator, MeanVarStdCalculator, SamplePercentilesCalculator, NaivSampleMeasureCalculator
+    MeanCalculator, MeanVarStdCalculator, SamplePercentilesCalculator, NaiveSampleMeasureCalculator
 ]
 DatasetMeasureCalculator = Union[MeanCalculator, MeanVarStdCalculator, DatasetPercentilesCalculator]
diff --git a/bioimageio/core/stat_measures.py b/bioimageio/core/stat_measures.py
index 726c90cf..99329498 100644
--- a/bioimageio/core/stat_measures.py
+++ b/bioimageio/core/stat_measures.py
@@ -32,10 +32,13 @@ class DatasetMeasureBase(MeasureBase, ABC):
 @dataclass(frozen=True)
 class _Mean:
     axes: Optional[Tuple[AxisId, ...]] = None
+    """`axes` to reduce"""


 @dataclass(frozen=True)
 class SampleMean(_Mean, SampleMeasureBase):
+    """The mean value of a single tensor"""
+
     def compute(self, sample: Sample) -> MeasureValue:
         return 
sample.data[self.tensor_id].mean(dim=self.axes) @@ -45,6 +48,8 @@ def __post_init__(self): @dataclass(frozen=True) class DatasetMean(_Mean, DatasetMeasureBase): + """The mean value across multiple samples""" + def __post_init__(self): assert self.axes is None or AxisId("batch") in self.axes @@ -52,10 +57,13 @@ def __post_init__(self): @dataclass(frozen=True) class _Std: axes: Optional[Tuple[AxisId, ...]] = None + """`axes` to reduce""" @dataclass(frozen=True) class SampleStd(_Std, SampleMeasureBase): + """The standard deviation of a single tensor""" + def compute(self, sample: Sample) -> MeasureValue: return sample.data[self.tensor_id].std(dim=self.axes) @@ -65,6 +73,8 @@ def __post_init__(self): @dataclass(frozen=True) class DatasetStd(_Std, DatasetMeasureBase): + """The standard deviation across multiple samples""" + def __post_init__(self): assert self.axes is None or AxisId("batch") in self.axes @@ -72,10 +82,13 @@ def __post_init__(self): @dataclass(frozen=True) class _Var: axes: Optional[Tuple[AxisId, ...]] = None + """`axes` to reduce""" @dataclass(frozen=True) class SampleVar(_Var, SampleMeasureBase): + """The variance of a single tensor""" + def compute(self, sample: Sample) -> MeasureValue: return sample.data[self.tensor_id].var(dim=self.axes) @@ -85,6 +98,8 @@ def __post_init__(self): @dataclass(frozen=True) class DatasetVar(_Var, DatasetMeasureBase): + """The variance across multiple samples""" + def __post_init__(self): assert self.axes is None or AxisId("batch") in self.axes @@ -93,6 +108,7 @@ def __post_init__(self): class _Percentile: n: float axes: Optional[Tuple[AxisId, ...]] = None + """`axes` to reduce""" def __post_init__(self): assert self.n >= 0 @@ -101,6 +117,8 @@ def __post_init__(self): @dataclass(frozen=True) class SamplePercentile(_Percentile, SampleMeasureBase): + """The `n`th percentile of a single tensor""" + def compute(self, sample: Sample) -> MeasureValue: return sample.data[self.tensor_id].quantile(self.n / 100.0, dim=self.axes) @@ -111,6 +129,8 @@ def __post_init__(self): @dataclass(frozen=True) class DatasetPercentile(_Percentile, DatasetMeasureBase): + """The `n`th percentile across multiple samples""" + def __post_init__(self): super().__post_init__() assert self.axes is None or AxisId("batch") in self.axes diff --git a/bioimageio/core/utils/__init__.py b/bioimageio/core/utils/__init__.py index 426c8591..7126bd75 100644 --- a/bioimageio/core/utils/__init__.py +++ b/bioimageio/core/utils/__init__.py @@ -1,9 +1,9 @@ import sys from pathlib import Path +from ._digest_spec import get_test_inputs as get_test_inputs +from ._digest_spec import get_test_outputs as get_test_outputs from ._import_callable import import_callable as import_callable -from ._tensor_io import get_test_inputs as get_test_inputs -from ._tensor_io import get_test_outputs as get_test_outputs if sys.version_info < (3, 9): diff --git a/bioimageio/core/utils/_tensor_io.py b/bioimageio/core/utils/_digest_spec.py similarity index 100% rename from bioimageio/core/utils/_tensor_io.py rename to bioimageio/core/utils/_digest_spec.py diff --git a/bioimageio/core/image_helper.py b/bioimageio/core/utils/image_helper.py similarity index 100% rename from bioimageio/core/image_helper.py rename to bioimageio/core/utils/image_helper.py diff --git a/bioimageio/core/weight_converter/__init__.py b/bioimageio/core/weight_converter/__init__.py index e69de29b..5f1674c9 100644 --- a/bioimageio/core/weight_converter/__init__.py +++ b/bioimageio/core/weight_converter/__init__.py @@ -0,0 +1 @@ +"""coming soon""" 
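
The docstrings added above describe a small, composable statistics API: a `Sample` bundles named tensors (plus accumulated statistics), and each measure is a frozen dataclass keyed by a tensor id whose `compute` derives one value from a sample. A minimal usage sketch follows; it is not part of the patches themselves, and it assumes the import locations introduced above and that `TensorId`/`AxisId` compare equal to plain strings:

    import numpy as np
    import xarray as xr

    from bioimageio.core.common import AxisId, Sample, TensorId  # locations assumed from the diffs above
    from bioimageio.core.stat_measures import SampleMean

    # a sample holding a single named tensor (dims given as plain strings for illustration)
    sample = Sample(data={TensorId("raw"): xr.DataArray(np.random.rand(1, 4, 4), dims=("batch", "x", "y"))})

    # reduce over the spatial axes only; the batch axis is left to the dataset measures
    mean = SampleMean(tensor_id=TensorId("raw"), axes=(AxisId("x"), AxisId("y")))
    per_batch_mean = mean.compute(sample)  # an xr.DataArray with the "batch" dim remaining
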
diff --git a/bioimageio/core/weight_converter/keras/__init__.py b/bioimageio/core/weight_converter/keras/__init__.py index 471713e2..195b42b8 100644 --- a/bioimageio/core/weight_converter/keras/__init__.py +++ b/bioimageio/core/weight_converter/keras/__init__.py @@ -1 +1 @@ -from .tensorflow import convert_weights_to_tensorflow_saved_model_bundle +# TODO: update keras weight converters diff --git a/bioimageio/core/weight_converter/keras/tensorflow.py b/bioimageio/core/weight_converter/keras/_tensorflow.py similarity index 85% rename from bioimageio/core/weight_converter/keras/tensorflow.py rename to bioimageio/core/weight_converter/keras/_tensorflow.py index e6476a46..5fa6be54 100644 --- a/bioimageio/core/weight_converter/keras/tensorflow.py +++ b/bioimageio/core/weight_converter/keras/_tensorflow.py @@ -7,7 +7,7 @@ try: import tensorflow.saved_model except Exception: - tensorflow = None + _tensorflow = None from bioimageio.spec._internal.io_utils import download from bioimageio.spec.model.v0_5 import ModelDescr @@ -35,7 +35,7 @@ def _zip_model_bundle(model_bundle_folder: Path): def _convert_tf1(keras_weight_path: Path, output_path: Path, input_name: str, output_name: str, zip_weights: bool): try: # try to build the tf model with the keras import from tensorflow - from tensorflow import keras # type: ignore + from bioimageio.core.weight_converter.keras._tensorflow import keras # type: ignore except Exception: # if the above fails try to export with the standalone keras @@ -44,17 +44,17 @@ def _convert_tf1(keras_weight_path: Path, output_path: Path, input_name: str, ou @no_type_check def build_tf_model(): keras_model = keras.models.load_model(keras_weight_path) - assert tensorflow is not None - builder = tensorflow.saved_model.builder.SavedModelBuilder(output_path) - signature = tensorflow.saved_model.signature_def_utils.predict_signature_def( + assert _tensorflow is not None + builder = _tensorflow.saved_model.builder.SavedModelBuilder(output_path) + signature = _tensorflow.saved_model.signature_def_utils.predict_signature_def( inputs={input_name: keras_model.input}, outputs={output_name: keras_model.output} ) - signature_def_map = {tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature} + signature_def_map = {_tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature} builder.add_meta_graph_and_variables( keras.backend.get_session(), - [tensorflow.saved_model.tag_constants.SERVING], + [_tensorflow.saved_model.tag_constants.SERVING], signature_def_map=signature_def_map, ) builder.save() @@ -71,7 +71,7 @@ def build_tf_model(): def _convert_tf2(keras_weight_path: Path, output_path: Path, zip_weights: bool): try: # try to build the tf model with the keras import from tensorflow - from tensorflow import keras + from bioimageio.core.weight_converter.keras._tensorflow import keras except Exception: # if the above fails try to export with the standalone keras import keras @@ -96,8 +96,8 @@ def convert_weights_to_tensorflow_saved_model_bundle(model: ModelDescr, output_p model: The bioimageio model description output_path: where to save the tensorflow weights. This path must not exist yet. 
""" - assert tensorflow is not None - tf_major_ver = int(tensorflow.__version__.split(".")[0]) + assert _tensorflow is not None + tf_major_ver = int(_tensorflow.__version__.split(".")[0]) if output_path.suffix == ".zip": output_path = output_path.with_suffix("") diff --git a/bioimageio/core/weight_converter/torch/__init__.py b/bioimageio/core/weight_converter/torch/__init__.py index c7bda015..1b1ba526 100644 --- a/bioimageio/core/weight_converter/torch/__init__.py +++ b/bioimageio/core/weight_converter/torch/__init__.py @@ -1,2 +1 @@ -from .onnx import add_onnx_weights -from .torchscript import convert_weights_to_torchscript +# TODO: torch weight converters diff --git a/bioimageio/core/weight_converter/torch/onnx.py b/bioimageio/core/weight_converter/torch/_onnx.py similarity index 97% rename from bioimageio/core/weight_converter/torch/onnx.py rename to bioimageio/core/weight_converter/torch/_onnx.py index 9fa90de1..f9b66b9f 100644 --- a/bioimageio/core/weight_converter/torch/onnx.py +++ b/bioimageio/core/weight_converter/torch/_onnx.py @@ -7,7 +7,7 @@ from numpy.testing import assert_array_almost_equal from bioimageio.core.utils import get_test_inputs -from bioimageio.core.weight_converter.torch.utils import load_torch_model +from bioimageio.core.weight_converter.torch._utils import load_torch_model from bioimageio.spec import load_description from bioimageio.spec.common import InvalidDescr from bioimageio.spec.model import v0_4, v0_5 diff --git a/bioimageio/core/weight_converter/torch/torchscript.py b/bioimageio/core/weight_converter/torch/_torchscript.py similarity index 99% rename from bioimageio/core/weight_converter/torch/torchscript.py rename to bioimageio/core/weight_converter/torch/_torchscript.py index 0dd23442..e724dac2 100644 --- a/bioimageio/core/weight_converter/torch/torchscript.py +++ b/bioimageio/core/weight_converter/torch/_torchscript.py @@ -9,7 +9,7 @@ from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import Version -from .utils import load_torch_model +from ._utils import load_torch_model # FIXME: remove Any diff --git a/bioimageio/core/weight_converter/torch/utils.py b/bioimageio/core/weight_converter/torch/_utils.py similarity index 100% rename from bioimageio/core/weight_converter/torch/utils.py rename to bioimageio/core/weight_converter/torch/_utils.py diff --git a/tests/conftest.py b/tests/conftest.py index dcf8e8d5..2355c48c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -112,28 +112,27 @@ skip_tensorflow = tensorflow is None skip_tensorflow_js = True # TODO: add a tensorflow_js example model +# load all model packages we need for testing +load_model_packages: Set[str] = set() +if not skip_torch: + load_model_packages |= set(TORCH_MODELS + TORCHSCRIPT_MODELS) + +if not skip_onnx: + load_model_packages |= set(ONNX_MODELS) + +if not skip_tensorflow: + load_model_packages |= set(TENSORFLOW_JS_MODELS) + if tf_major_version == 1: + load_model_packages |= set(KERAS_TF1_MODELS) + load_model_packages |= set(TENSORFLOW1_MODELS) + load_model_packages.add("stardist_wrong_shape") + load_model_packages.add("stardist_wrong_shape2") + elif tf_major_version == 2: + load_model_packages |= set(KERAS_TF2_MODELS) + load_model_packages |= set(TENSORFLOW2_MODELS) @fixture(scope="session") def model_packages() -> MappingProxyType[str, FilePath]: - # load all model packages we need for testing - load_model_packages: Set[str] = set() - if not skip_torch: - load_model_packages |= set(TORCH_MODELS + TORCHSCRIPT_MODELS) - - if not skip_onnx: - 
load_model_packages |= set(ONNX_MODELS) - - if not skip_tensorflow: - load_model_packages |= set(TENSORFLOW_JS_MODELS) - if tf_major_version == 1: - load_model_packages |= set(KERAS_TF1_MODELS) - load_model_packages |= set(TENSORFLOW1_MODELS) - load_model_packages.add("stardist_wrong_shape") - load_model_packages.add("stardist_wrong_shape2") - elif tf_major_version == 2: - load_model_packages |= set(KERAS_TF2_MODELS) - load_model_packages |= set(TENSORFLOW2_MODELS) - return MappingProxyType({name: save_bioimageio_package(MODEL_SOURCES[name]) for name in load_model_packages}) diff --git a/tests/test_cli.py b/tests/test_cli.py index f821e0c9..d70f21d8 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,9 +1,10 @@ import os import subprocess from pathlib import Path -from typing import Any, List, Optional, Sequence +from typing import Any, List, Optional, Sequence, Set import numpy as np +import pytest from bioimageio.core import load_description @@ -12,130 +13,108 @@ def run_subprocess(commands: Sequence[str], **kwargs: Any) -> "subprocess.Comple return subprocess.run(commands, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8", **kwargs) -def test_validate_model(unet2d_nuclei_broad_model: Path): - ret = run_subprocess(["bioimageio", "validate", unet2d_nuclei_broad_model]) - assert ret.returncode == 0, ret.stdout +FIXTURES = {"unet2d_nuclei_broad_model"} -def test_cli_package(unet2d_nuclei_broad_model: Path, tmp_path: Path): - out_path = tmp_path / "model.zip" - ret = run_subprocess(["bioimageio", "package", unet2d_nuclei_broad_model, str(out_path)]) +@pytest.mark.parametrize( + "args", + [ + ["package", "unet2d_nuclei_broad_model", "--weight-format", "pytorch_state_dict"], + ["package", "unet2d_nuclei_broad_model"], + ["test-model", "unet2d_nuclei_broad_model", "--weight-format", "pytorch_state_dict"], + ["test-model", "unet2d_nuclei_broad_model"], + ], +) +def test_cli(args: List[str], request: pytest.FixtureRequest): + resolved_args = [str(request.getfixturevalue(arg)) if arg in FIXTURES else arg for arg in args] + ret = run_subprocess(["bioimageio", *resolved_args]) assert ret.returncode == 0, ret.stdout - assert out_path.exists() -def test_cli_package_wo_cache(unet2d_nuclei_broad_model: Path): - env = os.environ.copy() - env["BIOIMAGEIO_USE_CACHE"] = "false" - ret = run_subprocess(["bioimageio", "package", unet2d_nuclei_broad_model], env=env) - assert ret.returncode == 0, ret.stdout +@pytest.mark.parametrize("args", [["test-model", "stardist_wrong_shape"]]) +def test_cli_fails(args: List[str], request: pytest.FixtureRequest): + resolved_args = [str(request.getfixturevalue(arg)) if arg in FIXTURES else arg for arg in args] + ret = run_subprocess(["bioimageio", *resolved_args]) + assert ret.returncode == 1, ret.stdout -def test_cli_test_model(unet2d_nuclei_broad_model: Path): - ret = run_subprocess(["bioimageio", "test-model", unet2d_nuclei_broad_model]) - assert ret.returncode == 0, ret.stdout +# TODO: update CLI test +# def _test_cli_predict_image(model: Path, tmp_path: Path, extra_cmd_args: Optional[List[str]] = None): +# spec = load_description(model) +# in_path = spec.test_inputs[0] +# out_path = tmp_path.with_suffix(".npy") +# cmd = ["bioimageio", "predict-image", model, "--input", str(in_path), "--output", str(out_path)] +# if extra_cmd_args is not None: +# cmd.extend(extra_cmd_args) +# ret = run_subprocess(cmd) +# assert ret.returncode == 0, ret.stdout +# assert out_path.exists() -def test_cli_test_model_fail(stardist_wrong_shape: Path): - ret = 
run_subprocess(["bioimageio", "test-model", stardist_wrong_shape]) - assert ret.returncode == 1 +# def test_cli_predict_image(unet2d_nuclei_broad_model: Path, tmp_path: Path): +# _test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path) -def test_cli_test_model_with_weight_format(unet2d_nuclei_broad_model: Path): - ret = run_subprocess( - ["bioimageio", "test-model", unet2d_nuclei_broad_model, "--weight-format", "pytorch_state_dict"] - ) - assert ret.returncode == 0, ret.stdout +# def test_cli_predict_image_with_weight_format(unet2d_nuclei_broad_model: Path, tmp_path: Path): +# _test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path, ["--weight-format", "pytorch_state_dict"]) -def test_cli_test_resource(unet2d_nuclei_broad_model: Path): - ret = run_subprocess(["bioimageio", "test-resource", unet2d_nuclei_broad_model]) - assert ret.returncode == 0, ret.stdout +# def _test_cli_predict_images(model: Path, tmp_path: Path, extra_cmd_args: Optional[List[str]] = None): +# n_images = 3 +# shape = (1, 1, 128, 128) +# expected_shape = (1, 1, 128, 128) -def test_cli_test_resource_with_weight_format(unet2d_nuclei_broad_model: Path): - ret = run_subprocess( - ["bioimageio", "test-resource", unet2d_nuclei_broad_model, "--weight-format", "pytorch_state_dict"] - ) - assert ret.returncode == 0, ret.stdout +# in_folder = tmp_path / "inputs" +# in_folder.mkdir() +# out_folder = tmp_path / "outputs" +# out_folder.mkdir() +# expected_outputs: List[Path] = [] +# for i in range(n_images): +# path = in_folder / f"im-{i}.npy" +# im = np.random.randint(0, 255, size=shape).astype("uint8") +# np.save(path, im) +# expected_outputs.append(out_folder / f"im-{i}.npy") -def _test_cli_predict_image(model: Path, tmp_path: Path, extra_cmd_args: Optional[List[str]] = None): - spec = load_description(model) - in_path = spec.test_inputs[0] +# input_pattern = str(in_folder / "*.npy") +# cmd = ["bioimageio", "predict-images", str(model), input_pattern, str(out_folder)] +# if extra_cmd_args is not None: +# cmd.extend(extra_cmd_args) +# ret = run_subprocess(cmd) +# assert ret.returncode == 0, ret.stdout - out_path = tmp_path.with_suffix(".npy") - cmd = ["bioimageio", "predict-image", model, "--input", str(in_path), "--output", str(out_path)] - if extra_cmd_args is not None: - cmd.extend(extra_cmd_args) - ret = run_subprocess(cmd) - assert ret.returncode == 0, ret.stdout - assert out_path.exists() - - -def test_cli_predict_image(unet2d_nuclei_broad_model: Path, tmp_path: Path): - _test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path) +# for out_path in expected_outputs: +# assert out_path.exists() +# assert np.load(out_path).shape == expected_shape -def test_cli_predict_image_with_weight_format(unet2d_nuclei_broad_model: Path, tmp_path: Path): - _test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path, ["--weight-format", "pytorch_state_dict"]) +# def test_cli_predict_images(unet2d_nuclei_broad_model: Path, tmp_path: Path): +# _test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path) -def _test_cli_predict_images(model: Path, tmp_path: Path, extra_cmd_args: Optional[List[str]] = None): - n_images = 3 - shape = (1, 1, 128, 128) - expected_shape = (1, 1, 128, 128) +# def test_cli_predict_images_with_weight_format(unet2d_nuclei_broad_model: Path, tmp_path: Path): +# _test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path, ["--weight-format", "pytorch_state_dict"]) - in_folder = tmp_path / "inputs" - in_folder.mkdir() - out_folder = tmp_path / "outputs" - out_folder.mkdir() - expected_outputs: List[Path] = [] - 
for i in range(n_images): - path = in_folder / f"im-{i}.npy" - im = np.random.randint(0, 255, size=shape).astype("uint8") - np.save(path, im) - expected_outputs.append(out_folder / f"im-{i}.npy") - - input_pattern = str(in_folder / "*.npy") - cmd = ["bioimageio", "predict-images", str(model), input_pattern, str(out_folder)] - if extra_cmd_args is not None: - cmd.extend(extra_cmd_args) - ret = run_subprocess(cmd) - assert ret.returncode == 0, ret.stdout +# def test_torch_to_torchscript(unet2d_nuclei_broad_model: Path, tmp_path: Path): +# out_path = tmp_path.with_suffix(".pt") +# ret = run_subprocess( +# ["bioimageio", "convert-torch-weights-to-torchscript", str(unet2d_nuclei_broad_model), str(out_path)] +# ) +# assert ret.returncode == 0, ret.stdout +# assert out_path.exists() - for out_path in expected_outputs: - assert out_path.exists() - assert np.load(out_path).shape == expected_shape +# def test_torch_to_onnx(convert_to_onnx: Path, tmp_path: Path): +# out_path = tmp_path.with_suffix(".onnx") +# ret = run_subprocess(["bioimageio", "convert-torch-weights-to-onnx", str(convert_to_onnx), str(out_path)]) +# assert ret.returncode == 0, ret.stdout +# assert out_path.exists() -def test_cli_predict_images(unet2d_nuclei_broad_model: Path, tmp_path: Path): - _test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path) - -def test_cli_predict_images_with_weight_format(unet2d_nuclei_broad_model: Path, tmp_path: Path): - _test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path, ["--weight-format", "pytorch_state_dict"]) - - -def test_torch_to_torchscript(unet2d_nuclei_broad_model: Path, tmp_path: Path): - out_path = tmp_path.with_suffix(".pt") - ret = run_subprocess( - ["bioimageio", "convert-torch-weights-to-torchscript", str(unet2d_nuclei_broad_model), str(out_path)] - ) - assert ret.returncode == 0, ret.stdout - assert out_path.exists() - - -def test_torch_to_onnx(convert_to_onnx: Path, tmp_path: Path): - out_path = tmp_path.with_suffix(".onnx") - ret = run_subprocess(["bioimageio", "convert-torch-weights-to-onnx", str(convert_to_onnx), str(out_path)]) - assert ret.returncode == 0, ret.stdout - assert out_path.exists() - - -def test_keras_to_tf(unet2d_keras: Path, tmp_path: Path): - out_path = tmp_path / "weights.zip" - ret = run_subprocess(["bioimageio", "convert-keras-weights-to-tensorflow", str(unet2d_keras), str(out_path)]) - assert ret.returncode == 0, ret.stdout - assert out_path.exists() +# def test_keras_to_tf(unet2d_keras: Path, tmp_path: Path): +# out_path = tmp_path / "weights.zip" +# ret = run_subprocess(["bioimageio", "convert-keras-weights-to-tensorflow", str(unet2d_keras), str(out_path)]) +# assert ret.returncode == 0, ret.stdout +# assert out_path.exists() diff --git a/tests/test_prediction_pipeline.py b/tests/test_prediction_pipeline.py index d4b16c12..5347380a 100644 --- a/tests/test_prediction_pipeline.py +++ b/tests/test_prediction_pipeline.py @@ -9,7 +9,7 @@ def _test_prediction_pipeline(model_package: Path, weights_format: WeightsFormat): - from bioimageio.core.prediction_pipeline import create_prediction_pipeline + from bioimageio.core._prediction_pipeline import create_prediction_pipeline bio_model = load_description(model_package) assert isinstance(bio_model, (ModelDescr, ModelDescr04)) diff --git a/tests/test_prediction_pipeline_device_management.py b/tests/test_prediction_pipeline_device_management.py index bda4af08..16354d18 100644 --- a/tests/test_prediction_pipeline_device_management.py +++ b/tests/test_prediction_pipeline_device_management.py @@ -21,7 +21,7 @@ 
def _test_device_management(model_package: Path, weight_format: WeightsFormat): if torch.cuda.device_count() == 0: raise TooFewDevicesException("Need at least one cuda device for this test") - from bioimageio.core.prediction_pipeline import create_prediction_pipeline + from bioimageio.core._prediction_pipeline import create_prediction_pipeline bio_model = load_description(model_package) assert isinstance(bio_model, (ModelDescr, ModelDescr04)) diff --git a/tests/test_resource_tests.py b/tests/test_resource_tests.py index 970bf2e2..9f69721a 100644 --- a/tests/test_resource_tests.py +++ b/tests/test_resource_tests.py @@ -4,7 +4,7 @@ def test_error_for_wrong_shape(stardist_wrong_shape: Path): - from bioimageio.core.resource_tests import test_model + from bioimageio.core._resource_tests import test_model summary = test_model(stardist_wrong_shape) expected_error_message = ( @@ -16,7 +16,7 @@ def test_error_for_wrong_shape(stardist_wrong_shape: Path): def test_error_for_wrong_shape2(stardist_wrong_shape2: Path): - from bioimageio.core.resource_tests import test_model + from bioimageio.core._resource_tests import test_model summary = test_model(stardist_wrong_shape2) expected_error_message = ( @@ -27,14 +27,14 @@ def test_error_for_wrong_shape2(stardist_wrong_shape2: Path): def test_test_model(any_model: Path): - from bioimageio.core.resource_tests import test_model + from bioimageio.core._resource_tests import test_model summary = test_model(any_model) assert summary.status == "passed" def test_test_resource(any_model: Path): - from bioimageio.core.resource_tests import test_description + from bioimageio.core._resource_tests import test_description summary = test_description(any_model) assert summary.status == "passed" @@ -42,7 +42,7 @@ def test_test_resource(any_model: Path): def test_validation_section_warning(unet2d_nuclei_broad_model: str, tmp_path: Path): from bioimageio.core import load_description - from bioimageio.core.resource_tests import test_description + from bioimageio.core._resource_tests import test_description model = load_description(unet2d_nuclei_broad_model) assert not isinstance(model, InvalidDescr) @@ -65,7 +65,7 @@ def test_issue289(unet2d_nuclei_broad_model: str): # remote model is a pytorch model, needing unet2d_nuclei_broad_model skips the test when needed _ = unet2d_nuclei_broad_model - from bioimageio.core.resource_tests import test_model + from bioimageio.core._resource_tests import test_model doi = "10.5281/zenodo.6287342" summary = test_model(doi) diff --git a/tests/test_image_helper.py b/tests/utils/test_image_helper.py similarity index 84% rename from tests/test_image_helper.py rename to tests/utils/test_image_helper.py index d9721fc2..8e86a919 100644 --- a/tests/test_image_helper.py +++ b/tests/utils/test_image_helper.py @@ -2,7 +2,7 @@ def test_transform_input_image(): - from bioimageio.core.image_helper import transpose_image + from bioimageio.core.utils.image_helper import transpose_image ax_list = ["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"] im = np.random.rand(256, 256) @@ -18,7 +18,7 @@ def test_transform_input_image(): def test_transform_output_tensor(): - from bioimageio.core.image_helper import transform_output_tensor + from bioimageio.core.utils.image_helper import transform_output_tensor tensor = np.random.rand(1, 3, 64, 64, 64) tensor_axes = "bczyx" diff --git a/tests/weight_converter/keras/test_tensorflow.py b/tests/weight_converter/keras/test_tensorflow.py index 5cc7f297..6cc42c57 100644 --- 
a/tests/weight_converter/keras/test_tensorflow.py +++ b/tests/weight_converter/keras/test_tensorflow.py @@ -1,10 +1,13 @@ import zipfile from pathlib import Path +import pytest + from bioimageio.spec import load_description from bioimageio.spec.model.v0_5 import ModelDescr +@pytest.mark.skip("tensorflow converter not updated yet") # TODO: test tensorflow converter def test_tensorflow_converter(any_keras_model: Path, tmp_path: Path): from bioimageio.core.weight_converter.keras import convert_weights_to_tensorflow_saved_model_bundle @@ -18,6 +21,7 @@ def test_tensorflow_converter(any_keras_model: Path, tmp_path: Path): assert ret_val == 0 # check for correctness is done in converter and returns 0 if it passes +@pytest.mark.skip("tensorflow converter not updated yet") # TODO: test tensorflow converter def test_tensorflow_converter_zipped(any_keras_model: Path, tmp_path: Path): from bioimageio.core.weight_converter.keras import convert_weights_to_tensorflow_saved_model_bundle diff --git a/tests/weight_converter/test_add_weights.py b/tests/weight_converter/test_add_weights.py index e3df4b80..836353c7 100644 --- a/tests/weight_converter/test_add_weights.py +++ b/tests/weight_converter/test_add_weights.py @@ -1,47 +1,48 @@ -import os +# TODO: update add weights tests +# import os -def _test_add_weights(model, tmp_path, base_weights, added_weights, **kwargs): - from bioimageio.core.build_spec import add_weights +# def _test_add_weights(model, tmp_path, base_weights, added_weights, **kwargs): +# from bioimageio.core.build_spec import add_weights - rdf = load_raw_resource_description(model) - assert base_weights in rdf.weights - assert added_weights in rdf.weights +# rdf = load_raw_resource_description(model) +# assert base_weights in rdf.weights +# assert added_weights in rdf.weights - weight_path = load_description(model).weights[added_weights].source - assert weight_path.exists() +# weight_path = load_description(model).weights[added_weights].source +# assert weight_path.exists() - drop_weights = set(rdf.weights.keys()) - {base_weights} - for drop in drop_weights: - rdf.weights.pop(drop) - assert tuple(rdf.weights.keys()) == (base_weights,) +# drop_weights = set(rdf.weights.keys()) - {base_weights} +# for drop in drop_weights: +# rdf.weights.pop(drop) +# assert tuple(rdf.weights.keys()) == (base_weights,) - in_path = tmp_path / "model1.zip" - export_resource_package(rdf, output_path=in_path) +# in_path = tmp_path / "model1.zip" +# export_resource_package(rdf, output_path=in_path) - out_path = tmp_path / "model2.zip" - add_weights(in_path, weight_path, weight_type=added_weights, output_path=out_path, **kwargs) +# out_path = tmp_path / "model2.zip" +# add_weights(in_path, weight_path, weight_type=added_weights, output_path=out_path, **kwargs) - assert out_path.exists() - new_rdf = load_description(out_path) - assert set(new_rdf.weights.keys()) == {base_weights, added_weights} - for weight in new_rdf.weights.values(): - assert weight.source.exists() +# assert out_path.exists() +# new_rdf = load_description(out_path) +# assert set(new_rdf.weights.keys()) == {base_weights, added_weights} +# for weight in new_rdf.weights.values(): +# assert weight.source.exists() - test_res = _test_model(out_path, added_weights) - failed = [s for s in test_res if s["status"] != "passed"] - assert not failed, failed - test_res = _test_model(out_path) - failed = [s for s in test_res if s["status"] != "passed"] - assert not failed, failed +# test_res = _test_model(out_path, added_weights) +# failed = [s for s in test_res 
if s["status"] != "passed"] +# assert not failed, failed +# test_res = _test_model(out_path) +# failed = [s for s in test_res if s["status"] != "passed"] +# assert not failed, failed - # make sure the weights were cleaned from the cwd - assert not os.path.exists(os.path.split(weight_path)[1]) +# # make sure the weights were cleaned from the cwd +# assert not os.path.exists(os.path.split(weight_path)[1]) -def test_add_torchscript(unet2d_nuclei_broad_model, tmp_path): - _test_add_weights(unet2d_nuclei_broad_model, tmp_path, "pytorch_state_dict", "torchscript") +# def test_add_torchscript(unet2d_nuclei_broad_model, tmp_path): +# _test_add_weights(unet2d_nuclei_broad_model, tmp_path, "pytorch_state_dict", "torchscript") -def test_add_onnx(unet2d_nuclei_broad_model, tmp_path): - _test_add_weights(unet2d_nuclei_broad_model, tmp_path, "pytorch_state_dict", "onnx", opset_version=12) +# def test_add_onnx(unet2d_nuclei_broad_model, tmp_path): +# _test_add_weights(unet2d_nuclei_broad_model, tmp_path, "pytorch_state_dict", "onnx", opset_version=12) diff --git a/tests/weight_converter/torch/test_onnx.py b/tests/weight_converter/torch/test_onnx.py index bc757806..c2efbcd8 100644 --- a/tests/weight_converter/torch/test_onnx.py +++ b/tests/weight_converter/torch/test_onnx.py @@ -4,9 +4,9 @@ import pytest -# todo: test with 'any_torch_model' -def test_onnx_converter(convert_to_onnx: Path, tmp_path, Path): - from bioimageio.core.weight_converter.torch.onnx import convert_weights_to_onnx +@pytest.mark.skip("onnx converter not updated yet") # TODO: test onnx converter +def test_onnx_converter(convert_to_onnx: Path, tmp_path: Path): + from bioimageio.core.weight_converter.torch._onnx import convert_weights_to_onnx out_path = tmp_path / "weights.onnx" ret_val = convert_weights_to_onnx(convert_to_onnx, out_path, test_decimal=3) diff --git a/tests/weight_converter/torch/test_torchscript.py b/tests/weight_converter/torch/test_torchscript.py index 2c1e47d2..e3f6e42c 100644 --- a/tests/weight_converter/torch/test_torchscript.py +++ b/tests/weight_converter/torch/test_torchscript.py @@ -1,7 +1,11 @@ from pathlib import Path + +import pytest + from bioimageio.spec.model import v0_4, v0_5 +@pytest.mark.skip("torchscript converter not updated yet") # TODO: test torchscript converter def test_torchscript_converter(any_torch_model: "v0_4.ModelDescr | v0_5.ModelDescr", tmp_path: Path): from bioimageio.core.weight_converter.torch import convert_weights_to_torchscript From 22853431afbdaf34689c4f9dc5ebda1e706f705b Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 28 Feb 2024 17:44:32 +0100 Subject: [PATCH 102/244] improve docs for pdoc --- .gitignore | 1 + bioimageio/core/__init__.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/.gitignore b/.gitignore index 4edd992c..a603dade 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ build/ cache dist/ +docs/ typings/pooch/ diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index 926c61d7..2dbf8081 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -1,3 +1,7 @@ +""" +.. 
include:: ../../README.md +""" + import json from bioimageio.core.utils import files From 40aae1005e6caea4a3ad6104f54deba8582d0f87 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 28 Feb 2024 17:44:45 +0100 Subject: [PATCH 103/244] move _get_complement_axis to spec --- bioimageio/core/proc_ops.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index a0a2a1f1..8a7b15f6 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -557,26 +557,3 @@ def get_proc_class(proc_spec: ProcDescr): else: assert_never(proc_spec) - -def _get_complement_axis(tensor: xr.DataArray, axes: Optional[Sequence[Hashable]]) -> Optional[Hashable]: - if axes is None: - return None - - v04_AXIS_TYPE_MAP = { - "b": "batch", - "t": "time", - "i": "index", - "c": "channel", - "x": "space", - "y": "space", - "z": "space", - } - converted_axes = [v04_AXIS_TYPE_MAP.get(a, a) for a in map(str, axes)] + ["batch"] - complement_axes = [a for a in tensor.dims if str(a) not in converted_axes] - if len(complement_axes) != 1: - raise ValueError( - f"Expected a single complement axis, but axes '{converted_axes}' (orignally '{axes}') " - f"for tensor dims '{tensor.dims}' leave '{complement_axes}'." - ) - - return complement_axes[0] From 569666b426cb089503f2ee3bb5651e124d8740e8 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 28 Feb 2024 21:53:04 +0100 Subject: [PATCH 104/244] parallel pytests with pytest-xdist --- pyproject.toml | 2 +- setup.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index be6c4a92..75ff4e09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ typeCheckingMode = "strict" useLibraryCodeForTypes = true [tool.pytest.ini_options] -addopts = "--capture=no --doctest-modules --failed-first" +addopts = " -n auto --capture=no --doctest-modules --failed-first" [tool.ruff] line-length = 120 diff --git a/setup.py b/setup.py index cd612047..adbf4305 100644 --- a/setup.py +++ b/setup.py @@ -39,11 +39,18 @@ ], include_package_data=True, extras_require={ - "test": ["pytest", "black[jupyter]", "onnxruntime", "torch>=1.6", "torchvision", "crick"], - "dev": ["pre-commit"], "pytorch": ["torch>=1.6", "torchvision"], "tensorflow": ["tensorflow"], "onnx": ["onnxruntime"], + "test": [ + "bioimageio.core[onnx]", + "bioimageio.core[pytorch]", + "black[jupyter]", + "crick", + "pytest-xdist[psutil]", # parallel pytest with 'pytest -n auto' + "pytest", + ], + "dev": ["pre-commit", "bioimageio.core[test]"], }, project_urls={ "Bug Reports": "https://github.com/bioimage-io/core-bioimage-io-python/issues", From 6f1c58cd5cf767104f3c4ea660694805d2d94720 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 6 Mar 2024 15:48:57 +0100 Subject: [PATCH 105/244] WIP _test_model_inference_with_parametrized_inputs --- bioimageio/core/__init__.py | 1 + bioimageio/core/_resource_tests.py | 68 +++++- bioimageio/core/utils/_digest_spec.py | 2 +- bioimageio/core/utils/image_helper.py | 339 ++++++++++++++++---------- 4 files changed, 280 insertions(+), 130 deletions(-) diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index 2dbf8081..4a0846d4 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -19,6 +19,7 @@ from bioimageio.spec import save_bioimageio_yaml_only as save_bioimageio_yaml_only from bioimageio.spec import validate_format as validate_format +from ._prediction_pipeline import PredictionPipeline as PredictionPipeline from 
._prediction_pipeline import create_prediction_pipeline as create_prediction_pipeline from ._resource_tests import load_description_and_test as load_description_and_test from ._resource_tests import test_description as test_description diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py index 0f30b550..d4138469 100644 --- a/bioimageio/core/_resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -6,9 +6,9 @@ import xarray as xr from bioimageio.core import __version__ as bioimageio_core_version -from bioimageio.core._prediction_pipeline import create_prediction_pipeline +from bioimageio.core import create_prediction_pipeline, PredictionPipeline from bioimageio.spec import InvalidDescr, ResourceDescr, build_description, dump_description, load_description -from bioimageio.spec._internal.base_nodes import ResourceDescrBase +from bioimageio.spec._internal.common_nodes import ResourceDescrBase from bioimageio.spec._internal.io_utils import load_array from bioimageio.spec.common import BioimageioYamlContent, FileSource from bioimageio.spec.model import v0_4, v0_5 @@ -81,6 +81,8 @@ def load_description_and_test( if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)): _test_model_inference(rd, weight_format, devices, decimal) + if not isinstance(rd, v0_4.ModelDescr): + _test_model_inference_with_parametrized_inputs(rd, weight_format, devices) return rd @@ -114,7 +116,7 @@ def _test_model_inference( if len(results) != len(expected): error = (error or "") + ( - f"Number of outputs and number of expected outputs disagree: {len(results)} != {len(expected)}" + f"Expected {len(expected)} outputs, but got {len(results)}" ) else: for res, exp in zip(results, expected): @@ -145,6 +147,66 @@ def _test_model_inference( ) ) +def _test_model_inference_with_parametrized_inputs( + model: v0_5.ModelDescr, + weight_format: Optional[WeightsFormat], + devices: Optional[List[str]], +) -> None: + if not any(isinstance(a.size, v0_5.ParameterizedSize) for ipt in model.inputs for a in ipt.axes): + return + + error: Optional[str] = None + tb: List[str] = [] + try: + test_inputs = [ + xr.DataArray(load_array(d.test_tensor.download().path), dims=tuple(a.id for a in d.axes)) + for d in model.inputs + ] + def generate_test_cases(): + for n in [0, 1, 2, 3]: + + + + with create_prediction_pipeline( + bioimageio_model=model, devices=devices, weight_format=weight_format + ) as prediction_pipeline: + for n, inputs, exptected_output_shape in generate_test_cases(): + results = prediction_pipeline.forward(*inputs) + + if len(results) != len(exptected_output_shape): + error = (error or "") + ( + f"Expected {len(exptected_output_shape)} outputs, but got {len(results)}" + ) + else: + for res, exp in zip(results, exptected_output_shape): + if res.shape != exp: + error = (error or "") + f"(n={n}) Expected output shape {exptected_output_shape}, but got {res.shape}\n" + + if error: + break + except Exception as e: + error = str(e) + tb = traceback.format_tb(e.__traceback__) + + model.validation_summary.add_detail( + ValidationDetail( + name="Reproduce test outputs from test inputs", + status="passed" if error is None else "failed", + errors=( + [] + if error is None + else [ + ErrorEntry( + loc=("weights",) if weight_format is None else ("weights", weight_format), + msg=error, + type="bioimageio.core", + traceback=tb, + ) + ] + ), + ) + ) + def _test_expected_resource_type(rd: Union[InvalidDescr, ResourceDescr], expected_type: str): has_expected_type = rd.type == expected_type diff --git 
a/bioimageio/core/utils/_digest_spec.py b/bioimageio/core/utils/_digest_spec.py index ad41789f..42ba8974 100644 --- a/bioimageio/core/utils/_digest_spec.py +++ b/bioimageio/core/utils/_digest_spec.py @@ -2,7 +2,7 @@ import xarray as xr -from bioimageio.spec.model import AnyModelDescr, v0_4 +from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 from bioimageio.spec.utils import load_array diff --git a/bioimageio/core/utils/image_helper.py b/bioimageio/core/utils/image_helper.py index a3045588..a5d0a5f1 100644 --- a/bioimageio/core/utils/image_helper.py +++ b/bioimageio/core/utils/image_helper.py @@ -1,126 +1,213 @@ -# # TODO: update - -# from __future__ import annotations - -# import os -# from copy import deepcopy -# from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union - -# import imageio -# import numpy as np -# from numpy.typing import ArrayLike, NDArray -# from xarray import DataArray - -# from bioimageio.spec._internal.io_utils import load_array -# from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensor04 -# from bioimageio.spec.model.v0_4 import OutputTensorDescr as OutputTensor04 -# from bioimageio.spec.model.v0_5 import InputTensorDescr as InputTensor05 -# from bioimageio.spec.model.v0_5 import OutputTensorDescr as OutputTensor05 - -# InputTensor = Union[InputTensor04, InputTensor05] -# OutputTensor = Union[OutputTensor04, OutputTensor05] - - -# # -# # helper functions to transform input images / output tensors to the required axes -# # - - -# def transpose_image(image: NDArray[Any], desired_axes: str, current_axes: Optional[str] = None) -> NDArray[Any]: -# """Transform an image to match desired axes. - -# Args: -# image: the input image -# desired_axes: the desired image axes -# current_axes: the axes of the input image -# """ -# # if the image axes are not given deduce them from the required axes and image shape -# if current_axes is None: -# has_z_axis = "z" in desired_axes -# ndim = image.ndim -# if ndim == 2: -# current_axes = "yx" -# elif ndim == 3: -# current_axes = "zyx" if has_z_axis else "cyx" -# elif ndim == 4: -# current_axes = "czyx" -# elif ndim == 5: -# current_axes = "bczyx" -# else: -# raise ValueError(f"Invalid number of image dimensions: {ndim}") - -# tensor = DataArray(image, dims=tuple(current_axes)) -# # expand the missing image axes -# missing_axes = tuple(set(desired_axes) - set(current_axes)) -# tensor = tensor.expand_dims(dim=missing_axes) -# # transpose to the correct axis order -# tensor = tensor.transpose(*tuple(desired_axes)) -# # return numpy array -# ret: NDArray[Any] = tensor.values -# return ret - - -# # -# # helper functions for loading and saving images -# # - - -# def load_image(in_path, axes: Sequence[str]) -> DataArray: -# ext = os.path.splitext(in_path)[1] -# if ext == ".npy": -# im = load_array(in_path) -# else: -# is_volume = "z" in axes -# im = imageio.volread(in_path) if is_volume else imageio.imread(in_path) -# im = transpose_image(im, axes) -# return DataArray(im, dims=axes) - - -# def load_tensors(sources, tensor_specs: List[Union[InputTensor, OutputTensor]]) -> List[DataArray]: -# return [load_image(s, sspec.axes) for s, sspec in zip(sources, tensor_specs)] - - -# # -# # helper function for padding -# # - - -# def pad(image, axes: Sequence[str], padding, pad_right=True) -> Tuple[np.ndarray, Dict[str, slice]]: -# assert image.ndim == len(axes), f"{image.ndim}, {len(axes)}" - -# padding_ = deepcopy(padding) -# mode = padding_.pop("mode", "dynamic") -# assert mode in ("dynamic", "fixed") - -# 
is_volume = "z" in axes -# if is_volume: -# assert len(padding_) == 3 -# else: -# assert len(padding_) == 2 - -# if isinstance(pad_right, bool): -# pad_right = len(axes) * [pad_right] - -# pad_width = [] -# crop = {} -# for ax, dlen, pr in zip(axes, image.shape, pad_right): -# if ax in "zyx": -# pad_to = padding_[ax] - -# if mode == "dynamic": -# r = dlen % pad_to -# pwidth = 0 if r == 0 else (pad_to - r) -# else: -# if pad_to < dlen: -# msg = f"Padding for axis {ax} failed; pad shape {pad_to} is smaller than the image shape {dlen}." -# raise RuntimeError(msg) -# pwidth = pad_to - dlen - -# pad_width.append([0, pwidth] if pr else [pwidth, 0]) -# crop[ax] = slice(0, dlen) if pr else slice(pwidth, None) -# else: -# pad_width.append([0, 0]) -# crop[ax] = slice(None) - -# image = np.pad(image, pad_width, mode="symmetric") -# return image, crop +# TODO: update + +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Tuple, Union + +import imageio +import numpy as np +import xarray as xr +from numpy.typing import NDArray +from typing_extensions import assert_never + +from bioimageio.spec.model import v0_4 +from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensorDescr04 +from bioimageio.spec.model.v0_4 import OutputTensorDescr as OutputTensorDescr04 +from bioimageio.spec.model.v0_5 import ( + AnyAxis, + AxisId, + BatchAxis, + ChannelAxis, + Identifier, + InputAxis, + InputTensorDescr, + OutputTensorDescr, + SpaceInputAxis, + convert_axes, +) +from bioimageio.spec.utils import load_array, save_array + +InputTensor = Union[InputTensorDescr04, InputTensorDescr] +OutputTensor = Union[OutputTensorDescr04, OutputTensorDescr] + + +def transpose_image( + image: NDArray[Any], + desired_axes: Union[v0_4.AxesStr, Sequence[AnyAxis]], + current_axes: Optional[Union[v0_4.AxesStr, Sequence[AnyAxis]]] = None, +) -> xr.DataArray: + """Transpose an image to match desired axes. 
+ + Args: + image: the input image + desired_axes: the desired image axes + current_axes: the axes of the input image + """ + # if the image axes are not given deduce them from the required axes and image shape + if current_axes is None: + if isinstance(desired_axes, str): + desired_space_axes = [a for a in desired_axes if a in "zyx"] + else: + desired_space_axes = [a for a in desired_axes if a.type == "space"] + + ndim = image.ndim + if ndim == 2 and len(desired_space_axes) >= 2: + current_axes = ( + SpaceInputAxis(id=AxisId("y"), size=image.shape[0]), + SpaceInputAxis(id=AxisId("x"), size=image.shape[1]), + ) + elif ndim == 3 and len(desired_space_axes) == 2: + current_axes = ( + ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(image.shape[0])]), + SpaceInputAxis(id=AxisId("y"), size=image.shape[1]), + SpaceInputAxis(id=AxisId("x"), size=image.shape[2]), + ) + elif ndim == 3 and len(desired_space_axes) == 3: + current_axes = ( + SpaceInputAxis(id=AxisId("z"), size=image.shape[0]), + SpaceInputAxis(id=AxisId("y"), size=image.shape[1]), + SpaceInputAxis(id=AxisId("x"), size=image.shape[2]), + ) + elif ndim == 4: + current_axes = ( + ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(image.shape[0])]), + SpaceInputAxis(id=AxisId("z"), size=image.shape[1]), + SpaceInputAxis(id=AxisId("y"), size=image.shape[2]), + SpaceInputAxis(id=AxisId("x"), size=image.shape[3]), + ) + elif ndim == 5: + current_axes = ( + BatchAxis(), + ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(image.shape[1])]), + SpaceInputAxis(id=AxisId("z"), size=image.shape[2]), + SpaceInputAxis(id=AxisId("y"), size=image.shape[3]), + SpaceInputAxis(id=AxisId("x"), size=image.shape[4]), + ) + else: + raise ValueError(f"Could not guess a mapping of {image.shape} to {desired_axes}") + + current_axes_ids = tuple(current_axes) if isinstance(current_axes, str) else tuple(a.id for a in current_axes) + desired_axes_ids = tuple(desired_axes) if isinstance(desired_axes, str) else tuple(a.id for a in desired_axes) + tensor = xr.DataArray(image, dims=current_axes_ids) + # expand the missing image axes + missing_axes = tuple(set(desired_axes_ids) - set(current_axes_ids)) + tensor = tensor.expand_dims(dim=missing_axes) + # transpose to the correct axis order + return tensor.transpose(*tuple(desired_axes_ids)) + + +def convert_axes_for_known_shape(axes: v0_4.AxesStr, shape: Sequence[int]): + return convert_axes(axes, shape=shape, tensor_type="input", halo=None, size_refs={}) + + +def load_tensor( + path: Path, + desired_axes: Union[v0_4.AxesStr, Sequence[AnyAxis]], + current_axes: Optional[Union[v0_4.AxesStr, Sequence[AnyAxis]]] = None, +) -> xr.DataArray: + + ext = path.suffix + if ext == ".npy": + im = load_array(path) + else: + guess_axes = current_axes or desired_axes + if isinstance(guess_axes, str): + is_volume = "z" in guess_axes or "t" in guess_axes + else: + is_volume = len([a for a in guess_axes if a.type in ("time", "space")]) > 2 + + im = imageio.volread(path) if is_volume else imageio.imread(path) + im = transpose_image(im, desired_axes=desired_axes, current_axes=current_axes) + + return xr.DataArray( + im, dims=tuple(desired_axes) if isinstance(desired_axes, str) else tuple(a.id for a in desired_axes) + ) + + +def pad( + tensor: xr.DataArray, + pad_with: Mapping[AxisId, Union[int, Tuple[int, int]]], + mode: Literal["edge", "reflect", "symmetric"] = "symmetric", +): + return tensor.pad(pad_with=pad_with, mode=mode) + + +def pad_to( + tensor: xr.DataArray, + sizes: 
Mapping[AxisId, int], + pad_where: Union[ + Literal["before", "center", "after"], Mapping[AxisId, Literal["before", "center", "after"]] + ] = "center", + mode: Literal["edge", "reflect", "symmetric"] = "symmetric", +): + """pad `tensor` to match `shape`""" + if isinstance(pad_where, str): + pad_axis_where: Mapping[AxisId, Literal["before", "center", "after"]] = { + AxisId(str(a)): pad_where for a in tensor.dims + } + else: + pad_axis_where = pad_where + + pad_with: Dict[AxisId, Union[int, Tuple[int, int]]] = {} + for a, s_is in tensor.sizes.items(): + a = AxisId(str(a)) + if a not in sizes or sizes[a] == s_is: + pad_with[a] = 0 + elif s_is < sizes[a]: + raise ValueError(f"Cannot pad axis {a} of size {s_is} to smaller size {sizes[a]}") + elif a not in pad_axis_where: + raise ValueError(f"Don't know where to pad axis {a}, `pad_where`={pad_where}") + else: + pad_this_axis_where = pad_axis_where[a] + p = sizes[a] - s_is + if pad_this_axis_where == "before": + pad_with[a] = (p, 0) + elif pad_this_axis_where == "after": + pad_with[a] = (0, p) + elif pad_this_axis_where == "center": + pad_with[a] = (left := p // 2, p - left) + else: + assert_never(pad_this_axis_where) + + return pad(tensor, pad_with, mode) + + +def pad_old(image, axes: Sequence[str], padding, pad_right=True) -> Tuple[np.ndarray, Dict[str, slice]]: + assert image.ndim == len(axes), f"{image.ndim}, {len(axes)}" + + padding_ = deepcopy(padding) + mode = padding_.pop("mode", "dynamic") + assert mode in ("dynamic", "fixed") + + is_volume = "z" in axes + if is_volume: + assert len(padding_) == 3 + else: + assert len(padding_) == 2 + + if isinstance(pad_right, bool): + pad_right = len(axes) * [pad_right] + + pad_width: Sequence[Tuple[int, int]] = [] + crop = {} + for ax, dlen, pr in zip(axes, image.shape, pad_right): + if ax in "zyx": + pad_to = padding_[ax] + + if mode == "dynamic": + r = dlen % pad_to + pwidth = 0 if r == 0 else (pad_to - r) + else: + if pad_to < dlen: + msg = f"Padding for axis {ax} failed; pad shape {pad_to} is smaller than the image shape {dlen}." 
+ raise RuntimeError(msg) + pwidth = pad_to - dlen + + pad_width.append([0, pwidth] if pr else [pwidth, 0]) + crop[ax] = slice(0, dlen) if pr else slice(pwidth, None) + else: + pad_width.append([0, 0]) + crop[ax] = slice(None) + + image = np.pad(image, pad_width, mode="symmetric") + return image, crop From 92d4373e38df0bca93c1463075f15ba0663590e6 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 7 Mar 2024 17:02:35 +0100 Subject: [PATCH 106/244] add model inference tests for various test input sizes --- bioimageio/core/__init__.py | 7 +- bioimageio/core/__main__.py | 4 +- bioimageio/core/_resource_tests.py | 118 +++++++++++------- bioimageio/core/common.py | 2 +- .../model_adapters/_onnx_model_adapter.py | 3 - bioimageio/core/utils/__init__.py | 6 + bioimageio/core/utils/image_helper.py | 16 +-- pyproject.toml | 9 ++ tests/conftest.py | 4 +- tests/test_cli.py | 12 +- tests/test_resource_tests.py | 32 ----- tests/test_stat_measures.py | 17 ++- 12 files changed, 120 insertions(+), 110 deletions(-) diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index 4a0846d4..5cb579b5 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -4,12 +4,6 @@ import json -from bioimageio.core.utils import files - -with files("bioimageio.core").joinpath("VERSION").open("r", encoding="utf-8") as f: - __version__: str = json.load(f)["version"] - assert isinstance(__version__, str) - from bioimageio.spec import build_description as build_description from bioimageio.spec import dump_description as dump_description from bioimageio.spec import load_description as load_description @@ -24,5 +18,6 @@ from ._resource_tests import load_description_and_test as load_description_and_test from ._resource_tests import test_description as test_description from ._resource_tests import test_model as test_model +from .utils import VERSION as __version__ test_resource = test_description diff --git a/bioimageio/core/__main__.py b/bioimageio/core/__main__.py index 9e767ef1..237bd782 100644 --- a/bioimageio/core/__main__.py +++ b/bioimageio/core/__main__.py @@ -63,8 +63,8 @@ def package( "--weights-priority-order", "-wpo", help="For model packages only. " - "If given, only the first matching weights entry is included. " - "Defaults to including all weights present in source.", + + "If given, only the first matching weights entry is included. 
" + + "Defaults to including all weights present in source.", show_default=False, ), ] = None, diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py index d4138469..e1b8c586 100644 --- a/bioimageio/core/_resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -1,35 +1,37 @@ import traceback import warnings -from typing import List, Literal, Optional, Union +from typing import List, Literal, Optional, Sequence, Set, Tuple, Union import numpy as np import xarray as xr -from bioimageio.core import __version__ as bioimageio_core_version -from bioimageio.core import create_prediction_pipeline, PredictionPipeline +from bioimageio.core._prediction_pipeline import create_prediction_pipeline +from bioimageio.core.common import AxisId, BatchSize +from bioimageio.core.utils import VERSION +from bioimageio.core.utils.image_helper import pad_to from bioimageio.spec import InvalidDescr, ResourceDescr, build_description, dump_description, load_description from bioimageio.spec._internal.common_nodes import ResourceDescrBase from bioimageio.spec._internal.io_utils import load_array -from bioimageio.spec.common import BioimageioYamlContent, FileSource +from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import WeightsFormat from bioimageio.spec.summary import ErrorEntry, InstalledPackage, ValidationDetail, ValidationSummary def test_model( - source: FileSource, + source: PermissiveFileSource, weight_format: Optional[WeightsFormat] = None, devices: Optional[List[str]] = None, decimal: int = 4, ) -> ValidationSummary: - """Test whether the test output(s) of a model can be reproduced.""" + """Test model inference""" return test_description( source, weight_format=weight_format, devices=devices, decimal=decimal, expected_type="model" ) def test_description( - source: Union[ResourceDescr, FileSource, BioimageioYamlContent], + source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent], *, format_version: Union[Literal["discover", "latest"], str] = "discover", weight_format: Optional[WeightsFormat] = None, @@ -37,7 +39,7 @@ def test_description( decimal: int = 4, expected_type: Optional[str] = None, ) -> ValidationSummary: - """Test RDF dynamically, e.g. model inference of test inputs""" + """Test a bioimage.io resource dynamically, e.g. 
prediction of test tensors for models""" rd = load_description_and_test( source, format_version=format_version, @@ -50,7 +52,7 @@ def test_description( def load_description_and_test( - source: Union[ResourceDescr, FileSource, BioimageioYamlContent], + source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent], *, format_version: Union[Literal["discover", "latest"], str] = "discover", weight_format: Optional[WeightsFormat] = None, @@ -74,20 +76,24 @@ def load_description_and_test( else: rd = load_description(source, format_version=format_version) - rd.validation_summary.env.append(InstalledPackage(name="bioimageio.core", version=bioimageio_core_version)) + rd.validation_summary.env.append(InstalledPackage(name="bioimageio.core", version=VERSION)) if expected_type is not None: _test_expected_resource_type(rd, expected_type) if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)): - _test_model_inference(rd, weight_format, devices, decimal) - if not isinstance(rd, v0_4.ModelDescr): - _test_model_inference_with_parametrized_inputs(rd, weight_format, devices) + if isinstance(rd, v0_4.ModelDescr): + _test_model_inference_v0_4(rd, weight_format, devices, decimal) + else: + _test_model_inference_impl(rd, weight_format, devices) + + # TODO: add execution of jupyter notebooks + # TODO: add more tests return rd -def _test_model_inference( +def _test_model_inference_v0_4( model: Union[v0_4.ModelDescr, v0_5.ModelDescr], weight_format: Optional[WeightsFormat], devices: Optional[List[str]], @@ -115,9 +121,7 @@ def _test_model_inference( results = prediction_pipeline.forward(*inputs) if len(results) != len(expected): - error = (error or "") + ( - f"Expected {len(expected)} outputs, but got {len(results)}" - ) + error = (error or "") + (f"Expected {len(expected)} outputs, but got {len(results)}") else: for res, exp in zip(results, expected): try: @@ -147,65 +151,91 @@ def _test_model_inference( ) ) -def _test_model_inference_with_parametrized_inputs( + +def _test_model_inference_impl( model: v0_5.ModelDescr, weight_format: Optional[WeightsFormat], devices: Optional[List[str]], + test_cases: Sequence[Tuple[v0_5.ParameterizedSize.N, BatchSize]] = ((0, 1), (1, 3), (2, 1), (3, 2)), ) -> None: if not any(isinstance(a.size, v0_5.ParameterizedSize) for ipt in model.inputs for a in ipt.axes): return - error: Optional[str] = None - tb: List[str] = [] try: test_inputs = [ xr.DataArray(load_array(d.test_tensor.download().path), dims=tuple(a.id for a in d.axes)) for d in model.inputs ] - def generate_test_cases(): - for n in [0, 1, 2, 3]: + def generate_test_cases(): + tested: Set[str] = set() + for n, batch_size in test_cases: + target_sizes = model.get_tensor_sizes(n, batch_size=batch_size) + hashable_target_size = str(target_sizes) + if hashable_target_size in tested: + continue + else: + tested.add(hashable_target_size) + resized_test_inputs = [ + pad_to(t, target_sizes[t_descr.id]) for t, t_descr in zip(test_inputs, model.inputs) + ] + expected_output_shapes = [target_sizes[t_descr.id] for t_descr in model.outputs] + yield n, batch_size, resized_test_inputs, expected_output_shapes with create_prediction_pipeline( bioimageio_model=model, devices=devices, weight_format=weight_format ) as prediction_pipeline: - for n, inputs, exptected_output_shape in generate_test_cases(): - results = prediction_pipeline.forward(*inputs) + for n, batch_size, inputs, exptected_output_shape in generate_test_cases(): + error: Optional[str] = None + results = prediction_pipeline.forward(*inputs) if len(results) != 
len(exptected_output_shape): - error = (error or "") + ( - f"Expected {len(exptected_output_shape)} outputs, but got {len(results)}" - ) + error = (error or "") + (f"Expected {len(exptected_output_shape)} outputs, but got {len(results)}") else: for res, exp in zip(results, exptected_output_shape): - if res.shape != exp: - error = (error or "") + f"(n={n}) Expected output shape {exptected_output_shape}, but got {res.shape}\n" - - if error: - break + if diff := {a: s for a, s in res.sizes.items() if s != exp[AxisId(str(a))]}: + error = ( + (error or "") + + f"(n={n}) Expected output shape {exp}," + + f" but got {exptected_output_shape} ({diff})\n" + ) + + model.validation_summary.add_detail( + ValidationDetail( + name="Reproduce test outputs from test inputs with batch_size:" + + f" {batch_size} and size parameter n: {n}", + status="passed" if error is None else "failed", + errors=( + [] + if error is None + else [ + ErrorEntry( + loc=("weights",) if weight_format is None else ("weights", weight_format), + msg=error, + type="bioimageio.core", + ) + ] + ), + ) + ) except Exception as e: error = str(e) tb = traceback.format_tb(e.__traceback__) - - model.validation_summary.add_detail( - ValidationDetail( - name="Reproduce test outputs from test inputs", - status="passed" if error is None else "failed", - errors=( - [] - if error is None - else [ + model.validation_summary.add_detail( + ValidationDetail( + name="Reproduce test outputs from test inputs", + status="failed", + errors=[ ErrorEntry( loc=("weights",) if weight_format is None else ("weights", weight_format), msg=error, type="bioimageio.core", traceback=tb, ) - ] - ), + ], + ) ) - ) def _test_expected_resource_type(rd: Union[InvalidDescr, ResourceDescr], expected_type: str): diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index 1981f2c5..915d1ca0 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -10,7 +10,7 @@ TensorId = v0_5.TensorId AxisId = v0_5.AxisId - +BatchSize = int Tensor = xr.DataArray Data = Dict[TensorId, Tensor] diff --git a/bioimageio/core/model_adapters/_onnx_model_adapter.py b/bioimageio/core/model_adapters/_onnx_model_adapter.py index d1f51946..b3a632b3 100644 --- a/bioimageio/core/model_adapters/_onnx_model_adapter.py +++ b/bioimageio/core/model_adapters/_onnx_model_adapter.py @@ -1,4 +1,3 @@ -import logging import warnings from typing import Any, List, Optional, Sequence, Union @@ -14,8 +13,6 @@ except Exception: rt = None -logger = logging.getLogger(__name__) - class ONNXModelAdapter(ModelAdapter): def __init__( diff --git a/bioimageio/core/utils/__init__.py b/bioimageio/core/utils/__init__.py index 7126bd75..4037be8a 100644 --- a/bioimageio/core/utils/__init__.py +++ b/bioimageio/core/utils/__init__.py @@ -1,3 +1,4 @@ +import json import sys from pathlib import Path @@ -13,3 +14,8 @@ def files(package_name: str): else: from importlib.resources import files as files + + +with files("bioimageio.core").joinpath("VERSION").open("r", encoding="utf-8") as f: + VERSION = json.load(f)["version"] + assert isinstance(VERSION, str) diff --git a/bioimageio/core/utils/image_helper.py b/bioimageio/core/utils/image_helper.py index a5d0a5f1..3dac1772 100644 --- a/bioimageio/core/utils/image_helper.py +++ b/bioimageio/core/utils/image_helper.py @@ -126,10 +126,10 @@ def load_tensor( def pad( tensor: xr.DataArray, - pad_with: Mapping[AxisId, Union[int, Tuple[int, int]]], + pad_width: Mapping[AxisId, Union[int, Tuple[int, int]]], mode: Literal["edge", "reflect", "symmetric"] = "symmetric", ): 
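`generate_test_cases` above pads each test tensor to the sizes the model reports for size parameter `n`. A sketch of that resizing with plain xarray; the `min`/`step` numbers are illustrative rather than taken from a real model description, and `pad_width` is the keyword `xarray.DataArray.pad` actually expects:

```python
import numpy as np
import xarray as xr

t = xr.DataArray(np.zeros((1, 64, 64)), dims=("b", "y", "x"))

# ParameterizedSize-style target: size = min + n * step (illustrative numbers)
min_size, step, n = 64, 16, 2
target = min_size + n * step  # 96

missing = target - t.sizes["x"]
left = missing // 2  # "center" placement splits the padding evenly
padded = t.pad(pad_width={"x": (left, missing - left)}, mode="symmetric")
assert padded.sizes["x"] == target
```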
- return tensor.pad(pad_with=pad_with, mode=mode) + return tensor.pad(pad_width=pad_width, mode=mode) def pad_to( @@ -148,11 +148,11 @@ def pad_to( else: pad_axis_where = pad_where - pad_with: Dict[AxisId, Union[int, Tuple[int, int]]] = {} + pad_width: Dict[AxisId, Union[int, Tuple[int, int]]] = {} for a, s_is in tensor.sizes.items(): a = AxisId(str(a)) if a not in sizes or sizes[a] == s_is: - pad_with[a] = 0 + pad_width[a] = 0 elif s_is < sizes[a]: raise ValueError(f"Cannot pad axis {a} of size {s_is} to smaller size {sizes[a]}") elif a not in pad_axis_where: @@ -161,15 +161,15 @@ def pad_to( pad_this_axis_where = pad_axis_where[a] p = sizes[a] - s_is if pad_this_axis_where == "before": - pad_with[a] = (p, 0) + pad_width[a] = (p, 0) elif pad_this_axis_where == "after": - pad_with[a] = (0, p) + pad_width[a] = (0, p) elif pad_this_axis_where == "center": - pad_with[a] = (left := p // 2, p - left) + pad_width[a] = (left := p // 2, p - left) else: assert_never(pad_this_axis_where) - return pad(tensor, pad_with, mode) + return pad(tensor, pad_width, mode) def pad_old(image, axes: Sequence[str], padding, pad_right=True) -> Tuple[np.ndarray, Dict[str, slice]]: diff --git a/pyproject.toml b/pyproject.toml index 75ff4e09..db0f8626 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,16 +7,25 @@ exclude = ["**/node_modules", "**/__pycache__", "tests/old_*"] include = ["bioimageio", "scripts", "tests"] pythonPlatform = "All" pythonVersion = "3.8" +reportDuplicateImport = "error" +reportImplicitStringConcatenation = "warning" reportIncompatibleMethodOverride = true +reportMatchNotExhaustive = "error" reportMissingSuperCall = "error" reportMissingTypeArgument = true reportMissingTypeStubs = "warning" +reportPropertyTypeMismatch = "error" reportUninitializedInstanceVariable = "error" reportUnknownMemberType = false reportUnnecessaryIsInstance = false reportUnnecessaryTypeIgnoreComment = "error" +reportUnsupportedDunderAll = "error" reportUnusedCallResult = "error" +reportUnusedClass = "error" +reportUnusedExpression = "error" +reportUnusedFunction = "error" reportUnusedVariable = "error" +reportWildcardImportFromLibrary = "error" typeCheckingMode = "strict" useLibraryCodeForTypes = true diff --git a/tests/conftest.py b/tests/conftest.py index 2355c48c..9c31410d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,12 @@ from __future__ import annotations -import logging import os import subprocess import warnings from types import MappingProxyType from typing import List, Set +from loguru import logger from pydantic import FilePath from pytest import FixtureRequest, fixture @@ -14,7 +14,6 @@ from bioimageio.spec import __version__ as bioimageio_spec_version from bioimageio.spec._package import save_bioimageio_package -logger = logging.getLogger(__name__) warnings.warn(f"testing with bioimageio.spec {bioimageio_spec_version}") # test models for various frameworks @@ -131,6 +130,7 @@ load_model_packages |= set(KERAS_TF2_MODELS) load_model_packages |= set(TENSORFLOW2_MODELS) + @fixture(scope="session") def model_packages() -> MappingProxyType[str, FilePath]: return MappingProxyType({name: save_bioimageio_package(MODEL_SOURCES[name]) for name in load_model_packages}) diff --git a/tests/test_cli.py b/tests/test_cli.py index d70f21d8..967d5d80 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -5,6 +5,7 @@ import numpy as np import pytest +from pydantic import FilePath from bioimageio.core import load_description @@ -13,9 +14,6 @@ def run_subprocess(commands: Sequence[str], **kwargs: 
Any) -> "subprocess.Comple return subprocess.run(commands, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8", **kwargs) -FIXTURES = {"unet2d_nuclei_broad_model"} - - @pytest.mark.parametrize( "args", [ @@ -25,15 +23,15 @@ def run_subprocess(commands: Sequence[str], **kwargs: Any) -> "subprocess.Comple ["test-model", "unet2d_nuclei_broad_model"], ], ) -def test_cli(args: List[str], request: pytest.FixtureRequest): - resolved_args = [str(request.getfixturevalue(arg)) if arg in FIXTURES else arg for arg in args] +def test_cli(args: List[str], unet2d_nuclei_broad_model: FilePath): + resolved_args = [str(unet2d_nuclei_broad_model) if arg == "unet2d_nuclei_broad_model" else arg for arg in args] ret = run_subprocess(["bioimageio", *resolved_args]) assert ret.returncode == 0, ret.stdout @pytest.mark.parametrize("args", [["test-model", "stardist_wrong_shape"]]) -def test_cli_fails(args: List[str], request: pytest.FixtureRequest): - resolved_args = [str(request.getfixturevalue(arg)) if arg in FIXTURES else arg for arg in args] +def test_cli_fails(args: List[str], stardist_wrong_shape: FilePath): + resolved_args = [str(stardist_wrong_shape) if arg == "stardist_wrong_shape" else arg for arg in args] ret = run_subprocess(["bioimageio", *resolved_args]) assert ret.returncode == 1, ret.stdout diff --git a/tests/test_resource_tests.py b/tests/test_resource_tests.py index 9f69721a..36bdcc5c 100644 --- a/tests/test_resource_tests.py +++ b/tests/test_resource_tests.py @@ -38,35 +38,3 @@ def test_test_resource(any_model: Path): summary = test_description(any_model) assert summary.status == "passed" - - -def test_validation_section_warning(unet2d_nuclei_broad_model: str, tmp_path: Path): - from bioimageio.core import load_description - from bioimageio.core._resource_tests import test_description - - model = load_description(unet2d_nuclei_broad_model) - assert not isinstance(model, InvalidDescr) - summary = test_description(model) - assert summary.name == "Test documentation completeness." - assert summary.warnings == {"documentation": "No '# Validation' (sub)section found."} - assert summary.status == "passed" - - doc_with_validation = tmp_path / "doc.md" - _ = doc_with_validation.write_text("# Validation\nThis is a section about how to validate the model on new data") - model.documentation = doc_with_validation - summary = test_description(model) - assert summary.name == "Test documentation completeness." 
- assert summary.warnings == {} - assert summary.status == "passed" - - -def test_issue289(unet2d_nuclei_broad_model: str): - """test for failure case from https://github.com/bioimage-io/core-bioimage-io-python/issues/289""" - # remote model is a pytorch model, needing unet2d_nuclei_broad_model skips the test when needed - _ = unet2d_nuclei_broad_model - - from bioimageio.core._resource_tests import test_model - - doi = "10.5281/zenodo.6287342" - summary = test_model(doi) - assert summary.status == "passed" diff --git a/tests/test_stat_measures.py b/tests/test_stat_measures.py index 7845da89..7e8581a9 100644 --- a/tests/test_stat_measures.py +++ b/tests/test_stat_measures.py @@ -1,5 +1,5 @@ from itertools import product -from typing import Optional, Tuple +from typing import Literal, Optional, Tuple import numpy as np import pytest @@ -11,13 +11,20 @@ from bioimageio.core.stat_measures import SamplePercentile -@pytest.mark.parametrize("name, axes", product(["mean", "var", "std"], [None, (AxisId("x"), AxisId("y"))])) -def test_individual_normal_measure(name: str, axes: Optional[Tuple[AxisId, AxisId]]): - measure = getattr(stat_measures, name.title() + "Measure")(axes=axes) +@pytest.mark.parametrize( + "name,sample_or_dataset,axes", + product(["mean", "var", "std"], ["Sample", "Dataset"], [None, (AxisId("x"), AxisId("y"))]), +) +def test_individual_normal_measure( + name: str, sample_or_dataset: Literal["Sample", "Dataset"], axes: Optional[Tuple[AxisId, AxisId]] +): + data_id = TensorId("test_data") + measure = getattr(stat_measures, sample_or_dataset + name.title())(axes=axes, tensor_id=data_id) data = xr.DataArray(np.random.random((5, 6, 3)), dims=("x", "y", "c")) expected = getattr(data, name)(dim=axes) - actual = measure.compute(data) + sample = Sample(data={data_id: data}) + actual = measure.compute(sample) xr.testing.assert_allclose(expected, actual) From 930f9544b9d9369c1d6c6b755b808cbc07930f5f Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 11 Mar 2024 10:27:57 +0100 Subject: [PATCH 107/244] WIP align with current spec --- bioimageio/core/_prediction_pipeline.py | 12 +- bioimageio/core/proc_ops.py | 4 +- bioimageio/core/proc_setup.py | 7 +- bioimageio/core/utils/_digest_spec.py | 2 +- bioimageio/core/utils/image_helper.py | 111 ++++++++++-------- pyproject.toml | 2 +- scripts/show_diff.py | 2 +- tests/conftest.py | 26 ++-- tests/test_cli.py | 7 +- tests/test_prediction.py | 49 ++++---- ...t_prediction_pipeline_device_management.py | 41 +++---- tests/utils/test_image_helper.py | 6 +- .../weight_converter/keras/test_tensorflow.py | 1 + tests/weight_converter/torch/test_onnx.py | 1 + .../torch/test_torchscript.py | 1 + 15 files changed, 129 insertions(+), 143 deletions(-) diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index 4f7db9e2..912aa9dd 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -37,8 +37,8 @@ def __init__( self._preprocessing = preprocessing self._postprocessing = postprocessing if isinstance(bioimageio_model, v0_4.ModelDescr): - self._input_ids = [TensorId(d.name) for d in bioimageio_model.inputs] - self._output_ids = [TensorId(d.name) for d in bioimageio_model.outputs] + self._input_ids = [TensorId(str(d.name)) for d in bioimageio_model.inputs] + self._output_ids = [TensorId(str(d.name)) for d in bioimageio_model.outputs] else: self._input_ids = [d.id for d in bioimageio_model.inputs] self._output_ids = [d.id for d in bioimageio_model.outputs] @@ -58,7 +58,7 @@ def 
__exit__(self, exc_type, exc_val, exc_tb): # type: ignore def predict(self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataArray) -> List[xr.DataArray]: """Predict input_tensor with the model without applying pre/postprocessing.""" - named_tensors = [named_input_tensors[k] for k in self._input_ids[len(input_tensors) :]] + named_tensors = [named_input_tensors[str(k)] for k in self._input_ids[len(input_tensors) :]] return self._adapter.forward(*input_tensors, *named_tensors) def apply_preprocessing(self, sample: Sample) -> None: @@ -71,11 +71,11 @@ def apply_postprocessing(self, sample: Sample) -> None: for op in self._postprocessing: op(sample) - def forward_sample(self, input_sample: Sample): + def forward_sample(self, input_sample: Sample) -> Sample: """Apply preprocessing, run prediction and apply postprocessing.""" self.apply_preprocessing(input_sample) - prediction_tensors = self.predict(**input_sample.data) + prediction_tensors = self.predict(**{str(k): v for k, v in input_sample.data.items()}) prediction = Sample(data=dict(zip(self._output_ids, prediction_tensors)), stat=input_sample.stat) self.apply_postprocessing(prediction) return prediction @@ -142,7 +142,7 @@ def create_prediction_pipeline( ) if isinstance(bioimageio_model, v0_4.ModelDescr): - input_ids = [TensorId(ipt.name) for ipt in bioimageio_model.inputs] + input_ids = [TensorId(str(ipt.name)) for ipt in bioimageio_model.inputs] else: input_ids = [ipt.id for ipt in bioimageio_model.inputs] diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 8a7b15f6..7c179e28 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -3,7 +3,6 @@ from dataclasses import InitVar, dataclass, field from typing import ( Collection, - Hashable, Literal, Mapping, Optional, @@ -302,7 +301,7 @@ def from_proc_descr( return cls( input=tensor_id, output=tensor_id, - reference_tensor=cast(TensorId, kwargs.reference_tensor), + reference_tensor=TensorId(str(kwargs.reference_tensor)), axes=axes, eps=kwargs.eps, ) @@ -556,4 +555,3 @@ def get_proc_class(proc_spec: ProcDescr): return ZeroMeanUnitVariance else: assert_never(proc_spec) - diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py index a71ba023..a375a2b7 100644 --- a/bioimageio/core/proc_setup.py +++ b/bioimageio/core/proc_setup.py @@ -7,7 +7,6 @@ Sequence, Set, Union, - cast, ) from typing_extensions import assert_never @@ -77,8 +76,8 @@ def _prepare_setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcess post_measures: Set[Measure] = set() if isinstance(model, v0_4.ModelDescr): - input_ids = {TensorId(d.name) for d in model.inputs} - output_ids = {TensorId(d.name) for d in model.outputs} + input_ids = {TensorId(str(d.name)) for d in model.inputs} + output_ids = {TensorId(str(d.name)) for d in model.outputs} else: input_ids = {d.id for d in model.inputs} output_ids = {d.id for d in model.outputs} @@ -98,7 +97,7 @@ def prepare_procs(tensor_descrs: Sequence[TensorDescr]): for proc_d in proc_descrs: proc_class = get_proc_class(proc_d) - tensor_id = cast(TensorId, t_descr.name) if isinstance(t_descr, v0_4.TensorDescrBase) else t_descr.id + tensor_id = TensorId(str(t_descr.name)) if isinstance(t_descr, v0_4.TensorDescrBase) else t_descr.id req = proc_class.from_proc_descr(proc_d, tensor_id) # pyright: ignore[reportArgumentType] for m in req.required_measures: if m.tensor_id in input_ids: diff --git a/bioimageio/core/utils/_digest_spec.py b/bioimageio/core/utils/_digest_spec.py index 42ba8974..ad41789f 100644 --- 
a/bioimageio/core/utils/_digest_spec.py +++ b/bioimageio/core/utils/_digest_spec.py @@ -2,7 +2,7 @@ import xarray as xr -from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 +from bioimageio.spec.model import AnyModelDescr, v0_4 from bioimageio.spec.utils import load_array diff --git a/bioimageio/core/utils/image_helper.py b/bioimageio/core/utils/image_helper.py index 3dac1772..80303260 100644 --- a/bioimageio/core/utils/image_helper.py +++ b/bioimageio/core/utils/image_helper.py @@ -31,69 +31,76 @@ OutputTensor = Union[OutputTensorDescr04, OutputTensorDescr] -def transpose_image( - image: NDArray[Any], +def interprete_array( + nd_array: NDArray[Any], + desired_axes: Union[v0_4.AxesStr, Sequence[AnyAxis]], +) -> xr.DataArray: + if isinstance(desired_axes, str): + desired_space_axes = [a for a in desired_axes if a in "zyx"] + else: + desired_space_axes = [a for a in desired_axes if a.type == "space"] + + ndim = nd_array.ndim + if ndim == 2 and len(desired_space_axes) >= 2: + current_axes = ( + SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[0]), + SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[1]), + ) + elif ndim == 3 and len(desired_space_axes) == 2: + current_axes = ( + ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(nd_array.shape[0])]), + SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[1]), + SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[2]), + ) + elif ndim == 3 and len(desired_space_axes) == 3: + current_axes = ( + SpaceInputAxis(id=AxisId("z"), size=nd_array.shape[0]), + SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[1]), + SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[2]), + ) + elif ndim == 4: + current_axes = ( + ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(nd_array.shape[0])]), + SpaceInputAxis(id=AxisId("z"), size=nd_array.shape[1]), + SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[2]), + SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[3]), + ) + elif ndim == 5: + current_axes = ( + BatchAxis(), + ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(nd_array.shape[1])]), + SpaceInputAxis(id=AxisId("z"), size=nd_array.shape[2]), + SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[3]), + SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[4]), + ) + else: + raise ValueError(f"Could not guess a mapping of {nd_array.shape} to {desired_axes}") + + current_axes_ids = tuple(current_axes) if isinstance(current_axes, str) else tuple(a.id for a in current_axes) + return xr.DataArray(nd_array, dims=current_axes_ids) + + +def transpose_array( + arary: xr.DataArray, desired_axes: Union[v0_4.AxesStr, Sequence[AnyAxis]], current_axes: Optional[Union[v0_4.AxesStr, Sequence[AnyAxis]]] = None, ) -> xr.DataArray: """Transpose an image to match desired axes. 
Args: - image: the input image + array: the input array desired_axes: the desired image axes current_axes: the axes of the input image """ - # if the image axes are not given deduce them from the required axes and image shape - if current_axes is None: - if isinstance(desired_axes, str): - desired_space_axes = [a for a in desired_axes if a in "zyx"] - else: - desired_space_axes = [a for a in desired_axes if a.type == "space"] - - ndim = image.ndim - if ndim == 2 and len(desired_space_axes) >= 2: - current_axes = ( - SpaceInputAxis(id=AxisId("y"), size=image.shape[0]), - SpaceInputAxis(id=AxisId("x"), size=image.shape[1]), - ) - elif ndim == 3 and len(desired_space_axes) == 2: - current_axes = ( - ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(image.shape[0])]), - SpaceInputAxis(id=AxisId("y"), size=image.shape[1]), - SpaceInputAxis(id=AxisId("x"), size=image.shape[2]), - ) - elif ndim == 3 and len(desired_space_axes) == 3: - current_axes = ( - SpaceInputAxis(id=AxisId("z"), size=image.shape[0]), - SpaceInputAxis(id=AxisId("y"), size=image.shape[1]), - SpaceInputAxis(id=AxisId("x"), size=image.shape[2]), - ) - elif ndim == 4: - current_axes = ( - ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(image.shape[0])]), - SpaceInputAxis(id=AxisId("z"), size=image.shape[1]), - SpaceInputAxis(id=AxisId("y"), size=image.shape[2]), - SpaceInputAxis(id=AxisId("x"), size=image.shape[3]), - ) - elif ndim == 5: - current_axes = ( - BatchAxis(), - ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(image.shape[1])]), - SpaceInputAxis(id=AxisId("z"), size=image.shape[2]), - SpaceInputAxis(id=AxisId("y"), size=image.shape[3]), - SpaceInputAxis(id=AxisId("x"), size=image.shape[4]), - ) - else: - raise ValueError(f"Could not guess a mapping of {image.shape} to {desired_axes}") - current_axes_ids = tuple(current_axes) if isinstance(current_axes, str) else tuple(a.id for a in current_axes) - desired_axes_ids = tuple(desired_axes) if isinstance(desired_axes, str) else tuple(a.id for a in desired_axes) - tensor = xr.DataArray(image, dims=current_axes_ids) + desired_axes_ids = ( + tuple(map(AxisId, desired_axes)) if isinstance(desired_axes, str) else tuple(a.id for a in desired_axes) + ) # expand the missing image axes - missing_axes = tuple(set(desired_axes_ids) - set(current_axes_ids)) - tensor = tensor.expand_dims(dim=missing_axes) + missing_axes = tuple(set(desired_axes_ids) - set(map(AxisId, array.dims))) + array = array.expand_dims(dim=missing_axes) # transpose to the correct axis order - return tensor.transpose(*tuple(desired_axes_ids)) + return arraytensor.transpose(*tuple(desired_axes_ids)) def convert_axes_for_known_shape(axes: v0_4.AxesStr, shape: Sequence[int]): @@ -117,7 +124,7 @@ def load_tensor( is_volume = len([a for a in guess_axes if a.type in ("time", "space")]) > 2 im = imageio.volread(path) if is_volume else imageio.imread(path) - im = transpose_image(im, desired_axes=desired_axes, current_axes=current_axes) + im = transpose_array(im, desired_axes=desired_axes, current_axes=current_axes) return xr.DataArray( im, dims=tuple(desired_axes) if isinstance(desired_axes, str) else tuple(a.id for a in desired_axes) diff --git a/pyproject.toml b/pyproject.toml index db0f8626..7d715f57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ include = ["bioimageio", "scripts", "tests"] pythonPlatform = "All" pythonVersion = "3.8" reportDuplicateImport = "error" -reportImplicitStringConcatenation = "warning" +reportImplicitStringConcatenation 
= "error" reportIncompatibleMethodOverride = true reportMatchNotExhaustive = "error" reportMissingSuperCall = "error" diff --git a/scripts/show_diff.py b/scripts/show_diff.py index 77623343..4a5d2223 100644 --- a/scripts/show_diff.py +++ b/scripts/show_diff.py @@ -7,7 +7,7 @@ from bioimageio.core import load_description, save_bioimageio_yaml_only if __name__ == "__main__": - rdf_source = "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/pydantic_axes/example_specs/models/unet2d_nuclei_broad/rdf_v0_4_9.yaml" + rdf_source = "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/pydantic_axes/example_descriptions/models/unet2d_nuclei_broad/rdf_v0_4_9.yaml" local_source = Path(pooch.retrieve(rdf_source, None)) # type: ignore model_as_is = load_description(rdf_source, format_version="discover") diff --git a/tests/conftest.py b/tests/conftest.py index 9c31410d..324586d5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os import subprocess import warnings from types import MappingProxyType @@ -10,7 +9,6 @@ from pydantic import FilePath from pytest import FixtureRequest, fixture -os.environ["BIOIMAGEIO_COUNT_RDF_DOWNLOADS"] = "false" # disable tracking before bioimageio imports from bioimageio.spec import __version__ as bioimageio_spec_version from bioimageio.spec._package import save_bioimageio_package @@ -35,50 +33,50 @@ MODEL_SOURCES = { "unet2d_keras": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "unet2d_keras_tf/rdf.yaml" ), "unet2d_keras_tf2": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "unet2d_keras_tf2/rdf.yaml" ), "unet2d_nuclei_broad_model": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "unet2d_nuclei_broad/rdf.yaml" ), "unet2d_expand_output_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "unet2d_nuclei_broad/rdf_expand_output_shape.yaml" ), "unet2d_fixed_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "unet2d_fixed_shape/rdf.yaml" ), "unet2d_multi_tensor": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "unet2d_multi_tensor/rdf.yaml" ), "unet2d_diff_output_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "unet2d_diff_output_shape/rdf.yaml" ), "hpa_densenet": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/hpa-densenet/rdf.yaml" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/hpa-densenet/rdf.yaml" ), "stardist": ( - 
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models" "/stardist_example_model/rdf.yaml" ), "stardist_wrong_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "stardist_example_model/rdf_wrong_shape.yaml" ), "stardist_wrong_shape2": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "stardist_example_model/rdf_wrong_shape2.yaml" ), "shape_change": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "upsample_test_model/rdf.yaml" ), } diff --git a/tests/test_cli.py b/tests/test_cli.py index 967d5d80..601882b3 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,14 +1,9 @@ -import os import subprocess -from pathlib import Path -from typing import Any, List, Optional, Sequence, Set +from typing import Any, List, Sequence -import numpy as np import pytest from pydantic import FilePath -from bioimageio.core import load_description - def run_subprocess(commands: Sequence[str], **kwargs: Any) -> "subprocess.CompletedProcess[str]": return subprocess.run(commands, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8", **kwargs) diff --git a/tests/test_prediction.py b/tests/test_prediction.py index a0e34b08..a95eb3f4 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -4,8 +4,11 @@ import numpy as np from numpy.testing import assert_array_almost_equal +from bioimageio.core.utils import get_test_inputs from bioimageio.spec import load_description -from bioimageio.spec.model.v0_5 import ModelDescr +from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensorDescr_v0_4 +from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr_v0_4 +from bioimageio.spec.model.v0_5 import InputTensorDescr, ModelDescr def test_predict_image(any_model: Path, tmpdir: Path): @@ -26,7 +29,7 @@ def test_predict_image(any_model: Path, tmpdir: Path): assert_array_almost_equal(res, exp, decimal=4) -def test_predict_image_with_weight_format(unet2d_fixed_shape_or_not, tmpdir): +def test_predict_image_with_weight_format(unet2d_fixed_shape_or_not: Path, tmpdir: Path): from bioimageio.core.prediction import predict_image spec = load_description(unet2d_fixed_shape_or_not) @@ -44,24 +47,18 @@ def test_predict_image_with_weight_format(unet2d_fixed_shape_or_not, tmpdir): assert_array_almost_equal(res, exp, decimal=4) -def _test_predict_with_padding(model, tmp_path): +def _test_predict_with_padding(any_model: Path, tmp_path: Path): from bioimageio.core.prediction import predict_image - spec = load_description(model) - assert isinstance(spec, Model) + model = load_description(any_model) + assert isinstance(model, (ModelDescr_v0_4, ModelDescr)) - input_spec, output_spec = spec.inputs[0], spec.outputs[0] - channel_axis = input_spec.axes.index("c") + input_spec, output_spec = model.inputs[0], model.outputs[0] + channel_axis = "c" if isinstance(input_spec, InputTensorDescr_v0_4) else [a.id for a in input_spec.axes][0] channel_first = channel_axis == 1 - image = np.load(str(spec.test_inputs[0])) - assert 
image.shape[channel_axis] == 1 - if channel_first: - image = image[0, 0] - else: - image = image[0, ..., 0] - original_shape = image.shape - assert image.ndim == 2 + # TODO: check more tensors + image = get_test_inputs(model)[0] if isinstance(output_spec.shape, list): n_channels = output_spec.shape[channel_axis] @@ -106,15 +103,17 @@ def check_result(): assert res.shape == exp_shape # test with dynamic padding - predict_image(model, in_path, out_path, padding={"x": 16, "y": 16, "mode": "dynamic"}) + predict_image(any_model, in_path, out_path, padding={"x": 16, "y": 16, "mode": "dynamic"}) check_result() # test with fixed padding - predict_image(model, in_path, out_path, padding={"x": original_shape[0], "y": original_shape[1], "mode": "fixed"}) + predict_image( + any_model, in_path, out_path, padding={"x": original_shape[0], "y": original_shape[1], "mode": "fixed"} + ) check_result() # test with automated padding - predict_image(model, in_path, out_path, padding=True) + predict_image(any_model, in_path, out_path, padding=True) check_result() @@ -133,7 +132,7 @@ def test_predict_image_with_padding_channel_last(stardist, tmp_path): _test_predict_with_padding(stardist, tmp_path) -def _test_predict_image_with_tiling(model, tmp_path: Path, exp_mean_deviation): +def _test_predict_image_with_tiling(model: Path, tmp_path: Path, exp_mean_deviation): from bioimageio.core.prediction import predict_image spec = load_description(model) @@ -166,27 +165,27 @@ def check_result(): # prediction with tiling with the parameters above may not be suited for any model # so we only run it for the pytorch unet2d here -def test_predict_image_with_tiling_1(unet2d_nuclei_broad_model, tmp_path: Path): +def test_predict_image_with_tiling_1(unet2d_nuclei_broad_model: Path, tmp_path: Path): _test_predict_image_with_tiling(unet2d_nuclei_broad_model, tmp_path, 0.012) -def test_predict_image_with_tiling_2(unet2d_diff_output_shape, tmp_path: Path): +def test_predict_image_with_tiling_2(unet2d_diff_output_shape: Path, tmp_path: Path): _test_predict_image_with_tiling(unet2d_diff_output_shape, tmp_path, 0.06) -def test_predict_image_with_tiling_3(shape_change_model, tmp_path: Path): +def test_predict_image_with_tiling_3(shape_change_model: Path, tmp_path: Path): _test_predict_image_with_tiling(shape_change_model, tmp_path, 0.012) -def test_predict_image_with_tiling_channel_last(stardist, tmp_path: Path): +def test_predict_image_with_tiling_channel_last(stardist: Path, tmp_path: Path): _test_predict_image_with_tiling(stardist, tmp_path, 0.13) -def test_predict_image_with_tiling_fixed_output_shape(unet2d_fixed_shape, tmp_path: Path): +def test_predict_image_with_tiling_fixed_output_shape(unet2d_fixed_shape: Path, tmp_path: Path): _test_predict_image_with_tiling(unet2d_fixed_shape, tmp_path, 0.025) -def test_predict_images(unet2d_nuclei_broad_model, tmp_path: Path): +def test_predict_images(unet2d_nuclei_broad_model: Path, tmp_path: Path): from bioimageio.core.prediction import predict_images n_images = 5 diff --git a/tests/test_prediction_pipeline_device_management.py b/tests/test_prediction_pipeline_device_management.py index 16354d18..1236383a 100644 --- a/tests/test_prediction_pipeline_device_management.py +++ b/tests/test_prediction_pipeline_device_management.py @@ -1,14 +1,12 @@ from pathlib import Path -import numpy as np -import xarray as xr from numpy.testing import assert_array_almost_equal +from bioimageio.core import load_description +from bioimageio.core.utils import get_test_inputs, get_test_outputs from 
bioimageio.core.utils.testing import skip_on -from bioimageio.spec import load_description from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr04 from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat -from bioimageio.spec.utils import load_array class TooFewDevicesException(Exception): @@ -27,24 +25,13 @@ def _test_device_management(model_package: Path, weight_format: WeightsFormat): assert isinstance(bio_model, (ModelDescr, ModelDescr04)) pred_pipe = create_prediction_pipeline(bioimageio_model=bio_model, weight_format=weight_format, devices=["cuda:0"]) - if isinstance(bio_model, ModelDescr04): - inputs = [ - xr.DataArray(np.load(str(test_tensor)), dims=tuple(spec.axes)) - for test_tensor, spec in zip(bio_model.test_inputs, bio_model.inputs) - ] - else: - inputs = [ - xr.DataArray(load_array(ipt.test_tensor), dims=tuple(a.id for a in ipt.axes)) for ipt in bio_model.inputs - ] + inputs = get_test_inputs(bio_model) with pred_pipe as pp: outputs = pp.forward(*inputs) assert isinstance(outputs, list) - expected_outputs = [ - xr.DataArray(np.load(str(test_tensor)), dims=tuple(spec.axes)) - for test_tensor, spec in zip(bio_model.test_outputs, bio_model.outputs) - ] + expected_outputs = get_test_outputs(bio_model) assert len(outputs) == len(expected_outputs) for out, exp in zip(outputs, expected_outputs): @@ -59,26 +46,26 @@ def _test_device_management(model_package: Path, weight_format: WeightsFormat): assert_array_almost_equal(out, exp, decimal=4) -@skip_on(TooFewDevicesException, reason="Too few devices") -def test_device_management_torch(any_torch_model): +@skip_on(TooFewDevicesException, reason="Too few devices") # pyright: ignore[reportArgumentType] +def test_device_management_torch(any_torch_model: Path): _test_device_management(any_torch_model, "pytorch_state_dict") -@skip_on(TooFewDevicesException, reason="Too few devices") -def test_device_management_torchscript(any_torchscript_model): +@skip_on(TooFewDevicesException, reason="Too few devices") # pyright: ignore[reportArgumentType] +def test_device_management_torchscript(any_torchscript_model: Path): _test_device_management(any_torchscript_model, "torchscript") -@skip_on(TooFewDevicesException, reason="Too few devices") -def test_device_management_onnx(any_onnx_model): +@skip_on(TooFewDevicesException, reason="Too few devices") # pyright: ignore[reportArgumentType] +def test_device_management_onnx(any_onnx_model: Path): _test_device_management(any_onnx_model, "onnx") -@skip_on(TooFewDevicesException, reason="Too few devices") -def test_device_management_tensorflow(any_tensorflow_model): +@skip_on(TooFewDevicesException, reason="Too few devices") # pyright: ignore[reportArgumentType] +def test_device_management_tensorflow(any_tensorflow_model: Path): _test_device_management(any_tensorflow_model, "tensorflow_saved_model_bundle") -@skip_on(TooFewDevicesException, reason="Too few devices") -def test_device_management_keras(any_keras_model): +@skip_on(TooFewDevicesException, reason="Too few devices") # pyright: ignore[reportArgumentType] +def test_device_management_keras(any_keras_model: Path): _test_device_management(any_keras_model, "keras_hdf5") diff --git a/tests/utils/test_image_helper.py b/tests/utils/test_image_helper.py index 8e86a919..6e0e9c08 100644 --- a/tests/utils/test_image_helper.py +++ b/tests/utils/test_image_helper.py @@ -2,18 +2,18 @@ def test_transform_input_image(): - from bioimageio.core.utils.image_helper import transpose_image + from bioimageio.core.utils.image_helper import transpose_array 
ax_list = ["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"] im = np.random.rand(256, 256) for axes in ax_list: - inp = transpose_image(im, axes) + inp = transpose_array(im, axes) assert inp.ndim == len(axes) ax_list = ["zyx", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"] vol = np.random.rand(64, 64, 64) for axes in ax_list: - inp = transpose_image(vol, axes) + inp = transpose_array(vol, axes) assert inp.ndim == len(axes) diff --git a/tests/weight_converter/keras/test_tensorflow.py b/tests/weight_converter/keras/test_tensorflow.py index 6cc42c57..069b6f23 100644 --- a/tests/weight_converter/keras/test_tensorflow.py +++ b/tests/weight_converter/keras/test_tensorflow.py @@ -1,3 +1,4 @@ +# type: ignore # TODO enable type checking import zipfile from pathlib import Path diff --git a/tests/weight_converter/torch/test_onnx.py b/tests/weight_converter/torch/test_onnx.py index c2efbcd8..a0315650 100644 --- a/tests/weight_converter/torch/test_onnx.py +++ b/tests/weight_converter/torch/test_onnx.py @@ -1,3 +1,4 @@ +# type: ignore # TODO enable type checking import os from pathlib import Path diff --git a/tests/weight_converter/torch/test_torchscript.py b/tests/weight_converter/torch/test_torchscript.py index e3f6e42c..945e778b 100644 --- a/tests/weight_converter/torch/test_torchscript.py +++ b/tests/weight_converter/torch/test_torchscript.py @@ -1,3 +1,4 @@ +# type: ignore # TODO enable type checking from pathlib import Path import pytest From baa563fe62904e8af53781645b7ee78bc2b53426 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 11 Mar 2024 10:29:15 +0100 Subject: [PATCH 108/244] update release.yml --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 60223279..e4333b5c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -35,7 +35,7 @@ jobs: - name: Check if there is a parent commit id: check-parent-commit run: | - echo "::set-output name=sha::$(git rev-parse --verify --quiet HEAD^)" + echo "name=sha::$(git rev-parse --verify --quiet HEAD^)" >> $GITHUB_OUTPUT - name: Detect new version id: check-version From 676a3e7a613dcb4410c290d452234ce0b17b9df5 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 11 Mar 2024 11:46:04 +0100 Subject: [PATCH 109/244] update example rdf urls --- tests/conftest.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 324586d5..87ae4fe1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,11 +34,11 @@ MODEL_SOURCES = { "unet2d_keras": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_keras_tf/rdf.yaml" + "unet2d_keras_tf/rdf_v0_4.yaml" ), "unet2d_keras_tf2": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_keras_tf2/rdf.yaml" + "unet2d_keras_tf2/rdf_v0_4.yaml" ), "unet2d_nuclei_broad_model": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" @@ -46,26 +46,26 @@ ), "unet2d_expand_output_shape": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_nuclei_broad/rdf_expand_output_shape.yaml" + "unet2d_nuclei_broad/rdf_expand_output_shape_v0_4.bioimageio.yaml" ), "unet2d_fixed_shape": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - 
"unet2d_fixed_shape/rdf.yaml" + "unet2d_fixed_shape/rdf_v0_4.yaml" ), "unet2d_multi_tensor": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_multi_tensor/rdf.yaml" + "unet2d_multi_tensor/rdf_v0_4.yaml" ), "unet2d_diff_output_shape": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_diff_output_shape/rdf.yaml" + "unet2d_diff_output_shape/rdf_v0_4.yaml" ), "hpa_densenet": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/hpa-densenet/rdf.yaml" ), "stardist": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models" - "/stardist_example_model/rdf.yaml" + "/stardist_example_model/rdf_v0_4.yaml" ), "stardist_wrong_shape": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" @@ -73,11 +73,11 @@ ), "stardist_wrong_shape2": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "stardist_example_model/rdf_wrong_shape2.yaml" + "stardist_example_model/rdf_wrong_shape2_v0_4.yaml" ), "shape_change": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "upsample_test_model/rdf.yaml" + "upsample_test_model/v0_4_bioimageio.yaml" ), } From 613e1543133e91dd41643398815127d39f145b4a Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 11 Mar 2024 11:52:24 +0100 Subject: [PATCH 110/244] add pdocs docs --- .github/workflows/build.yml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 065dfa53..3495e794 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -137,3 +137,21 @@ jobs: - name: linux conda build run: | conda mambabuild -c conda-forge conda-recipe + + docs: + if: github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + steps: + - name: Install dependencies + run: | + pip install --upgrade pip + pip install -e .[dev] + - name: Generate developer docs + run: pdoc -o ./dist bioimageio.spec + - run: cp README.md ./dist/README.md + - name: Deploy to gh-pages 🚀 + uses: JamesIves/github-pages-deploy-action@v4 + with: + branch: gh-pages + folder: dist + From 76f55531f262edf41a1e49ff5ecb60571c111038 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 12 Mar 2024 01:53:01 +0100 Subject: [PATCH 111/244] update release.yml --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e4333b5c..73d9b263 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -35,7 +35,7 @@ jobs: - name: Check if there is a parent commit id: check-parent-commit run: | - echo "name=sha::$(git rev-parse --verify --quiet HEAD^)" >> $GITHUB_OUTPUT + echo "sha=$(git rev-parse --verify --quiet HEAD^)" >> $GITHUB_OUTPUT - name: Detect new version id: check-version From 1a5c50bcae36f689688b6f25e4cc585083946d88 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 12 Mar 2024 10:29:17 +0100 Subject: [PATCH 112/244] mark that model_creatioin.ipynb needs an update --- .../{model_creation.ipynb => model_creation.ipynb.needs_update} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename example/{model_creation.ipynb => model_creation.ipynb.needs_update} (100%) diff --git a/example/model_creation.ipynb b/example/model_creation.ipynb.needs_update similarity index 
100% rename from example/model_creation.ipynb rename to example/model_creation.ipynb.needs_update From 0555bf5362804d49638b23fec8c7eb47603642fd Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 12 Mar 2024 12:12:28 +0100 Subject: [PATCH 113/244] update image helpers --- bioimageio/core/_resource_tests.py | 33 ++-- bioimageio/core/common.py | 10 +- bioimageio/core/utils/_digest_spec.py | 4 +- bioimageio/core/utils/image_helper.py | 219 +++++++++++++++----------- scripts/show_diff.py | 2 +- tests/conftest.py | 26 ++- tests/test_cli.py | 1 + tests/test_resource_tests.py | 6 +- tests/utils/test_image_helper.py | 57 ++++--- 9 files changed, 208 insertions(+), 150 deletions(-) diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py index e1b8c586..eadfd822 100644 --- a/bioimageio/core/_resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -7,8 +7,8 @@ from bioimageio.core._prediction_pipeline import create_prediction_pipeline from bioimageio.core.common import AxisId, BatchSize -from bioimageio.core.utils import VERSION -from bioimageio.core.utils.image_helper import pad_to +from bioimageio.core.utils import VERSION, get_test_inputs +from bioimageio.core.utils.image_helper import resize_to from bioimageio.spec import InvalidDescr, ResourceDescr, build_description, dump_description, load_description from bioimageio.spec._internal.common_nodes import ResourceDescrBase from bioimageio.spec._internal.io_utils import load_array @@ -19,7 +19,7 @@ def test_model( - source: PermissiveFileSource, + source: Union[v0_5.ModelDescr, PermissiveFileSource], weight_format: Optional[WeightsFormat] = None, devices: Optional[List[str]] = None, decimal: int = 4, @@ -82,10 +82,9 @@ def load_description_and_test( _test_expected_resource_type(rd, expected_type) if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)): - if isinstance(rd, v0_4.ModelDescr): - _test_model_inference_v0_4(rd, weight_format, devices, decimal) - else: - _test_model_inference_impl(rd, weight_format, devices) + _test_model_inference(rd, weight_format, devices, decimal) + if not isinstance(rd, v0_4.ModelDescr): + _test_model_inference_parametrized(rd, weight_format, devices) # TODO: add execution of jupyter notebooks # TODO: add more tests @@ -93,7 +92,7 @@ def load_description_and_test( return rd -def _test_model_inference_v0_4( +def _test_model_inference( model: Union[v0_4.ModelDescr, v0_5.ModelDescr], weight_format: Optional[WeightsFormat], devices: Optional[List[str]], @@ -107,11 +106,11 @@ def _test_model_inference_v0_4( expected = [xr.DataArray(load_array(src), dims=d.axes) for src, d in zip(model.test_outputs, model.outputs)] else: inputs = [ - xr.DataArray(load_array(d.test_tensor.download().path), dims=tuple(a.id for a in d.axes)) + xr.DataArray(load_array(d.test_tensor.download().path), dims=tuple(str(a.id) for a in d.axes)) for d in model.inputs ] expected = [ - xr.DataArray(load_array(d.test_tensor.download().path), dims=tuple(a.id for a in d.axes)) + xr.DataArray(load_array(d.test_tensor.download().path), dims=tuple(str(a.id) for a in d.axes)) for d in model.outputs ] @@ -152,7 +151,7 @@ def _test_model_inference_v0_4( ) -def _test_model_inference_impl( +def _test_model_inference_parametrized( model: v0_5.ModelDescr, weight_format: Optional[WeightsFormat], devices: Optional[List[str]], @@ -162,10 +161,7 @@ def _test_model_inference_impl( return try: - test_inputs = [ - xr.DataArray(load_array(d.test_tensor.download().path), dims=tuple(a.id for a in d.axes)) - for d in model.inputs - ] + test_inputs 
= get_test_inputs(model) def generate_test_cases(): tested: Set[str] = set() @@ -178,7 +174,7 @@ def generate_test_cases(): tested.add(hashable_target_size) resized_test_inputs = [ - pad_to(t, target_sizes[t_descr.id]) for t, t_descr in zip(test_inputs, model.inputs) + resize_to(t, target_sizes[t_descr.id]) for t, t_descr in zip(test_inputs, model.inputs) ] expected_output_shapes = [target_sizes[t_descr.id] for t_descr in model.outputs] yield n, batch_size, resized_test_inputs, expected_output_shapes @@ -203,8 +199,7 @@ def generate_test_cases(): model.validation_summary.add_detail( ValidationDetail( - name="Reproduce test outputs from test inputs with batch_size:" - + f" {batch_size} and size parameter n: {n}", + name="Run inference for inputs with batch_size:" + f" {batch_size} and size parameter n: {n}", status="passed" if error is None else "failed", errors=( [] @@ -224,7 +219,7 @@ def generate_test_cases(): tb = traceback.format_tb(e.__traceback__) model.validation_summary.add_detail( ValidationDetail( - name="Reproduce test outputs from test inputs", + name="Run inference for parametrized inputs", status="failed", errors=[ ErrorEntry( diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index 915d1ca0..f6684169 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, Literal import xarray as xr @@ -10,6 +10,14 @@ TensorId = v0_5.TensorId AxisId = v0_5.AxisId + + +@dataclass +class Axis: + id: AxisId + type: Literal["batch", "channel", "index", "space", "time"] + + BatchSize = int Tensor = xr.DataArray diff --git a/bioimageio/core/utils/_digest_spec.py b/bioimageio/core/utils/_digest_spec.py index ad41789f..a0514a02 100644 --- a/bioimageio/core/utils/_digest_spec.py +++ b/bioimageio/core/utils/_digest_spec.py @@ -10,11 +10,11 @@ def get_test_inputs(model: AnyModelDescr) -> List[xr.DataArray]: if isinstance(model, v0_4.ModelDescr): return [xr.DataArray(load_array(tt), dims=tuple(d.axes)) for d, tt in zip(model.inputs, model.test_inputs)] else: - return [xr.DataArray(load_array(d.test_tensor), dims=tuple(a.id for a in d.axes)) for d in model.inputs] + return [xr.DataArray(load_array(d.test_tensor), dims=tuple(str(a.id) for a in d.axes)) for d in model.inputs] def get_test_outputs(model: AnyModelDescr) -> List[xr.DataArray]: if isinstance(model, v0_4.ModelDescr): return [xr.DataArray(load_array(tt), dims=tuple(d.axes)) for d, tt in zip(model.outputs, model.test_outputs)] else: - return [xr.DataArray(load_array(d.test_tensor), dims=tuple(a.id for a in d.axes)) for d in model.outputs] + return [xr.DataArray(load_array(d.test_tensor), dims=tuple(str(a.id) for a in d.axes)) for d in model.outputs] diff --git a/bioimageio/core/utils/image_helper.py b/bioimageio/core/utils/image_helper.py index 80303260..a625a6cf 100644 --- a/bioimageio/core/utils/image_helper.py +++ b/bioimageio/core/utils/image_helper.py @@ -1,15 +1,13 @@ -# TODO: update - -from copy import deepcopy +import warnings from pathlib import Path -from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Literal, Mapping, Optional, Sequence, Tuple, Union import imageio -import numpy as np import xarray as xr from numpy.typing import NDArray from typing_extensions import assert_never +from bioimageio.core.common import Axis from bioimageio.spec.model import v0_4 from bioimageio.spec.model.v0_4 import 
InputTensorDescr as InputTensorDescr04 from bioimageio.spec.model.v0_4 import OutputTensorDescr as OutputTensorDescr04 @@ -19,19 +17,18 @@ BatchAxis, ChannelAxis, Identifier, - InputAxis, InputTensorDescr, OutputTensorDescr, SpaceInputAxis, convert_axes, ) -from bioimageio.spec.utils import load_array, save_array +from bioimageio.spec.utils import load_array InputTensor = Union[InputTensorDescr04, InputTensorDescr] OutputTensor = Union[OutputTensorDescr04, OutputTensorDescr] -def interprete_array( +def interprete_array_with_desired_axes( nd_array: NDArray[Any], desired_axes: Union[v0_4.AxesStr, Sequence[AnyAxis]], ) -> xr.DataArray: @@ -40,19 +37,29 @@ def interprete_array( else: desired_space_axes = [a for a in desired_axes if a.type == "space"] + return interprete_array(nd_array, len(desired_space_axes)) + + +def interprete_array( + nd_array: NDArray[Any], + n_expected_space_axes: Optional[int] = None, +) -> xr.DataArray: + ndim = nd_array.ndim - if ndim == 2 and len(desired_space_axes) >= 2: + if ndim == 2 and (n_expected_space_axes is None or n_expected_space_axes >= 2): current_axes = ( SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[0]), SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[1]), ) - elif ndim == 3 and len(desired_space_axes) == 2: + elif ndim == 3 and ( + (n_expected_space_axes is None and any(s <= 3 for s in nd_array.shape)) or n_expected_space_axes == 2 + ): current_axes = ( ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(nd_array.shape[0])]), SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[1]), SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[2]), ) - elif ndim == 3 and len(desired_space_axes) == 3: + elif ndim == 3 and (n_expected_space_axes is None or n_expected_space_axes == 3): current_axes = ( SpaceInputAxis(id=AxisId("z"), size=nd_array.shape[0]), SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[1]), @@ -74,61 +81,60 @@ def interprete_array( SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[4]), ) else: - raise ValueError(f"Could not guess a mapping of {nd_array.shape} to {desired_axes}") + raise ValueError( + f"Could not guess an axis mapping for {nd_array.shape} with {n_expected_space_axes} expected space axes" + ) current_axes_ids = tuple(current_axes) if isinstance(current_axes, str) else tuple(a.id for a in current_axes) return xr.DataArray(nd_array, dims=current_axes_ids) -def transpose_array( - arary: xr.DataArray, - desired_axes: Union[v0_4.AxesStr, Sequence[AnyAxis]], - current_axes: Optional[Union[v0_4.AxesStr, Sequence[AnyAxis]]] = None, +def axis_descr_to_ids(axes: Union[v0_4.AxesStr, Sequence[AnyAxis]]) -> Tuple[AxisId, ...]: + if isinstance(axes, str): + return tuple(map(AxisId, axes)) + else: + return tuple(a.id for a in axes) + + +def transpose_tensor( + tensor: xr.DataArray, + axes: Sequence[AxisId], ) -> xr.DataArray: - """Transpose an image to match desired axes. + """Transpose `array` to `axes` order. 
     Args:
-        array: the input array
-        desired_axes: the desired image axes
-        current_axes: the axes of the input image
+        tensor: the input array
+        axes: the desired array axes
     """
-    desired_axes_ids = (
-        tuple(map(AxisId, desired_axes)) if isinstance(desired_axes, str) else tuple(a.id for a in desired_axes)
-    )
     # expand the missing image axes
-    missing_axes = tuple(set(desired_axes_ids) - set(map(AxisId, array.dims)))
-    array = array.expand_dims(dim=missing_axes)
+    current_axes = tuple(AxisId(str(d)) for d in tensor.dims)
+    missing_axes = tuple(str(a) for a in axes if a not in current_axes)
+    tensor = tensor.expand_dims(missing_axes)
     # transpose to the correct axis order
-    return arraytensor.transpose(*tuple(desired_axes_ids))
+    return tensor.transpose(*axes)
 
 
-def convert_axes_for_known_shape(axes: v0_4.AxesStr, shape: Sequence[int]):
+def convert_v0_4_axes_for_known_shape(axes: v0_4.AxesStr, shape: Sequence[int]):
     return convert_axes(axes, shape=shape, tensor_type="input", halo=None, size_refs={})
 
 
 def load_tensor(
     path: Path,
-    desired_axes: Union[v0_4.AxesStr, Sequence[AnyAxis]],
-    current_axes: Optional[Union[v0_4.AxesStr, Sequence[AnyAxis]]] = None,
+    axes: Optional[Sequence[Axis]] = None,
 ) -> xr.DataArray:
 
     ext = path.suffix
     if ext == ".npy":
-        im = load_array(path)
+        array = load_array(path)
     else:
-        guess_axes = current_axes or desired_axes
-        if isinstance(guess_axes, str):
-            is_volume = "z" in guess_axes or "t" in guess_axes
-        else:
-            is_volume = len([a for a in guess_axes if a.type in ("time", "space")]) > 2
-
-        im = imageio.volread(path) if is_volume else imageio.imread(path)
-        im = transpose_array(im, desired_axes=desired_axes, current_axes=current_axes)
+        is_volume = True if axes is None else sum(a.type != "channel" for a in axes) > 2
+        array = imageio.volread(path) if is_volume else imageio.imread(path)
 
-    return xr.DataArray(
-        im, dims=tuple(desired_axes) if isinstance(desired_axes, str) else tuple(a.id for a in desired_axes)
-    )
+    if axes is None:
+        return interprete_array(array)
+    else:
+        return xr.DataArray(array, dims=tuple(a.id for a in axes))
 
 
 def pad(
@@ -136,7 +142,83 @@
     pad_width: Mapping[AxisId, Union[int, Tuple[int, int]]],
     mode: Literal["edge", "reflect", "symmetric"] = "symmetric",
 ):
-    return tensor.pad(pad_width=pad_width, mode=mode)
+    return tensor.pad(pad_width={str(k): v for k, v in pad_width.items()}, mode=mode)
+
+
+def resize_to(
+    tensor: xr.DataArray,
+    sizes: Mapping[AxisId, int],
+    *,
+    pad_where: Union[
+        Literal["before", "center", "after"], Mapping[AxisId, Literal["before", "center", "after"]]
+    ] = "center",
+    crop_where: Union[
+        Literal["before", "center", "after"], Mapping[AxisId, Literal["before", "center", "after"]]
+    ] = "center",
+    pad_mode: Literal["edge", "reflect", "symmetric"] = "symmetric",
+):
+    """crop and pad `tensor` to match `sizes`"""
+    crop_to_sizes: Dict[AxisId, int] = {}
+    pad_to_sizes: Dict[AxisId, int] = {}
+    new_axes = dict(sizes)
+    for a, s_is in tensor.sizes.items():
+        a = AxisId(str(a))
+        _ = new_axes.pop(a, None)
+        if a not in sizes or sizes[a] == s_is:
+            pass
+        elif s_is > sizes[a]:
+            crop_to_sizes[a] = sizes[a]
+        else:
+            pad_to_sizes[a] = sizes[a]
+
+    if crop_to_sizes:
+        tensor = crop_to(tensor, crop_to_sizes, crop_where=crop_where)
+
+    if pad_to_sizes:
+        tensor = pad_to(tensor, pad_to_sizes, pad_where=pad_where, mode=pad_mode)
+
+    if new_axes:
+        tensor = tensor.expand_dims({str(k): v for k, v in new_axes.items()})
+
+    return tensor
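For orientation, a minimal usage sketch of the helpers above (illustrative only, not part of this patch; it assumes the module layout of this commit and the centered crop/pad defaults):

# illustrative sketch, not committed code
import numpy as np
import xarray as xr

from bioimageio.core.common import AxisId
from bioimageio.core.utils.image_helper import interprete_array, resize_to

# a 3d array with any dimension of size <= 3 is guessed to be (channel, y, x)
tensor = interprete_array(np.zeros((3, 256, 256)))
assert tuple(tensor.dims) == ("channel", "y", "x")

# y is larger than the target, so it is cropped (centered);
# x is smaller, so it is padded (centered, symmetric) -- both in one call
resized = resize_to(tensor, {AxisId("y"): 128, AxisId("x"): 512})
assert dict(resized.sizes) == {"channel": 3, "y": 128, "x": 512}

Splitting the size change into a crop set and a pad set lets `resize_to` delegate to `crop_to` and `pad_to` below, which each handle only one direction.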
+
+
+def crop_to(
+    tensor: xr.DataArray,
+    sizes: Mapping[AxisId, int],
+    crop_where: Union[
+        Literal["before", "center", "after"], Mapping[AxisId, Literal["before", "center", "after"]]
+    ] = "center",
+):
+    """crop `tensor` to match `sizes`"""
+    axes = [AxisId(str(a)) for a in tensor.dims]
+    if crop_where in ("before", "center", "after"):
+        crop_axis_where: Mapping[AxisId, Literal["before", "center", "after"]] = {a: crop_where for a in axes}
+    else:
+        crop_axis_where = crop_where
+
+    slices: Dict[AxisId, slice] = {}
+
+    for a, s_is in tensor.sizes.items():
+        a = AxisId(str(a))
+        if a not in sizes or sizes[a] == s_is:
+            pass
+        elif sizes[a] > s_is:
+            warnings.warn(f"Cannot crop axis {a} of size {s_is} to larger size {sizes[a]}")
+        elif a not in crop_axis_where:
+            raise ValueError(f"Don't know where to crop axis {a}, `crop_where`={crop_where}")
+        else:
+            crop_this_axis_where = crop_axis_where[a]
+            if crop_this_axis_where == "before":
+                slices[a] = slice(s_is - sizes[a], s_is)
+            elif crop_this_axis_where == "after":
+                slices[a] = slice(0, sizes[a])
+            elif crop_this_axis_where == "center":
+                slices[a] = slice(start := (s_is - sizes[a]) // 2, sizes[a] + start)
+            else:
+                assert_never(crop_this_axis_where)
+
+    return tensor.isel({str(a): s for a, s in slices.items()})
 
 
 def pad_to(
     tensor: xr.DataArray,
     sizes: Mapping[AxisId, int],
     pad_where: Union[
         Literal["before", "center", "after"], Mapping[AxisId, Literal["before", "center", "after"]]
     ] = "center",
     mode: Literal["edge", "reflect", "symmetric"] = "symmetric",
 ):
-    """pad `tensor` to match `shape`"""
-    if isinstance(pad_where, str):
-        pad_axis_where: Mapping[AxisId, Literal["before", "center", "after"]] = {
-            AxisId(str(a)): pad_where for a in tensor.dims
-        }
+    """pad `tensor` to match `sizes`"""
+    axes = [AxisId(str(a)) for a in tensor.dims]
+    if pad_where in ("before", "center", "after"):
+        pad_axis_where: Mapping[AxisId, Literal["before", "center", "after"]] = {a: pad_where for a in axes}
     else:
         pad_axis_where = pad_where
 
     pad_width: Dict[AxisId, Union[int, Tuple[int, int]]] = {}
     for a, s_is in tensor.sizes.items():
         a = AxisId(str(a))
         if a not in sizes or sizes[a] == s_is:
             pad_width[a] = 0
         elif sizes[a] < s_is:
-            raise ValueError(f"Cannot pad axis {a} of size {s_is} to smaller size {sizes[a]}")
+            pad_width[a] = 0
+            warnings.warn(f"Cannot pad axis {a} of size {s_is} to smaller size {sizes[a]}")
         elif a not in pad_axis_where:
             raise ValueError(f"Don't know where to pad axis {a}, `pad_where`={pad_where}")
         else:
 
-def pad_old(image, axes: Sequence[str], padding, pad_right=True) -> Tuple[np.ndarray, Dict[str, slice]]:
-    assert image.ndim == len(axes), f"{image.ndim}, {len(axes)}"
-
-    padding_ = deepcopy(padding)
-    mode = padding_.pop("mode", "dynamic")
-    assert mode in ("dynamic", "fixed")
-
-    is_volume = "z" in axes
-    if is_volume:
-        assert len(padding_) == 3
-    else:
-        assert len(padding_) == 2
-
-    if isinstance(pad_right, bool):
-        pad_right = len(axes) * [pad_right]
-
-    pad_width: Sequence[Tuple[int, int]] = []
-    crop = {}
-    for ax, dlen, pr in zip(axes, image.shape, pad_right):
-        if ax in "zyx":
-            pad_to = padding_[ax]
-
-            if mode == "dynamic":
-                r = dlen % pad_to
-                pwidth = 0 if r == 0 else (pad_to - r)
-            else:
-                if pad_to < dlen:
-                    msg = f"Padding for axis {ax} failed; pad shape {pad_to} is smaller than the image shape {dlen}."
- raise RuntimeError(msg) - pwidth = pad_to - dlen - - pad_width.append([0, pwidth] if pr else [pwidth, 0]) - crop[ax] = slice(0, dlen) if pr else slice(pwidth, None) - else: - pad_width.append([0, 0]) - crop[ax] = slice(None) - - image = np.pad(image, pad_width, mode="symmetric") - return image, crop diff --git a/scripts/show_diff.py b/scripts/show_diff.py index 4a5d2223..affbe685 100644 --- a/scripts/show_diff.py +++ b/scripts/show_diff.py @@ -7,7 +7,7 @@ from bioimageio.core import load_description, save_bioimageio_yaml_only if __name__ == "__main__": - rdf_source = "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/pydantic_axes/example_descriptions/models/unet2d_nuclei_broad/rdf_v0_4_9.yaml" + rdf_source = "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/v0_4_9.bioimageio.yaml" local_source = Path(pooch.retrieve(rdf_source, None)) # type: ignore model_as_is = load_description(rdf_source, format_version="discover") diff --git a/tests/conftest.py b/tests/conftest.py index 87ae4fe1..85c0b722 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,15 +34,15 @@ MODEL_SOURCES = { "unet2d_keras": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_keras_tf/rdf_v0_4.yaml" + "unet2d_keras_tf/v0_4.bioimageio.yaml" ), "unet2d_keras_tf2": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_keras_tf2/rdf_v0_4.yaml" + "unet2d_keras_tf2/v0_4.bioimageio.yaml" ), "unet2d_nuclei_broad_model": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_nuclei_broad/rdf.yaml" + "unet2d_nuclei_broad/bioimageio.yaml" ), "unet2d_expand_output_shape": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" @@ -50,22 +50,22 @@ ), "unet2d_fixed_shape": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_fixed_shape/rdf_v0_4.yaml" + "unet2d_fixed_shape/v0_4.bioimageio.yaml" ), "unet2d_multi_tensor": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_multi_tensor/rdf_v0_4.yaml" + "unet2d_multi_tensor/v0_4.bioimageio.yaml" ), "unet2d_diff_output_shape": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_diff_output_shape/rdf_v0_4.yaml" + "unet2d_diff_output_shape/v0_4.bioimageio.yaml" ), "hpa_densenet": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/hpa-densenet/rdf.yaml" ), "stardist": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models" - "/stardist_example_model/rdf_v0_4.yaml" + "/stardist_example_model/v0_4.bioimageio.yaml" ), "stardist_wrong_shape": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" @@ -77,7 +77,7 @@ ), "shape_change": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "upsample_test_model/v0_4_bioimageio.yaml" + "upsample_test_model/v0_4.bioimageio.yaml" ), } @@ -122,8 +122,6 @@ if tf_major_version == 1: load_model_packages |= set(KERAS_TF1_MODELS) load_model_packages |= set(TENSORFLOW1_MODELS) - load_model_packages.add("stardist_wrong_shape") - load_model_packages.add("stardist_wrong_shape2") elif tf_major_version == 2: 
diff --git a/tests/conftest.py b/tests/conftest.py
index 87ae4fe1..85c0b722 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -34,15 +34,15 @@
 MODEL_SOURCES = {
     "unet2d_keras": (
         "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
-        "unet2d_keras_tf/rdf_v0_4.yaml"
+        "unet2d_keras_tf/v0_4.bioimageio.yaml"
     ),
     "unet2d_keras_tf2": (
         "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
-        "unet2d_keras_tf2/rdf_v0_4.yaml"
+        "unet2d_keras_tf2/v0_4.bioimageio.yaml"
     ),
     "unet2d_nuclei_broad_model": (
         "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
-        "unet2d_nuclei_broad/rdf.yaml"
+        "unet2d_nuclei_broad/bioimageio.yaml"
     ),
     "unet2d_expand_output_shape": (
         "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
@@ -50,22 +50,22 @@
     ),
     "unet2d_fixed_shape": (
         "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
-        "unet2d_fixed_shape/rdf_v0_4.yaml"
+        "unet2d_fixed_shape/v0_4.bioimageio.yaml"
     ),
     "unet2d_multi_tensor": (
         "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
-        "unet2d_multi_tensor/rdf_v0_4.yaml"
+        "unet2d_multi_tensor/v0_4.bioimageio.yaml"
     ),
     "unet2d_diff_output_shape": (
         "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
-        "unet2d_diff_output_shape/rdf_v0_4.yaml"
+        "unet2d_diff_output_shape/v0_4.bioimageio.yaml"
     ),
     "hpa_densenet": (
         "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/hpa-densenet/rdf.yaml"
     ),
     "stardist": (
         "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models"
-        "/stardist_example_model/rdf_v0_4.yaml"
+        "/stardist_example_model/v0_4.bioimageio.yaml"
     ),
     "stardist_wrong_shape": (
         "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
@@ -77,7 +77,7 @@
     ),
     "shape_change": (
         "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
-        "upsample_test_model/v0_4_bioimageio.yaml"
+        "upsample_test_model/v0_4.bioimageio.yaml"
     ),
 }
 
@@ -122,8 +122,6 @@
 if tf_major_version == 1:
     load_model_packages |= set(KERAS_TF1_MODELS)
     load_model_packages |= set(TENSORFLOW1_MODELS)
-    load_model_packages.add("stardist_wrong_shape")
-    load_model_packages.add("stardist_wrong_shape2")
 elif tf_major_version == 2:
     load_model_packages |= set(KERAS_TF2_MODELS)
     load_model_packages |= set(TENSORFLOW2_MODELS)
@@ -246,14 +244,14 @@ def shape_change_model(request: FixtureRequest, model_packages: MappingProxyType
 
 # written as model group to automatically skip on missing tensorflow 1
 @fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape"])
-def stardist_wrong_shape(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]):
-    return model_packages[request.param]
+def stardist_wrong_shape(request: FixtureRequest):
+    return MODEL_SOURCES[request.param]
 
 
 # written as model group to automatically skip on missing tensorflow 1
 @fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape2"])
-def stardist_wrong_shape2(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]):
-    return model_packages[request.param]
+def stardist_wrong_shape2(request: FixtureRequest):
+    return MODEL_SOURCES[request.param]
 
 
 # written as model group to automatically skip on missing tensorflow 1
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 601882b3..944d5f5d 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -19,6 +19,7 @@ def run_subprocess(commands: Sequence[str], **kwargs: Any) -> "subprocess.Comple
     ],
 )
 def test_cli(args: List[str], unet2d_nuclei_broad_model: FilePath):
+    assert unet2d_nuclei_broad_model.exists()
     resolved_args = [str(unet2d_nuclei_broad_model) if arg == "unet2d_nuclei_broad_model" else arg for arg in args]
     ret = run_subprocess(["bioimageio", *resolved_args])
     assert ret.returncode == 0, ret.stdout
diff --git a/tests/test_resource_tests.py b/tests/test_resource_tests.py
index 36bdcc5c..810f8256 100644
--- a/tests/test_resource_tests.py
+++ b/tests/test_resource_tests.py
@@ -1,9 +1,7 @@
 from pathlib import Path
 
-from bioimageio.spec import InvalidDescr
-
 
-def test_error_for_wrong_shape(stardist_wrong_shape: Path):
+def test_error_for_wrong_shape(stardist_wrong_shape: str):
     from bioimageio.core._resource_tests import test_model
 
     summary = test_model(stardist_wrong_shape)
@@ -15,7 +13,7 @@
     assert summary.details[0].errors[0].msg == expected_error_message
 
 
-def test_error_for_wrong_shape2(stardist_wrong_shape2: Path):
+def test_error_for_wrong_shape2(stardist_wrong_shape2: str):
     from bioimageio.core._resource_tests import test_model
 
     summary = test_model(stardist_wrong_shape2)
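
The rewrite of the image-helper tests below boils down to the following usage (a condensed sketch; it assumes, as the rewritten tests do, that a plain 2d array is interpreted with two space axes and that transpose_tensor inserts missing axes as singleton dimensions):

    import numpy as np

    from bioimageio.core.common import AxisId
    from bioimageio.core.utils.image_helper import interprete_array, transpose_tensor

    tensor = interprete_array(np.random.rand(256, 256), 2)  # axes guessed for a plain array
    transposed = transpose_tensor(tensor, [AxisId(a) for a in "bcyx"])
    assert transposed.ndim == 4  # missing axes show up as size-1 dimensions
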
diff --git a/tests/utils/test_image_helper.py b/tests/utils/test_image_helper.py
index 6e0e9c08..a0186c78 100644
--- a/tests/utils/test_image_helper.py
+++ b/tests/utils/test_image_helper.py
@@ -1,29 +1,46 @@
+from typing import Sequence
+
 import numpy as np
+import pytest
+import xarray as xr
+
+from bioimageio.core.common import AxisId
+from bioimageio.core.utils.image_helper import interprete_array
+
+
+@pytest.mark.parametrize(
+    "axes", [[AxisId(a) for a in axes] for axes in ["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"]]
+)
+def test_transpose_tensor_2d(axes: Sequence[AxisId]):
+    from bioimageio.core.utils.image_helper import transpose_tensor
+
+    tensor = interprete_array(np.random.rand(256, 256), len(axes))
+    transposed = transpose_tensor(tensor, axes)
+    assert transposed.ndim == len(axes)
+
+
+@pytest.mark.parametrize(
+    "axes", [[AxisId(a) for a in axes] for axes in ["zyx", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"]]
+)
+def test_transpose_tensor_3d(axes: Sequence[AxisId]):
+    from bioimageio.core.utils.image_helper import transpose_tensor
 
-def test_transform_input_image():
-    from bioimageio.core.utils.image_helper import transpose_array
+    tensor = interprete_array(np.random.rand(64, 64, 64), len(axes))
+    transposed = transpose_tensor(tensor, axes)
+    assert transposed.ndim == len(axes)
 
-    ax_list = ["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"]
-    im = np.random.rand(256, 256)
-    for axes in ax_list:
-        inp = transpose_array(im, axes)
-        assert inp.ndim == len(axes)
 
-    ax_list = ["zyx", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"]
-    vol = np.random.rand(64, 64, 64)
-    for axes in ax_list:
-        inp = transpose_array(vol, axes)
-        assert inp.ndim == len(axes)
+def test_crop_and_pad():
+    tensor = xr.DataArray(np.random.rand(64))
 
 
-def test_transform_output_tensor():
-    from bioimageio.core.utils.image_helper import transform_output_tensor
+
+# def test_transform_output_tensor():
+#     from bioimageio.core.utils.image_helper import transform_output_tensor
 
-    tensor = np.random.rand(1, 3, 64, 64, 64)
-    tensor_axes = "bczyx"
+#     tensor = np.random.rand(1, 3, 64, 64, 64)
+#     tensor_axes = "bczyx"
 
-    out_ax_list = ["bczyx", "cyx", "xyc", "byxc", "zyx", "xyz"]
-    for out_axes in out_ax_list:
-        out = transform_output_tensor(tensor, tensor_axes, out_axes)
-        assert out.ndim == len(out_axes)
+#     out_ax_list = ["bczyx", "cyx", "xyc", "byxc", "zyx", "xyz"]
+#     for out_axes in out_ax_list:
+#         out = transform_output_tensor(tensor, tensor_axes, out_axes)
+#         assert out.ndim == len(out_axes)

From ed2bbcb368dcaa647494bd4ace76769061dfcae2 Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Tue, 12 Mar 2024 12:12:45 +0100
Subject: [PATCH 114/244] update conda recipe

---
 conda-recipe/meta.yaml | 10 +++-------
 setup.py               | 15 +++++++++------
 2 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml
index 482aaf33..a4c7aca6 100644
--- a/conda-recipe/meta.yaml
+++ b/conda-recipe/meta.yaml
@@ -18,12 +18,10 @@ build:
 
 requirements:
   host:
-    - python >=3.7,<3.10
+    - python >=3.8,<3.13
     - pip
   run:
-    - python >=3.7,<3.10
-    - tqdm
-    - typer
+    - python >=3.8,<3.13
 {% for dep in setup_py_data['install_requires'] %}
     - {{ dep.lower() }}
 {% endfor %}
@@ -47,11 +45,10 @@ requirements:
 test:
   imports:
     - bioimageio.core
-    - bioimageio.core.build_spec
   source_files:
     - tests
   requires:
    {% for dep in setup_py_data['extras_require']['dev'] %}
    - {{ dep.lower() }}
 {% endfor %}
  commands:
@@ -64,6 +61,5 @@
 about:
   license_family: MIT
   license_file: LICENSE
  summary: 'Tools for running BioimageIO compliant neural networks in Python.'
- doc_url: https://github.com/bioimage-io/core-bioimage-io-python dev_url: https://github.com/bioimage-io/core-bioimage-io-python diff --git a/setup.py b/setup.py index adbf4305..0024c735 100644 --- a/setup.py +++ b/setup.py @@ -42,15 +42,18 @@ "pytorch": ["torch>=1.6", "torchvision"], "tensorflow": ["tensorflow"], "onnx": ["onnxruntime"], - "test": [ - "bioimageio.core[onnx]", - "bioimageio.core[pytorch]", - "black[jupyter]", + "dev": [ + "black", "crick", - "pytest-xdist[psutil]", # parallel pytest with 'pytest -n auto' + "filelock", + "onnxruntime", + "pre-commit", + "psutil", # parallel pytest with 'pytest -n auto' + "pytest-xdist", # parallel pytest "pytest", + "torch>=1.6", + "torchvision", ], - "dev": ["pre-commit", "bioimageio.core[test]"], }, project_urls={ "Bug Reports": "https://github.com/bioimage-io/core-bioimage-io-python/issues", From 0116a5db1593909d818fa74ed5052bf38c01e85f Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 12 Mar 2024 12:27:18 +0100 Subject: [PATCH 115/244] use common.Tensor --- bioimageio/core/utils/image_helper.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/bioimageio/core/utils/image_helper.py b/bioimageio/core/utils/image_helper.py index a625a6cf..48e989af 100644 --- a/bioimageio/core/utils/image_helper.py +++ b/bioimageio/core/utils/image_helper.py @@ -3,11 +3,10 @@ from typing import Any, Dict, Literal, Mapping, Optional, Sequence, Tuple, Union import imageio -import xarray as xr from numpy.typing import NDArray from typing_extensions import assert_never -from bioimageio.core.common import Axis +from bioimageio.core.common import Axis, Tensor from bioimageio.spec.model import v0_4 from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensorDescr04 from bioimageio.spec.model.v0_4 import OutputTensorDescr as OutputTensorDescr04 @@ -31,7 +30,7 @@ def interprete_array_with_desired_axes( nd_array: NDArray[Any], desired_axes: Union[v0_4.AxesStr, Sequence[AnyAxis]], -) -> xr.DataArray: +) -> Tensor: if isinstance(desired_axes, str): desired_space_axes = [a for a in desired_axes if a in "zyx"] else: @@ -43,7 +42,7 @@ def interprete_array_with_desired_axes( def interprete_array( nd_array: NDArray[Any], n_expected_space_axes: Optional[int] = None, -) -> xr.DataArray: +) -> Tensor: ndim = nd_array.ndim if ndim == 2 and (n_expected_space_axes is None or n_expected_space_axes >= 2): @@ -86,7 +85,7 @@ def interprete_array( ) current_axes_ids = tuple(current_axes) if isinstance(current_axes, str) else tuple(a.id for a in current_axes) - return xr.DataArray(nd_array, dims=current_axes_ids) + return Tensor(nd_array, dims=current_axes_ids) def axis_descr_to_ids(axes: Union[v0_4.AxesStr, Sequence[AnyAxis]]) -> Tuple[AxisId, ...]: @@ -97,9 +96,9 @@ def axis_descr_to_ids(axes: Union[v0_4.AxesStr, Sequence[AnyAxis]]) -> Tuple[Axi def transpose_tensor( - tensor: xr.DataArray, + tensor: Tensor, axes: Sequence[AxisId], -) -> xr.DataArray: +) -> Tensor: """Transpose `array` to `axes` order. 
Args: @@ -122,7 +121,7 @@ def convert_v0_4_axes_for_known_shape(axes: v0_4.AxesStr, shape: Sequence[int]): def load_tensor( path: Path, axes: Optional[Sequence[Axis]] = None, -) -> xr.DataArray: +) -> Tensor: ext = path.suffix if ext == ".npy": @@ -134,11 +133,11 @@ def load_tensor( if axes is None: return interprete_array(array) else: - return xr.DataArray(array, dims=tuple(a.id for a in axes)) + return Tensor(array, dims=tuple(a.id for a in axes)) def pad( - tensor: xr.DataArray, + tensor: Tensor, pad_width: Mapping[AxisId, Union[int, Tuple[int, int]]], mode: Literal["edge", "reflect", "symmetric"] = "symmetric", ): @@ -146,7 +145,7 @@ def pad( def resize_to( - tensor: xr.DataArray, + tensor: Tensor, sizes: Mapping[AxisId, int], *, pad_where: Union[ @@ -184,7 +183,7 @@ def resize_to( def crop_to( - tensor: xr.DataArray, + tensor: Tensor, sizes: Mapping[AxisId, int], crop_where: Union[ Literal["before", "center", "after"], Mapping[AxisId, Literal["before", "center", "after"]] @@ -222,7 +221,7 @@ def crop_to( def pad_to( - tensor: xr.DataArray, + tensor: Tensor, sizes: Mapping[AxisId, int], pad_where: Union[ Literal["before", "center", "after"], Mapping[AxisId, Literal["before", "center", "after"]] From 70be1b64d915a0b308df0cf0a565253809f69dd3 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 12 Mar 2024 12:27:42 +0100 Subject: [PATCH 116/244] update model_packages fixture for pytest-xdist --- tests/conftest.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 85c0b722..61a66b30 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,9 +5,10 @@ from types import MappingProxyType from typing import List, Set +from filelock import FileLock from loguru import logger from pydantic import FilePath -from pytest import FixtureRequest, fixture +from pytest import FixtureRequest, TempPathFactory, fixture from bioimageio.spec import __version__ as bioimageio_spec_version from bioimageio.spec._package import save_bioimageio_package @@ -128,8 +129,30 @@ @fixture(scope="session") -def model_packages() -> MappingProxyType[str, FilePath]: - return MappingProxyType({name: save_bioimageio_package(MODEL_SOURCES[name]) for name in load_model_packages}) +def model_packages(tmp_path_factory: TempPathFactory, worker_id: str) -> MappingProxyType[str, FilePath]: + """prepare model packages (only run with one worker) + see https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once + """ + root_tmp_dir = tmp_path_factory.getbasetemp().parent + + packages = MappingProxyType({name: (root_tmp_dir / name).with_suffix(".zip") for name in load_model_packages}) + + def generate_packages(): + for name in load_model_packages: + actual_out = save_bioimageio_package(MODEL_SOURCES[name], output_path=packages[name]) + assert actual_out == packages[name] + + info_path = root_tmp_dir / "packages_created" + if worker_id == "master": + # no workers + generate_packages() + else: + with FileLock(info_path.with_suffix(".lock")): + if not info_path.is_file(): + generate_packages() + _ = info_path.write_text("") + + return packages @fixture(scope="session") From 3b76d89bb4ffe2f67eaec21d577520493fbe315a Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 12 Mar 2024 13:46:53 +0100 Subject: [PATCH 117/244] black line-length 88 --- bioimageio/core/__init__.py | 12 +- bioimageio/core/__main__.py | 51 ++++-- bioimageio/core/_prediction_pipeline.py | 40 +++-- bioimageio/core/_resource_tests.py | 120 
++++++++++--- .../model_adapters/_keras_model_adapter.py | 32 +++- .../core/model_adapters/_model_adapter.py | 25 ++- .../model_adapters/_onnx_model_adapter.py | 30 +++- .../model_adapters/_pytorch_model_adapter.py | 42 ++++- .../_tensorflow_model_adapter.py | 62 +++++-- .../_torchscript_model_adapter.py | 28 ++- bioimageio/core/proc_ops.py | 103 ++++++++--- bioimageio/core/proc_setup.py | 28 ++- bioimageio/core/stat_calculators.py | 161 +++++++++++++----- bioimageio/core/utils/_digest_spec.py | 24 ++- bioimageio/core/utils/_import_callable.py | 13 +- bioimageio/core/utils/image_helper.py | 67 ++++++-- .../weight_converter/keras/_tensorflow.py | 31 +++- .../core/weight_converter/torch/_onnx.py | 21 ++- .../weight_converter/torch/_torchscript.py | 49 ++++-- .../core/weight_converter/torch/_utils.py | 4 +- pyproject.toml | 6 +- scripts/setup_dev_env.py | 8 +- scripts/show_diff.py | 4 +- tests/conftest.py | 115 ++++++++++--- tests/test_bioimageio_spec_version.py | 4 +- tests/test_cli.py | 36 +++- tests/test_prediction.py | 42 +++-- tests/test_prediction_pipeline.py | 4 +- ...t_prediction_pipeline_device_management.py | 24 ++- tests/test_proc_ops.py | 44 ++++- tests/test_stat_measures.py | 19 ++- tests/utils/test_image_helper.py | 12 +- .../weight_converter/keras/test_tensorflow.py | 24 ++- tests/weight_converter/torch/test_onnx.py | 4 +- .../torch/test_torchscript.py | 12 +- 35 files changed, 1010 insertions(+), 291 deletions(-) diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index 5cb579b5..9bd8324d 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -7,14 +7,20 @@ from bioimageio.spec import build_description as build_description from bioimageio.spec import dump_description as dump_description from bioimageio.spec import load_description as load_description -from bioimageio.spec import load_description_and_validate_format_only as load_description_and_validate_format_only +from bioimageio.spec import ( + load_description_and_validate_format_only as load_description_and_validate_format_only, +) from bioimageio.spec import save_bioimageio_package as save_bioimageio_package -from bioimageio.spec import save_bioimageio_package_as_folder as save_bioimageio_package_as_folder +from bioimageio.spec import ( + save_bioimageio_package_as_folder as save_bioimageio_package_as_folder, +) from bioimageio.spec import save_bioimageio_yaml_only as save_bioimageio_yaml_only from bioimageio.spec import validate_format as validate_format from ._prediction_pipeline import PredictionPipeline as PredictionPipeline -from ._prediction_pipeline import create_prediction_pipeline as create_prediction_pipeline +from ._prediction_pipeline import ( + create_prediction_pipeline as create_prediction_pipeline, +) from ._resource_tests import load_description_and_test as load_description_and_test from ._resource_tests import test_description as test_description from ._resource_tests import test_model as test_model diff --git a/bioimageio/core/__main__.py b/bioimageio/core/__main__.py index 237bd782..4d769cd4 100644 --- a/bioimageio/core/__main__.py +++ b/bioimageio/core/__main__.py @@ -27,7 +27,9 @@ # prevent rewrapping with \b\n: https://click.palletsprojects.com/en/7.x/documentation/#preventing-rewrapping app = typer.Typer( help="\b\n" + help_version, - context_settings={"help_option_names": ["-h", "--help", "--version"]}, # make --version display help with version + context_settings={ + "help_option_names": ["-h", "--help", "--version"] + }, # make --version display help 
with version ) # https://typer.tiangolo.com/ @@ -56,7 +58,9 @@ class WeightsFormatEnum(enum.Enum): @app.command() def package( source: Annotated[str, typer.Argument(help="path or url to a bioimageio RDF")], - path: Annotated[Path, typer.Argument(help="Save package as")] = Path("bioimageio-package.zip"), + path: Annotated[Path, typer.Argument(help="Save package as")] = Path( + "bioimageio-package.zip" + ), weights_priority_order: Annotated[ Optional[List[WeightsFormatEnum]], typer.Option( @@ -70,22 +74,35 @@ def package( ] = None, ): # typer bug: typer returns empty tuple instead of None if weights_order_priority is not given - weights_priority_order = weights_priority_order or None # TODO: check if this is still the case + weights_priority_order = ( + weights_priority_order or None + ) # TODO: check if this is still the case _ = save_bioimageio_package( source, output_path=path, - weights_priority_order=None if weights_priority_order is None else [wpo.name for wpo in weights_priority_order], + weights_priority_order=( + None + if weights_priority_order is None + else [wpo.name for wpo in weights_priority_order] + ), ) @app.command() def test_model( model_rdf: Annotated[ - str, typer.Argument(help="Path or URL to the model resource description file (rdf.yaml) or zipped model.") + str, + typer.Argument( + help="Path or URL to the model resource description file (rdf.yaml) or zipped model." + ), ], - weight_format: Annotated[Optional[WeightsFormatEnum], typer.Option(help="The weight format to use.")] = None, - devices: Annotated[Optional[List[str]], typer.Option(help="Devices for running the model.")] = None, + weight_format: Annotated[ + Optional[WeightsFormatEnum], typer.Option(help="The weight format to use.") + ] = None, + devices: Annotated[ + Optional[List[str]], typer.Option(help="Devices for running the model.") + ] = None, decimal: Annotated[int, typer.Option(help="The test precision.")] = 4, ): # this is a weird typer bug: default devices are empty tuple although they should be None @@ -108,22 +125,32 @@ def test_model( @app.command() def test_resource( rdf: Annotated[ - str, typer.Argument(help="Path or URL to the resource description file (rdf.yaml) or zipped resource package.") + str, + typer.Argument( + help="Path or URL to the resource description file (rdf.yaml) or zipped resource package." 
+ ), ], weight_format: Annotated[ - Optional[WeightsFormatEnum], typer.Option(help="(for model only) The weight format to use.") + Optional[WeightsFormatEnum], + typer.Option(help="(for model only) The weight format to use."), ] = None, devices: Annotated[ - Optional[List[str]], typer.Option(help="(for model only) Devices for running the model.") + Optional[List[str]], + typer.Option(help="(for model only) Devices for running the model."), ] = None, - decimal: Annotated[int, typer.Option(help="(for model only) The test precision.")] = 4, + decimal: Annotated[ + int, typer.Option(help="(for model only) The test precision.") + ] = 4, ): # this is a weird typer bug: default devices are empty tuple although they should be None if devices is None or len(devices) == 0: devices = None summary = _test_description( - rdf, weight_format=None if weight_format is None else weight_format.value, devices=devices, decimal=decimal + rdf, + weight_format=None if weight_format is None else weight_format.value, + devices=devices, + decimal=decimal, ) print(summary.format()) sys.exit(0 if summary.status == "passed" else 1) diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index 912aa9dd..b2cf998f 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -31,7 +31,9 @@ def __init__( ) -> None: super().__init__() if bioimageio_model.run_mode: - warnings.warn(f"Not yet implemented inference for run mode '{bioimageio_model.run_mode.name}'") + warnings.warn( + f"Not yet implemented inference for run mode '{bioimageio_model.run_mode.name}'" + ) self.name = name self._preprocessing = preprocessing @@ -45,7 +47,9 @@ def __init__( self._adapter: ModelAdapter = model - def __call__(self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataArray) -> List[xr.DataArray]: + def __call__( + self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataArray + ) -> List[xr.DataArray]: return self.forward(*input_tensors, **named_input_tensors) def __enter__(self): @@ -56,9 +60,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore self.unload() return False - def predict(self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataArray) -> List[xr.DataArray]: + def predict( + self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataArray + ) -> List[xr.DataArray]: """Predict input_tensor with the model without applying pre/postprocessing.""" - named_tensors = [named_input_tensors[str(k)] for k in self._input_ids[len(input_tensors) :]] + named_tensors = [ + named_input_tensors[str(k)] for k in self._input_ids[len(input_tensors) :] + ] return self._adapter.forward(*input_tensors, *named_tensors) def apply_preprocessing(self, sample: Sample) -> None: @@ -75,8 +83,12 @@ def forward_sample(self, input_sample: Sample) -> Sample: """Apply preprocessing, run prediction and apply postprocessing.""" self.apply_preprocessing(input_sample) - prediction_tensors = self.predict(**{str(k): v for k, v in input_sample.data.items()}) - prediction = Sample(data=dict(zip(self._output_ids, prediction_tensors)), stat=input_sample.stat) + prediction_tensors = self.predict( + **{str(k): v for k, v in input_sample.data.items()} + ) + prediction = Sample( + data=dict(zip(self._output_ids, prediction_tensors)), stat=input_sample.stat + ) self.apply_postprocessing(prediction) return prediction @@ -92,7 +104,9 @@ def forward_tensors( ) return self.forward_sample(input_sample).data - def forward(self, *input_tensors: xr.DataArray, 
**named_input_tensors: xr.DataArray) -> List[xr.DataArray]: + def forward( + self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataArray + ) -> List[xr.DataArray]: """Apply preprocessing, run prediction and apply postprocessing.""" named_outputs = self.forward_tensors(*input_tensors, **named_input_tensors) return [named_outputs[x] for x in self._output_ids] @@ -116,9 +130,13 @@ def create_prediction_pipeline( devices: Optional[Sequence[str]] = None, weight_format: Optional[WeightsFormat] = None, weights_format: Optional[WeightsFormat] = None, - dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[xr.DataArray]]] = tuple(), + dataset_for_initial_statistics: Iterable[ + Union[Sample, Sequence[xr.DataArray]] + ] = tuple(), keep_updating_initial_dataset_statistics: bool = False, - fixed_dataset_statistics: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType({}), + fixed_dataset_statistics: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType( + {} + ), model_adapter: Optional[ModelAdapter] = None, **deprecated_kwargs: Any, ) -> PredictionPipeline: @@ -133,7 +151,9 @@ def create_prediction_pipeline( weights_format = weight_format or weights_format del weight_format if deprecated_kwargs: - warnings.warn(f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}") + warnings.warn( + f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}" + ) model_adapter = model_adapter or create_model_adapter( model_description=bioimageio_model, diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py index eadfd822..80f57372 100644 --- a/bioimageio/core/_resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -9,13 +9,24 @@ from bioimageio.core.common import AxisId, BatchSize from bioimageio.core.utils import VERSION, get_test_inputs from bioimageio.core.utils.image_helper import resize_to -from bioimageio.spec import InvalidDescr, ResourceDescr, build_description, dump_description, load_description +from bioimageio.spec import ( + InvalidDescr, + ResourceDescr, + build_description, + dump_description, + load_description, +) from bioimageio.spec._internal.common_nodes import ResourceDescrBase from bioimageio.spec._internal.io_utils import load_array from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import WeightsFormat -from bioimageio.spec.summary import ErrorEntry, InstalledPackage, ValidationDetail, ValidationSummary +from bioimageio.spec.summary import ( + ErrorEntry, + InstalledPackage, + ValidationDetail, + ValidationSummary, +) def test_model( @@ -26,7 +37,11 @@ def test_model( ) -> ValidationSummary: """Test model inference""" return test_description( - source, weight_format=weight_format, devices=devices, decimal=decimal, expected_type="model" + source, + weight_format=weight_format, + devices=devices, + decimal=decimal, + expected_type="model", ) @@ -66,7 +81,9 @@ def load_description_and_test( and format_version != "discover" and source.format_version != format_version ): - warnings.warn(f"deserializing source to ensure we validate and test using format {format_version}") + warnings.warn( + f"deserializing source to ensure we validate and test using format {format_version}" + ) source = dump_description(source) if isinstance(source, ResourceDescrBase): @@ -76,7 +93,9 @@ def load_description_and_test( else: rd = load_description(source, format_version=format_version) - 
rd.validation_summary.env.append(InstalledPackage(name="bioimageio.core", version=VERSION)) + rd.validation_summary.env.append( + InstalledPackage(name="bioimageio.core", version=VERSION) + ) if expected_type is not None: _test_expected_resource_type(rd, expected_type) @@ -102,15 +121,27 @@ def _test_model_inference( tb: List[str] = [] try: if isinstance(model, v0_4.ModelDescr): - inputs = [xr.DataArray(load_array(src), dims=d.axes) for src, d in zip(model.test_inputs, model.inputs)] - expected = [xr.DataArray(load_array(src), dims=d.axes) for src, d in zip(model.test_outputs, model.outputs)] + inputs = [ + xr.DataArray(load_array(src), dims=d.axes) + for src, d in zip(model.test_inputs, model.inputs) + ] + expected = [ + xr.DataArray(load_array(src), dims=d.axes) + for src, d in zip(model.test_outputs, model.outputs) + ] else: inputs = [ - xr.DataArray(load_array(d.test_tensor.download().path), dims=tuple(str(a.id) for a in d.axes)) + xr.DataArray( + load_array(d.test_tensor.download().path), + dims=tuple(str(a.id) for a in d.axes), + ) for d in model.inputs ] expected = [ - xr.DataArray(load_array(d.test_tensor.download().path), dims=tuple(str(a.id) for a in d.axes)) + xr.DataArray( + load_array(d.test_tensor.download().path), + dims=tuple(str(a.id) for a in d.axes), + ) for d in model.outputs ] @@ -120,13 +151,17 @@ def _test_model_inference( results = prediction_pipeline.forward(*inputs) if len(results) != len(expected): - error = (error or "") + (f"Expected {len(expected)} outputs, but got {len(results)}") + error = (error or "") + ( + f"Expected {len(expected)} outputs, but got {len(results)}" + ) else: for res, exp in zip(results, expected): try: np.testing.assert_array_almost_equal(res, exp, decimal=decimal) except AssertionError as e: - error = (error or "") + f"Output and expected output disagree:\n {e}" + error = ( + error or "" + ) + f"Output and expected output disagree:\n {e}" except Exception as e: error = str(e) tb = traceback.format_tb(e.__traceback__) @@ -140,7 +175,11 @@ def _test_model_inference( if error is None else [ ErrorEntry( - loc=("weights",) if weight_format is None else ("weights", weight_format), + loc=( + ("weights",) + if weight_format is None + else ("weights", weight_format) + ), msg=error, type="bioimageio.core", traceback=tb, @@ -155,9 +194,18 @@ def _test_model_inference_parametrized( model: v0_5.ModelDescr, weight_format: Optional[WeightsFormat], devices: Optional[List[str]], - test_cases: Sequence[Tuple[v0_5.ParameterizedSize.N, BatchSize]] = ((0, 1), (1, 3), (2, 1), (3, 2)), + test_cases: Sequence[Tuple[v0_5.ParameterizedSize.N, BatchSize]] = ( + (0, 1), + (1, 3), + (2, 1), + (3, 2), + ), ) -> None: - if not any(isinstance(a.size, v0_5.ParameterizedSize) for ipt in model.inputs for a in ipt.axes): + if not any( + isinstance(a.size, v0_5.ParameterizedSize) + for ipt in model.inputs + for a in ipt.axes + ): return try: @@ -174,9 +222,12 @@ def generate_test_cases(): tested.add(hashable_target_size) resized_test_inputs = [ - resize_to(t, target_sizes[t_descr.id]) for t, t_descr in zip(test_inputs, model.inputs) + resize_to(t, target_sizes[t_descr.id]) + for t, t_descr in zip(test_inputs, model.inputs) + ] + expected_output_shapes = [ + target_sizes[t_descr.id] for t_descr in model.outputs ] - expected_output_shapes = [target_sizes[t_descr.id] for t_descr in model.outputs] yield n, batch_size, resized_test_inputs, expected_output_shapes with create_prediction_pipeline( @@ -187,10 +238,16 @@ def generate_test_cases(): error: Optional[str] = None results 
= prediction_pipeline.forward(*inputs) if len(results) != len(exptected_output_shape): - error = (error or "") + (f"Expected {len(exptected_output_shape)} outputs, but got {len(results)}") + error = (error or "") + ( + f"Expected {len(exptected_output_shape)} outputs, but got {len(results)}" + ) else: for res, exp in zip(results, exptected_output_shape): - if diff := {a: s for a, s in res.sizes.items() if s != exp[AxisId(str(a))]}: + if diff := { + a: s + for a, s in res.sizes.items() + if s != exp[AxisId(str(a))] + }: error = ( (error or "") + f"(n={n}) Expected output shape {exp}," @@ -199,14 +256,19 @@ def generate_test_cases(): model.validation_summary.add_detail( ValidationDetail( - name="Run inference for inputs with batch_size:" + f" {batch_size} and size parameter n: {n}", + name="Run inference for inputs with batch_size:" + + f" {batch_size} and size parameter n: {n}", status="passed" if error is None else "failed", errors=( [] if error is None else [ ErrorEntry( - loc=("weights",) if weight_format is None else ("weights", weight_format), + loc=( + ("weights",) + if weight_format is None + else ("weights", weight_format) + ), msg=error, type="bioimageio.core", ) @@ -223,7 +285,11 @@ def generate_test_cases(): status="failed", errors=[ ErrorEntry( - loc=("weights",) if weight_format is None else ("weights", weight_format), + loc=( + ("weights",) + if weight_format is None + else ("weights", weight_format) + ), msg=error, type="bioimageio.core", traceback=tb, @@ -233,7 +299,9 @@ def generate_test_cases(): ) -def _test_expected_resource_type(rd: Union[InvalidDescr, ResourceDescr], expected_type: str): +def _test_expected_resource_type( + rd: Union[InvalidDescr, ResourceDescr], expected_type: str +): has_expected_type = rd.type == expected_type rd.validation_summary.details.append( ValidationDetail( @@ -242,7 +310,13 @@ def _test_expected_resource_type(rd: Union[InvalidDescr, ResourceDescr], expecte errors=( [] if has_expected_type - else [ErrorEntry(loc=("type",), type="type", msg=f"expected type {expected_type}, found {rd.type}")] + else [ + ErrorEntry( + loc=("type",), + type="type", + msg=f"expected type {expected_type}, found {rd.type}", + ) + ] ), ) ) diff --git a/bioimageio/core/model_adapters/_keras_model_adapter.py b/bioimageio/core/model_adapters/_keras_model_adapter.py index e353df17..af429644 100644 --- a/bioimageio/core/model_adapters/_keras_model_adapter.py +++ b/bioimageio/core/model_adapters/_keras_model_adapter.py @@ -27,7 +27,10 @@ class KerasModelAdapter(ModelAdapter): def __init__( - self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None + self, + *, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + devices: Optional[Sequence[str]] = None, ) -> None: assert keras is not None super().__init__() @@ -41,12 +44,19 @@ def __init__( warnings.warn( f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}." ) - elif (model_tf_version.major, model_tf_version.minor) != (tf_version.major, tf_version.minor): - warnings.warn(f"Model tensorflow version {model_tf_version} does not match {tf_version}.") + elif (model_tf_version.major, model_tf_version.minor) != ( + tf_version.major, + tf_version.minor, + ): + warnings.warn( + f"Model tensorflow version {model_tf_version} does not match {tf_version}." 
+ ) # TODO keras device management if devices is not None: - warnings.warn(f"Device management is not implemented for keras yet, ignoring the devices {devices}") + warnings.warn( + f"Device management is not implemented for keras yet, ignoring the devices {devices}" + ) weight_path = download(model_description.weights.keras_hdf5.source).path @@ -54,8 +64,10 @@ def __init__( self._output_axes = [tuple(out.axes) for out in model_description.outputs] def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - _result: Union[Sequence[NDArray[Any]], NDArray[Any]] = ( # pyright: ignore[reportUnknownVariableType] - self._network.predict(*input_tensors) + _result: Union[Sequence[NDArray[Any]], NDArray[Any]] = ( + self._network.predict( # pyright: ignore[reportUnknownVariableType] + *input_tensors + ) ) if isinstance(_result, (tuple, list)): result: Sequence[NDArray[Any]] = _result @@ -63,7 +75,11 @@ def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: result = [_result] # type: ignore assert len(result) == len(self._output_axes) - return [xr.DataArray(r, dims=axes) for r, axes, in zip(result, self._output_axes)] + return [ + xr.DataArray(r, dims=axes) for r, axes, in zip(result, self._output_axes) + ] def unload(self) -> None: - warnings.warn("Device management is not implemented for keras yet, cannot unload model") + warnings.warn( + "Device management is not implemented for keras yet, cannot unload model" + ) diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py index dabaff5f..acedc122 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ b/bioimageio/core/model_adapters/_model_adapter.py @@ -59,29 +59,40 @@ def create( from ._pytorch_model_adapter import PytorchModelAdapter return PytorchModelAdapter( - outputs=model_description.outputs, weights=weights.pytorch_state_dict, devices=devices + outputs=model_description.outputs, + weights=weights.pytorch_state_dict, + devices=devices, ) except Exception as e: errors.append(e) - elif wf == "tensorflow_saved_model_bundle" and weights.tensorflow_saved_model_bundle is not None: + elif ( + wf == "tensorflow_saved_model_bundle" + and weights.tensorflow_saved_model_bundle is not None + ): try: from ._tensorflow_model_adapter import TensorflowModelAdapter - return TensorflowModelAdapter(model_description=model_description, devices=devices) + return TensorflowModelAdapter( + model_description=model_description, devices=devices + ) except Exception as e: errors.append(e) elif wf == "onnx" and weights.onnx is not None: try: from ._onnx_model_adapter import ONNXModelAdapter - return ONNXModelAdapter(model_description=model_description, devices=devices) + return ONNXModelAdapter( + model_description=model_description, devices=devices + ) except Exception as e: errors.append(e) elif wf == "torchscript" and weights.torchscript is not None: try: from ._torchscript_model_adapter import TorchscriptModelAdapter - return TorchscriptModelAdapter(model_description=model_description, devices=devices) + return TorchscriptModelAdapter( + model_description=model_description, devices=devices + ) except Exception as e: errors.append(e) elif wf == "keras_hdf5" and weights.keras_hdf5 is not None: @@ -94,7 +105,9 @@ def create( if keras is None: from ._tensorflow_model_adapter import KerasModelAdapter - return KerasModelAdapter(model_description=model_description, devices=devices) + return KerasModelAdapter( + model_description=model_description, devices=devices + ) except 
Exception as e: errors.append(e) diff --git a/bioimageio/core/model_adapters/_onnx_model_adapter.py b/bioimageio/core/model_adapters/_onnx_model_adapter.py index b3a632b3..26400eda 100644 --- a/bioimageio/core/model_adapters/_onnx_model_adapter.py +++ b/bioimageio/core/model_adapters/_onnx_model_adapter.py @@ -16,12 +16,19 @@ class ONNXModelAdapter(ModelAdapter): def __init__( - self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None + self, + *, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + devices: Optional[Sequence[str]] = None, ): assert rt is not None super().__init__() self._internal_output_axes = [ - tuple(out.axes) if isinstance(out.axes, str) else tuple(a.id for a in out.axes) + ( + tuple(out.axes) + if isinstance(out.axes, str) + else tuple(a.id for a in out.axes) + ) for out in model_description.outputs ] if model_description.weights.onnx is None: @@ -32,18 +39,27 @@ def __init__( self._input_names: List[str] = [ipt.name for ipt in onnx_inputs] # type: ignore if devices is not None: - warnings.warn(f"Device management is not implemented for onnx yet, ignoring the devices {devices}") + warnings.warn( + f"Device management is not implemented for onnx yet, ignoring the devices {devices}" + ) def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: assert len(input_tensors) == len(self._input_names) input_arrays = [ipt.data for ipt in input_tensors] - result: Union[Sequence[NDArray[Any]], NDArray[Any]] = ( # pyright: ignore[reportUnknownVariableType] - self._session.run(None, dict(zip(self._input_names, input_arrays))) + result: Union[Sequence[NDArray[Any]], NDArray[Any]] = ( + self._session.run( # pyright: ignore[reportUnknownVariableType] + None, dict(zip(self._input_names, input_arrays)) + ) ) if not isinstance(result, (list, tuple)): result = [] - return [xr.DataArray(r, dims=axes) for r, axes in zip(result, self._internal_output_axes)] + return [ + xr.DataArray(r, dims=axes) + for r, axes in zip(result, self._internal_output_axes) + ] def unload(self) -> None: - warnings.warn("Device management is not implemented for onnx yet, cannot unload model") + warnings.warn( + "Device management is not implemented for onnx yet, cannot unload model" + ) diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py index 95f3de50..b54b82d8 100644 --- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py +++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py @@ -20,18 +20,27 @@ class PytorchModelAdapter(ModelAdapter): def __init__( self, *, - outputs: Union[Sequence[v0_4.OutputTensorDescr], Sequence[v0_5.OutputTensorDescr]], - weights: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr], + outputs: Union[ + Sequence[v0_4.OutputTensorDescr], Sequence[v0_5.OutputTensorDescr] + ], + weights: Union[ + v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr + ], devices: Optional[Sequence[str]] = None, ): assert torch is not None super().__init__() - self.output_dims = [tuple(a if isinstance(a, str) else a.id for a in out.axes) for out in outputs] + self.output_dims = [ + tuple(a if isinstance(a, str) else a.id for a in out.axes) + for out in outputs + ] self._network = self.get_network(weights) self._devices = self.get_devices(devices) self._network = self._network.to(self._devices[0]) - state: Any = torch.load(download(weights.source).path, map_location=self._devices[0]) + state: Any = 
torch.load( + download(weights.source).path, map_location=self._devices[0] + ) _ = self._network.load_state_dict(state) self._network = self._network.eval() @@ -44,9 +53,14 @@ def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: if not isinstance(result, (tuple, list)): result = [result] - result = [r.detach().cpu().numpy() if isinstance(r, torch.Tensor) else r for r in result] + result = [ + r.detach().cpu().numpy() if isinstance(r, torch.Tensor) else r + for r in result + ] if len(result) > len(self.output_dims): - raise ValueError(f"Expected at most {len(self.output_dims)} outputs, but got {len(result)}") + raise ValueError( + f"Expected at most {len(self.output_dims)} outputs, but got {len(result)}" + ) return [xr.DataArray(r, dims=out) for r, out in zip(result, self.output_dims)] @@ -57,7 +71,9 @@ def unload(self) -> None: @staticmethod def get_network( - weight_spec: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr] + weight_spec: Union[ + v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr + ] ) -> "torch.nn.Module": arch = import_callable( weight_spec.architecture, @@ -74,14 +90,22 @@ def get_network( ) network = arch(**model_kwargs) if not isinstance(network, torch.nn.Module): - raise ValueError(f"calling {weight_spec.architecture.callable} did not return a torch.nn.Module") + raise ValueError( + f"calling {weight_spec.architecture.callable} did not return a torch.nn.Module" + ) return network @staticmethod def get_devices(devices: Optional[Sequence[str]] = None) -> List["torch.device"]: if not devices: - torch_devices = [torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")] + torch_devices = [ + ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ) + ] else: torch_devices = [torch.device(d) for d in devices] diff --git a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py index a845f380..cba0ad04 100644 --- a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py +++ b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py @@ -46,22 +46,34 @@ def __init__( warnings.warn( f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}." ) - elif (model_tf_version.major, model_tf_version.minor) != (tf_version.major, tf_version.minor): + elif (model_tf_version.major, model_tf_version.minor) != ( + tf_version.major, + tf_version.minor, + ): warnings.warn( "The tensorflow version specified by the model does not match the installed: " f"{model_tf_version} != {tf_version}." 
) - self.use_keras_api = tf_version.major > 1 or self.weight_format == KerasModelAdapter.weight_format + self.use_keras_api = ( + tf_version.major > 1 + or self.weight_format == KerasModelAdapter.weight_format + ) # TODO tf device management if devices is not None: - warnings.warn(f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}") + warnings.warn( + f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}" + ) weight_file = self.require_unzipped(weights.source) self._network = self._get_network(weight_file) self._internal_output_axes = [ - tuple(out.axes) if isinstance(out.axes, str) else tuple(a.id for a in out.axes) + ( + tuple(out.axes) + if isinstance(out.axes, str) + else tuple(a.id for a in out.axes) + ) for out in model_description.outputs ] @@ -88,15 +100,19 @@ def _get_network(self, weight_file: FileSource): # alive in between of forward passes (but then the sessions need to be properly opened / closed) def _forward_tf(self, *input_tensors): input_keys = [ - ipt.name if isinstance(ipt, v0_4.InputTensorDescr) else ipt.id for ipt in self.model_description.inputs + ipt.name if isinstance(ipt, v0_4.InputTensorDescr) else ipt.id + for ipt in self.model_description.inputs ] output_keys = [ - out.name if isinstance(out, v0_4.OutputTensorDescr) else out.id for out in self.model_description.outputs + out.name if isinstance(out, v0_4.OutputTensorDescr) else out.id + for out in self.model_description.outputs ] # TODO read from spec tag = tf.saved_model.tag_constants.SERVING - signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + signature_key = ( + tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + ) graph = tf.Graph() with graph.as_default(): @@ -106,13 +122,20 @@ def _forward_tf(self, *input_tensors): signature = graph_def.signature_def # get the tensors into the graph - in_names = [signature[signature_key].inputs[key].name for key in input_keys] - out_names = [signature[signature_key].outputs[key].name for key in output_keys] + in_names = [ + signature[signature_key].inputs[key].name for key in input_keys + ] + out_names = [ + signature[signature_key].outputs[key].name for key in output_keys + ] in_tensors = [graph.get_tensor_by_name(name) for name in in_names] out_tensors = [graph.get_tensor_by_name(name) for name in out_names] # run prediction - res = sess.run(dict(zip(out_names, out_tensors)), dict(zip(in_tensors, input_tensors))) + res = sess.run( + dict(zip(out_names, out_tensors)), + dict(zip(in_tensors, input_tensors)), + ) # from dict to list of tensors res = [res[out] for out in out_names] @@ -140,17 +163,25 @@ def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: else: result = self._forward_tf(*data) - return [xr.DataArray(r, dims=axes) for r, axes in zip(result, self._internal_output_axes)] + return [ + xr.DataArray(r, dims=axes) + for r, axes in zip(result, self._internal_output_axes) + ] def unload(self) -> None: - warnings.warn("Device management is not implemented for keras yet, cannot unload model") + warnings.warn( + "Device management is not implemented for keras yet, cannot unload model" + ) class TensorflowModelAdapter(TensorflowModelAdapterBase): weight_format = "tensorflow_saved_model_bundle" def __init__( - self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None + self, + *, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + devices: Optional[Sequence[str]] = 
None, ): if model_description.weights.tensorflow_saved_model_bundle is None: raise ValueError("missing tensorflow_saved_model_bundle weights") @@ -166,7 +197,10 @@ class KerasModelAdapter(TensorflowModelAdapterBase): weight_format = "keras_hdf5" def __init__( - self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None + self, + *, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + devices: Optional[Sequence[str]] = None, ): if model_description.weights.keras_hdf5 is None: raise ValueError("missing keras_hdf5 weights") diff --git a/bioimageio/core/model_adapters/_torchscript_model_adapter.py b/bioimageio/core/model_adapters/_torchscript_model_adapter.py index 876136b8..7637bd8a 100644 --- a/bioimageio/core/model_adapters/_torchscript_model_adapter.py +++ b/bioimageio/core/model_adapters/_torchscript_model_adapter.py @@ -19,12 +19,17 @@ class TorchscriptModelAdapter(ModelAdapter): def __init__( - self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None + self, + *, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + devices: Optional[Sequence[str]] = None, ): assert torch is not None super().__init__() if model_description.weights.torchscript is None: - raise ValueError(f"No torchscript weights found for model {model_description.name}") + raise ValueError( + f"No torchscript weights found for model {model_description.name}" + ) weight_path = download(model_description.weights.torchscript.source).path if devices is None: @@ -33,12 +38,18 @@ def __init__( self.devices = [torch.device(d) for d in devices] if len(self.devices) > 1: - warnings.warn("Multiple devices for single torchscript model not yet implemented") + warnings.warn( + "Multiple devices for single torchscript model not yet implemented" + ) self._model = torch.jit.load(weight_path) self._model.to(self.devices[0]) self._internal_output_axes = [ - tuple(out.axes) if isinstance(out.axes, str) else tuple(a.id for a in out.axes) + ( + tuple(out.axes) + if isinstance(out.axes, str) + else tuple(a.id for a in out.axes) + ) for out in model_description.outputs ] @@ -53,10 +64,15 @@ def forward(self, *batch: xr.DataArray) -> List[xr.DataArray]: else: result = [_result] - result = [r.cpu().numpy() if not isinstance(r, np.ndarray) else r for r in result] + result = [ + r.cpu().numpy() if not isinstance(r, np.ndarray) else r for r in result + ] assert len(result) == len(self._internal_output_axes) - return [xr.DataArray(r, dims=axes) for r, axes in zip(result, self._internal_output_axes)] + return [ + xr.DataArray(r, dims=axes) + for r, axes in zip(result, self._internal_output_axes) + ] def unload(self) -> None: self._devices = None diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 7c179e28..080b988f 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -44,7 +44,8 @@ def convert_axis_ids( - axes: Union[Sequence[AxisId], v0_4.AxesInCZYX], mode: Literal["per_sample", "per_dataset"] + axes: Union[Sequence[AxisId], v0_4.AxesInCZYX], + mode: Literal["per_sample", "per_dataset"], ) -> Tuple[AxisId, ...]: if not isinstance(axes, str): return tuple(axes) @@ -152,7 +153,8 @@ def required_measures(self) -> Set[Measure]: def __post_init__(self): self._keep_updating_initial_dataset_stats = ( - self.keep_updating_initial_dataset_stats or not self.stats_calculator.has_dataset_measures + self.keep_updating_initial_dataset_stats + or not self.stats_calculator.has_dataset_measures ) 
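
For orientation in the proc_ops hunks here: the from_proc_descr classmethods map spec processing descriptions onto these runtime operators. A rough sketch of that round trip (assuming v0_5.ClipDescr and v0_5.ClipKwargs as provided by bioimageio.spec; the private _apply is called directly only for illustration):

    import numpy as np
    import xarray as xr

    from bioimageio.core.proc_ops import Clip
    from bioimageio.spec.model import v0_5
    from bioimageio.spec.model.v0_5 import TensorId

    descr = v0_5.ClipDescr(kwargs=v0_5.ClipKwargs(min=0.0, max=1.0))
    op = Clip.from_proc_descr(descr, TensorId("input"))

    data = xr.DataArray(np.linspace(-1.0, 2.0, 4), dims=("x",))
    clipped = op._apply(data, stat={})  # Clip ignores the stat mapping
    assert float(clipped.min()) == 0.0 and float(clipped.max()) == 1.0
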
def __call__(self, sample: Sample) -> None: @@ -178,7 +180,9 @@ def _apply(self, input: Tensor, stat: Stat) -> xr.DataArray: # def get_descr(self): # return v0_5.BinarizeDescr(kwargs=v0_5.BinarizeKwargs(threshold=self.threshold)) @classmethod - def from_proc_descr(cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], tensor_id: TensorId) -> Self: + def from_proc_descr( + cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], tensor_id: TensorId + ) -> Self: return cls(input=tensor_id, output=tensor_id, threshold=descr.kwargs.threshold) @@ -199,8 +203,15 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor: return input.clip(self.min, self.max) @classmethod - def from_proc_descr(cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], tensor_id: TensorId) -> Self: - return cls(input=tensor_id, output=tensor_id, min=descr.kwargs.min, max=descr.kwargs.max) + def from_proc_descr( + cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], tensor_id: TensorId + ) -> Self: + return cls( + input=tensor_id, + output=tensor_id, + min=descr.kwargs.min, + max=descr.kwargs.max, + ) @dataclass @@ -212,7 +223,9 @@ def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, tensor_id: TensorId): return cls(input=tensor_id, output=tensor_id, dtype=descr.kwargs.dtype) def get_descr(self): - return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=str(self.dtype))) + return v0_5.EnsureDtypeDescr( + kwargs=v0_5.EnsureDtypeKwargs(dtype=str(self.dtype)) + ) def _apply(self, input: Tensor, stat: Stat) -> Tensor: return input.astype(self.dtype) @@ -234,12 +247,18 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor: # ... @classmethod - def from_proc_descr(cls, descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr], tensor_id: TensorId) -> Self: + def from_proc_descr( + cls, + descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr], + tensor_id: TensorId, + ) -> Self: kwargs = descr.kwargs if isinstance(kwargs, v0_5.ScaleLinearKwargs): axis = kwargs.axis elif kwargs.axes is not None: - raise NotImplementedError("ScaleLinear operator from v0_4.ScaleLinearDescr with axes") + raise NotImplementedError( + "ScaleLinear operator from v0_4.ScaleLinearDescr with axes" + ) else: axis = None @@ -248,9 +267,15 @@ def from_proc_descr(cls, descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDes offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis) else: assert isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1 - gain = kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0] + gain = ( + kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0] + ) assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1 - offset = kwargs.offset if isinstance(kwargs.offset, (float, int)) else kwargs.offset[0] + offset = ( + kwargs.offset + if isinstance(kwargs.offset, (float, int)) + else kwargs.offset[0] + ) return cls(input=tensor_id, output=tensor_id, gain=gain, offset=offset) @@ -293,7 +318,9 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor: @classmethod def from_proc_descr( - cls, descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr], tensor_id: TensorId + cls, + descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr], + tensor_id: TensorId, ) -> Self: kwargs = descr.kwargs axes = _get_axes(descr.kwargs) @@ -331,8 +358,12 @@ def _get_axes( @dataclass class ScaleRange(_SimpleOperator): - lower_percentile: InitVar[Optional[Union[SamplePercentile, DatasetPercentile]]] = None - upper_percentile: 
InitVar[Optional[Union[SamplePercentile, DatasetPercentile]]] = None + lower_percentile: InitVar[Optional[Union[SamplePercentile, DatasetPercentile]]] = ( + None + ) + upper_percentile: InitVar[Optional[Union[SamplePercentile, DatasetPercentile]]] = ( + None + ) lower: Union[SamplePercentile, DatasetPercentile] = field(init=False) upper: Union[SamplePercentile, DatasetPercentile] = field(init=False) @@ -363,7 +394,11 @@ def required_measures(self): return {self.lower, self.upper} @classmethod - def from_proc_descr(cls, descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr], tensor_id: TensorId): + def from_proc_descr( + cls, + descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr], + tensor_id: TensorId, + ): kwargs = descr.kwargs ref_tensor = cast(TensorId, kwargs.reference_tensor) or tensor_id axes = _get_axes(descr.kwargs) @@ -375,8 +410,12 @@ def from_proc_descr(cls, descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr return cls( input=tensor_id, output=tensor_id, - lower_percentile=Percentile(n=kwargs.min_percentile, axes=axes, tensor_id=ref_tensor), - upper_percentile=Percentile(n=kwargs.max_percentile, axes=axes, tensor_id=ref_tensor), + lower_percentile=Percentile( + n=kwargs.min_percentile, axes=axes, tensor_id=ref_tensor + ), + upper_percentile=Percentile( + n=kwargs.max_percentile, axes=axes, tensor_id=ref_tensor + ), ) def _apply(self, input: xr.DataArray, stat: Stat) -> xr.DataArray: @@ -411,7 +450,9 @@ def required_measures(self) -> Collection[Measure]: return {} @classmethod - def from_proc_descr(cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], tensor_id: TensorId) -> Self: + def from_proc_descr( + cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], tensor_id: TensorId + ) -> Self: assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)) return cls(input=tensor_id, output=tensor_id) @@ -437,7 +478,9 @@ def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]: @classmethod def from_proc_descr( - cls, descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr], tensor_id: TensorId + cls, + descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr], + tensor_id: TensorId, ): axes = _get_axes(descr.kwargs) @@ -461,7 +504,9 @@ def _apply(self, input: xr.DataArray, stat: Stat) -> xr.DataArray: return (input - mean) / (std + self.eps) def get_descr(self): - return v0_5.ZeroMeanUnitVarianceDescr(kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps)) + return v0_5.ZeroMeanUnitVarianceDescr( + kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps) + ) @dataclass @@ -475,7 +520,9 @@ class FixedZeroMeanUnitVariance(_SimpleOperator): def __post_init__(self): assert ( - isinstance(self.mean, (int, float)) or isinstance(self.std, (int, float)) or self.mean.dims == self.std.dims + isinstance(self.mean, (int, float)) + or isinstance(self.std, (int, float)) + or self.mean.dims == self.std.dims ) @classmethod @@ -512,7 +559,12 @@ def _apply(self, input: xr.DataArray, stat: Stat) -> xr.DataArray: return (input - self.mean) / (self.std + self.eps) -ProcDescr = Union[v0_4.PreprocessingDescr, v0_4.PostprocessingDescr, v0_5.PreprocessingDescr, v0_5.PostprocessingDescr] +ProcDescr = Union[ + v0_4.PreprocessingDescr, + v0_4.PostprocessingDescr, + v0_5.PreprocessingDescr, + v0_5.PostprocessingDescr, +] Processing = Union[ AddKnownDatasetStats, @@ -540,13 +592,18 @@ def get_proc_class(proc_spec: ProcDescr): return FixedZeroMeanUnitVariance elif isinstance(proc_spec, (v0_4.ScaleLinearDescr, 
v0_5.ScaleLinearDescr)): return ScaleLinear - elif isinstance(proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr)): + elif isinstance( + proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr) + ): return ScaleMeanVariance elif isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)): return ScaleRange elif isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)): return Sigmoid - elif isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr) and proc_spec.kwargs.mode == "fixed": + elif ( + isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr) + and proc_spec.kwargs.mode == "fixed" + ): return FixedZeroMeanUnitVariance elif isinstance( proc_spec, diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py index a375a2b7..fbfb37ff 100644 --- a/bioimageio/core/proc_setup.py +++ b/bioimageio/core/proc_setup.py @@ -12,13 +12,23 @@ from typing_extensions import assert_never from bioimageio.core.common import Sample -from bioimageio.core.proc_ops import AddKnownDatasetStats, Processing, UpdateStats, get_proc_class +from bioimageio.core.proc_ops import ( + AddKnownDatasetStats, + Processing, + UpdateStats, + get_proc_class, +) from bioimageio.core.stat_calculators import StatsCalculator from bioimageio.core.stat_measures import DatasetMeasure, Measure, MeasureValue from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 from bioimageio.spec.model.v0_5 import TensorId -TensorDescr = Union[v0_4.InputTensorDescr, v0_4.OutputTensorDescr, v0_5.InputTensorDescr, v0_5.OutputTensorDescr] +TensorDescr = Union[ + v0_4.InputTensorDescr, + v0_4.OutputTensorDescr, + v0_5.InputTensorDescr, + v0_5.OutputTensorDescr, +] class PreAndPostprocessing(NamedTuple): @@ -44,7 +54,9 @@ def setup_pre_and_postprocessing( userd in `bioimageio.core.create_prediction_pipeline""" prep, post, prep_meas, post_meas = _prepare_setup_pre_and_postprocessing(model) - missing_dataset_stats = {m for m in prep_meas | post_meas if m not in fixed_dataset_stats} + missing_dataset_stats = { + m for m in prep_meas | post_meas if m not in fixed_dataset_stats + } initial_stats_calc = StatsCalculator(missing_dataset_stats) for sample in dataset_for_initial_statistics: initial_stats_calc.update(sample) @@ -97,8 +109,14 @@ def prepare_procs(tensor_descrs: Sequence[TensorDescr]): for proc_d in proc_descrs: proc_class = get_proc_class(proc_d) - tensor_id = TensorId(str(t_descr.name)) if isinstance(t_descr, v0_4.TensorDescrBase) else t_descr.id - req = proc_class.from_proc_descr(proc_d, tensor_id) # pyright: ignore[reportArgumentType] + tensor_id = ( + TensorId(str(t_descr.name)) + if isinstance(t_descr, v0_4.TensorDescrBase) + else t_descr.id + ) + req = proc_class.from_proc_descr( + proc_d, tensor_id + ) # pyright: ignore[reportArgumentType] for m in req.required_measures: if m.tensor_id in input_ids: pre_measures.add(m) diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py index 3cc4d67e..818e6303 100644 --- a/bioimageio/core/stat_calculators.py +++ b/bioimageio/core/stat_calculators.py @@ -127,7 +127,9 @@ def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]): self._mean: Optional[xr.DataArray] = None self._m2: Optional[xr.DataArray] = None - def compute(self, sample: Sample) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]: + def compute( + self, sample: Sample + ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]: tensor = sample.data[self._tensor_id] mean = tensor.mean(dim=self._axes) c = tensor - mean 
@@ -170,7 +172,9 @@ def update(self, sample: Sample):
             self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n
             assert self._m2.dtype == np.float64
 
-    def finalize(self) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
+    def finalize(
+        self,
+    ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]:
         if self._mean is None:
             return {}
         else:
@@ -187,7 +191,12 @@ def finalize(self) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureVa
 class SamplePercentilesCalculator:
     """to calculate sample percentiles"""
 
-    def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Collection[float]):
+    def __init__(
+        self,
+        tensor_id: TensorId,
+        axes: Optional[Sequence[AxisId]],
+        ns: Collection[float],
+    ):
         super().__init__()
         assert all(0 <= n <= 100 for n in ns)
         self.ns = ns
@@ -198,14 +207,23 @@ def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Co
     def compute(self, sample: Sample) -> Dict[SamplePercentile, MeasureValue]:
         tensor = sample.data[self._tensor_id]
         ps = tensor.quantile(self._qs, dim=self._axes)
-        return {SamplePercentile(n=n, axes=self._axes, tensor_id=self._tensor_id): p for n, p in zip(self.ns, ps)}
+        return {
+            SamplePercentile(n=n, axes=self._axes, tensor_id=self._tensor_id): p
+            for n, p in zip(self.ns, ps)
+        }
 
 
 class MeanPercentilesCalculator:
     """to calculate dataset percentiles heuristically by averaging across samples
-    **note**: the returned dataset percentiles are an estimate and **not mathematically correct**"""
+    **note**: the returned dataset percentiles are an estimate and **not mathematically correct**
+    """
 
-    def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Collection[float]):
+    def __init__(
+        self,
+        tensor_id: TensorId,
+        axes: Optional[Sequence[AxisId]],
+        ns: Collection[float],
+    ):
         super().__init__()
         assert all(0 <= n <= 100 for n in ns)
         self._ns = ns
@@ -217,7 +235,9 @@ def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Co
     def update(self, sample: Sample):
         tensor = sample.data[self._tensor_id]
-        sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(np.float64, copy=False)
+        sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype(
+            np.float64, copy=False
+        )
 
         # reduced voxel count
         n = int(np.prod(tensor.shape) / np.prod(sample_estimates.shape[1:]))
@@ -226,7 +246,9 @@
             assert self._n == 0
             self._estimates = sample_estimates
         else:
-            self._estimates = (self._n * self._estimates + n * sample_estimates) / (self._n + n)
+            self._estimates = (self._n * self._estimates + n * sample_estimates) / (
+                self._n + n
+            )
             assert self._estimates.dtype == np.float64
 
         self._n += n
@@ -235,7 +257,9 @@ def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
         if self._estimates is None:
             return {}
         else:
-            warnings.warn("Computed dataset percentiles naively by averaging percentiles of samples.")
+            warnings.warn(
+                "Computed dataset percentiles naively by averaging percentiles of samples."
+            )
             return {
                 DatasetPercentile(n=n, axes=self._axes, tensor_id=self._tensor_id): e
                 for n, e in zip(self._ns, self._estimates)
@@ -245,8 +269,15 @@ def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
 class CrickPercentilesCalculator:
     """to calculate dataset percentiles with the experimental [crick library](https://github.com/dask/crick)"""
 
-    def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Collection[float]):
-        warnings.warn("Computing dataset percentiles with experimental 'crick' library.")
+    def __init__(
+        self,
+        tensor_id: TensorId,
+        axes: Optional[Sequence[AxisId]],
+        ns: Collection[float],
+    ):
+        warnings.warn(
+            "Computing dataset percentiles with experimental 'crick' library."
+        )
         super().__init__()
         assert all(0 <= n <= 100 for n in ns)
         assert axes is None or "_percentiles" not in axes
@@ -261,7 +292,9 @@ def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Co
     def _initialize(self, tensor_sizes: Mapping[Hashable, int]):
         assert crick is not None
-        out_sizes: OrderedDict[Hashable, int] = collections.OrderedDict(_percentiles=len(self._ns))
+        out_sizes: OrderedDict[Hashable, int] = collections.OrderedDict(
+            _percentiles=len(self._ns)
+        )
         if self._axes is not None:
             for d, s in tensor_sizes.items():
                 if d not in self._axes:
@@ -291,19 +324,21 @@ def finalize(self) -> Dict[DatasetPercentile, MeasureValue]:
         assert self._dims is not None
         assert self._shape is not None
 
-        vs: NDArray[Any] = np.asarray([[d.quantile(q) for d in self._digest] for q in self._qs]).reshape(
-            self._shape
-        )
+        vs: NDArray[Any] = np.asarray(
+            [[d.quantile(q) for d in self._digest] for q in self._qs]
+        ).reshape(self._shape)
         return {
-            DatasetPercentile(n=n, axes=self._axes, tensor_id=self._tensor_id): xr.DataArray(v, dims=self._dims[1:])
+            DatasetPercentile(
+                n=n, axes=self._axes, tensor_id=self._tensor_id
+            ): xr.DataArray(v, dims=self._dims[1:])
             for n, v in zip(self._ns, vs)
         }
 
 
 if crick is None:
-    DatasetPercentilesCalculator: Type[Union[MeanPercentilesCalculator, CrickPercentilesCalculator]] = (
-        MeanPercentilesCalculator
-    )
+    DatasetPercentilesCalculator: Type[
+        Union[MeanPercentilesCalculator, CrickPercentilesCalculator]
+    ] = MeanPercentilesCalculator
 else:
     DatasetPercentilesCalculator = CrickPercentilesCalculator
@@ -321,9 +356,14 @@ def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]:
 
 
 SampleMeasureCalculator = Union[
-    MeanCalculator, MeanVarStdCalculator, SamplePercentilesCalculator, NaiveSampleMeasureCalculator
+    MeanCalculator,
+    MeanVarStdCalculator,
+    SamplePercentilesCalculator,
+    NaiveSampleMeasureCalculator,
+]
+DatasetMeasureCalculator = Union[
+    MeanCalculator, MeanVarStdCalculator, DatasetPercentilesCalculator
 ]
-DatasetMeasureCalculator = Union[MeanCalculator, MeanVarStdCalculator, DatasetPercentilesCalculator]
 
 
 class StatsCalculator:
@@ -332,19 +372,30 @@ class StatsCalculator:
     def __init__(
         self,
         measures: Collection[Measure],
-        initial_dataset_measures: Optional[Mapping[DatasetMeasure, MeasureValue]] = None,
+        initial_dataset_measures: Optional[
+            Mapping[DatasetMeasure, MeasureValue]
+        ] = None,
    ):
         super().__init__()
         self.sample_count = 0
-        self.sample_calculators, self.dataset_calculators = get_measure_calculators(measures)
+        self.sample_calculators, self.dataset_calculators = get_measure_calculators(
+            measures
+        )
         if initial_dataset_measures is None:
-            self._current_dataset_measures: Optional[Dict[DatasetMeasure, MeasureValue]] = None
+            self._current_dataset_measures: Optional[
+                Dict[DatasetMeasure, 
MeasureValue] + ] = None else: missing_dataset_meas = { - m for m in measures if isinstance(m, DatasetMeasureBase) and m not in initial_dataset_measures + m + for m in measures + if isinstance(m, DatasetMeasureBase) + and m not in initial_dataset_measures } if missing_dataset_meas: - warnings.warn(f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}") + warnings.warn( + f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}" + ) self._current_dataset_measures = None else: self._current_dataset_measures = dict(initial_dataset_measures) @@ -366,7 +417,9 @@ def finalize(self) -> Dict[DatasetMeasure, MeasureValue]: return self._current_dataset_measures - def update_and_get_all(self, sample: Union[Sample, Iterable[Sample]]) -> Dict[Measure, MeasureValue]: + def update_and_get_all( + self, sample: Union[Sample, Iterable[Sample]] + ) -> Dict[Measure, MeasureValue]: """Returns sample as well as updated dataset statistics""" last_sample = self._update(sample) if last_sample is None: @@ -411,9 +464,15 @@ def get_measure_calculators( required_sample_means: Set[SampleMean] = set() required_dataset_means: Set[DatasetMean] = set() required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set() - required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = set() - required_sample_percentiles: Dict[Tuple[TensorId, Optional[Tuple[AxisId, ...]]], Set[float]] = {} - required_dataset_percentiles: Dict[Tuple[TensorId, Optional[Tuple[AxisId, ...]]], Set[float]] = {} + required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = ( + set() + ) + required_sample_percentiles: Dict[ + Tuple[TensorId, Optional[Tuple[AxisId, ...]]], Set[float] + ] = {} + required_dataset_percentiles: Dict[ + Tuple[TensorId, Optional[Tuple[AxisId, ...]]], Set[float] + ] = {} for rm in required_measures: if isinstance(rm, SampleMean): @@ -422,18 +481,28 @@ def get_measure_calculators( required_dataset_means.add(rm) elif isinstance(rm, (SampleVar, SampleStd)): required_sample_mean_var_std.update( - {msv(axes=rm.axes, tensor_id=rm.tensor_id) for msv in (SampleMean, SampleStd, SampleVar)} + { + msv(axes=rm.axes, tensor_id=rm.tensor_id) + for msv in (SampleMean, SampleStd, SampleVar) + } ) assert rm in required_sample_mean_var_std elif isinstance(rm, (DatasetVar, DatasetStd)): required_dataset_mean_var_std.update( - {msv(axes=rm.axes, tensor_id=rm.tensor_id) for msv in (DatasetMean, DatasetStd, DatasetVar)} + { + msv(axes=rm.axes, tensor_id=rm.tensor_id) + for msv in (DatasetMean, DatasetStd, DatasetVar) + } ) assert rm in required_dataset_mean_var_std elif isinstance(rm, SamplePercentile): - required_sample_percentiles.setdefault((rm.tensor_id, rm.axes), set()).add(rm.n) + required_sample_percentiles.setdefault((rm.tensor_id, rm.axes), set()).add( + rm.n + ) elif isinstance(rm, DatasetPercentile): - required_dataset_percentiles.setdefault((rm.tensor_id, rm.axes), set()).add(rm.n) + required_dataset_percentiles.setdefault((rm.tensor_id, rm.axes), set()).add( + rm.n + ) else: assert_never(rm) @@ -445,7 +514,9 @@ def get_measure_calculators( sample_calculators.append(MeanCalculator(tensor_id=rm.tensor_id, axes=rm.axes)) for rm in required_sample_mean_var_std: - sample_calculators.append(MeanVarStdCalculator(tensor_id=rm.tensor_id, axes=rm.axes)) + sample_calculators.append( + MeanVarStdCalculator(tensor_id=rm.tensor_id, axes=rm.axes) + ) for rm in required_dataset_means: if rm in required_dataset_mean_var_std: @@ -455,13 +526,19 @@ def 
get_measure_calculators( dataset_calculators.append(MeanCalculator(tensor_id=rm.tensor_id, axes=rm.axes)) for rm in required_dataset_mean_var_std: - dataset_calculators.append(MeanVarStdCalculator(tensor_id=rm.tensor_id, axes=rm.axes)) + dataset_calculators.append( + MeanVarStdCalculator(tensor_id=rm.tensor_id, axes=rm.axes) + ) for (tid, axes), ns in required_sample_percentiles.items(): - sample_calculators.append(SamplePercentilesCalculator(tensor_id=tid, axes=axes, ns=ns)) + sample_calculators.append( + SamplePercentilesCalculator(tensor_id=tid, axes=axes, ns=ns) + ) for (tid, axes), ns in required_dataset_percentiles.items(): - dataset_calculators.append(DatasetPercentilesCalculator(tensor_id=tid, axes=axes, ns=ns)) + dataset_calculators.append( + DatasetPercentilesCalculator(tensor_id=tid, axes=axes, ns=ns) + ) return sample_calculators, dataset_calculators @@ -485,7 +562,9 @@ def compute_dataset_measures( return ret -def compute_sample_measures(measures: Iterable[SampleMeasure], sample: Sample) -> Dict[SampleMeasure, MeasureValue]: +def compute_sample_measures( + measures: Iterable[SampleMeasure], sample: Sample +) -> Dict[SampleMeasure, MeasureValue]: """compute all sample `measures` for the given `sample`""" calculators, dataset_calculators = get_measure_calculators(measures) assert not dataset_calculators @@ -497,7 +576,9 @@ def compute_sample_measures(measures: Iterable[SampleMeasure], sample: Sample) - return ret -def compute_measures(measures: Iterable[Measure], dataset: Iterable[Sample]) -> Dict[Measure, MeasureValue]: +def compute_measures( + measures: Iterable[Measure], dataset: Iterable[Sample] +) -> Dict[Measure, MeasureValue]: """compute all `measures` for the given `dataset` sample measures are computed for the last sample in `dataset`""" sample_calculators, dataset_calculators = get_measure_calculators(measures) diff --git a/bioimageio/core/utils/_digest_spec.py b/bioimageio/core/utils/_digest_spec.py index a0514a02..f01773ac 100644 --- a/bioimageio/core/utils/_digest_spec.py +++ b/bioimageio/core/utils/_digest_spec.py @@ -8,13 +8,29 @@ def get_test_inputs(model: AnyModelDescr) -> List[xr.DataArray]: if isinstance(model, v0_4.ModelDescr): - return [xr.DataArray(load_array(tt), dims=tuple(d.axes)) for d, tt in zip(model.inputs, model.test_inputs)] + return [ + xr.DataArray(load_array(tt), dims=tuple(d.axes)) + for d, tt in zip(model.inputs, model.test_inputs) + ] else: - return [xr.DataArray(load_array(d.test_tensor), dims=tuple(str(a.id) for a in d.axes)) for d in model.inputs] + return [ + xr.DataArray( + load_array(d.test_tensor), dims=tuple(str(a.id) for a in d.axes) + ) + for d in model.inputs + ] def get_test_outputs(model: AnyModelDescr) -> List[xr.DataArray]: if isinstance(model, v0_4.ModelDescr): - return [xr.DataArray(load_array(tt), dims=tuple(d.axes)) for d, tt in zip(model.outputs, model.test_outputs)] + return [ + xr.DataArray(load_array(tt), dims=tuple(d.axes)) + for d, tt in zip(model.outputs, model.test_outputs) + ] else: - return [xr.DataArray(load_array(d.test_tensor), dims=tuple(str(a.id) for a in d.axes)) for d in model.outputs] + return [ + xr.DataArray( + load_array(d.test_tensor), dims=tuple(str(a.id) for a in d.axes) + ) + for d in model.outputs + ] diff --git a/bioimageio/core/utils/_import_callable.py b/bioimageio/core/utils/_import_callable.py index 40ff1c45..a60259d9 100644 --- a/bioimageio/core/utils/_import_callable.py +++ b/bioimageio/core/utils/_import_callable.py @@ -9,7 +9,10 @@ from bioimageio.spec._internal.io_utils import HashKwargs, 
download from bioimageio.spec.common import FileSource from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile -from bioimageio.spec.model.v0_5 import ArchitectureFromFileDescr, ArchitectureFromLibraryDescr +from bioimageio.spec.model.v0_5 import ( + ArchitectureFromFileDescr, + ArchitectureFromLibraryDescr, +) @singledispatch @@ -47,10 +50,14 @@ def import_from_file05(node: ArchitectureFromFileDescr, **kwargs: Unpack[HashKwa return _import_from_file_impl(node.source, node.callable, sha256=node.sha256) -def _import_from_file_impl(source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]): +def _import_from_file_impl( + source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs] +): local_file = download(source, **kwargs) module_name = local_file.path.stem - importlib_spec = importlib.util.spec_from_file_location(module_name, local_file.path) + importlib_spec = importlib.util.spec_from_file_location( + module_name, local_file.path + ) if importlib_spec is None: raise ImportError(f"Failed to import {module_name} from {source}.") diff --git a/bioimageio/core/utils/image_helper.py b/bioimageio/core/utils/image_helper.py index 48e989af..444e0fc9 100644 --- a/bioimageio/core/utils/image_helper.py +++ b/bioimageio/core/utils/image_helper.py @@ -51,10 +51,15 @@ def interprete_array( SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[1]), ) elif ndim == 3 and ( - (n_expected_space_axes is None and any(s <= 3 for s in nd_array.shape)) or n_expected_space_axes == 2 + (n_expected_space_axes is None and any(s <= 3 for s in nd_array.shape)) + or n_expected_space_axes == 2 ): current_axes = ( - ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(nd_array.shape[0])]), + ChannelAxis( + channel_names=[ + Identifier(f"channel{i}") for i in range(nd_array.shape[0]) + ] + ), SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[1]), SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[2]), ) @@ -66,7 +71,11 @@ def interprete_array( ) elif ndim == 4: current_axes = ( - ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(nd_array.shape[0])]), + ChannelAxis( + channel_names=[ + Identifier(f"channel{i}") for i in range(nd_array.shape[0]) + ] + ), SpaceInputAxis(id=AxisId("z"), size=nd_array.shape[1]), SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[2]), SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[3]), @@ -74,7 +83,11 @@ def interprete_array( elif ndim == 5: current_axes = ( BatchAxis(), - ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(nd_array.shape[1])]), + ChannelAxis( + channel_names=[ + Identifier(f"channel{i}") for i in range(nd_array.shape[1]) + ] + ), SpaceInputAxis(id=AxisId("z"), size=nd_array.shape[2]), SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[3]), SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[4]), @@ -84,11 +97,17 @@ def interprete_array( f"Could not guess an axis mapping for {nd_array.shape} with {n_expected_space_axes} expected space axes" ) - current_axes_ids = tuple(current_axes) if isinstance(current_axes, str) else tuple(a.id for a in current_axes) + current_axes_ids = ( + tuple(current_axes) + if isinstance(current_axes, str) + else tuple(a.id for a in current_axes) + ) return Tensor(nd_array, dims=current_axes_ids) -def axis_descr_to_ids(axes: Union[v0_4.AxesStr, Sequence[AnyAxis]]) -> Tuple[AxisId, ...]: +def axis_descr_to_ids( + axes: Union[v0_4.AxesStr, Sequence[AnyAxis]] +) -> Tuple[AxisId, ...]: if isinstance(axes, str): return tuple(map(AxisId, axes)) else: @@ -149,10 
+168,12 @@ def resize_to( sizes: Mapping[AxisId, int], *, pad_where: Union[ - Literal["before", "center", "after"], Mapping[AxisId, Literal["before", "center", "after"]] + Literal["before", "center", "after"], + Mapping[AxisId, Literal["before", "center", "after"]], ] = "center", crop_where: Union[ - Literal["before", "center", "after"], Mapping[AxisId, Literal["before", "center", "after"]] + Literal["before", "center", "after"], + Mapping[AxisId, Literal["before", "center", "after"]], ] = "center", pad_mode: Literal["edge", "reflect", "symmetric"] = "symmetric", ): @@ -186,13 +207,16 @@ def crop_to( tensor: Tensor, sizes: Mapping[AxisId, int], crop_where: Union[ - Literal["before", "center", "after"], Mapping[AxisId, Literal["before", "center", "after"]] + Literal["before", "center", "after"], + Mapping[AxisId, Literal["before", "center", "after"]], ] = "center", ): """crop `tensor` to match `sizes`""" axes = [AxisId(str(a)) for a in tensor.dims] if crop_where in ("before", "center", "after"): - crop_axis_where: Mapping[AxisId, Literal["before", "center", "after"]] = {a: crop_where for a in axes} + crop_axis_where: Mapping[AxisId, Literal["before", "center", "after"]] = { + a: crop_where for a in axes + } else: crop_axis_where = crop_where @@ -203,9 +227,13 @@ def crop_to( if a not in sizes or sizes[a] == s_is: pass elif sizes[a] > s_is: - warnings.warn(f"Cannot crop axis {a} of size {s_is} to larger size {sizes[a]}") + warnings.warn( + f"Cannot crop axis {a} of size {s_is} to larger size {sizes[a]}" + ) elif a not in crop_axis_where: - raise ValueError(f"Don't know where to crop axis {a}, `crop_where`={crop_where}") + raise ValueError( + f"Don't know where to crop axis {a}, `crop_where`={crop_where}" + ) else: crop_this_axis_where = crop_axis_where[a] if crop_this_axis_where == "before": @@ -224,14 +252,17 @@ def pad_to( tensor: Tensor, sizes: Mapping[AxisId, int], pad_where: Union[ - Literal["before", "center", "after"], Mapping[AxisId, Literal["before", "center", "after"]] + Literal["before", "center", "after"], + Mapping[AxisId, Literal["before", "center", "after"]], ] = "center", mode: Literal["edge", "reflect", "symmetric"] = "symmetric", ): """pad `tensor` to match `sizes`""" axes = [AxisId(str(a)) for a in tensor.dims] if pad_where in ("before", "center", "after"): - pad_axis_where: Mapping[AxisId, Literal["before", "center", "after"]] = {a: pad_where for a in axes} + pad_axis_where: Mapping[AxisId, Literal["before", "center", "after"]] = { + a: pad_where for a in axes + } else: pad_axis_where = pad_where @@ -242,9 +273,13 @@ def pad_to( pad_width[a] = 0 elif s_is < sizes[a]: pad_width[a] = 0 - warnings.warn(f"Cannot pad axis {a} of size {s_is} to smaller size {sizes[a]}") + warnings.warn( + f"Cannot pad axis {a} of size {s_is} to smaller size {sizes[a]}" + ) elif a not in pad_axis_where: - raise ValueError(f"Don't know where to pad axis {a}, `pad_where`={pad_where}") + raise ValueError( + f"Don't know where to pad axis {a}, `pad_where`={pad_where}" + ) else: pad_this_axis_where = pad_axis_where[a] p = sizes[a] - s_is diff --git a/bioimageio/core/weight_converter/keras/_tensorflow.py b/bioimageio/core/weight_converter/keras/_tensorflow.py index 5fa6be54..adad502b 100644 --- a/bioimageio/core/weight_converter/keras/_tensorflow.py +++ b/bioimageio/core/weight_converter/keras/_tensorflow.py @@ -32,7 +32,13 @@ def _zip_model_bundle(model_bundle_folder: Path): # adapted from # https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236 -def 
_convert_tf1(keras_weight_path: Path, output_path: Path, input_name: str, output_name: str, zip_weights: bool): +def _convert_tf1( + keras_weight_path: Path, + output_path: Path, + input_name: str, + output_name: str, + zip_weights: bool, +): try: # try to build the tf model with the keras import from tensorflow from bioimageio.core.weight_converter.keras._tensorflow import keras # type: ignore @@ -47,10 +53,13 @@ def build_tf_model(): assert _tensorflow is not None builder = _tensorflow.saved_model.builder.SavedModelBuilder(output_path) signature = _tensorflow.saved_model.signature_def_utils.predict_signature_def( - inputs={input_name: keras_model.input}, outputs={output_name: keras_model.output} + inputs={input_name: keras_model.input}, + outputs={output_name: keras_model.output}, ) - signature_def_map = {_tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature} + signature_def_map = { + _tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature + } builder.add_meta_graph_and_variables( keras.backend.get_session(), @@ -86,7 +95,9 @@ def _convert_tf2(keras_weight_path: Path, output_path: Path, zip_weights: bool): return 0 -def convert_weights_to_tensorflow_saved_model_bundle(model: ModelDescr, output_path: Path): +def convert_weights_to_tensorflow_saved_model_bundle( + model: ModelDescr, output_path: Path +): """Convert model weights from format 'keras_hdf5' to 'tensorflow_saved_model_bundle'. Adapted from @@ -117,13 +128,21 @@ def convert_weights_to_tensorflow_saved_model_bundle(model: ModelDescr, output_p if weight_spec.tensorflow_version: model_tf_major_ver = int(weight_spec.tensorflow_version.major) if model_tf_major_ver != tf_major_ver: - raise RuntimeError(f"Tensorflow major versions of model {model_tf_major_ver} is not {tf_major_ver}") + raise RuntimeError( + f"Tensorflow major versions of model {model_tf_major_ver} is not {tf_major_ver}" + ) if tf_major_ver == 1: if len(model.inputs) != 1 or len(model.outputs) != 1: raise NotImplementedError( "Weight conversion for models with multiple inputs or outputs is not yet implemented." 
) - return _convert_tf1(weight_path, output_path, model.inputs[0].id, model.outputs[0].id, zip_weights) + return _convert_tf1( + weight_path, + output_path, + model.inputs[0].id, + model.outputs[0].id, + zip_weights, + ) else: return _convert_tf2(weight_path, output_path, zip_weights) diff --git a/bioimageio/core/weight_converter/torch/_onnx.py b/bioimageio/core/weight_converter/torch/_onnx.py index f9b66b9f..2b4d1caf 100644 --- a/bioimageio/core/weight_converter/torch/_onnx.py +++ b/bioimageio/core/weight_converter/torch/_onnx.py @@ -42,7 +42,9 @@ def add_onnx_weights( state_dict_weights_descr = model_spec.weights.pytorch_state_dict if state_dict_weights_descr is None: - raise ValueError("The provided model does not have weights in the pytorch state dict format") + raise ValueError( + "The provided model does not have weights in the pytorch state dict format" + ) with torch.no_grad(): @@ -53,7 +55,9 @@ def add_onnx_weights( expected_tensors = model(*input_tensors) if isinstance(expected_tensors, torch.Tensor): expected_tensors = [expected_tensors] - expected_outputs: List[np.ndarray[Any, Any]] = [out.numpy() for out in expected_tensors] + expected_outputs: List[np.ndarray[Any, Any]] = [ + out.numpy() for out in expected_tensors + ] if use_tracing: torch.onnx.export( @@ -75,9 +79,16 @@ def add_onnx_weights( # check the onnx model sess = rt.InferenceSession(str(output_path)) - onnx_input_node_args = cast(List[Any], sess.get_inputs()) # fixme: remove cast, try using rt.NodeArg instead of Any - onnx_inputs = {input_name.name: inp for input_name, inp in zip(onnx_input_node_args, input_data)} - outputs = cast(Sequence[np.ndarray[Any, Any]], sess.run(None, onnx_inputs)) # FIXME: remove cast + onnx_input_node_args = cast( + List[Any], sess.get_inputs() + ) # fixme: remove cast, try using rt.NodeArg instead of Any + onnx_inputs = { + input_name.name: inp + for input_name, inp in zip(onnx_input_node_args, input_data) + } + outputs = cast( + Sequence[np.ndarray[Any, Any]], sess.run(None, onnx_inputs) + ) # FIXME: remove cast try: for exp, out in zip(expected_outputs, outputs): diff --git a/bioimageio/core/weight_converter/torch/_torchscript.py b/bioimageio/core/weight_converter/torch/_torchscript.py index e724dac2..ee11610c 100644 --- a/bioimageio/core/weight_converter/torch/_torchscript.py +++ b/bioimageio/core/weight_converter/torch/_torchscript.py @@ -14,13 +14,18 @@ # FIXME: remove Any def _check_predictions( - model: Any, scripted_model: Any, model_spec: "v0_4.ModelDescr | v0_5.ModelDescr", input_data: Sequence[torch.Tensor] + model: Any, + scripted_model: Any, + model_spec: "v0_4.ModelDescr | v0_5.ModelDescr", + input_data: Sequence[torch.Tensor], ): def _check(input_: Sequence[torch.Tensor]) -> None: expected_tensors = model(*input_) if isinstance(expected_tensors, torch.Tensor): expected_tensors = [expected_tensors] - expected_outputs: List[np.ndarray[Any, Any]] = [out.numpy() for out in expected_tensors] + expected_outputs: List[np.ndarray[Any, Any]] = [ + out.numpy() for out in expected_tensors + ] output_tensors = scripted_model(*input_) if isinstance(output_tensors, torch.Tensor): @@ -31,7 +36,9 @@ def _check(input_: Sequence[torch.Tensor]) -> None: for exp, out in zip(expected_outputs, outputs): assert_array_almost_equal(exp, out, decimal=4) except AssertionError as e: - raise ValueError(f"Results before and after weights conversion do not agree:\n {str(e)}") + raise ValueError( + f"Results before and after weights conversion do not agree:\n {str(e)}" + ) _check(input_data) @@ -55,7 
+62,9 @@ def _check(input_: Sequence[torch.Tensor]) -> None: min_shape.append(axis.size) step.append(0) elif axis.size is None: - raise NotImplementedError(f"Can't verify inputs that don't specify their shape fully: {axis}") + raise NotImplementedError( + f"Can't verify inputs that don't specify their shape fully: {axis}" + ) elif isinstance(axis.size, v0_5.SizeReference): raise NotImplementedError(f"Can't handle axes like '{axis}' yet") else: @@ -66,16 +75,23 @@ def _check(input_: Sequence[torch.Tensor]) -> None: # check that input and output agree for decreasing input sizes for step_factor in range(1, max_steps + 1): - slice_ = tuple(slice(None) if st == 0 else slice(step_factor * st, -step_factor * st) for st in half_step) + slice_ = tuple( + slice(None) if st == 0 else slice(step_factor * st, -step_factor * st) + for st in half_step + ) this_input = [inp[slice_] for inp in input_data] this_shape = this_input[0].shape if any(tsh < msh for tsh, msh in zip(this_shape, min_shape)): - raise ValueError(f"Mismatched shapes: {this_shape}. Expected at least {min_shape}") + raise ValueError( + f"Mismatched shapes: {this_shape}. Expected at least {min_shape}" + ) _check(this_input) def convert_weights_to_torchscript( - model_descr: Union[v0_4.ModelDescr, v0_5.ModelDescr], output_path: Path, use_tracing: bool = True + model_descr: Union[v0_4.ModelDescr, v0_5.ModelDescr], + output_path: Path, + use_tracing: bool = True, ) -> v0_5.TorchscriptWeightsDescr: """Convert model weights from format 'pytorch_state_dict' to 'torchscript'. @@ -87,7 +103,9 @@ def convert_weights_to_torchscript( state_dict_weights_descr = model_descr.weights.pytorch_state_dict if state_dict_weights_descr is None: - raise ValueError("The provided model does not have weights in the pytorch state dict format") + raise ValueError( + "The provided model does not have weights in the pytorch state dict format" + ) input_data = model_descr.get_input_test_arrays() @@ -102,11 +120,20 @@ def convert_weights_to_torchscript( else: scripted_model: Any = torch.jit.script(model) - _check_predictions(model=model, scripted_model=scripted_model, model_spec=model_descr, input_data=input_data) + _check_predictions( + model=model, + scripted_model=scripted_model, + model_spec=model_descr, + input_data=input_data, + ) # save the torchscript model - scripted_model.save(str(output_path)) # does not support Path, so need to cast to str + scripted_model.save( + str(output_path) + ) # does not support Path, so need to cast to str return v0_5.TorchscriptWeightsDescr( - source=output_path, pytorch_version=Version(torch.__version__), parent="pytorch_state_dict" + source=output_path, + pytorch_version=Version(torch.__version__), + parent="pytorch_state_dict", ) diff --git a/bioimageio/core/weight_converter/torch/_utils.py b/bioimageio/core/weight_converter/torch/_utils.py index 4b5debad..2acf17be 100644 --- a/bioimageio/core/weight_converter/torch/_utils.py +++ b/bioimageio/core/weight_converter/torch/_utils.py @@ -7,7 +7,9 @@ # additional convenience for pytorch state dict, eventually we want this in python-bioimageio too # and for each weight format -def load_torch_model(node: "v0_4.PytorchStateDictWeightsDescr | v0_5.PytorchStateDictWeightsDescr"): +def load_torch_model( + node: "v0_4.PytorchStateDictWeightsDescr | v0_5.PytorchStateDictWeightsDescr", +): model = PytorchModelAdapter.get_network(node) state = torch.load(download(node.source).path, map_location="cpu") _ = model.load_state_dict(state) # FIXME: check incompatible keys? 
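
Aside on the `# FIXME: check incompatible keys?` above: `torch.nn.Module.load_state_dict`
already reports mismatches via its return value, so the discarded `_` could be
inspected instead of thrown away. A minimal sketch, assuming a stand-in network
and a hypothetical "weights.pt" checkpoint (neither is part of this patch):

import torch

net = torch.nn.Linear(3, 1)  # stand-in for PytorchModelAdapter.get_network(node)
state = torch.load("weights.pt", map_location="cpu")  # hypothetical checkpoint file
incompatible = net.load_state_dict(state, strict=False)
# load_state_dict returns a named tuple of missing_keys and unexpected_keys
if incompatible.missing_keys or incompatible.unexpected_keys:
    raise ValueError(f"state dict mismatch: {incompatible}")
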
diff --git a/pyproject.toml b/pyproject.toml index 7d715f57..083aaf84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.black] -line-length = 120 -target-version = ["py38", "py39", "py310", "py311"] +line-length = 88 +target-version = ["py38", "py39", "py310", "py311", "py312"] [tool.pyright] exclude = ["**/node_modules", "**/__pycache__", "tests/old_*"] @@ -33,6 +33,6 @@ useLibraryCodeForTypes = true addopts = " -n auto --capture=no --doctest-modules --failed-first" [tool.ruff] -line-length = 120 +line-length = 88 include = ["*.py", "*.pyi", "**/pyproject.toml", "*.ipynb"] target-version = "py38" diff --git a/scripts/setup_dev_env.py b/scripts/setup_dev_env.py index 315306a2..b1df230f 100644 --- a/scripts/setup_dev_env.py +++ b/scripts/setup_dev_env.py @@ -14,8 +14,12 @@ def run(prompt: str): chdir(str(repo_dir)) try: run("mamba env create --file core-bioimage-io/dev/env.yaml") - run("pip install --no-deps --config-settings editable_mode=compat -e spec-bioimage-io") - run("pip install --no-deps --config-settings editable_mode=compat -e core-bioimage-io") + run( + "pip install --no-deps --config-settings editable_mode=compat -e spec-bioimage-io" + ) + run( + "pip install --no-deps --config-settings editable_mode=compat -e core-bioimage-io" + ) except Exception: chdir(cur_dir) raise diff --git a/scripts/show_diff.py b/scripts/show_diff.py index affbe685..1b0163bb 100644 --- a/scripts/show_diff.py +++ b/scripts/show_diff.py @@ -17,7 +17,9 @@ with TemporaryDirectory() as tmp: as_is = Path(tmp) / "as_is.bioimageio.yaml" - save_bioimageio_yaml_only(model_as_is, file=as_is) # write out as is to avoid sorting diff + save_bioimageio_yaml_only( + model_as_is, file=as_is + ) # write out as is to avoid sorting diff latest = Path(tmp) / "latest.bioimageio.yaml" save_bioimageio_yaml_only(model_latest, file=latest) diff --git a/tests/conftest.py b/tests/conftest.py index 61a66b30..03fe126e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -129,17 +129,26 @@ @fixture(scope="session") -def model_packages(tmp_path_factory: TempPathFactory, worker_id: str) -> MappingProxyType[str, FilePath]: +def model_packages( + tmp_path_factory: TempPathFactory, worker_id: str +) -> MappingProxyType[str, FilePath]: """prepare model packages (only run with one worker) see https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once """ root_tmp_dir = tmp_path_factory.getbasetemp().parent - packages = MappingProxyType({name: (root_tmp_dir / name).with_suffix(".zip") for name in load_model_packages}) + packages = MappingProxyType( + { + name: (root_tmp_dir / name).with_suffix(".zip") + for name in load_model_packages + } + ) def generate_packages(): for name in load_model_packages: - actual_out = save_bioimageio_package(MODEL_SOURCES[name], output_path=packages[name]) + actual_out = save_bioimageio_package( + MODEL_SOURCES[name], output_path=packages[name] + ) assert actual_out == packages[name] info_path = root_tmp_dir / "packages_created" @@ -176,32 +185,56 @@ def mamba_cmd(): @fixture(params=[] if skip_torch else TORCH_MODELS) -def any_torch_model(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): +def any_torch_model( + request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] +): return model_packages[request.param] @fixture(params=[] if skip_torch else TORCHSCRIPT_MODELS) -def any_torchscript_model(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): +def any_torchscript_model( + 
request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] +): return model_packages[request.param] @fixture(params=[] if skip_onnx else ONNX_MODELS) -def any_onnx_model(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): +def any_onnx_model( + request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] +): return model_packages[request.param] -@fixture(params=[] if skip_tensorflow else TENSORFLOW1_MODELS if tf_major_version == 1 else TENSORFLOW2_MODELS) -def any_tensorflow_model(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): +@fixture( + params=( + [] + if skip_tensorflow + else TENSORFLOW1_MODELS if tf_major_version == 1 else TENSORFLOW2_MODELS + ) +) +def any_tensorflow_model( + request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] +): return model_packages[request.param] -@fixture(params=[] if skip_tensorflow else KERAS_TF1_MODELS if tf_major_version == 1 else KERAS_TF2_MODELS) -def any_keras_model(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): +@fixture( + params=( + [] + if skip_tensorflow + else KERAS_TF1_MODELS if tf_major_version == 1 else KERAS_TF2_MODELS + ) +) +def any_keras_model( + request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] +): return model_packages[request.param] @fixture(params=[] if skip_tensorflow_js else TENSORFLOW_JS_MODELS) -def any_tensorflow_js_model(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): +def any_tensorflow_js_model( + request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] +): return model_packages[request.param] @@ -220,59 +253,93 @@ def any_model(request: FixtureRequest, model_packages: MappingProxyType[str, Fil # -@fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_fixed_shape"]) -def unet2d_fixed_shape_or_not(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): +@fixture( + params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_fixed_shape"] +) +def unet2d_fixed_shape_or_not( + request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] +): return model_packages[request.param] -@fixture(params=[] if skip_onnx or skip_torch else ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"]) -def convert_to_onnx(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): +@fixture( + params=( + [] + if skip_onnx or skip_torch + else ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"] + ) +) +def convert_to_onnx( + request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] +): return model_packages[request.param] -@fixture(params=[] if skip_tensorflow else ["unet2d_keras" if tf_major_version == 1 else "unet2d_keras_tf2"]) -def unet2d_keras(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): +@fixture( + params=( + [] + if skip_tensorflow + else ["unet2d_keras" if tf_major_version == 1 else "unet2d_keras_tf2"] + ) +) +def unet2d_keras( + request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] +): return model_packages[request.param] # written as model group to automatically skip on missing torch @fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model"]) -def unet2d_nuclei_broad_model(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): +def unet2d_nuclei_broad_model( + request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] +): return model_packages[request.param] # written as model 
group to automatically skip on missing torch @fixture(params=[] if skip_torch else ["unet2d_diff_output_shape"]) -def unet2d_diff_output_shape(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): +def unet2d_diff_output_shape( + request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] +): return model_packages[request.param] # written as model group to automatically skip on missing torch @fixture(params=[] if skip_torch else ["unet2d_expand_output_shape"]) -def unet2d_expand_output_shape(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): +def unet2d_expand_output_shape( + request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] +): return model_packages[request.param] # written as model group to automatically skip on missing torch @fixture(params=[] if skip_torch else ["unet2d_fixed_shape"]) -def unet2d_fixed_shape(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): +def unet2d_fixed_shape( + request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] +): return model_packages[request.param] # written as model group to automatically skip on missing torch @fixture(params=[] if skip_torch else ["shape_change"]) -def shape_change_model(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): +def shape_change_model( + request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] +): return model_packages[request.param] # written as model group to automatically skip on missing tensorflow 1 -@fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape"]) +@fixture( + params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape"] +) def stardist_wrong_shape(request: FixtureRequest): return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing tensorflow 1 -@fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape2"]) +@fixture( + params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape2"] +) def stardist_wrong_shape2(request: FixtureRequest): return MODEL_SOURCES[request.param] diff --git a/tests/test_bioimageio_spec_version.py b/tests/test_bioimageio_spec_version.py index af47dfea..2a0ae2c2 100644 --- a/tests/test_bioimageio_spec_version.py +++ b/tests/test_bioimageio_spec_version.py @@ -14,7 +14,9 @@ def test_bioimageio_spec_version(mamba_cmd: Optional[str]): # get latest released bioimageio.spec version mamba_repoquery = subprocess.run( - f"{pytest.mamba_cmd} repoquery search -c conda-forge --json bioimageio.spec".split(" "), + f"{pytest.mamba_cmd} repoquery search -c conda-forge --json bioimageio.spec".split( + " " + ), encoding="utf-8", capture_output=True, check=True, diff --git a/tests/test_cli.py b/tests/test_cli.py index 944d5f5d..ee09f66f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -5,29 +5,53 @@ from pydantic import FilePath -def run_subprocess(commands: Sequence[str], **kwargs: Any) -> "subprocess.CompletedProcess[str]": - return subprocess.run(commands, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8", **kwargs) +def run_subprocess( + commands: Sequence[str], **kwargs: Any +) -> "subprocess.CompletedProcess[str]": + return subprocess.run( + commands, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + encoding="utf-8", + **kwargs, + ) @pytest.mark.parametrize( "args", [ - ["package", "unet2d_nuclei_broad_model", "--weight-format", "pytorch_state_dict"], + [ + 
"package", + "unet2d_nuclei_broad_model", + "--weight-format", + "pytorch_state_dict", + ], ["package", "unet2d_nuclei_broad_model"], - ["test-model", "unet2d_nuclei_broad_model", "--weight-format", "pytorch_state_dict"], + [ + "test-model", + "unet2d_nuclei_broad_model", + "--weight-format", + "pytorch_state_dict", + ], ["test-model", "unet2d_nuclei_broad_model"], ], ) def test_cli(args: List[str], unet2d_nuclei_broad_model: FilePath): assert unet2d_nuclei_broad_model.exists() - resolved_args = [str(unet2d_nuclei_broad_model) if arg == "unet2d_nuclei_broad_model" else arg for arg in args] + resolved_args = [ + str(unet2d_nuclei_broad_model) if arg == "unet2d_nuclei_broad_model" else arg + for arg in args + ] ret = run_subprocess(["bioimageio", *resolved_args]) assert ret.returncode == 0, ret.stdout @pytest.mark.parametrize("args", [["test-model", "stardist_wrong_shape"]]) def test_cli_fails(args: List[str], stardist_wrong_shape: FilePath): - resolved_args = [str(stardist_wrong_shape) if arg == "stardist_wrong_shape" else arg for arg in args] + resolved_args = [ + str(stardist_wrong_shape) if arg == "stardist_wrong_shape" else arg + for arg in args + ] ret = run_subprocess(["bioimageio", *resolved_args]) assert ret.returncode == 1, ret.stdout diff --git a/tests/test_prediction.py b/tests/test_prediction.py index a95eb3f4..2a6c4487 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -29,7 +29,9 @@ def test_predict_image(any_model: Path, tmpdir: Path): assert_array_almost_equal(res, exp, decimal=4) -def test_predict_image_with_weight_format(unet2d_fixed_shape_or_not: Path, tmpdir: Path): +def test_predict_image_with_weight_format( + unet2d_fixed_shape_or_not: Path, tmpdir: Path +): from bioimageio.core.prediction import predict_image spec = load_description(unet2d_fixed_shape_or_not) @@ -37,7 +39,9 @@ def test_predict_image_with_weight_format(unet2d_fixed_shape_or_not: Path, tmpdi inputs = spec.test_inputs outputs = [Path(tmpdir) / f"out{i}.npy" for i in range(len(spec.test_outputs))] - predict_image(unet2d_fixed_shape_or_not, inputs, outputs, weight_format="pytorch_state_dict") + predict_image( + unet2d_fixed_shape_or_not, inputs, outputs, weight_format="pytorch_state_dict" + ) for out_path in outputs: assert out_path.exists() @@ -54,7 +58,11 @@ def _test_predict_with_padding(any_model: Path, tmp_path: Path): assert isinstance(model, (ModelDescr_v0_4, ModelDescr)) input_spec, output_spec = model.inputs[0], model.outputs[0] - channel_axis = "c" if isinstance(input_spec, InputTensorDescr_v0_4) else [a.id for a in input_spec.axes][0] + channel_axis = ( + "c" + if isinstance(input_spec, InputTensorDescr_v0_4) + else [a.id for a in input_spec.axes][0] + ) channel_first = channel_axis == 1 # TODO: check more tensors @@ -78,14 +86,17 @@ def _test_predict_with_padding(any_model: Path, tmp_path: Path): scale = dict(zip(output_spec.axes, output_spec.shape.scale)) offset = dict(zip(output_spec.axes, output_spec.shape.offset)) spatial_axes = [ax for ax in output_spec.axes if ax in "xyz"] - network_resizes = any(sc != 1 for ax, sc in scale.items() if ax in spatial_axes) or any( - off != 0 for ax, off in offset.items() if ax in spatial_axes - ) + network_resizes = any( + sc != 1 for ax, sc in scale.items() if ax in spatial_axes + ) or any(off != 0 for ax, off in offset.items() if ax in spatial_axes) else: network_resizes = False if network_resizes: - exp_shape = tuple(int(sh * scale[ax] + 2 * offset[ax]) for sh, ax in zip(image.shape, spatial_axes)) + exp_shape = tuple( + int(sh * 
scale[ax] + 2 * offset[ax]) + for sh, ax in zip(image.shape, spatial_axes) + ) else: exp_shape = image.shape @@ -103,12 +114,17 @@ def check_result(): assert res.shape == exp_shape # test with dynamic padding - predict_image(any_model, in_path, out_path, padding={"x": 16, "y": 16, "mode": "dynamic"}) + predict_image( + any_model, in_path, out_path, padding={"x": 16, "y": 16, "mode": "dynamic"} + ) check_result() # test with fixed padding predict_image( - any_model, in_path, out_path, padding={"x": original_shape[0], "y": original_shape[1], "mode": "fixed"} + any_model, + in_path, + out_path, + padding={"x": original_shape[0], "y": original_shape[1], "mode": "fixed"}, ) check_result() @@ -124,7 +140,9 @@ def test_predict_image_with_padding(unet2d_fixed_shape_or_not, tmp_path): # and with different output shape -def test_predict_image_with_padding_diff_output_shape(unet2d_diff_output_shape, tmp_path): +def test_predict_image_with_padding_diff_output_shape( + unet2d_diff_output_shape, tmp_path +): _test_predict_with_padding(unet2d_diff_output_shape, tmp_path) @@ -181,7 +199,9 @@ def test_predict_image_with_tiling_channel_last(stardist: Path, tmp_path: Path): _test_predict_image_with_tiling(stardist, tmp_path, 0.13) -def test_predict_image_with_tiling_fixed_output_shape(unet2d_fixed_shape: Path, tmp_path: Path): +def test_predict_image_with_tiling_fixed_output_shape( + unet2d_fixed_shape: Path, tmp_path: Path +): _test_predict_image_with_tiling(unet2d_fixed_shape, tmp_path, 0.025) diff --git a/tests/test_prediction_pipeline.py b/tests/test_prediction_pipeline.py index 5347380a..ddc0b6d1 100644 --- a/tests/test_prediction_pipeline.py +++ b/tests/test_prediction_pipeline.py @@ -13,7 +13,9 @@ def _test_prediction_pipeline(model_package: Path, weights_format: WeightsFormat bio_model = load_description(model_package) assert isinstance(bio_model, (ModelDescr, ModelDescr04)) - pp = create_prediction_pipeline(bioimageio_model=bio_model, weight_format=weights_format) + pp = create_prediction_pipeline( + bioimageio_model=bio_model, weight_format=weights_format + ) inputs = get_test_inputs(bio_model) outputs = pp.forward(*inputs) diff --git a/tests/test_prediction_pipeline_device_management.py b/tests/test_prediction_pipeline_device_management.py index 1236383a..4c7ac1a0 100644 --- a/tests/test_prediction_pipeline_device_management.py +++ b/tests/test_prediction_pipeline_device_management.py @@ -23,7 +23,9 @@ def _test_device_management(model_package: Path, weight_format: WeightsFormat): bio_model = load_description(model_package) assert isinstance(bio_model, (ModelDescr, ModelDescr04)) - pred_pipe = create_prediction_pipeline(bioimageio_model=bio_model, weight_format=weight_format, devices=["cuda:0"]) + pred_pipe = create_prediction_pipeline( + bioimageio_model=bio_model, weight_format=weight_format, devices=["cuda:0"] + ) inputs = get_test_inputs(bio_model) with pred_pipe as pp: @@ -46,26 +48,36 @@ def _test_device_management(model_package: Path, weight_format: WeightsFormat): assert_array_almost_equal(out, exp, decimal=4) -@skip_on(TooFewDevicesException, reason="Too few devices") # pyright: ignore[reportArgumentType] +@skip_on( + TooFewDevicesException, reason="Too few devices" +) # pyright: ignore[reportArgumentType] def test_device_management_torch(any_torch_model: Path): _test_device_management(any_torch_model, "pytorch_state_dict") -@skip_on(TooFewDevicesException, reason="Too few devices") # pyright: ignore[reportArgumentType] +@skip_on( + TooFewDevicesException, reason="Too few devices" +) # 
pyright: ignore[reportArgumentType] def test_device_management_torchscript(any_torchscript_model: Path): _test_device_management(any_torchscript_model, "torchscript") -@skip_on(TooFewDevicesException, reason="Too few devices") # pyright: ignore[reportArgumentType] +@skip_on( + TooFewDevicesException, reason="Too few devices" +) # pyright: ignore[reportArgumentType] def test_device_management_onnx(any_onnx_model: Path): _test_device_management(any_onnx_model, "onnx") -@skip_on(TooFewDevicesException, reason="Too few devices") # pyright: ignore[reportArgumentType] +@skip_on( + TooFewDevicesException, reason="Too few devices" +) # pyright: ignore[reportArgumentType] def test_device_management_tensorflow(any_tensorflow_model: Path): _test_device_management(any_tensorflow_model, "tensorflow_saved_model_bundle") -@skip_on(TooFewDevicesException, reason="Too few devices") # pyright: ignore[reportArgumentType] +@skip_on( + TooFewDevicesException, reason="Too few devices" +) # pyright: ignore[reportArgumentType] def test_device_management_keras(any_keras_model: Path): _test_device_management(any_keras_model, "keras_hdf5") diff --git a/tests/test_proc_ops.py b/tests/test_proc_ops.py index f029517e..c3e5d51b 100644 --- a/tests/test_proc_ops.py +++ b/tests/test_proc_ops.py @@ -79,11 +79,20 @@ def test_zero_mean_unit_variance_fixed(tid: TensorId): from bioimageio.core.proc_ops import FixedZeroMeanUnitVariance op = FixedZeroMeanUnitVariance( - tid, tid, mean=xr.DataArray([1, 4, 7], dims=("y")), std=xr.DataArray([0.81650, 0.81650, 0.81650], dims=("y")) + tid, + tid, + mean=xr.DataArray([1, 4, 7], dims=("y")), + std=xr.DataArray([0.81650, 0.81650, 0.81650], dims=("y")), ) data = xr.DataArray(np.arange(9).reshape((1, 1, 3, 3)), dims=("b", "c", "x", "y")) expected = xr.DataArray( - np.array([[-1.224743, 0.0, 1.224743], [-1.224743, 0.0, 1.224743], [-1.224743, 0.0, 1.224743]])[None, None], + np.array( + [ + [-1.224743, 0.0, 1.224743], + [-1.224743, 0.0, 1.224743], + [-1.224743, 0.0, 1.224743], + ] + )[None, None], dims=("b", "c", "x", "y"), ) sample = Sample(data={tid: data}) @@ -96,7 +105,9 @@ def test_zero_mean_unit_across_axes(tid: TensorId): data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) - op = ZeroMeanUnitVariance(tid, tid, SampleMean(tid, (AxisId("c"),)), SampleStd(tid, (AxisId("c"),))) + op = ZeroMeanUnitVariance( + tid, tid, SampleMean(tid, (AxisId("c"),)), SampleStd(tid, (AxisId("c"),)) + ) sample = Sample(data={tid: data}) sample.stat = compute_measures(op.required_measures, [sample]) @@ -166,7 +177,9 @@ def test_clip(tid: TensorId): data = xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y")) sample = Sample(data={tid: data}) - expected = xr.DataArray(np.array([[3, 3, 3], [3, 4, 5], [5, 5, 5]]), dims=("x", "y")) + expected = xr.DataArray( + np.array([[3, 3, 3], [3, 4, 5], [5, 5, 5]]), dims=("x", "y") + ) op(sample) xr.testing.assert_equal(expected, sample.data[tid]) @@ -176,7 +189,9 @@ def test_combination_of_op_steps_with_dims_specified(tid: TensorId): data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) sample = Sample(data={tid: data}) - op = ZeroMeanUnitVariance(tid, tid, SampleMean(tid, (AxisId("c"),)), SampleStd(tid, (AxisId("c"),))) + op = ZeroMeanUnitVariance( + tid, tid, SampleMean(tid, (AxisId("c"),)), SampleStd(tid, (AxisId("c"),)) + ) sample.stat = compute_measures(op.required_measures, [sample]) expected = xr.DataArray( @@ -194,7 +209,15 @@ def test_combination_of_op_steps_with_dims_specified(tid: TensorId): 
xr.testing.assert_allclose(expected, sample.data[tid]) -@pytest.mark.parametrize("axes", [None, tuple(map(AxisId, "cy")), tuple(map(AxisId, "cyx")), tuple(map(AxisId, "x"))]) +@pytest.mark.parametrize( + "axes", + [ + None, + tuple(map(AxisId, "cy")), + tuple(map(AxisId, "cyx")), + tuple(map(AxisId, "x")), + ], +) def test_scale_mean_variance(tid: TensorId, axes: Optional[Tuple[AxisId, ...]]): from bioimageio.core.proc_ops import ScaleMeanVariance @@ -211,8 +234,13 @@ def test_scale_mean_variance(tid: TensorId, axes: Optional[Tuple[AxisId, ...]]): xr.testing.assert_allclose(ref_data, sample.data[tid]) -@pytest.mark.parametrize("axes", [None, tuple(map(AxisId, "cy")), tuple(map(AxisId, "y")), tuple(map(AxisId, "yx"))]) -def test_scale_mean_variance_per_channel(tid: TensorId, axes: Optional[Tuple[AxisId, ...]]): +@pytest.mark.parametrize( + "axes", + [None, tuple(map(AxisId, "cy")), tuple(map(AxisId, "y")), tuple(map(AxisId, "yx"))], +) +def test_scale_mean_variance_per_channel( + tid: TensorId, axes: Optional[Tuple[AxisId, ...]] +): from bioimageio.core.proc_ops import ScaleMeanVariance shape = (3, 32, 46) diff --git a/tests/test_stat_measures.py b/tests/test_stat_measures.py index 7e8581a9..4287b0db 100644 --- a/tests/test_stat_measures.py +++ b/tests/test_stat_measures.py @@ -7,19 +7,30 @@ from bioimageio.core import stat_measures from bioimageio.core.common import AxisId, Sample, TensorId -from bioimageio.core.stat_calculators import SamplePercentilesCalculator, get_measure_calculators +from bioimageio.core.stat_calculators import ( + SamplePercentilesCalculator, + get_measure_calculators, +) from bioimageio.core.stat_measures import SamplePercentile @pytest.mark.parametrize( "name,sample_or_dataset,axes", - product(["mean", "var", "std"], ["Sample", "Dataset"], [None, (AxisId("x"), AxisId("y"))]), + product( + ["mean", "var", "std"], + ["Sample", "Dataset"], + [None, (AxisId("x"), AxisId("y"))], + ), ) def test_individual_normal_measure( - name: str, sample_or_dataset: Literal["Sample", "Dataset"], axes: Optional[Tuple[AxisId, AxisId]] + name: str, + sample_or_dataset: Literal["Sample", "Dataset"], + axes: Optional[Tuple[AxisId, AxisId]], ): data_id = TensorId("test_data") - measure = getattr(stat_measures, sample_or_dataset + name.title())(axes=axes, tensor_id=data_id) + measure = getattr(stat_measures, sample_or_dataset + name.title())( + axes=axes, tensor_id=data_id + ) data = xr.DataArray(np.random.random((5, 6, 3)), dims=("x", "y", "c")) expected = getattr(data, name)(dim=axes) diff --git a/tests/utils/test_image_helper.py b/tests/utils/test_image_helper.py index a0186c78..9e17be3f 100644 --- a/tests/utils/test_image_helper.py +++ b/tests/utils/test_image_helper.py @@ -9,7 +9,11 @@ @pytest.mark.parametrize( - "axes", [[AxisId(a) for a in axes] for axes in ["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"]] + "axes", + [ + [AxisId(a) for a in axes] + for axes in ["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"] + ], ) def test_transpose_tensor_2d(axes: Sequence[AxisId]): from bioimageio.core.utils.image_helper import transpose_tensor @@ -20,7 +24,11 @@ def test_transpose_tensor_2d(axes: Sequence[AxisId]): @pytest.mark.parametrize( - "axes", [[AxisId(a) for a in axes] for axes in ["zyx", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"]] + "axes", + [ + [AxisId(a) for a in axes] + for axes in ["zyx", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"] + ], ) def test_transpose_tensor_3d(axes: Sequence[AxisId]): from bioimageio.core.utils.image_helper import 
transpose_tensor diff --git a/tests/weight_converter/keras/test_tensorflow.py b/tests/weight_converter/keras/test_tensorflow.py index 069b6f23..65c93f60 100644 --- a/tests/weight_converter/keras/test_tensorflow.py +++ b/tests/weight_converter/keras/test_tensorflow.py @@ -8,9 +8,13 @@ from bioimageio.spec.model.v0_5 import ModelDescr -@pytest.mark.skip("tensorflow converter not updated yet") # TODO: test tensorflow converter +@pytest.mark.skip( + "tensorflow converter not updated yet" +) # TODO: test tensorflow converter def test_tensorflow_converter(any_keras_model: Path, tmp_path: Path): - from bioimageio.core.weight_converter.keras import convert_weights_to_tensorflow_saved_model_bundle + from bioimageio.core.weight_converter.keras import ( + convert_weights_to_tensorflow_saved_model_bundle, + ) out_path = tmp_path / "weights" model = load_description(any_keras_model) @@ -19,19 +23,27 @@ def test_tensorflow_converter(any_keras_model: Path, tmp_path: Path): assert out_path.exists() assert (out_path / "variables").exists() assert (out_path / "saved_model.pb").exists() - assert ret_val == 0 # check for correctness is done in converter and returns 0 if it passes + assert ( + ret_val == 0 + ) # check for correctness is done in converter and returns 0 if it passes -@pytest.mark.skip("tensorflow converter not updated yet") # TODO: test tensorflow converter +@pytest.mark.skip( + "tensorflow converter not updated yet" +) # TODO: test tensorflow converter def test_tensorflow_converter_zipped(any_keras_model: Path, tmp_path: Path): - from bioimageio.core.weight_converter.keras import convert_weights_to_tensorflow_saved_model_bundle + from bioimageio.core.weight_converter.keras import ( + convert_weights_to_tensorflow_saved_model_bundle, + ) out_path = tmp_path / "weights.zip" model = load_description(any_keras_model) assert isinstance(model, ModelDescr), model.validation_summary.format() ret_val = convert_weights_to_tensorflow_saved_model_bundle(model, out_path) assert out_path.exists() - assert ret_val == 0 # check for correctness is done in converter and returns 0 if it passes + assert ( + ret_val == 0 + ) # check for correctness is done in converter and returns 0 if it passes # make sure that the zip package was created correctly expected_names = {"saved_model.pb", "variables/variables.index"} diff --git a/tests/weight_converter/torch/test_onnx.py b/tests/weight_converter/torch/test_onnx.py index a0315650..54f2cdf4 100644 --- a/tests/weight_converter/torch/test_onnx.py +++ b/tests/weight_converter/torch/test_onnx.py @@ -13,4 +13,6 @@ def test_onnx_converter(convert_to_onnx: Path, tmp_path: Path): ret_val = convert_weights_to_onnx(convert_to_onnx, out_path, test_decimal=3) assert os.path.exists(out_path) if not pytest.skip_onnx: - assert ret_val == 0 # check for correctness is done in converter and returns 0 if it passes + assert ( + ret_val == 0 + ) # check for correctness is done in converter and returns 0 if it passes diff --git a/tests/weight_converter/torch/test_torchscript.py b/tests/weight_converter/torch/test_torchscript.py index 945e778b..e0cee3d8 100644 --- a/tests/weight_converter/torch/test_torchscript.py +++ b/tests/weight_converter/torch/test_torchscript.py @@ -6,11 +6,17 @@ from bioimageio.spec.model import v0_4, v0_5 -@pytest.mark.skip("torchscript converter not updated yet") # TODO: test torchscript converter -def test_torchscript_converter(any_torch_model: "v0_4.ModelDescr | v0_5.ModelDescr", tmp_path: Path): +@pytest.mark.skip( + "torchscript converter not updated yet" +) # 
TODO: test torchscript converter
+def test_torchscript_converter(
+    any_torch_model: "v0_4.ModelDescr | v0_5.ModelDescr", tmp_path: Path
+):
     from bioimageio.core.weight_converter.torch import convert_weights_to_torchscript

     out_path = tmp_path / "weights.pt"
     ret_val = convert_weights_to_torchscript(any_torch_model, out_path)
     assert out_path.exists()
-    assert ret_val == 0 # check for correctness is done in converter and returns 0 if it passes
+    assert (
+        ret_val == 0
+    )  # check for correctness is done in converter and returns 0 if it passes

From 8ef28ffe2081fa2e895a71a5cfb0f83fb96378b4 Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Tue, 12 Mar 2024 13:48:19 +0100
Subject: [PATCH 118/244] update pre-commit config

---
 .pre-commit-config.yaml | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index faeee4a4..ef0eba58 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,10 +1,19 @@
 repos:
   - repo: https://github.com/ambv/black
-    rev: 23.7.0
+    rev: 24.2.0
     hooks:
       - id: black-jupyter
-  - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.3.2
     hooks:
-      - id: isort
-        name: isort
+      - id: ruff
+        args: [--fix]
+  - repo: local
+    hooks:
+      - id: pyright
+        name: pyright
+        entry: pyright
+        language: system
+        always_run: true
+        pass_filenames: true
+        files: ^.*\.py$

From b29789236dc9aab7c8a874398adc7fb73b4c9963 Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Tue, 12 Mar 2024 14:34:41 +0100
Subject: [PATCH 119/244] housekeeping

---
 .github/workflows/{build.yml => build.yaml} | 4 ++--
 setup.py | 2 ++
 2 files changed, 4 insertions(+), 2 deletions(-)
 rename .github/workflows/{build.yml => build.yaml} (98%)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yaml
similarity index 98%
rename from .github/workflows/build.yml
rename to .github/workflows/build.yaml
index 3495e794..e6e1fc88 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yaml
@@ -26,7 +26,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [3.7, 3.8, 3.9]
+        python-version: [3.8, 3.9, 3.10, 3.11, 3.12]
     steps:
       - uses: actions/checkout@v3
       - name: Install Conda environment with Micromamba
@@ -46,7 +46,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [3.7, 3.8, 3.9]
+        python-version: [3.8, 3.12]
     steps:
       - uses: actions/checkout@v3
       - name: Install Conda environment with Micromamba
diff --git a/setup.py b/setup.py
index 0024c735..9665a2a1 100644
--- a/setup.py
+++ b/setup.py
@@ -25,11 +25,13 @@
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
         "Programming Language :: Python :: 3.11",
+        "Programming Language :: Python :: 3.12",
     ],
     packages=find_namespace_packages(exclude=["tests"]),
     install_requires=[
         "bioimageio.spec==0.5.0.*",
         "imageio>=2.5",
+        "loguru",
         "numpy",
         "ruyaml",
         "tifffile",

From 534534618b466695cf03f2103bf299d1cd502b5b Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Tue, 12 Mar 2024 15:48:05 +0100
Subject: [PATCH 120/244] update _import_callable.py

---
 bioimageio/core/utils/_import_callable.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/bioimageio/core/utils/_import_callable.py b/bioimageio/core/utils/_import_callable.py
index a60259d9..166d4014 100644
--- a/bioimageio/core/utils/_import_callable.py
+++ b/bioimageio/core/utils/_import_callable.py
@@ -23,7 +23,7 @@ def import_callable(node: type, /) -> Callable[..., Any]:
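# A minimal sketch (an editorial illustration, not part of this patch) of the
# dynamic-import pattern the hunks below adjust: resolve a callable from a
# module by attribute name and validate it. resolve_callable, module_name and
# callable_name are placeholder names:
#
#     import importlib
#
#     def resolve_callable(module_name: str, callable_name: str):
#         module = importlib.import_module(module_name)
#         c = getattr(module, str(callable_name))  # plain-str coercion, as below
#         if not callable(c):
#             raise ValueError(f"{callable_name!r} is not callable")
#         return c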
@import_callable.register def import_from_dependency04(node: CallableFromDepencency) -> Callable[..., Any]: module = importlib.import_module(node.module_name) - c = getattr(module, node.callable_name) + c = getattr(module, str(node.callable_name)) if not callable(c): raise ValueError(f"{node} (imported: {c}) is not callable") @@ -33,7 +33,7 @@ def import_from_dependency04(node: CallableFromDepencency) -> Callable[..., Any] @import_callable.register def import_from_dependency05(node: ArchitectureFromLibraryDescr) -> Callable[..., Any]: module = importlib.import_module(node.import_from) - c = getattr(module, node.callable) + c = getattr(module, str(node.callable)) if not callable(c): raise ValueError(f"{node} (imported: {c}) is not callable") @@ -42,22 +42,18 @@ def import_from_dependency05(node: ArchitectureFromLibraryDescr) -> Callable[... @import_callable.register def import_from_file04(node: CallableFromFile, **kwargs: Unpack[HashKwargs]): - return _import_from_file_impl(node.file, node.callable_name, **kwargs) + return _import_from_file_impl(node.file, str(node.callable_name), **kwargs) @import_callable.register def import_from_file05(node: ArchitectureFromFileDescr, **kwargs: Unpack[HashKwargs]): - return _import_from_file_impl(node.source, node.callable, sha256=node.sha256) + return _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256) -def _import_from_file_impl( - source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs] -): +def _import_from_file_impl(source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]): local_file = download(source, **kwargs) module_name = local_file.path.stem - importlib_spec = importlib.util.spec_from_file_location( - module_name, local_file.path - ) + importlib_spec = importlib.util.spec_from_file_location(module_name, local_file.path) if importlib_spec is None: raise ImportError(f"Failed to import {module_name} from {source}.") From 63eaae048b2f153926f9250dac85b8ad646edb23 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 12 Mar 2024 16:10:16 +0100 Subject: [PATCH 121/244] avoid implicit string concatenation --- bioimageio/core/model_adapters/_model_adapter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py index acedc122..b9e913be 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ b/bioimageio/core/model_adapters/_model_adapter.py @@ -117,8 +117,8 @@ def create( error_msg = "" raise ValueError( - f"None of the weight formats {weight_format_priority_order} is supported for {model_description.name} " - f"in this environment.{error_msg}" + f"None of the weight formats {weight_format_priority_order} is " + + f"supported for {model_description.name} in this environment.{error_msg}" ) @final From 185b7e8199852391cd80384829f4f515df3ec13e Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 12 Mar 2024 16:23:08 +0100 Subject: [PATCH 122/244] properly export __version__ --- bioimageio/core/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index 9bd8324d..780e5cc3 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -2,8 +2,6 @@ .. 
include:: ../../README.md """ -import json - from bioimageio.spec import build_description as build_description from bioimageio.spec import dump_description as dump_description from bioimageio.spec import load_description as load_description @@ -24,6 +22,8 @@ from ._resource_tests import load_description_and_test as load_description_and_test from ._resource_tests import test_description as test_description from ._resource_tests import test_model as test_model -from .utils import VERSION as __version__ +from .utils import VERSION + +__version__ = VERSION test_resource = test_description From 92178e4250eb128bfd9c36a7006222022e0a13df Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 12 Mar 2024 16:32:59 +0100 Subject: [PATCH 123/244] update dev envs --- dev/env.yaml | 11 +++++++---- dev/environment-base.yaml | 20 +++++++++++++------- dev/environment-tf-legacy.yaml | 12 +++++++++--- dev/environment-tf.yaml | 12 +++++++++--- dev/environment-torch.yaml | 17 +++++++++++------ 5 files changed, 49 insertions(+), 23 deletions(-) diff --git a/dev/env.yaml b/dev/env.yaml index 7ee582f5..78ed1f8d 100644 --- a/dev/env.yaml +++ b/dev/env.yaml @@ -4,24 +4,27 @@ channels: - defaults dependencies: - annotated-types + - bioimageio.spec==0.5.* - black - deepdiff - email-validator - - imageio[version='>=2.5'] + - filelock + - imageio>=2.5 + - loguru - lxml - numpy - onnxruntime - - packaging[version='>=17.0'] + - packaging>=17.0 - pooch - pre-commit - - pydantic[version='>=2.3.0'] + - pydantic>=2.6.4 - pyright - pytest - python-dateutil - python=3.8 - pytorch - - ruyaml - ruff + - ruyaml - torchvision - tqdm - typer diff --git a/dev/environment-base.yaml b/dev/environment-base.yaml index 88a336ee..96a96d91 100644 --- a/dev/environment-base.yaml +++ b/dev/environment-base.yaml @@ -3,20 +3,26 @@ channels: - conda-forge - defaults dependencies: + - bioimageio.spec==0.5.* - black - - bioimageio.spec - conda-build + - filelock - h5py >=2.10,<2.11 + - loguru - mypy + - onnx + - onnxruntime - pip - pre-commit + - psutil - pytest - - python >=3.7,<3.8 # this environment is only available for python 3.7 - - xarray + - pytest-xdist + - python >=3.7,<3.8 # this environment is only available for python 3.7 - pytorch - - onnx - - onnxruntime + - ruyaml - tensorflow >=1.12,<2.0 - - tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259 + - tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259 + - typer + - xarray - pip: - - keras==1.2.2 + - keras==1.2.2 diff --git a/dev/environment-tf-legacy.yaml b/dev/environment-tf-legacy.yaml index 976ea3d6..d1bbbc28 100644 --- a/dev/environment-tf-legacy.yaml +++ b/dev/environment-tf-legacy.yaml @@ -4,14 +4,20 @@ channels: - defaults dependencies: - black - - bioimageio.spec + - bioimageio.spec==0.5.* - conda-build - h5py >=2.10,<2.11 - mypy - pip - pytest - - python >=3.7,<3.8 # this environment is only available for python 3.7 + - pytest-xdist + - filelock + - psutil + - python >=3.7,<3.8 # this environment is only available for python 3.7 - xarray - tensorflow >1.14,<2.0 - - tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259 + - tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259 - keras + - ruyaml + - typer + - loguru diff --git a/dev/environment-tf.yaml b/dev/environment-tf.yaml index 4ecd57d8..03c6b08b 100644 --- a/dev/environment-tf.yaml +++ 
b/dev/environment-tf.yaml @@ -3,13 +3,19 @@ channels: - conda-forge - defaults dependencies: + - bioimageio.spec==0.5.* - black - - bioimageio.spec - conda-build + - filelock + - loguru - mypy - pip + - psutil - pytest + - pytest-xdist - python - - xarray + - ruyaml - tensorflow >=2.9,<3.0 - - tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259 + - tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259 + - typer + - xarray diff --git a/dev/environment-torch.yaml b/dev/environment-torch.yaml index 98a944cd..c97a8225 100644 --- a/dev/environment-torch.yaml +++ b/dev/environment-torch.yaml @@ -3,17 +3,22 @@ channels: - conda-forge - defaults dependencies: + - bioimageio.spec==0.5.* - black - - bioimageio.spec >=0.4.4 - conda-build + - filelock - h5py + - loguru - mypy + - onnx + - onnxruntime - pip + - psutil - pytest + - pytest-xdist - python >=3.7 - - xarray - pytorch - - onnx - - onnxruntime - - tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259 - + - ruyaml + - tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259 + - typer + - xarray From ff7cca3c955a2ec63eb6d867cef09c7ebd99b2a5 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 12 Mar 2024 16:48:12 +0100 Subject: [PATCH 124/244] remove unsupported tf-legacy env --- .github/workflows/build.yaml | 25 ----------------------- dev/environment-tf-legacy.yaml | 23 --------------------- dev/environment-torch.yaml | 2 +- example/model_creation.ipynb.needs_update | 8 ++++---- 4 files changed, 5 insertions(+), 53 deletions(-) delete mode 100644 dev/environment-tf-legacy.yaml diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index e6e1fc88..945912ba 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -90,31 +90,6 @@ jobs: - name: pytest-spec-tf run: pytest --disable-pytest-warnings - test-spec-tf-legacy: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: [3.7] - steps: - - uses: actions/checkout@v3 - - name: Install Conda environment with Micromamba - uses: mamba-org/setup-micromamba@v1 - with: - cache-downloads: true - cache-environment: true - environment-file: dev/environment-tf-legacy.yaml - condarc: | - channel_priority: flexible - create-args: | - python=${{ matrix.python-version }} - - name: additional setup - run: | - conda remove --yes --force bioimageio.spec || true # allow failure for cached env - pip install --no-deps git+https://github.com/bioimage-io/spec-bioimage-io - pip install --no-deps -e . 
- - name: pytest-spec-tf-legacy - run: pytest --disable-pytest-warnings - conda-build: runs-on: ubuntu-latest needs: test-spec-conda diff --git a/dev/environment-tf-legacy.yaml b/dev/environment-tf-legacy.yaml deleted file mode 100644 index d1bbbc28..00000000 --- a/dev/environment-tf-legacy.yaml +++ /dev/null @@ -1,23 +0,0 @@ -name: bio-core-tf-legacy -channels: - - conda-forge - - defaults -dependencies: - - black - - bioimageio.spec==0.5.* - - conda-build - - h5py >=2.10,<2.11 - - mypy - - pip - - pytest - - pytest-xdist - - filelock - - psutil - - python >=3.7,<3.8 # this environment is only available for python 3.7 - - xarray - - tensorflow >1.14,<2.0 - - tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259 - - keras - - ruyaml - - typer - - loguru diff --git a/dev/environment-torch.yaml b/dev/environment-torch.yaml index c97a8225..d5809082 100644 --- a/dev/environment-torch.yaml +++ b/dev/environment-torch.yaml @@ -16,7 +16,7 @@ dependencies: - psutil - pytest - pytest-xdist - - python >=3.7 + - python >=3.8 - pytorch - ruyaml - tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259 diff --git a/example/model_creation.ipynb.needs_update b/example/model_creation.ipynb.needs_update index dc81c52e..e45714e7 100644 --- a/example/model_creation.ipynb.needs_update +++ b/example/model_creation.ipynb.needs_update @@ -162,7 +162,7 @@ "# it will output a list of dictionaries. each dict gives the status of a different test that is being run\n", "# if all of them contain \"status\": \"passed\" then all tests were successful\n", "from bioimageio.core.resource_tests import test_model\n", - "my_model = bioimageio.core.load_resource_description(\"my-model/model.zip\") \n", + "my_model = bioimageio.core.load_resource_description(\"my-model/model.zip\")\n", "test_model(my_model)" ] }, @@ -272,7 +272,7 @@ "zip_path = os.path.join(model_root, f\"{name}.zip\")\n", "\n", "# `build_model` needs some additional information about the model, like citation information\n", - "# all this additional information is passed as plain python types and will be converted into the bioimageio representation internally \n", + "# all this additional information is passed as plain python types and will be converted into the bioimageio representation internally\n", "# for more informantion, check out the function signature\n", "# https://github.com/bioimage-io/core-bioimage-io-python/blob/main/bioimageio/core/build_spec/build_model.py#L252\n", "cite = [{\"text\": cite_entry.text, \"url\": cite_entry.url} for cite_entry in model_resource.cite]\n", @@ -388,7 +388,7 @@ "# the path to save the new model with torchscript weights\n", "temp_zip_path = f\"{model_root}/new_model3.zip\"\n", "\n", - "_ = build_model( \n", + "_ = build_model(\n", " weight_uri=weight_file,\n", " weight_type=\"pytorch_state_dict\",\n", " architecture=model_source,\n", @@ -494,7 +494,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.8.17" } }, "nbformat": 4, From e18b1a6145e534a5bf44a0cb7b729607172ecf9e Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 12 Mar 2024 17:18:44 +0100 Subject: [PATCH 125/244] fix tests --- bioimageio/core/proc_ops.py | 2 +- bioimageio/core/utils/_import_callable.py | 10 +++++++--- bioimageio/core/utils/image_helper.py | 9 +++------ tests/utils/test_image_helper.py | 17 ++++++++++++----- 4 files changed, 23 insertions(+), 15 deletions(-) diff 
--git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 080b988f..e26a9c98 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -152,7 +152,7 @@ def required_measures(self) -> Set[Measure]: return set() def __post_init__(self): - self._keep_updating_initial_dataset_stats = ( + self._keep_updating_dataset_stats = ( self.keep_updating_initial_dataset_stats or not self.stats_calculator.has_dataset_measures ) diff --git a/bioimageio/core/utils/_import_callable.py b/bioimageio/core/utils/_import_callable.py index 166d4014..3e1569b7 100644 --- a/bioimageio/core/utils/_import_callable.py +++ b/bioimageio/core/utils/_import_callable.py @@ -42,7 +42,7 @@ def import_from_dependency05(node: ArchitectureFromLibraryDescr) -> Callable[... @import_callable.register def import_from_file04(node: CallableFromFile, **kwargs: Unpack[HashKwargs]): - return _import_from_file_impl(node.file, str(node.callable_name), **kwargs) + return _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs) @import_callable.register @@ -50,10 +50,14 @@ def import_from_file05(node: ArchitectureFromFileDescr, **kwargs: Unpack[HashKwa return _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256) -def _import_from_file_impl(source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]): +def _import_from_file_impl( + source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs] +): local_file = download(source, **kwargs) module_name = local_file.path.stem - importlib_spec = importlib.util.spec_from_file_location(module_name, local_file.path) + importlib_spec = importlib.util.spec_from_file_location( + module_name, local_file.path + ) if importlib_spec is None: raise ImportError(f"Failed to import {module_name} from {source}.") diff --git a/bioimageio/core/utils/image_helper.py b/bioimageio/core/utils/image_helper.py index 444e0fc9..8b56000f 100644 --- a/bioimageio/core/utils/image_helper.py +++ b/bioimageio/core/utils/image_helper.py @@ -97,11 +97,8 @@ def interprete_array( f"Could not guess an axis mapping for {nd_array.shape} with {n_expected_space_axes} expected space axes" ) - current_axes_ids = ( - tuple(current_axes) - if isinstance(current_axes, str) - else tuple(a.id for a in current_axes) - ) + current_axes_ids = tuple(str(a.id) for a in current_axes) + return Tensor(nd_array, dims=current_axes_ids) @@ -130,7 +127,7 @@ def transpose_tensor( missing_axes = tuple(str(a) for a in axes if a not in current_axes) tensor = tensor.expand_dims(missing_axes) # transpose to the correct axis order - return tensor.transpose(*axes) + return tensor.transpose(*map(str, axes)) def convert_v0_4_axes_for_known_shape(axes: v0_4.AxesStr, shape: Sequence[int]): diff --git a/tests/utils/test_image_helper.py b/tests/utils/test_image_helper.py index 9e17be3f..07f8938a 100644 --- a/tests/utils/test_image_helper.py +++ b/tests/utils/test_image_helper.py @@ -3,9 +3,15 @@ import numpy as np import pytest import xarray as xr +from xarray.testing import assert_equal # pyright: ignore[reportUnknownVariableType] from bioimageio.core.common import AxisId -from bioimageio.core.utils.image_helper import interprete_array +from bioimageio.core.utils.image_helper import ( + crop_to, + interprete_array, + pad, + transpose_tensor, +) @pytest.mark.parametrize( @@ -16,7 +22,6 @@ ], ) def test_transpose_tensor_2d(axes: Sequence[AxisId]): - from bioimageio.core.utils.image_helper import transpose_tensor tensor = interprete_array(np.random.rand(256, 256), len(axes)) 
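# An illustrative sketch (assumed behavior, not part of the original test): for
# this 2D input, transpose_tensor is expected to add any missing axes as size-1
# dimensions before reordering to the requested order, e.g.:
#
#     t2 = transpose_tensor(tensor, [AxisId(a) for a in "cyx"])
#     assert t2.dims == ("c", "y", "x")  # "c" inserted as a singleton dimension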
transposed = transpose_tensor(tensor, axes) @@ -31,15 +36,17 @@ def test_transpose_tensor_2d(axes: Sequence[AxisId]): ], ) def test_transpose_tensor_3d(axes: Sequence[AxisId]): - from bioimageio.core.utils.image_helper import transpose_tensor - tensor = interprete_array(np.random.rand(64, 64, 64), len(axes)) transposed = transpose_tensor(tensor, axes) assert transposed.ndim == len(axes) def test_crop_and_pad(): - tensor = xr.DataArray(np.random.rand(64)) + tensor = xr.DataArray(np.random.rand(10, 20), dims=("x", "y")) + sizes = {AxisId(str(k)): v for k, v in tensor.sizes.items()} + padded = pad(tensor, {AxisId("x"): 7, AxisId("y"): (3, 3)}) + cropped = crop_to(padded, sizes) + assert_equal(tensor, cropped) # def test_transform_output_tensor(): From d0a6cdcee262ed62c048c5f1aac905b97ce0a0b4 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 13 Mar 2024 16:25:41 +0100 Subject: [PATCH 126/244] expose input_ids and output_ids --- bioimageio/core/_prediction_pipeline.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index b2cf998f..a596c790 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -39,11 +39,11 @@ def __init__( self._preprocessing = preprocessing self._postprocessing = postprocessing if isinstance(bioimageio_model, v0_4.ModelDescr): - self._input_ids = [TensorId(str(d.name)) for d in bioimageio_model.inputs] - self._output_ids = [TensorId(str(d.name)) for d in bioimageio_model.outputs] + self.input_ids = [TensorId(str(d.name)) for d in bioimageio_model.inputs] + self.output_ids = [TensorId(str(d.name)) for d in bioimageio_model.outputs] else: - self._input_ids = [d.id for d in bioimageio_model.inputs] - self._output_ids = [d.id for d in bioimageio_model.outputs] + self.input_ids = [d.id for d in bioimageio_model.inputs] + self.output_ids = [d.id for d in bioimageio_model.outputs] self._adapter: ModelAdapter = model @@ -65,7 +65,7 @@ def predict( ) -> List[xr.DataArray]: """Predict input_tensor with the model without applying pre/postprocessing.""" named_tensors = [ - named_input_tensors[str(k)] for k in self._input_ids[len(input_tensors) :] + named_input_tensors[str(k)] for k in self.input_ids[len(input_tensors) :] ] return self._adapter.forward(*input_tensors, *named_tensors) @@ -87,7 +87,7 @@ def forward_sample(self, input_sample: Sample) -> Sample: **{str(k): v for k, v in input_sample.data.items()} ) prediction = Sample( - data=dict(zip(self._output_ids, prediction_tensors)), stat=input_sample.stat + data=dict(zip(self.output_ids, prediction_tensors)), stat=input_sample.stat ) self.apply_postprocessing(prediction) return prediction @@ -98,7 +98,7 @@ def forward_tensors( """Apply preprocessing, run prediction and apply postprocessing.""" input_sample = Sample( data={ - **dict(zip(self._input_ids, input_tensors)), + **dict(zip(self.input_ids, input_tensors)), **{TensorId(k): v for k, v in named_input_tensors.items()}, } ) @@ -109,7 +109,7 @@ def forward( ) -> List[xr.DataArray]: """Apply preprocessing, run prediction and apply postprocessing.""" named_outputs = self.forward_tensors(*input_tensors, **named_input_tensors) - return [named_outputs[x] for x in self._output_ids] + return [named_outputs[x] for x in self.output_ids] def load(self): """ From d36fff04b1eec2e75865d24133a833ed5cfd9105 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 13 Mar 2024 23:45:28 +0100 Subject: [PATCH 127/244] updates around AxisId --- 
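A brief usage sketch for the input_ids/output_ids attributes made public in the
previous patch (the model source is a placeholder; any valid bioimage.io model
RDF would do, and the pipeline's default weight format handling is used):

    from bioimageio.core import load_description
    from bioimageio.core._prediction_pipeline import create_prediction_pipeline

    model = load_description("path/to/model/rdf.yaml")
    pp = create_prediction_pipeline(bioimageio_model=model)
    print(pp.input_ids)   # tensor ids the model expects
    print(pp.output_ids)  # tensor ids the model produces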
bioimageio/core/common.py | 6 +++- bioimageio/core/proc_ops.py | 6 +++- bioimageio/core/stat_measures.py | 12 ++++--- bioimageio/core/utils/image_helper.py | 1 - tests/test_proc_ops.py | 47 +++++++++++++-------------- tests/test_stat_measures.py | 8 ++--- tests/utils/test_image_helper.py | 20 ++++-------- 7 files changed, 51 insertions(+), 49 deletions(-) diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index f6684169..635af235 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, Literal +from typing import TYPE_CHECKING, Dict, Iterable, Literal import xarray as xr @@ -21,6 +21,7 @@ class Axis: BatchSize = int Tensor = xr.DataArray + Data = Dict[TensorId, Tensor] Stat = Dict["Measure", "MeasureValue"] @@ -34,3 +35,6 @@ class Sample: stat: Stat = field(default_factory=dict) """sample and dataset statistics""" + + +Dataset = Iterable[Sample] diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index e26a9c98..6a262639 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -400,7 +400,11 @@ def from_proc_descr( tensor_id: TensorId, ): kwargs = descr.kwargs - ref_tensor = cast(TensorId, kwargs.reference_tensor) or tensor_id + ref_tensor = ( + tensor_id + if kwargs.reference_tensor is None + else TensorId(str(kwargs.reference_tensor)) + ) axes = _get_axes(descr.kwargs) if axes is None or AxisId("batch") in axes: Percentile = DatasetPercentile diff --git a/bioimageio/core/stat_measures.py b/bioimageio/core/stat_measures.py index 99329498..5c599af2 100644 --- a/bioimageio/core/stat_measures.py +++ b/bioimageio/core/stat_measures.py @@ -40,7 +40,8 @@ class SampleMean(_Mean, SampleMeasureBase): """The mean value of a single tensor""" def compute(self, sample: Sample) -> MeasureValue: - return sample.data[self.tensor_id].mean(dim=self.axes) + tensor = sample.data[self.tensor_id] + return tensor.mean(dim=self.axes) def __post_init__(self): assert self.axes is None or AxisId("batch") not in self.axes @@ -65,7 +66,8 @@ class SampleStd(_Std, SampleMeasureBase): """The standard deviation of a single tensor""" def compute(self, sample: Sample) -> MeasureValue: - return sample.data[self.tensor_id].std(dim=self.axes) + tensor = sample.data[self.tensor_id] + return tensor.std(dim=self.axes) def __post_init__(self): assert self.axes is None or AxisId("batch") not in self.axes @@ -90,7 +92,8 @@ class SampleVar(_Var, SampleMeasureBase): """The variance of a single tensor""" def compute(self, sample: Sample) -> MeasureValue: - return sample.data[self.tensor_id].var(dim=self.axes) + tensor = sample.data[self.tensor_id] + return tensor.var(dim=self.axes) def __post_init__(self): assert self.axes is None or AxisId("batch") not in self.axes @@ -120,7 +123,8 @@ class SamplePercentile(_Percentile, SampleMeasureBase): """The `n`th percentile of a single tensor""" def compute(self, sample: Sample) -> MeasureValue: - return sample.data[self.tensor_id].quantile(self.n / 100.0, dim=self.axes) + tensor = sample.data[self.tensor_id] + return tensor.quantile(self.n / 100.0, dim=self.axes) def __post_init__(self): super().__post_init__() diff --git a/bioimageio/core/utils/image_helper.py b/bioimageio/core/utils/image_helper.py index 8b56000f..2ad325e0 100644 --- a/bioimageio/core/utils/image_helper.py +++ b/bioimageio/core/utils/image_helper.py @@ -121,7 +121,6 @@ def transpose_tensor( tensor: the input array axes: the desired array axes """ - # expand 
the missing image axes current_axes = tuple(AxisId(str(d)) for d in tensor.dims) missing_axes = tuple(str(a) for a in axes if a not in current_axes) diff --git a/tests/test_proc_ops.py b/tests/test_proc_ops.py index c3e5d51b..24dd3379 100644 --- a/tests/test_proc_ops.py +++ b/tests/test_proc_ops.py @@ -81,19 +81,21 @@ def test_zero_mean_unit_variance_fixed(tid: TensorId): op = FixedZeroMeanUnitVariance( tid, tid, - mean=xr.DataArray([1, 4, 7], dims=("y")), - std=xr.DataArray([0.81650, 0.81650, 0.81650], dims=("y")), + mean=xr.DataArray([3, 4, 5], dims=("c")), + std=xr.DataArray([2.44948974, 2.44948974, 2.44948974], dims=("c")), ) - data = xr.DataArray(np.arange(9).reshape((1, 1, 3, 3)), dims=("b", "c", "x", "y")) + data = xr.DataArray(np.arange(9).reshape((1, 3, 3)), dims=("b", "c", "x")) expected = xr.DataArray( np.array( [ - [-1.224743, 0.0, 1.224743], - [-1.224743, 0.0, 1.224743], - [-1.224743, 0.0, 1.224743], + [ + [-1.22474487, -0.81649658, -0.40824829], + [-0.40824829, 0.0, 0.40824829], + [0.40824829, 0.81649658, 1.22474487], + ] ] - )[None, None], - dims=("b", "c", "x", "y"), + ), + dims=("b", "c", "x"), ) sample = Sample(data={tid: data}) op(sample) @@ -106,20 +108,17 @@ def test_zero_mean_unit_across_axes(tid: TensorId): data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) op = ZeroMeanUnitVariance( - tid, tid, SampleMean(tid, (AxisId("c"),)), SampleStd(tid, (AxisId("c"),)) + tid, + tid, + SampleMean(tid, (AxisId("x"), AxisId("y"))), + SampleStd(tid, (AxisId("x"), AxisId("y"))), ) sample = Sample(data={tid: data}) sample.stat = compute_measures(op.required_measures, [sample]) expected = xr.DataArray( - np.array( - [ - [-1.54919274, -1.16189455, -0.77459637], - [-0.38729818, 0.0, 0.38729818], - [0.77459637, 1.16189455, 1.54919274], - ] - ), - dims=("x", "y"), + np.array([]), + dims=("c", "x", "y"), ) op(sample) xr.testing.assert_allclose(expected, sample.data[tid]) @@ -235,14 +234,14 @@ def test_scale_mean_variance(tid: TensorId, axes: Optional[Tuple[AxisId, ...]]): @pytest.mark.parametrize( - "axes", - [None, tuple(map(AxisId, "cy")), tuple(map(AxisId, "y")), tuple(map(AxisId, "yx"))], + "axes_str", + [None, "cy", "y", "yx"], ) -def test_scale_mean_variance_per_channel( - tid: TensorId, axes: Optional[Tuple[AxisId, ...]] -): +def test_scale_mean_variance_per_channel(tid: TensorId, axes_str: Optional[str]): from bioimageio.core.proc_ops import ScaleMeanVariance + axes = None if axes_str is None else tuple(map(AxisId, axes_str)) + shape = (3, 32, 46) ipt_axes = ("c", "y", "x") np_data = np.random.rand(*shape) @@ -288,8 +287,8 @@ def test_scale_range(tid: TensorId): def test_scale_range_axes(tid: TensorId): from bioimageio.core.proc_ops import ScaleRange - lower_percentile = SamplePercentile(tid, 1, axes=(AxisId("c"),)) - upper_percentile = SamplePercentile(tid, 100, axes=(AxisId("c"),)) + lower_percentile = SamplePercentile(tid, 1, axes=(AxisId("x"), AxisId("y"))) + upper_percentile = SamplePercentile(tid, 100, axes=(AxisId("x"), AxisId("y"))) op = ScaleRange(tid, tid, lower_percentile, upper_percentile) np_data = np.arange(18).reshape((2, 3, 3)).astype("float32") diff --git a/tests/test_stat_measures.py b/tests/test_stat_measures.py index 4287b0db..ea8774b2 100644 --- a/tests/test_stat_measures.py +++ b/tests/test_stat_measures.py @@ -15,20 +15,18 @@ @pytest.mark.parametrize( - "name,sample_or_dataset,axes", + "name,axes", product( ["mean", "var", "std"], - ["Sample", "Dataset"], - [None, (AxisId("x"), AxisId("y"))], + [None, (AxisId("c"),), (AxisId("x"), 
AxisId("y"))], ), ) def test_individual_normal_measure( name: str, - sample_or_dataset: Literal["Sample", "Dataset"], axes: Optional[Tuple[AxisId, AxisId]], ): data_id = TensorId("test_data") - measure = getattr(stat_measures, sample_or_dataset + name.title())( + measure = getattr(stat_measures, "Sample" + name.title())( axes=axes, tensor_id=data_id ) data = xr.DataArray(np.random.random((5, 6, 3)), dims=("x", "y", "c")) diff --git a/tests/utils/test_image_helper.py b/tests/utils/test_image_helper.py index 07f8938a..d51f186a 100644 --- a/tests/utils/test_image_helper.py +++ b/tests/utils/test_image_helper.py @@ -16,28 +16,22 @@ @pytest.mark.parametrize( "axes", - [ - [AxisId(a) for a in axes] - for axes in ["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"] - ], + ["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"], ) -def test_transpose_tensor_2d(axes: Sequence[AxisId]): +def test_transpose_tensor_2d(axes: str): tensor = interprete_array(np.random.rand(256, 256), len(axes)) - transposed = transpose_tensor(tensor, axes) + transposed = transpose_tensor(tensor, [AxisId(a) for a in axes]) assert transposed.ndim == len(axes) @pytest.mark.parametrize( "axes", - [ - [AxisId(a) for a in axes] - for axes in ["zyx", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"] - ], + ["zyx", "cyzx", "yzixc", "bczyx", "xyz", "xyzc", "bzyxtc"], ) -def test_transpose_tensor_3d(axes: Sequence[AxisId]): - tensor = interprete_array(np.random.rand(64, 64, 64), len(axes)) - transposed = transpose_tensor(tensor, axes) +def test_transpose_tensor_3d(axes: str): + tensor = interprete_array(np.random.rand(64, 64, 64), 3) + transposed = transpose_tensor(tensor, [AxisId(a) for a in axes]) assert transposed.ndim == len(axes) From 808f23fd4e0b7b5cd3d4759834663cf1751cbc5f Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 13 Mar 2024 23:47:07 +0100 Subject: [PATCH 128/244] add test_stat_calculators.py --- tests/test_stat_calculators.py | 52 ++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 tests/test_stat_calculators.py diff --git a/tests/test_stat_calculators.py b/tests/test_stat_calculators.py new file mode 100644 index 00000000..4d1117a8 --- /dev/null +++ b/tests/test_stat_calculators.py @@ -0,0 +1,52 @@ +from typing import Tuple, Union + +import numpy as np +import pytest +from xarray.testing import assert_allclose # pyright: ignore[reportUnknownVariableType] + +from bioimageio.core.common import AxisId, Sample, Tensor, TensorId +from bioimageio.core.stat_calculators import MeanVarStdCalculator +from bioimageio.core.stat_measures import ( + DatasetMean, + DatasetStd, + DatasetVar, +) + + +def create_random_dataset(tid: TensorId, axes: Tuple[str, ...], n: int = 3): + assert axes[0] == "batch" + sizes = list(range(1, len(axes) + 1)) + b = sizes[0] + ds_array = Tensor(np.random.rand(n * b, *sizes[1:]), dims=axes) + ds = [Sample(data={tid: ds_array[i * b : (i + 1) * b]}) for i in range(n)] + return ds_array, ds + + +@pytest.mark.parametrize( + "axes", + [ + None, + ("x", "y"), + ("channel", "y"), + ], +) +def test_mean_var_std_calculator(axes: Union[None, str, Tuple[str, ...]]): + tid = TensorId("tensor") + axes = tuple(map(AxisId, ("batch", "channel", "x", "y"))) + data, ds = create_random_dataset(tid, axes) + expected_mean = data.mean(axes) + expected_var = data.var(axes) + expected_std = data.std(axes) + + calc = MeanVarStdCalculator(tid, axes=axes) + for s in ds: + calc.update(s) + + actual = calc.finalize() + actual_mean = actual[DatasetMean(tid, axes=axes)] + actual_var = 
actual[DatasetVar(tid, axes=axes)] + actual_std = actual[DatasetStd(tid, axes=axes)] + + assert_allclose(actual_mean, expected_mean) + assert_allclose(actual_var, expected_var) + assert_allclose(actual_std, expected_std) From ccc84eb9569a4629cb8d657d89460565919716a3 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 14 Mar 2024 00:35:45 +0100 Subject: [PATCH 129/244] fix some tests --- tests/test_proc_ops.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/tests/test_proc_ops.py b/tests/test_proc_ops.py index 24dd3379..c8b0b6d5 100644 --- a/tests/test_proc_ops.py +++ b/tests/test_proc_ops.py @@ -116,9 +116,8 @@ def test_zero_mean_unit_across_axes(tid: TensorId): sample = Sample(data={tid: data}) sample.stat = compute_measures(op.required_measures, [sample]) - expected = xr.DataArray( - np.array([]), - dims=("c", "x", "y"), + expected = xr.concat( + [(data[i : i + 1] - data[i].mean()) / data[i].std() for i in range(2)], dim="c" ) op(sample) xr.testing.assert_allclose(expected, sample.data[tid]) @@ -189,19 +188,35 @@ def test_combination_of_op_steps_with_dims_specified(tid: TensorId): data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) sample = Sample(data={tid: data}) op = ZeroMeanUnitVariance( - tid, tid, SampleMean(tid, (AxisId("c"),)), SampleStd(tid, (AxisId("c"),)) + tid, + tid, + SampleMean( + tid, + (AxisId("x"), AxisId("y")), + ), + SampleStd( + tid, + (AxisId("x"), AxisId("y")), + ), ) sample.stat = compute_measures(op.required_measures, [sample]) expected = xr.DataArray( np.array( [ - [-1.54919274, -1.16189455, -0.77459637], - [-0.38729818, 0.0, 0.38729818], - [0.77459637, 1.16189455, 1.54919274], + [ + [-1.54919274, -1.16189455, -0.77459637], + [-0.38729818, 0.0, 0.38729818], + [0.77459637, 1.16189455, 1.54919274], + ], + [ + [-1.54919274, -1.16189455, -0.77459637], + [-0.38729818, 0.0, 0.38729818], + [0.77459637, 1.16189455, 1.54919274], + ], ] ), - dims=("x", "y"), + dims=("c", "x", "y"), ) op(sample) From f798344213c179a6c836938aff9e4f6c46f8d23c Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 15 Mar 2024 09:56:06 +0100 Subject: [PATCH 130/244] improve tests --- .../core/model_adapters/_model_adapter.py | 8 +- tests/conftest.py | 157 ++++-------------- tests/test_any_model_fixture.py | 6 + tests/test_prediction_pipeline.py | 4 +- tests/test_resource_tests.py | 4 +- 5 files changed, 51 insertions(+), 128 deletions(-) create mode 100644 tests/test_any_model_fixture.py diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py index b9e913be..607317a4 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ b/bioimageio/core/model_adapters/_model_adapter.py @@ -23,17 +23,17 @@ class ModelAdapter(ABC): """ Represents model *without* any preprocessing or postprocessing. - >>> from bioimageio.core import read_description - >>> model = read_description() + >>> from bioimageio.core import load_description + >>> model = load_description() >>> print("option 1:") option 1: >>> adapter = ModelAdapter.create(model) - >>> adapter.forward() + >>> adapter.forward # (...) >>> adapter.unload() >>> print("option 2:") option 2: >>> with ModelAdapter.create(model) as adapter: - >>> adapter.forward() + >>> adapter.forward # (...) 
""" diff --git a/tests/conftest.py b/tests/conftest.py index 03fe126e..ee302035 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,16 +2,12 @@ import subprocess import warnings -from types import MappingProxyType -from typing import List, Set +from typing import List -from filelock import FileLock from loguru import logger -from pydantic import FilePath -from pytest import FixtureRequest, TempPathFactory, fixture +from pytest import FixtureRequest, fixture from bioimageio.spec import __version__ as bioimageio_spec_version -from bioimageio.spec._package import save_bioimageio_package warnings.warn(f"testing with bioimageio.spec {bioimageio_spec_version}") @@ -47,7 +43,7 @@ ), "unet2d_expand_output_shape": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_nuclei_broad/rdf_expand_output_shape_v0_4.bioimageio.yaml" + "unet2d_nuclei_broad/expand_output_shape_v0_4.bioimageio.yaml" ), "unet2d_fixed_shape": ( "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" @@ -110,59 +106,6 @@ skip_tensorflow = tensorflow is None skip_tensorflow_js = True # TODO: add a tensorflow_js example model -# load all model packages we need for testing -load_model_packages: Set[str] = set() -if not skip_torch: - load_model_packages |= set(TORCH_MODELS + TORCHSCRIPT_MODELS) - -if not skip_onnx: - load_model_packages |= set(ONNX_MODELS) - -if not skip_tensorflow: - load_model_packages |= set(TENSORFLOW_JS_MODELS) - if tf_major_version == 1: - load_model_packages |= set(KERAS_TF1_MODELS) - load_model_packages |= set(TENSORFLOW1_MODELS) - elif tf_major_version == 2: - load_model_packages |= set(KERAS_TF2_MODELS) - load_model_packages |= set(TENSORFLOW2_MODELS) - - -@fixture(scope="session") -def model_packages( - tmp_path_factory: TempPathFactory, worker_id: str -) -> MappingProxyType[str, FilePath]: - """prepare model packages (only run with one worker) - see https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once - """ - root_tmp_dir = tmp_path_factory.getbasetemp().parent - - packages = MappingProxyType( - { - name: (root_tmp_dir / name).with_suffix(".zip") - for name in load_model_packages - } - ) - - def generate_packages(): - for name in load_model_packages: - actual_out = save_bioimageio_package( - MODEL_SOURCES[name], output_path=packages[name] - ) - assert actual_out == packages[name] - - info_path = root_tmp_dir / "packages_created" - if worker_id == "master": - # no workers - generate_packages() - else: - with FileLock(info_path.with_suffix(".lock")): - if not info_path.is_file(): - generate_packages() - _ = info_path.write_text("") - - return packages - @fixture(scope="session") def mamba_cmd(): @@ -185,24 +128,18 @@ def mamba_cmd(): @fixture(params=[] if skip_torch else TORCH_MODELS) -def any_torch_model( - request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] -): - return model_packages[request.param] +def any_torch_model(request: FixtureRequest): + return MODEL_SOURCES[request.param] @fixture(params=[] if skip_torch else TORCHSCRIPT_MODELS) -def any_torchscript_model( - request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] -): - return model_packages[request.param] +def any_torchscript_model(request: FixtureRequest): + return MODEL_SOURCES[request.param] @fixture(params=[] if skip_onnx else ONNX_MODELS) -def any_onnx_model( - request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] -): - 
return model_packages[request.param] +def any_onnx_model(request: FixtureRequest): + return MODEL_SOURCES[request.param] @fixture( @@ -212,10 +149,8 @@ def any_onnx_model( else TENSORFLOW1_MODELS if tf_major_version == 1 else TENSORFLOW2_MODELS ) ) -def any_tensorflow_model( - request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] -): - return model_packages[request.param] +def any_tensorflow_model(request: FixtureRequest): + return MODEL_SOURCES[request.param] @fixture( @@ -225,25 +160,21 @@ def any_tensorflow_model( else KERAS_TF1_MODELS if tf_major_version == 1 else KERAS_TF2_MODELS ) ) -def any_keras_model( - request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] -): - return model_packages[request.param] +def any_keras_model(request: FixtureRequest): + return MODEL_SOURCES[request.param] @fixture(params=[] if skip_tensorflow_js else TENSORFLOW_JS_MODELS) -def any_tensorflow_js_model( - request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] -): - return model_packages[request.param] +def any_tensorflow_js_model(request: FixtureRequest): + return MODEL_SOURCES[request.param] # fixture to test with all models that should run in the current environment # we exclude stardist_wrong_shape here because it is not a valid model # and included only to test that validation for this model fails -@fixture(params=load_model_packages - {"stardist_wrong_shape", "stardist_wrong_shape2"}) -def any_model(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): - return model_packages[request.param] +@fixture(params=set(MODEL_SOURCES) - {"stardist_wrong_shape", "stardist_wrong_shape2"}) +def any_model(request: FixtureRequest): + return MODEL_SOURCES[request.param] # TODO it would be nice to just generate fixtures for all the individual models dynamically @@ -256,10 +187,8 @@ def any_model(request: FixtureRequest, model_packages: MappingProxyType[str, Fil @fixture( params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_fixed_shape"] ) -def unet2d_fixed_shape_or_not( - request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] -): - return model_packages[request.param] +def unet2d_fixed_shape_or_not(request: FixtureRequest): + return MODEL_SOURCES[request.param] @fixture( @@ -269,10 +198,8 @@ def unet2d_fixed_shape_or_not( else ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"] ) ) -def convert_to_onnx( - request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] -): - return model_packages[request.param] +def convert_to_onnx(request: FixtureRequest): + return MODEL_SOURCES[request.param] @fixture( @@ -282,50 +209,38 @@ def convert_to_onnx( else ["unet2d_keras" if tf_major_version == 1 else "unet2d_keras_tf2"] ) ) -def unet2d_keras( - request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] -): - return model_packages[request.param] +def unet2d_keras(request: FixtureRequest): + return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing torch @fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model"]) -def unet2d_nuclei_broad_model( - request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] -): - return model_packages[request.param] +def unet2d_nuclei_broad_model(request: FixtureRequest): + return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing torch @fixture(params=[] if skip_torch else ["unet2d_diff_output_shape"]) -def unet2d_diff_output_shape( - request: FixtureRequest, 
model_packages: MappingProxyType[str, FilePath] -): - return model_packages[request.param] +def unet2d_diff_output_shape(request: FixtureRequest): + return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing torch @fixture(params=[] if skip_torch else ["unet2d_expand_output_shape"]) -def unet2d_expand_output_shape( - request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] -): - return model_packages[request.param] +def unet2d_expand_output_shape(request: FixtureRequest): + return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing torch @fixture(params=[] if skip_torch else ["unet2d_fixed_shape"]) -def unet2d_fixed_shape( - request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] -): - return model_packages[request.param] +def unet2d_fixed_shape(request: FixtureRequest): + return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing torch @fixture(params=[] if skip_torch else ["shape_change"]) -def shape_change_model( - request: FixtureRequest, model_packages: MappingProxyType[str, FilePath] -): - return model_packages[request.param] +def shape_change_model(request: FixtureRequest): + return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing tensorflow 1 @@ -346,5 +261,5 @@ def stardist_wrong_shape2(request: FixtureRequest): # written as model group to automatically skip on missing tensorflow 1 @fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist"]) -def stardist(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]): - return model_packages[request.param] +def stardist(request: FixtureRequest): + return MODEL_SOURCES[request.param] diff --git a/tests/test_any_model_fixture.py b/tests/test_any_model_fixture.py new file mode 100644 index 00000000..a4cc1bce --- /dev/null +++ b/tests/test_any_model_fixture.py @@ -0,0 +1,6 @@ +from bioimageio.spec import load_description_and_validate_format_only + + +def test_model(any_model: str): + summary = load_description_and_validate_format_only(any_model) + assert summary.status == "passed", summary.format() diff --git a/tests/test_prediction_pipeline.py b/tests/test_prediction_pipeline.py index ddc0b6d1..61fde356 100644 --- a/tests/test_prediction_pipeline.py +++ b/tests/test_prediction_pipeline.py @@ -12,7 +12,9 @@ def _test_prediction_pipeline(model_package: Path, weights_format: WeightsFormat from bioimageio.core._prediction_pipeline import create_prediction_pipeline bio_model = load_description(model_package) - assert isinstance(bio_model, (ModelDescr, ModelDescr04)) + assert isinstance( + bio_model, (ModelDescr, ModelDescr04) + ), bio_model.validation_summary.format() pp = create_prediction_pipeline( bioimageio_model=bio_model, weight_format=weights_format ) diff --git a/tests/test_resource_tests.py b/tests/test_resource_tests.py index 810f8256..08b4343d 100644 --- a/tests/test_resource_tests.py +++ b/tests/test_resource_tests.py @@ -28,11 +28,11 @@ def test_test_model(any_model: Path): from bioimageio.core._resource_tests import test_model summary = test_model(any_model) - assert summary.status == "passed" + assert summary.status == "passed", summary.format() def test_test_resource(any_model: Path): from bioimageio.core._resource_tests import test_description summary = test_description(any_model) - assert summary.status == "passed" + assert summary.status == "passed", summary.format() From 7d022421718044c67268281292141172e9199b56 Mon 
Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 18 Mar 2024 12:53:38 +0100 Subject: [PATCH 131/244] avoid pytest checking example use in ModelAdapter docstring --- .../core/model_adapters/_model_adapter.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py index 607317a4..cd2769b9 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ b/bioimageio/core/model_adapters/_model_adapter.py @@ -23,18 +23,20 @@ class ModelAdapter(ABC): """ Represents model *without* any preprocessing or postprocessing. - >>> from bioimageio.core import load_description - >>> model = load_description() - >>> print("option 1:") - option 1: - >>> adapter = ModelAdapter.create(model) - >>> adapter.forward # (...) - >>> adapter.unload() - >>> print("option 2:") - option 2: - >>> with ModelAdapter.create(model) as adapter: - >>> adapter.forward # (...) + ``` + from bioimageio.core import load_description + model = load_description(...) + + # option 1: + adapter = ModelAdapter.create(model) + adapter.forward(...) + adapter.unload() + + # option 2: + with ModelAdapter.create(model) as adapter: + adapter.forward(...) + ``` """ @final From 183cc0e2765593b2acb65628d457a12cac7eec73 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 18 Mar 2024 14:44:14 +0100 Subject: [PATCH 132/244] some refactoring and improve interprete_array --- bioimageio/core/_resource_tests.py | 30 +-- bioimageio/core/common.py | 7 +- bioimageio/core/utils/_digest_spec.py | 39 ++-- bioimageio/core/utils/image_helper.py | 305 +++++++++----------------- bioimageio/core/utils/tiling.py | 147 +++++++++++++ tests/test_cli.py | 5 +- tests/test_resource_tests.py | 7 +- tests/utils/test_image_helper.py | 9 +- 8 files changed, 278 insertions(+), 271 deletions(-) create mode 100644 bioimageio/core/utils/tiling.py diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py index 80f57372..1e72ff1d 100644 --- a/bioimageio/core/_resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -3,11 +3,10 @@ from typing import List, Literal, Optional, Sequence, Set, Tuple, Union import numpy as np -import xarray as xr from bioimageio.core._prediction_pipeline import create_prediction_pipeline from bioimageio.core.common import AxisId, BatchSize -from bioimageio.core.utils import VERSION, get_test_inputs +from bioimageio.core.utils import VERSION, get_test_inputs, get_test_outputs from bioimageio.core.utils.image_helper import resize_to from bioimageio.spec import ( InvalidDescr, @@ -17,7 +16,6 @@ load_description, ) from bioimageio.spec._internal.common_nodes import ResourceDescrBase -from bioimageio.spec._internal.io_utils import load_array from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import WeightsFormat @@ -120,30 +118,8 @@ def _test_model_inference( error: Optional[str] = None tb: List[str] = [] try: - if isinstance(model, v0_4.ModelDescr): - inputs = [ - xr.DataArray(load_array(src), dims=d.axes) - for src, d in zip(model.test_inputs, model.inputs) - ] - expected = [ - xr.DataArray(load_array(src), dims=d.axes) - for src, d in zip(model.test_outputs, model.outputs) - ] - else: - inputs = [ - xr.DataArray( - load_array(d.test_tensor.download().path), - dims=tuple(str(a.id) for a in d.axes), - ) - for d in model.inputs - ] - expected = [ - xr.DataArray( - 
load_array(d.test_tensor.download().path), - dims=tuple(str(a.id) for a in d.axes), - ) - for d in model.outputs - ] + inputs = get_test_inputs(model) + expected = get_test_outputs(model) with create_prediction_pipeline( bioimageio_model=model, devices=devices, weight_format=weight_format diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index 635af235..10878dd4 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, Iterable, Literal +from typing import TYPE_CHECKING, Dict, Iterable, Literal, Protocol import xarray as xr @@ -18,6 +18,11 @@ class Axis: type: Literal["batch", "channel", "index", "space", "time"] +class AxisLike(Protocol): + id: str + type: Literal["batch", "channel", "index", "space", "time"] + + BatchSize = int Tensor = xr.DataArray diff --git a/bioimageio/core/utils/_digest_spec.py b/bioimageio/core/utils/_digest_spec.py index f01773ac..7f0b892c 100644 --- a/bioimageio/core/utils/_digest_spec.py +++ b/bioimageio/core/utils/_digest_spec.py @@ -1,36 +1,27 @@ from typing import List -import xarray as xr - +from bioimageio.core.common import Tensor from bioimageio.spec.model import AnyModelDescr, v0_4 from bioimageio.spec.utils import load_array +from .image_helper import interprete_array + -def get_test_inputs(model: AnyModelDescr) -> List[xr.DataArray]: +def get_test_inputs(model: AnyModelDescr) -> List[Tensor]: + axes = [d.axes for d in model.inputs] if isinstance(model, v0_4.ModelDescr): - return [ - xr.DataArray(load_array(tt), dims=tuple(d.axes)) - for d, tt in zip(model.inputs, model.test_inputs) - ] + arrays = [load_array(tt) for tt in model.test_inputs] else: - return [ - xr.DataArray( - load_array(d.test_tensor), dims=tuple(str(a.id) for a in d.axes) - ) - for d in model.inputs - ] + arrays = [load_array(d.test_tensor) for d in model.inputs] + return [interprete_array(arr, ax) for arr, ax in zip(arrays, axes)] -def get_test_outputs(model: AnyModelDescr) -> List[xr.DataArray]: + +def get_test_outputs(model: AnyModelDescr) -> List[Tensor]: + axes = [d.axes for d in model.outputs] if isinstance(model, v0_4.ModelDescr): - return [ - xr.DataArray(load_array(tt), dims=tuple(d.axes)) - for d, tt in zip(model.outputs, model.test_outputs) - ] + arrays = [load_array(tt) for tt in model.test_outputs] else: - return [ - xr.DataArray( - load_array(d.test_tensor), dims=tuple(str(a.id) for a in d.axes) - ) - for d in model.outputs - ] + arrays = [load_array(d.test_tensor) for d in model.outputs] + + return [interprete_array(arr, ax) for arr, ax in zip(arrays, axes)] diff --git a/bioimageio/core/utils/image_helper.py b/bioimageio/core/utils/image_helper.py index 2ad325e0..e3b3e5d4 100644 --- a/bioimageio/core/utils/image_helper.py +++ b/bioimageio/core/utils/image_helper.py @@ -1,12 +1,11 @@ -import warnings from pathlib import Path -from typing import Any, Dict, Literal, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Literal, Optional, Sequence, Tuple, Union import imageio +import numpy as np from numpy.typing import NDArray -from typing_extensions import assert_never -from bioimageio.core.common import Axis, Tensor +from bioimageio.core.common import Axis, AxisLike, Tensor from bioimageio.spec.model import v0_4 from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensorDescr04 from bioimageio.spec.model.v0_4 import OutputTensorDescr as OutputTensorDescr04 @@ -18,6 +17,7 @@ Identifier, InputTensorDescr, 
OutputTensorDescr, + SizeReference, SpaceInputAxis, convert_axes, ) @@ -27,88 +27,124 @@ OutputTensor = Union[OutputTensorDescr04, OutputTensorDescr] -def interprete_array_with_desired_axes( - nd_array: NDArray[Any], - desired_axes: Union[v0_4.AxesStr, Sequence[AnyAxis]], -) -> Tensor: - if isinstance(desired_axes, str): - desired_space_axes = [a for a in desired_axes if a in "zyx"] +def normalize_axes( + axes: Union[v0_4.AxesStr, Sequence[Union[AnyAxis, AxisLike]]] +) -> Tuple[Axis, ...]: + AXIS_TYPE_MAP: Dict[str, Literal["batch", "time", "index", "channel", "space"]] = { + "b": "batch", + "t": "time", + "i": "index", + "c": "channel", + "x": "space", + "y": "space", + "z": "space", + } + if isinstance(axes, str): + return tuple(Axis(id=AxisId(a), type=AXIS_TYPE_MAP[a]) for a in axes) else: - desired_space_axes = [a for a in desired_axes if a.type == "space"] - - return interprete_array(nd_array, len(desired_space_axes)) - + return tuple( + Axis(id=a.id if isinstance(a.id, AxisId) else AxisId(a.id), type=a.type) + for a in axes + ) -def interprete_array( - nd_array: NDArray[Any], - n_expected_space_axes: Optional[int] = None, -) -> Tensor: - ndim = nd_array.ndim - if ndim == 2 and (n_expected_space_axes is None or n_expected_space_axes >= 2): +def _interprete_array_wo_known_axes(array: NDArray[Any]): + ndim = array.ndim + if ndim == 2: current_axes = ( - SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[0]), - SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[1]), + SpaceInputAxis(id=AxisId("y"), size=array.shape[0]), + SpaceInputAxis(id=AxisId("x"), size=array.shape[1]), ) - elif ndim == 3 and ( - (n_expected_space_axes is None and any(s <= 3 for s in nd_array.shape)) - or n_expected_space_axes == 2 - ): + elif ndim == 3 and any(s <= 3 for s in array.shape): current_axes = ( ChannelAxis( - channel_names=[ - Identifier(f"channel{i}") for i in range(nd_array.shape[0]) - ] + channel_names=[Identifier(f"channel{i}") for i in range(array.shape[0])] ), - SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[1]), - SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[2]), + SpaceInputAxis(id=AxisId("y"), size=array.shape[1]), + SpaceInputAxis(id=AxisId("x"), size=array.shape[2]), ) - elif ndim == 3 and (n_expected_space_axes is None or n_expected_space_axes == 3): + elif ndim == 3: current_axes = ( - SpaceInputAxis(id=AxisId("z"), size=nd_array.shape[0]), - SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[1]), - SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[2]), + SpaceInputAxis(id=AxisId("z"), size=array.shape[0]), + SpaceInputAxis(id=AxisId("y"), size=array.shape[1]), + SpaceInputAxis(id=AxisId("x"), size=array.shape[2]), ) elif ndim == 4: current_axes = ( ChannelAxis( - channel_names=[ - Identifier(f"channel{i}") for i in range(nd_array.shape[0]) - ] + channel_names=[Identifier(f"channel{i}") for i in range(array.shape[0])] ), - SpaceInputAxis(id=AxisId("z"), size=nd_array.shape[1]), - SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[2]), - SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[3]), + SpaceInputAxis(id=AxisId("z"), size=array.shape[1]), + SpaceInputAxis(id=AxisId("y"), size=array.shape[2]), + SpaceInputAxis(id=AxisId("x"), size=array.shape[3]), ) elif ndim == 5: current_axes = ( BatchAxis(), ChannelAxis( - channel_names=[ - Identifier(f"channel{i}") for i in range(nd_array.shape[1]) - ] + channel_names=[Identifier(f"channel{i}") for i in range(array.shape[1])] ), - SpaceInputAxis(id=AxisId("z"), size=nd_array.shape[2]), - SpaceInputAxis(id=AxisId("y"), 
size=nd_array.shape[3]), - SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[4]), + SpaceInputAxis(id=AxisId("z"), size=array.shape[2]), + SpaceInputAxis(id=AxisId("y"), size=array.shape[3]), + SpaceInputAxis(id=AxisId("x"), size=array.shape[4]), ) else: - raise ValueError( - f"Could not guess an axis mapping for {nd_array.shape} with {n_expected_space_axes} expected space axes" - ) + raise ValueError(f"Could not guess an axis mapping for {array.shape}") - current_axes_ids = tuple(str(a.id) for a in current_axes) + return Tensor(array, dims=tuple(a.id for a in current_axes)) - return Tensor(nd_array, dims=current_axes_ids) +def interprete_array( + array: NDArray[Any], + axes: Optional[Union[v0_4.AxesStr, Sequence[AnyAxis]]], +) -> Tensor: + if axes is None: + return _interprete_array_wo_known_axes(array) -def axis_descr_to_ids( - axes: Union[v0_4.AxesStr, Sequence[AnyAxis]] -) -> Tuple[AxisId, ...]: - if isinstance(axes, str): - return tuple(map(AxisId, axes)) - else: - return tuple(a.id for a in axes) + original_shape = tuple(array.shape) + if len(array.shape) > len(axes): + # remove singletons + for i, s in enumerate(array.shape): + if s == 1: + array = np.take(array, 0, axis=i) + if len(array.shape) == len(axes): + break + + if len(array.shape) < len(axes): + # add singletons + for a in axes: + if len(array.shape) == len(axes): + break + + if isinstance(a, str) or a.size is None: + array = array[None] + continue + + if isinstance(a.size, int): + if a.size == 1: + array = array[None] + + continue + + if isinstance(a.size, SizeReference): + continue # TODO: check if singleton is ok for a `SizeReference` + + try: + maybe_size_one = a.size.validate_size( + 1 + ) # TODO: refactor validate_size() to have boolean func here + except ValueError: + continue + + if maybe_size_one == 1: + array = array[None] + + if len(array.shape) != len(axes): + raise ValueError(f"Array shape {original_shape} does not map to axes {axes}") + + normalized_axes = normalize_axes(axes) + assert len(normalized_axes) == len(axes) + return Tensor(array, dims=tuple(a.id for a in normalized_axes)) def transpose_tensor( @@ -122,8 +158,8 @@ def transpose_tensor( axes: the desired array axes """ # expand the missing image axes - current_axes = tuple(AxisId(str(d)) for d in tensor.dims) - missing_axes = tuple(str(a) for a in axes if a not in current_axes) + current_axes = tuple(d if isinstance(d, AxisId) else AxisId(d) for d in tensor.dims) + missing_axes = tuple(a for a in axes if a not in current_axes) tensor = tensor.expand_dims(missing_axes) # transpose to the correct axis order return tensor.transpose(*map(str, axes)) @@ -135,7 +171,7 @@ def convert_v0_4_axes_for_known_shape(axes: v0_4.AxesStr, shape: Sequence[int]): def load_tensor( path: Path, - axes: Optional[Sequence[Axis]] = None, + axes: Optional[Sequence[AnyAxis]] = None, ) -> Tensor: ext = path.suffix @@ -145,147 +181,4 @@ def load_tensor( is_volume = True if axes is None else sum(a.type != "channel" for a in axes) > 2 array = imageio.volread(path) if is_volume else imageio.imread(path) - if axes is None: - return interprete_array(array) - else: - return Tensor(array, dims=tuple(a.id for a in axes)) - - -def pad( - tensor: Tensor, - pad_width: Mapping[AxisId, Union[int, Tuple[int, int]]], - mode: Literal["edge", "reflect", "symmetric"] = "symmetric", -): - return tensor.pad(pad_width={str(k): v for k, v in pad_width.items()}, mode=mode) - - -def resize_to( - tensor: Tensor, - sizes: Mapping[AxisId, int], - *, - pad_where: Union[ - Literal["before", "center", 
"after"], - Mapping[AxisId, Literal["before", "center", "after"]], - ] = "center", - crop_where: Union[ - Literal["before", "center", "after"], - Mapping[AxisId, Literal["before", "center", "after"]], - ] = "center", - pad_mode: Literal["edge", "reflect", "symmetric"] = "symmetric", -): - """crop and pad `tensor` to match `sizes`""" - crop_to_sizes: Dict[AxisId, int] = {} - pad_to_sizes: Dict[AxisId, int] = {} - new_axes = dict(sizes) - for a, s_is in tensor.sizes.items(): - a = AxisId(str(a)) - _ = new_axes.pop(a, None) - if a not in sizes or sizes[a] == s_is: - pass - elif s_is < sizes[a]: - crop_to_sizes[a] = sizes[a] - else: - pad_to_sizes[a] = sizes[a] - - if crop_to_sizes: - tensor = crop_to(tensor, crop_to_sizes, crop_where=crop_where) - - if pad_to_sizes: - tensor = pad_to(tensor, pad_to_sizes, pad_where=pad_where, mode=pad_mode) - - if new_axes: - tensor = tensor.expand_dims({str(k): v for k, v in new_axes}) - - return tensor - - -def crop_to( - tensor: Tensor, - sizes: Mapping[AxisId, int], - crop_where: Union[ - Literal["before", "center", "after"], - Mapping[AxisId, Literal["before", "center", "after"]], - ] = "center", -): - """crop `tensor` to match `sizes`""" - axes = [AxisId(str(a)) for a in tensor.dims] - if crop_where in ("before", "center", "after"): - crop_axis_where: Mapping[AxisId, Literal["before", "center", "after"]] = { - a: crop_where for a in axes - } - else: - crop_axis_where = crop_where - - slices: Dict[AxisId, slice] = {} - - for a, s_is in tensor.sizes.items(): - a = AxisId(str(a)) - if a not in sizes or sizes[a] == s_is: - pass - elif sizes[a] > s_is: - warnings.warn( - f"Cannot crop axis {a} of size {s_is} to larger size {sizes[a]}" - ) - elif a not in crop_axis_where: - raise ValueError( - f"Don't know where to crop axis {a}, `crop_where`={crop_where}" - ) - else: - crop_this_axis_where = crop_axis_where[a] - if crop_this_axis_where == "before": - slices[a] = slice(s_is - sizes[a], s_is) - elif crop_this_axis_where == "after": - slices[a] = slice(0, sizes[a]) - elif crop_this_axis_where == "center": - slices[a] = slice(start := (s_is - sizes[a]) // 2, sizes[a] + start) - else: - assert_never(crop_this_axis_where) - - return tensor.isel({str(a): s for a, s in slices.items()}) - - -def pad_to( - tensor: Tensor, - sizes: Mapping[AxisId, int], - pad_where: Union[ - Literal["before", "center", "after"], - Mapping[AxisId, Literal["before", "center", "after"]], - ] = "center", - mode: Literal["edge", "reflect", "symmetric"] = "symmetric", -): - """pad `tensor` to match `sizes`""" - axes = [AxisId(str(a)) for a in tensor.dims] - if pad_where in ("before", "center", "after"): - pad_axis_where: Mapping[AxisId, Literal["before", "center", "after"]] = { - a: pad_where for a in axes - } - else: - pad_axis_where = pad_where - - pad_width: Dict[AxisId, Union[int, Tuple[int, int]]] = {} - for a, s_is in tensor.sizes.items(): - a = AxisId(str(a)) - if a not in sizes or sizes[a] == s_is: - pad_width[a] = 0 - elif s_is < sizes[a]: - pad_width[a] = 0 - warnings.warn( - f"Cannot pad axis {a} of size {s_is} to smaller size {sizes[a]}" - ) - elif a not in pad_axis_where: - raise ValueError( - f"Don't know where to pad axis {a}, `pad_where`={pad_where}" - ) - else: - pad_this_axis_where = pad_axis_where[a] - p = sizes[a] - s_is - if pad_this_axis_where == "before": - pad_width[a] = (p, 0) - elif pad_this_axis_where == "after": - pad_width[a] = (0, p) - elif pad_this_axis_where == "center": - pad_width[a] = (left := p // 2, p - left) - else: - assert_never(pad_this_axis_where) - 
- return pad(tensor, pad_width, mode) + return interprete_array(array, axes) diff --git a/bioimageio/core/utils/tiling.py b/bioimageio/core/utils/tiling.py new file mode 100644 index 00000000..2b65b361 --- /dev/null +++ b/bioimageio/core/utils/tiling.py @@ -0,0 +1,147 @@ +import warnings +from typing import Dict, Literal, Mapping, Tuple, Union + +from typing_extensions import assert_never + +from bioimageio.core.common import Tensor +from bioimageio.spec.model.v0_5 import AxisId + + +def pad( + tensor: Tensor, + pad_width: Mapping[AxisId, Union[int, Tuple[int, int]]], + mode: Literal["edge", "reflect", "symmetric"] = "symmetric", +): + return tensor.pad(pad_width={str(k): v for k, v in pad_width.items()}, mode=mode) + + +def pad_to( + tensor: Tensor, + sizes: Mapping[AxisId, int], + pad_where: Union[ + Literal["before", "center", "after"], + Mapping[AxisId, Literal["before", "center", "after"]], + ] = "center", + mode: Literal["edge", "reflect", "symmetric"] = "symmetric", +): + """pad `tensor` to match `sizes`""" + axes = [AxisId(str(a)) for a in tensor.dims] + if pad_where in ("before", "center", "after"): + pad_axis_where: Mapping[AxisId, Literal["before", "center", "after"]] = { + a: pad_where for a in axes + } + else: + pad_axis_where = pad_where + + pad_width: Dict[AxisId, Union[int, Tuple[int, int]]] = {} + for a, s_is in tensor.sizes.items(): + a = AxisId(str(a)) + if a not in sizes or sizes[a] == s_is: + pad_width[a] = 0 + elif s_is < sizes[a]: + pad_width[a] = 0 + warnings.warn( + f"Cannot pad axis {a} of size {s_is} to smaller size {sizes[a]}" + ) + elif a not in pad_axis_where: + raise ValueError( + f"Don't know where to pad axis {a}, `pad_where`={pad_where}" + ) + else: + pad_this_axis_where = pad_axis_where[a] + p = sizes[a] - s_is + if pad_this_axis_where == "before": + pad_width[a] = (p, 0) + elif pad_this_axis_where == "after": + pad_width[a] = (0, p) + elif pad_this_axis_where == "center": + pad_width[a] = (left := p // 2, p - left) + else: + assert_never(pad_this_axis_where) + + return pad(tensor, pad_width, mode) + + +def crop_to( + tensor: Tensor, + sizes: Mapping[AxisId, int], + crop_where: Union[ + Literal["before", "center", "after"], + Mapping[AxisId, Literal["before", "center", "after"]], + ] = "center", +): + """crop `tensor` to match `sizes`""" + axes = [AxisId(str(a)) for a in tensor.dims] + if crop_where in ("before", "center", "after"): + crop_axis_where: Mapping[AxisId, Literal["before", "center", "after"]] = { + a: crop_where for a in axes + } + else: + crop_axis_where = crop_where + + slices: Dict[AxisId, slice] = {} + + for a, s_is in tensor.sizes.items(): + a = AxisId(str(a)) + if a not in sizes or sizes[a] == s_is: + pass + elif sizes[a] > s_is: + warnings.warn( + f"Cannot crop axis {a} of size {s_is} to larger size {sizes[a]}" + ) + elif a not in crop_axis_where: + raise ValueError( + f"Don't know where to crop axis {a}, `crop_where`={crop_where}" + ) + else: + crop_this_axis_where = crop_axis_where[a] + if crop_this_axis_where == "before": + slices[a] = slice(s_is - sizes[a], s_is) + elif crop_this_axis_where == "after": + slices[a] = slice(0, sizes[a]) + elif crop_this_axis_where == "center": + slices[a] = slice(start := (s_is - sizes[a]) // 2, sizes[a] + start) + else: + assert_never(crop_this_axis_where) + + return tensor.isel({str(a): s for a, s in slices.items()}) + + +def resize_to( + tensor: Tensor, + sizes: Mapping[AxisId, int], + *, + pad_where: Union[ + Literal["before", "center", "after"], + Mapping[AxisId, Literal["before", "center", 
"after"]], + ] = "center", + crop_where: Union[ + Literal["before", "center", "after"], + Mapping[AxisId, Literal["before", "center", "after"]], + ] = "center", + pad_mode: Literal["edge", "reflect", "symmetric"] = "symmetric", +): + """crop and pad `tensor` to match `sizes`""" + crop_to_sizes: Dict[AxisId, int] = {} + pad_to_sizes: Dict[AxisId, int] = {} + new_axes = dict(sizes) + for a, s_is in tensor.sizes.items(): + a = AxisId(str(a)) + _ = new_axes.pop(a, None) + if a not in sizes or sizes[a] == s_is: + pass + elif s_is < sizes[a]: + crop_to_sizes[a] = sizes[a] + else: + pad_to_sizes[a] = sizes[a] + + if crop_to_sizes: + tensor = crop_to(tensor, crop_to_sizes, crop_where=crop_where) + + if pad_to_sizes: + tensor = pad_to(tensor, pad_to_sizes, pad_where=pad_where, mode=pad_mode) + + if new_axes: + tensor = tensor.expand_dims({str(k): v for k, v in new_axes}) + + return tensor diff --git a/tests/test_cli.py b/tests/test_cli.py index ee09f66f..5844b115 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,4 +1,5 @@ import subprocess +from pathlib import Path from typing import Any, List, Sequence import pytest @@ -36,8 +37,8 @@ def run_subprocess( ["test-model", "unet2d_nuclei_broad_model"], ], ) -def test_cli(args: List[str], unet2d_nuclei_broad_model: FilePath): - assert unet2d_nuclei_broad_model.exists() +def test_cli(args: List[str], unet2d_nuclei_broad_model: str): + assert Path(unet2d_nuclei_broad_model).exists() resolved_args = [ str(unet2d_nuclei_broad_model) if arg == "unet2d_nuclei_broad_model" else arg for arg in args diff --git a/tests/test_resource_tests.py b/tests/test_resource_tests.py index 08b4343d..48312322 100644 --- a/tests/test_resource_tests.py +++ b/tests/test_resource_tests.py @@ -1,6 +1,3 @@ -from pathlib import Path - - def test_error_for_wrong_shape(stardist_wrong_shape: str): from bioimageio.core._resource_tests import test_model @@ -24,14 +21,14 @@ def test_error_for_wrong_shape2(stardist_wrong_shape2: str): assert summary.details[0].errors[0].msg == expected_error_message -def test_test_model(any_model: Path): +def test_test_model(any_model: str): from bioimageio.core._resource_tests import test_model summary = test_model(any_model) assert summary.status == "passed", summary.format() -def test_test_resource(any_model: Path): +def test_test_resource(any_model: str): from bioimageio.core._resource_tests import test_description summary = test_description(any_model) diff --git a/tests/utils/test_image_helper.py b/tests/utils/test_image_helper.py index d51f186a..ea3b4f24 100644 --- a/tests/utils/test_image_helper.py +++ b/tests/utils/test_image_helper.py @@ -1,5 +1,3 @@ -from typing import Sequence - import numpy as np import pytest import xarray as xr @@ -7,11 +5,10 @@ from bioimageio.core.common import AxisId from bioimageio.core.utils.image_helper import ( - crop_to, interprete_array, - pad, transpose_tensor, ) +from bioimageio.core.utils.tiling import crop_to, pad @pytest.mark.parametrize( @@ -20,7 +17,7 @@ ) def test_transpose_tensor_2d(axes: str): - tensor = interprete_array(np.random.rand(256, 256), len(axes)) + tensor = interprete_array(np.random.rand(256, 256), None) transposed = transpose_tensor(tensor, [AxisId(a) for a in axes]) assert transposed.ndim == len(axes) @@ -30,7 +27,7 @@ def test_transpose_tensor_2d(axes: str): ["zyx", "cyzx", "yzixc", "bczyx", "xyz", "xyzc", "bzyxtc"], ) def test_transpose_tensor_3d(axes: str): - tensor = interprete_array(np.random.rand(64, 64, 64), 3) + tensor = interprete_array(np.random.rand(64, 64, 64), None) 
transposed = transpose_tensor(tensor, [AxisId(a) for a in axes]) assert transposed.ndim == len(axes) From 4866fb4aa1c03bd075dc8b4544c861e4b5057142 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 20 Mar 2024 10:27:34 +0100 Subject: [PATCH 133/244] update model inference testing --- bioimageio/core/_resource_tests.py | 59 ++++++++++++++++++++++-------- setup.py | 2 +- 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py index 1e72ff1d..2a482da7 100644 --- a/bioimageio/core/_resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -1,13 +1,13 @@ import traceback import warnings -from typing import List, Literal, Optional, Sequence, Set, Tuple, Union +from typing import Dict, Hashable, List, Literal, Optional, Sequence, Set, Tuple, Union import numpy as np from bioimageio.core._prediction_pipeline import create_prediction_pipeline from bioimageio.core.common import AxisId, BatchSize from bioimageio.core.utils import VERSION, get_test_inputs, get_test_outputs -from bioimageio.core.utils.image_helper import resize_to +from bioimageio.core.utils.tiling import resize_to from bioimageio.spec import ( InvalidDescr, ResourceDescr, @@ -171,7 +171,7 @@ def _test_model_inference_parametrized( weight_format: Optional[WeightsFormat], devices: Optional[List[str]], test_cases: Sequence[Tuple[v0_5.ParameterizedSize.N, BatchSize]] = ( - (0, 1), + (0, 2), (1, 3), (2, 1), (3, 2), @@ -182,27 +182,51 @@ def _test_model_inference_parametrized( for ipt in model.inputs for a in ipt.axes ): - return + # only test different batch sizes for n=0 + test_cases = [tc for tc in test_cases if tc[0] == 0] + if not test_cases: + return try: test_inputs = get_test_inputs(model) def generate_test_cases(): - tested: Set[str] = set() + tested: Set[Hashable] = set() + + def get_ns(n: int): + return {(t.id, a.id): n for t in model.inputs for a in t.axes} + for n, batch_size in test_cases: - target_sizes = model.get_tensor_sizes(n, batch_size=batch_size) - hashable_target_size = str(target_sizes) + input_target_sizes, expected_output_sizes = model.get_axis_sizes( + get_ns(n), batch_size=batch_size + ) + hashable_target_size = tuple( + (input_target_sizes, input_target_sizes[ts]) + for ts in sorted(input_target_sizes) + ) if hashable_target_size in tested: continue else: tested.add(hashable_target_size) resized_test_inputs = [ - resize_to(t, target_sizes[t_descr.id]) + resize_to( + t, + { + aid: s + for (tid, aid), s in input_target_sizes.items() + if tid == t_descr.id + }, + ) for t, t_descr in zip(test_inputs, model.inputs) ] expected_output_shapes = [ - target_sizes[t_descr.id] for t_descr in model.outputs + { + aid: s + for (tid, aid), s in expected_output_sizes.items() + if tid == t_descr.id + } + for t_descr in model.outputs ] yield n, batch_size, resized_test_inputs, expected_output_shapes @@ -219,15 +243,20 @@ def generate_test_cases(): ) else: for res, exp in zip(results, exptected_output_shape): - if diff := { - a: s - for a, s in res.sizes.items() - if s != exp[AxisId(str(a))] - }: + diff: Dict[AxisId, int] = {} + for a, s in res.sizes.items(): + if isinstance((e_aid := exp[AxisId(a)]), int): + if s != e_aid: + diff[AxisId(a)] = s + elif ( + s < e_aid.min or e_aid.max is not None and s > e_aid.max + ): + diff[AxisId(a)] = s + if diff: error = ( (error or "") + f"(n={n}) Expected output shape {exp}," - + f" but got {exptected_output_shape} ({diff})\n" + + f" but got {res.sizes} ({diff})\n" ) model.validation_summary.add_detail( diff 
--git a/setup.py b/setup.py index 9665a2a1..0af57391 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ ], packages=find_namespace_packages(exclude=["tests"]), install_requires=[ - "bioimageio.spec==0.5.0.*", + "bioimageio.spec==0.5.1.*", "imageio>=2.5", "loguru", "numpy", From 4a5c0468a97f4f5dc033c03245c1ba8eae9f0f7f Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 20 Mar 2024 10:28:02 +0100 Subject: [PATCH 134/244] remove outdated environment variable docs --- README.md | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/README.md b/README.md index b99db4c2..54cfb069 100644 --- a/README.md +++ b/README.md @@ -68,16 +68,6 @@ pip install -e . --no-deps There are different environment files that only install tensorflow or pytorch as dependencies available. -## 🏞 Environment variables - -| Name | Default | Description | -|---|---|---| -| BIOIMAGEIO_USE_CACHE | "true" | Enables simple URL to file cache. possible, case-insensitive, positive values are: -"true", "yes", "1". Any other value is interpreted as "false" | -| BIOIMAGEIO_CACHE_PATH | generated tmp folder | File path for simple URL to file cache; -changes of URL source are not detected. | -| BIOIMAGEIO_CACHE_WARNINGS_LIMIT | "3" | Maximum number of warnings generated for simple cache hits. | - ## 💻 Command Line `bioimageio.core` installs a command line interface (CLI) for testing models and other functionality. From c0e25b870e82f5208c058419d0a64e9484859f75 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 20 Mar 2024 23:53:58 +0100 Subject: [PATCH 135/244] add test_loading_description_multiple_times --- tests/test_resource_tests.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_resource_tests.py b/tests/test_resource_tests.py index 48312322..b9d3cf66 100644 --- a/tests/test_resource_tests.py +++ b/tests/test_resource_tests.py @@ -1,3 +1,6 @@ +from bioimageio.spec import InvalidDescr + + def test_error_for_wrong_shape(stardist_wrong_shape: str): from bioimageio.core._resource_tests import test_model @@ -33,3 +36,14 @@ def test_test_resource(any_model: str): summary = test_description(any_model) assert summary.status == "passed", summary.format() + + +def test_loading_description_multiple_times(unet2d_nuclei_broad_model: str): + from bioimageio.core import load_description + + model_descr = load_description(unet2d_nuclei_broad_model) + assert not isinstance(model_descr, InvalidDescr) + + # load again, which some users might end up doing + model_descr = load_description(model_descr) # pyright: ignore[reportArgumentType] + assert not isinstance(model_descr, InvalidDescr) From de6435f773abde3526b80e10793536fd36acab76 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 21 Mar 2024 00:46:15 +0100 Subject: [PATCH 136/244] allow for optional tensors --- bioimageio/core/_prediction_pipeline.py | 40 ++++++++++--------- .../model_adapters/_keras_model_adapter.py | 11 ++--- .../core/model_adapters/_model_adapter.py | 5 +-- .../model_adapters/_onnx_model_adapter.py | 23 ++++++----- .../model_adapters/_pytorch_model_adapter.py | 21 +++++++--- .../_tensorflow_model_adapter.py | 23 +++++++---- .../_torchscript_model_adapter.py | 24 +++++++---- 7 files changed, 90 insertions(+), 57 deletions(-) diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index a596c790..2d3c42b0 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -2,9 +2,7 @@ from types import MappingProxyType from typing import Any, Dict, Iterable, List, Mapping, 
Optional, Sequence, Union -import xarray as xr - -from bioimageio.core.common import Sample, TensorId +from bioimageio.core.common import Sample, Tensor, TensorId from bioimageio.core.model_adapters import ModelAdapter, create_model_adapter from bioimageio.core.model_adapters import get_weight_formats as get_weight_formats from bioimageio.core.proc_ops import Processing @@ -48,8 +46,8 @@ def __init__( self._adapter: ModelAdapter = model def __call__( - self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataArray - ) -> List[xr.DataArray]: + self, *input_tensors: Tensor, **named_input_tensors: Tensor + ) -> List[Tensor]: return self.forward(*input_tensors, **named_input_tensors) def __enter__(self): @@ -61,11 +59,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore return False def predict( - self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataArray - ) -> List[xr.DataArray]: + self, *input_tensors: Optional[Tensor], **named_input_tensors: Optional[Tensor] + ) -> List[Tensor]: """Predict input_tensor with the model without applying pre/postprocessing.""" named_tensors = [ - named_input_tensors[str(k)] for k in self.input_ids[len(input_tensors) :] + named_input_tensors.get(str(k)) + for k in self.input_ids[len(input_tensors) :] ] return self._adapter.forward(*input_tensors, *named_tensors) @@ -93,23 +92,30 @@ def forward_sample(self, input_sample: Sample) -> Sample: return prediction def forward_tensors( - self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataArray - ) -> Dict[TensorId, xr.DataArray]: + self, *input_tensors: Optional[Tensor], **named_input_tensors: Optional[Tensor] + ) -> Dict[TensorId, Tensor]: """Apply preprocessing, run prediction and apply postprocessing.""" + assert all(TensorId(k) in self.input_ids for k in named_input_tensors) input_sample = Sample( data={ - **dict(zip(self.input_ids, input_tensors)), - **{TensorId(k): v for k, v in named_input_tensors.items()}, + **{ + k: v for k, v in zip(self.input_ids, input_tensors) if v is not None + }, + **{ + TensorId(k): v + for k, v in named_input_tensors.items() + if v is not None + }, } ) return self.forward_sample(input_sample).data def forward( - self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataArray - ) -> List[xr.DataArray]: + self, *input_tensors: Optional[Tensor], **named_input_tensors: Optional[Tensor] + ) -> List[Optional[Tensor]]: """Apply preprocessing, run prediction and apply postprocessing.""" named_outputs = self.forward_tensors(*input_tensors, **named_input_tensors) - return [named_outputs[x] for x in self.output_ids] + return [named_outputs.get(x) for x in self.output_ids] def load(self): """ @@ -130,9 +136,7 @@ def create_prediction_pipeline( devices: Optional[Sequence[str]] = None, weight_format: Optional[WeightsFormat] = None, weights_format: Optional[WeightsFormat] = None, - dataset_for_initial_statistics: Iterable[ - Union[Sample, Sequence[xr.DataArray]] - ] = tuple(), + dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[Tensor]]] = tuple(), keep_updating_initial_dataset_statistics: bool = False, fixed_dataset_statistics: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType( {} diff --git a/bioimageio/core/model_adapters/_keras_model_adapter.py b/bioimageio/core/model_adapters/_keras_model_adapter.py index af429644..445f4069 100644 --- a/bioimageio/core/model_adapters/_keras_model_adapter.py +++ b/bioimageio/core/model_adapters/_keras_model_adapter.py @@ -4,6 +4,8 @@ from numpy.typing import NDArray from 
packaging.version import Version +from bioimageio.core.common import Tensor + # by default, we use the keras integrated with tensorflow try: import tensorflow as tf @@ -63,11 +65,10 @@ def __init__( self._network = keras.models.load_model(weight_path) self._output_axes = [tuple(out.axes) for out in model_description.outputs] - def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - _result: Union[Sequence[NDArray[Any]], NDArray[Any]] = ( - self._network.predict( # pyright: ignore[reportUnknownVariableType] - *input_tensors - ) + def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: + _result: Union[Sequence[NDArray[Any]], NDArray[Any]] + _result = self._network.predict( # pyright: ignore[reportUnknownVariableType] + *input_tensors ) if isinstance(_result, (tuple, list)): result: Sequence[NDArray[Any]] = _result diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py index cd2769b9..ec028a2d 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ b/bioimageio/core/model_adapters/_model_adapter.py @@ -2,8 +2,7 @@ from abc import ABC, abstractmethod from typing import List, Optional, Sequence, Tuple, Union, final -import xarray as xr - +from bioimageio.core.common import Tensor from bioimageio.spec.model import v0_4, v0_5 WeightsFormat = Union[v0_4.WeightsFormat, v0_5.WeightsFormat] @@ -128,7 +127,7 @@ def load(self, *, devices: Optional[Sequence[str]] = None) -> None: warnings.warn("Deprecated. ModelAdapter is always loaded") @abstractmethod - def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: + def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: """ Run forward pass of model to get model predictions """ diff --git a/bioimageio/core/model_adapters/_onnx_model_adapter.py b/bioimageio/core/model_adapters/_onnx_model_adapter.py index 26400eda..19fdf0cc 100644 --- a/bioimageio/core/model_adapters/_onnx_model_adapter.py +++ b/bioimageio/core/model_adapters/_onnx_model_adapter.py @@ -1,9 +1,9 @@ import warnings from typing import Any, List, Optional, Sequence, Union -import xarray as xr from numpy.typing import NDArray +from bioimageio.core.common import Tensor from bioimageio.spec.model import v0_4, v0_5 from ._model_adapter import ModelAdapter @@ -43,20 +43,21 @@ def __init__( f"Device management is not implemented for onnx yet, ignoring the devices {devices}" ) - def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: + def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: assert len(input_tensors) == len(self._input_names) - input_arrays = [ipt.data for ipt in input_tensors] - result: Union[Sequence[NDArray[Any]], NDArray[Any]] = ( - self._session.run( # pyright: ignore[reportUnknownVariableType] - None, dict(zip(self._input_names, input_arrays)) - ) + input_arrays = [None if ipt is None else ipt.data for ipt in input_tensors] + result: Union[Sequence[Optional[NDArray[Any]]], Optional[NDArray[Any]]] + result = self._session.run( # pyright: ignore[reportUnknownVariableType] + None, dict(zip(self._input_names, input_arrays)) ) - if not isinstance(result, (list, tuple)): - result = [] + if isinstance(result, (list, tuple)): + result_seq: Sequence[Optional[NDArray[Any]]] = result + else: + result_seq = [result] # type: ignore return [ - xr.DataArray(r, dims=axes) - for r, axes in zip(result, self._internal_output_axes) + None if r is None else Tensor(r, dims=axes) + for r, axes in zip(result_seq, 
self._internal_output_axes) ] def unload(self) -> None: diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py index b54b82d8..98af48b4 100644 --- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py +++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py @@ -2,8 +2,7 @@ import warnings from typing import Any, List, Optional, Sequence, Tuple, Union -import xarray as xr - +from bioimageio.core.common import Tensor from bioimageio.core.utils import import_callable from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.utils import download @@ -45,16 +44,23 @@ def __init__( self._network = self._network.eval() - def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: + def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: with torch.no_grad(): - tensors = [torch.from_numpy(ipt.data) for ipt in input_tensors] + tensors = [ + None if ipt is None else torch.from_numpy(ipt.data) + for ipt in input_tensors + ] tensors = [t.to(self._devices[0]) for t in tensors] result: Union[Tuple[Any, ...], List[Any], Any] = self._network(*tensors) if not isinstance(result, (tuple, list)): result = [result] result = [ - r.detach().cpu().numpy() if isinstance(r, torch.Tensor) else r + ( + None + if r is None + else r.detach().cpu().numpy() if isinstance(r, torch.Tensor) else r + ) for r in result ] if len(result) > len(self.output_dims): @@ -62,7 +68,10 @@ def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: f"Expected at most {len(self.output_dims)} outputs, but got {len(result)}" ) - return [xr.DataArray(r, dims=out) for r, out in zip(result, self.output_dims)] + return [ + None if r is None else Tensor(r, dims=out) + for r, out in zip(result, self.output_dims) + ] def unload(self) -> None: del self._network diff --git a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py index cba0ad04..905a4c73 100644 --- a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py +++ b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py @@ -3,8 +3,8 @@ from typing import List, Literal, Optional, Sequence, Union import numpy as np -import xarray as xr +from bioimageio.core.common import Tensor from bioimageio.spec.common import FileSource from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.utils import download @@ -141,10 +141,12 @@ def _forward_tf(self, *input_tensors): return res - def _forward_keras(self, *input_tensors: xr.DataArray): + def _forward_keras(self, *input_tensors: Optional[Tensor]): assert self.use_keras_api assert not isinstance(self._network, str) - tf_tensor = [tf.convert_to_tensor(ipt) for ipt in input_tensors] + tf_tensor = [ + None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_tensors + ] try: result = self._network.forward(*tf_tensor) @@ -154,17 +156,24 @@ def _forward_keras(self, *input_tensors: xr.DataArray): if not isinstance(result, (tuple, list)): result = [result] - return [r if isinstance(r, np.ndarray) else tf.make_ndarray(r) for r in result] + return [ + ( + None + if r is None + else r if isinstance(r, np.ndarray) else tf.make_ndarray(r) + ) + for r in result + ] - def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - data = [ipt.data for ipt in input_tensors] + def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: + data = [None if ipt is None else ipt.data for ipt in input_tensors] if 
self.use_keras_api: result = self._forward_keras(*data) else: result = self._forward_tf(*data) return [ - xr.DataArray(r, dims=axes) + None if r is None else Tensor(r, dims=axes) for r, axes in zip(result, self._internal_output_axes) ] diff --git a/bioimageio/core/model_adapters/_torchscript_model_adapter.py b/bioimageio/core/model_adapters/_torchscript_model_adapter.py index 7637bd8a..8ef52616 100644 --- a/bioimageio/core/model_adapters/_torchscript_model_adapter.py +++ b/bioimageio/core/model_adapters/_torchscript_model_adapter.py @@ -3,9 +3,9 @@ from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np -import xarray as xr from numpy.typing import NDArray +from bioimageio.core.common import Tensor from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.utils import download @@ -53,24 +53,34 @@ def __init__( for out in model_description.outputs ] - def forward(self, *batch: xr.DataArray) -> List[xr.DataArray]: + def forward(self, *batch: Optional[Tensor]) -> List[Optional[Tensor]]: with torch.no_grad(): - torch_tensor = [torch.from_numpy(b.data).to(self.devices[0]) for b in batch] + torch_tensor = [ + None if b is None else torch.from_numpy(b.data).to(self.devices[0]) + for b in batch + ] _result: Union[ # pyright: ignore[reportUnknownVariableType] - Tuple[NDArray[Any], ...], List[NDArray[Any]], NDArray[Any] + Tuple[Optional[NDArray[Any]], ...], + List[Optional[NDArray[Any]]], + Optional[NDArray[Any]], ] = self._model.forward(*torch_tensor) if isinstance(_result, (tuple, list)): - result: Sequence[NDArray[Any]] = _result + result: Sequence[Optional[NDArray[Any]]] = _result else: result = [_result] result = [ - r.cpu().numpy() if not isinstance(r, np.ndarray) else r for r in result + ( + None + if r is None + else r.cpu().numpy() if not isinstance(r, np.ndarray) else r + ) + for r in result ] assert len(result) == len(self._internal_output_axes) return [ - xr.DataArray(r, dims=axes) + None if r is None else Tensor(r, dims=axes) for r, axes in zip(result, self._internal_output_axes) ] From 789ba05c0991a79987eb711e06d8e3ad1216ed45 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 21 Mar 2024 09:10:25 +0100 Subject: [PATCH 137/244] Fix optional inputs for pytorch --- bioimageio/core/model_adapters/_pytorch_model_adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py index 98af48b4..103af9bb 100644 --- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py +++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py @@ -50,7 +50,7 @@ def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: None if ipt is None else torch.from_numpy(ipt.data) for ipt in input_tensors ] - tensors = [t.to(self._devices[0]) for t in tensors] + tensors = [None if t is None else t.to(self._devices[0]) for t in tensors] result: Union[Tuple[Any, ...], List[Any], Any] = self._network(*tensors) if not isinstance(result, (tuple, list)): result = [result] From 9ba69aefc4ca3e467588687437cf4c515b9c7f8b Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 21 Mar 2024 16:40:41 +0100 Subject: [PATCH 138/244] fix pyright issues in model_adapters --- .../model_adapters/_keras_model_adapter.py | 49 ++++---- .../core/model_adapters/_model_adapter.py | 5 +- .../model_adapters/_pytorch_model_adapter.py | 36 ++++-- .../_tensorflow_model_adapter.py | 116 +++++++++++++----- .../_torchscript_model_adapter.py | 6 +- 5 files changed, 146 
insertions(+), 66 deletions(-) diff --git a/bioimageio/core/model_adapters/_keras_model_adapter.py b/bioimageio/core/model_adapters/_keras_model_adapter.py index 445f4069..af69315f 100644 --- a/bioimageio/core/model_adapters/_keras_model_adapter.py +++ b/bioimageio/core/model_adapters/_keras_model_adapter.py @@ -1,30 +1,30 @@ -import warnings from typing import Any, List, Optional, Sequence, Union +from loguru import logger from numpy.typing import NDArray -from packaging.version import Version from bioimageio.core.common import Tensor +from bioimageio.spec._internal.io_utils import download +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.model.v0_5 import Version + +from ._model_adapter import ModelAdapter # by default, we use the keras integrated with tensorflow try: - import tensorflow as tf - from tensorflow import keras + import tensorflow as tf # pyright: ignore[reportMissingImports] + from tensorflow import ( # pyright: ignore[reportMissingImports] + keras, # pyright: ignore[reportUnknownVariableType] + ) - tf_version = Version(tf.__version__) + tf_version = Version(tf.__version__) # pyright: ignore[reportUnknownArgumentType] except Exception: try: - import keras + import keras # pyright: ignore[reportMissingImports] except Exception: keras = None tf_version = None -import xarray as xr - -from bioimageio.spec._internal.io_utils import download -from bioimageio.spec.model import v0_4, v0_5 - -from ._model_adapter import ModelAdapter class KerasModelAdapter(ModelAdapter): @@ -41,23 +41,28 @@ def __init__( model_tf_version = model_description.weights.keras_hdf5.tensorflow_version if tf_version is None or model_tf_version is None: - warnings.warn("Could not check tensorflow versions.") + logger.warning("Could not check tensorflow versions.") elif model_tf_version > tf_version: - warnings.warn( - f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}." + logger.warning( + "The model specifies a newer tensorflow version than installed: {} > {}.", + model_tf_version, + tf_version, ) elif (model_tf_version.major, model_tf_version.minor) != ( tf_version.major, tf_version.minor, ): - warnings.warn( - f"Model tensorflow version {model_tf_version} does not match {tf_version}." 
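        # Editorial note (not part of the patch): the replacement below defers
        # message formatting to loguru's "{}"-style placeholders, e.g.
        #   logger.warning("Model tensorflow version {} does not match {}.", a, b)
        # interpolates its arguments only when the record is actually emitted,
        # instead of eagerly via an f-string.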
+ logger.warning( + "Model tensorflow version {} does not match {}.", + model_tf_version, + tf_version, ) # TODO keras device management if devices is not None: - warnings.warn( - f"Device management is not implemented for keras yet, ignoring the devices {devices}" + logger.warning( + "Device management is not implemented for keras yet, ignoring the devices {}", + devices, ) weight_path = download(model_description.weights.keras_hdf5.source).path @@ -76,11 +81,9 @@ def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: result = [_result] # type: ignore assert len(result) == len(self._output_axes) - return [ - xr.DataArray(r, dims=axes) for r, axes, in zip(result, self._output_axes) - ] + return [Tensor(r, dims=axes) for r, axes, in zip(result, self._output_axes)] def unload(self) -> None: - warnings.warn( + logger.warning( "Device management is not implemented for keras yet, cannot unload model" ) diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py index ec028a2d..3e4da1df 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ b/bioimageio/core/model_adapters/_model_adapter.py @@ -101,7 +101,10 @@ def create( # we try to first import the keras model adapter using the separate package and, # if it is not available, try to load the one using tf try: - from ._keras_model_adapter import KerasModelAdapter, keras + from ._keras_model_adapter import ( + KerasModelAdapter, + keras, # type: ignore + ) if keras is None: from ._tensorflow_model_adapter import KerasModelAdapter diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py index 103af9bb..9a6fd4bf 100644 --- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py +++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py @@ -37,21 +37,36 @@ def __init__( self._devices = self.get_devices(devices) self._network = self._network.to(self._devices[0]) + self._primary_device = self._devices[0] state: Any = torch.load( - download(weights.source).path, map_location=self._devices[0] + download(weights.source).path, + map_location=self._primary_device, # pyright: ignore[reportUnknownArgumentType] ) - _ = self._network.load_state_dict(state) + self._network.load_state_dict(state) self._network = self._network.eval() def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: + assert torch is not None with torch.no_grad(): tensors = [ None if ipt is None else torch.from_numpy(ipt.data) for ipt in input_tensors ] - tensors = [None if t is None else t.to(self._devices[0]) for t in tensors] - result: Union[Tuple[Any, ...], List[Any], Any] = self._network(*tensors) + tensors = [ + ( + None + if t is None + else t.to( + self._primary_device # pyright: ignore[reportUnknownArgumentType] + ) + ) + for t in tensors + ] + result: Union[Tuple[Any, ...], List[Any], Any] + result = self._network( # pyright: ignore[reportUnknownVariableType] + *tensors + ) if not isinstance(result, (tuple, list)): result = [result] @@ -61,7 +76,7 @@ def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: if r is None else r.detach().cpu().numpy() if isinstance(r, torch.Tensor) else r ) - for r in result + for r in result # pyright: ignore[reportUnknownVariableType] ] if len(result) > len(self.output_dims): raise ValueError( @@ -76,14 +91,16 @@ def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: def unload(self) -> None: del self._network _ = 
gc.collect() # deallocate memory + assert torch is not None torch.cuda.empty_cache() # release reserved memory @staticmethod - def get_network( + def get_network( # pyright: ignore[reportUnknownParameterType] weight_spec: Union[ v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr ] - ) -> "torch.nn.Module": + ) -> "torch.nn.Module": # pyright: ignore[reportInvalidTypeForm] + assert torch is not None arch = import_callable( weight_spec.architecture, sha256=( @@ -106,7 +123,10 @@ def get_network( return network @staticmethod - def get_devices(devices: Optional[Sequence[str]] = None) -> List["torch.device"]: + def get_devices( # pyright: ignore[reportUnknownParameterType] + devices: Optional[Sequence[str]] = None, + ) -> List["torch.device"]: # pyright: ignore[reportInvalidTypeForm] + assert torch is not None if not devices: torch_devices = [ ( diff --git a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py index 905a4c73..0f238925 100644 --- a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py +++ b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py @@ -12,7 +12,7 @@ from ._model_adapter import ModelAdapter try: - import tensorflow as tf + import tensorflow as tf # pyright: ignore[reportMissingImports] except Exception: tf = None @@ -35,12 +35,14 @@ def __init__( assert tf is not None super().__init__() self.model_description = model_description - tf_version = v0_5.Version(tf.__version__) + tf_version = v0_5.Version( + tf.__version__ # pyright: ignore[reportUnknownArgumentType] + ) model_tf_version = weights.tensorflow_version if model_tf_version is None: warnings.warn( "The model does not specify the tensorflow version." - f"Cannot check if it is compatible with intalled tensorflow {tf_version}." + + f"Cannot check if it is compatible with intalled tensorflow {tf_version}." ) elif model_tf_version > tf_version: warnings.warn( @@ -52,7 +54,7 @@ def __init__( ): warnings.warn( "The tensorflow version specified by the model does not match the installed: " - f"{model_tf_version} != {tf_version}." + + f"{model_tf_version} != {tf_version}." ) self.use_keras_api = ( @@ -88,17 +90,25 @@ def require_unzipped(self, weight_file: FileSource): else: return loacl_weights_file - def _get_network(self, weight_file: FileSource): + def _get_network( # pyright: ignore[reportUnknownParameterType] + self, weight_file: FileSource + ): weight_file = self.require_unzipped(weight_file) + assert tf is not None if self.use_keras_api: - return tf.keras.models.load_model(weight_file, compile=False) + return tf.keras.models.load_model( + weight_file, compile=False + ) # pyright: ignore[reportUnknownVariableType] else: # NOTE in tf1 the model needs to be loaded inside of the session, so we cannot preload the model return str(weight_file) # TODO currently we relaod the model every time. 
it would be better to keep the graph and session # alive in between of forward passes (but then the sessions need to be properly opened / closed) - def _forward_tf(self, *input_tensors): + def _forward_tf( # pyright: ignore[reportUnknownParameterType] + self, *input_tensors: Optional[Tensor] + ): + assert tf is not None input_keys = [ ipt.name if isinstance(ipt, v0_4.InputTensorDescr) else ipt.id for ipt in self.model_description.inputs @@ -107,74 +117,114 @@ def _forward_tf(self, *input_tensors): out.name if isinstance(out, v0_4.OutputTensorDescr) else out.id for out in self.model_description.outputs ] - # TODO read from spec - tag = tf.saved_model.tag_constants.SERVING - signature_key = ( + tag = ( # pyright: ignore[reportUnknownVariableType] + tf.saved_model.tag_constants.SERVING + ) + signature_key = ( # pyright: ignore[reportUnknownVariableType] tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY ) - graph = tf.Graph() + graph = tf.Graph() # pyright: ignore[reportUnknownVariableType] with graph.as_default(): - with tf.Session(graph=graph) as sess: + with tf.Session( + graph=graph + ) as sess: # pyright: ignore[reportUnknownVariableType] # load the model and the signature - graph_def = tf.saved_model.loader.load(sess, [tag], self._network) - signature = graph_def.signature_def + graph_def = tf.saved_model.loader.load( # pyright: ignore[reportUnknownVariableType] + sess, [tag], self._network + ) + signature = ( # pyright: ignore[reportUnknownVariableType] + graph_def.signature_def + ) # get the tensors into the graph - in_names = [ + in_names = [ # pyright: ignore[reportUnknownVariableType] signature[signature_key].inputs[key].name for key in input_keys ] - out_names = [ + out_names = [ # pyright: ignore[reportUnknownVariableType] signature[signature_key].outputs[key].name for key in output_keys ] - in_tensors = [graph.get_tensor_by_name(name) for name in in_names] - out_tensors = [graph.get_tensor_by_name(name) for name in out_names] + in_tensors = [ # pyright: ignore[reportUnknownVariableType] + graph.get_tensor_by_name(name) + for name in in_names # pyright: ignore[reportUnknownVariableType] + ] + out_tensors = [ # pyright: ignore[reportUnknownVariableType] + graph.get_tensor_by_name(name) + for name in out_names # pyright: ignore[reportUnknownVariableType] + ] # run prediction - res = sess.run( - dict(zip(out_names, out_tensors)), - dict(zip(in_tensors, input_tensors)), + res = sess.run( # pyright: ignore[reportUnknownVariableType] + dict( + zip( + out_names, # pyright: ignore[reportUnknownArgumentType] + out_tensors, # pyright: ignore[reportUnknownArgumentType] + ) + ), + dict( + zip( + in_tensors, # pyright: ignore[reportUnknownArgumentType] + input_tensors, + ) + ), ) # from dict to list of tensors - res = [res[out] for out in out_names] + res = [ # pyright: ignore[reportUnknownVariableType] + res[out] + for out in out_names # pyright: ignore[reportUnknownVariableType] + ] - return res + return res # pyright: ignore[reportUnknownVariableType] - def _forward_keras(self, *input_tensors: Optional[Tensor]): + def _forward_keras( # pyright: ignore[reportUnknownParameterType] + self, *input_tensors: Optional[Tensor] + ): assert self.use_keras_api assert not isinstance(self._network, str) - tf_tensor = [ + assert tf is not None + tf_tensor = [ # pyright: ignore[reportUnknownVariableType] None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_tensors ] try: - result = self._network.forward(*tf_tensor) + result = ( # pyright: ignore[reportUnknownVariableType] 
+ self._network.forward(*tf_tensor) + ) except AttributeError: - result = self._network.predict(*tf_tensor) + result = ( # pyright: ignore[reportUnknownVariableType] + self._network.predict(*tf_tensor) + ) if not isinstance(result, (tuple, list)): - result = [result] + result = [result] # pyright: ignore[reportUnknownVariableType] - return [ + return [ # pyright: ignore[reportUnknownVariableType] ( None if r is None else r if isinstance(r, np.ndarray) else tf.make_ndarray(r) ) - for r in result + for r in result # pyright: ignore[reportUnknownVariableType] ] def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: data = [None if ipt is None else ipt.data for ipt in input_tensors] if self.use_keras_api: - result = self._forward_keras(*data) + result = self._forward_keras( # pyright: ignore[reportUnknownVariableType] + *data + ) else: - result = self._forward_tf(*data) + result = self._forward_tf( # pyright: ignore[reportUnknownVariableType] + *data + ) return [ None if r is None else Tensor(r, dims=axes) - for r, axes in zip(result, self._internal_output_axes) + for r, axes in zip( # pyright: ignore[reportUnknownVariableType] + result, # pyright: ignore[reportUnknownArgumentType] + self._internal_output_axes, + ) ] def unload(self) -> None: diff --git a/bioimageio/core/model_adapters/_torchscript_model_adapter.py b/bioimageio/core/model_adapters/_torchscript_model_adapter.py index 8ef52616..c50d131a 100644 --- a/bioimageio/core/model_adapters/_torchscript_model_adapter.py +++ b/bioimageio/core/model_adapters/_torchscript_model_adapter.py @@ -42,7 +42,9 @@ def __init__( "Multiple devices for single torchscript model not yet implemented" ) - self._model = torch.jit.load(weight_path) + self._model = torch.jit.load( # pyright: ignore[reportPrivateImportUsage] + weight_path + ) self._model.to(self.devices[0]) self._internal_output_axes = [ ( @@ -54,6 +56,7 @@ def __init__( ] def forward(self, *batch: Optional[Tensor]) -> List[Optional[Tensor]]: + assert torch is not None with torch.no_grad(): torch_tensor = [ None if b is None else torch.from_numpy(b.data).to(self.devices[0]) @@ -85,6 +88,7 @@ def forward(self, *batch: Optional[Tensor]) -> List[Optional[Tensor]]: ] def unload(self) -> None: + assert torch is not None self._devices = None del self._model _ = gc.collect() # deallocate memory From 5299d232f26457ff71f40376620dc24681b34533 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 21 Mar 2024 17:05:52 +0100 Subject: [PATCH 139/244] small fixes and improve typing --- bioimageio/core/_prediction_pipeline.py | 13 ++-- bioimageio/core/_resource_tests.py | 27 +++++--- bioimageio/core/proc_ops.py | 69 ++++++++++--------- bioimageio/core/proc_setup.py | 4 +- .../weight_converter/keras/_tensorflow.py | 21 +++--- .../core/weight_converter/torch/_onnx.py | 1 + .../weight_converter/torch/_torchscript.py | 1 + .../core/weight_converter/torch/_utils.py | 18 +++-- tests/test_stat_measures.py | 8 +-- 9 files changed, 96 insertions(+), 66 deletions(-) diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index 2d3c42b0..8f49c654 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -46,8 +46,8 @@ def __init__( self._adapter: ModelAdapter = model def __call__( - self, *input_tensors: Tensor, **named_input_tensors: Tensor - ) -> List[Tensor]: + self, *input_tensors: Optional[Tensor], **named_input_tensors: Optional[Tensor] + ) -> List[Optional[Tensor]]: return self.forward(*input_tensors, 
**named_input_tensors) def __enter__(self): @@ -60,7 +60,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore def predict( self, *input_tensors: Optional[Tensor], **named_input_tensors: Optional[Tensor] - ) -> List[Tensor]: + ) -> List[Optional[Tensor]]: """Predict input_tensor with the model without applying pre/postprocessing.""" named_tensors = [ named_input_tensors.get(str(k)) @@ -86,7 +86,12 @@ def forward_sample(self, input_sample: Sample) -> Sample: **{str(k): v for k, v in input_sample.data.items()} ) prediction = Sample( - data=dict(zip(self.output_ids, prediction_tensors)), stat=input_sample.stat + data={ + tid: t + for tid, t in zip(self.output_ids, prediction_tensors) + if t is not None + }, + stat=input_sample.stat, ) self.apply_postprocessing(prediction) return prediction diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py index 2a482da7..b04e42f8 100644 --- a/bioimageio/core/_resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -127,17 +127,18 @@ def _test_model_inference( results = prediction_pipeline.forward(*inputs) if len(results) != len(expected): - error = (error or "") + ( - f"Expected {len(expected)} outputs, but got {len(results)}" - ) + error = f"Expected {len(expected)} outputs, but got {len(results)}" + else: for res, exp in zip(results, expected): + if res is None: + error = "Output tensors for test case may not be None" + break try: np.testing.assert_array_almost_equal(res, exp, decimal=decimal) except AssertionError as e: - error = ( - error or "" - ) + f"Output and expected output disagree:\n {e}" + error = f"Output and expected output disagree:\n {e}" + break except Exception as e: error = str(e) tb = traceback.format_tb(e.__traceback__) @@ -238,11 +239,17 @@ def get_ns(n: int): error: Optional[str] = None results = prediction_pipeline.forward(*inputs) if len(results) != len(exptected_output_shape): - error = (error or "") + ( - f"Expected {len(exptected_output_shape)} outputs, but got {len(results)}" + error = ( + f"Expected {len(exptected_output_shape)} outputs," + + f" but got {len(results)}" ) + else: for res, exp in zip(results, exptected_output_shape): + if res is None: + error = "Output tensors may not be None for test case" + break + diff: Dict[AxisId, int] = {} for a, s in res.sizes.items(): if isinstance((e_aid := exp[AxisId(a)]), int): @@ -254,10 +261,10 @@ def get_ns(n: int): diff[AxisId(a)] = s if diff: error = ( - (error or "") - + f"(n={n}) Expected output shape {exp}," + f"(n={n}) Expected output shape {exp}," + f" but got {res.sizes} ({diff})\n" ) + break model.validation_summary.add_detail( ValidationDetail( diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 6a262639..5299df4a 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -10,7 +10,6 @@ Set, Tuple, Union, - cast, ) import numpy as np @@ -168,22 +167,29 @@ def __call__(self, sample: Sample) -> None: class Binarize(_SimpleOperator): """'output = tensor > threshold'.""" - threshold: float + threshold: Union[float, Sequence[float]] + axis: Optional[AxisId] = None def _apply(self, input: Tensor, stat: Stat) -> xr.DataArray: return input > self.threshold - # @classmethod - # def from_descr(cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr]): - # return cls(threshold=descr.kwargs.threshold) - - # def get_descr(self): - # return v0_5.BinarizeDescr(kwargs=v0_5.BinarizeKwargs(threshold=self.threshold)) @classmethod def from_proc_descr( cls, descr: Union[v0_4.BinarizeDescr, 
v0_5.BinarizeDescr], tensor_id: TensorId ) -> Self: - return cls(input=tensor_id, output=tensor_id, threshold=descr.kwargs.threshold) + if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)): + return cls( + input=tensor_id, output=tensor_id, threshold=descr.kwargs.threshold + ) + elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs): + return cls( + input=tensor_id, + output=tensor_id, + threshold=descr.kwargs.threshold, + axis=descr.kwargs.axis, + ) + else: + assert_never(descr.kwargs) @dataclass @@ -224,7 +230,9 @@ def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, tensor_id: TensorId): def get_descr(self): return v0_5.EnsureDtypeDescr( - kwargs=v0_5.EnsureDtypeKwargs(dtype=str(self.dtype)) + kwargs=v0_5.EnsureDtypeKwargs( + dtype=str(self.dtype) # pyright: ignore[reportArgumentType] + ) ) def _apply(self, input: Tensor, stat: Stat) -> Tensor: @@ -242,10 +250,6 @@ class ScaleLinear(_SimpleOperator): def _apply(self, input: Tensor, stat: Stat) -> Tensor: return input * self.gain + self.offset - # @classmethod - # def from_descr(cls, descr: ScaleLinearDescr) -> Self: - # ... - @classmethod def from_proc_descr( cls, @@ -253,14 +257,12 @@ def from_proc_descr( tensor_id: TensorId, ) -> Self: kwargs = descr.kwargs - if isinstance(kwargs, v0_5.ScaleLinearKwargs): + if isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs): axis = kwargs.axis - elif kwargs.axes is not None: - raise NotImplementedError( - "ScaleLinear operator from v0_4.ScaleLinearDescr with axes" - ) - else: + elif isinstance(kwargs, (v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs)): axis = None + else: + assert_never(kwargs) if axis: gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis) @@ -535,29 +537,34 @@ def from_proc_descr( descr: v0_5.FixedZeroMeanUnitVarianceDescr, tensor_id: TensorId, ) -> Self: + if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs): + dims = None + elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs): + dims = (descr.kwargs.axis,) + else: + assert_never(descr.kwargs) + return cls( input=tensor_id, output=tensor_id, - mean=xr.DataArray(descr.kwargs.mean, dims=(descr.kwargs.axis,)), - std=xr.DataArray(descr.kwargs.std, dims=(descr.kwargs.axis,)), + mean=xr.DataArray(descr.kwargs.mean, dims=dims), + std=xr.DataArray(descr.kwargs.std, dims=dims), ) def get_descr(self): if isinstance(self.mean, (int, float)): assert isinstance(self.std, (int, float)) - axis = None - mean = self.mean - std = self.std + kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=self.mean, std=self.std) else: assert isinstance(self.std, xr.DataArray) assert len(self.mean.dims) == 1 - axis = AxisId(str(self.mean.dims[0])) - mean = tuple(self.mean) - std = tuple(self.std) + kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs( + axis=AxisId(str(self.mean.dims[0])), + mean=list(self.mean), + std=list(self.std), + ) - return v0_5.FixedZeroMeanUnitVarianceDescr( - kwargs=v0_5.FixedZeroMeanUnitVarianceKwargs(axis=axis, mean=mean, std=std) - ) + return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs) def _apply(self, input: xr.DataArray, stat: Stat) -> xr.DataArray: return (input - self.mean) / (self.std + self.eps) diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py index fbfb37ff..b7bd54bb 100644 --- a/bioimageio/core/proc_setup.py +++ b/bioimageio/core/proc_setup.py @@ -115,8 +115,8 @@ def prepare_procs(tensor_descrs: Sequence[TensorDescr]): else t_descr.id ) req = proc_class.from_proc_descr( - proc_d, tensor_id - ) # pyright: 
ignore[reportArgumentType] + proc_d, tensor_id # pyright: ignore[reportArgumentType] + ) for m in req.required_measures: if m.tensor_id in input_ids: pre_measures.add(m) diff --git a/bioimageio/core/weight_converter/keras/_tensorflow.py b/bioimageio/core/weight_converter/keras/_tensorflow.py index adad502b..c901f458 100644 --- a/bioimageio/core/weight_converter/keras/_tensorflow.py +++ b/bioimageio/core/weight_converter/keras/_tensorflow.py @@ -1,3 +1,4 @@ +# type: ignore # TODO: type import os import shutil from pathlib import Path @@ -7,7 +8,7 @@ try: import tensorflow.saved_model except Exception: - _tensorflow = None + tensorflow = None from bioimageio.spec._internal.io_utils import download from bioimageio.spec.model.v0_5 import ModelDescr @@ -41,7 +42,9 @@ def _convert_tf1( ): try: # try to build the tf model with the keras import from tensorflow - from bioimageio.core.weight_converter.keras._tensorflow import keras # type: ignore + from bioimageio.core.weight_converter.keras._tensorflow import ( + keras, # type: ignore + ) except Exception: # if the above fails try to export with the standalone keras @@ -50,20 +53,20 @@ def _convert_tf1( @no_type_check def build_tf_model(): keras_model = keras.models.load_model(keras_weight_path) - assert _tensorflow is not None - builder = _tensorflow.saved_model.builder.SavedModelBuilder(output_path) - signature = _tensorflow.saved_model.signature_def_utils.predict_signature_def( + assert tensorflow is not None + builder = tensorflow.saved_model.builder.SavedModelBuilder(output_path) + signature = tensorflow.saved_model.signature_def_utils.predict_signature_def( inputs={input_name: keras_model.input}, outputs={output_name: keras_model.output}, ) signature_def_map = { - _tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature + tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature } builder.add_meta_graph_and_variables( keras.backend.get_session(), - [_tensorflow.saved_model.tag_constants.SERVING], + [tensorflow.saved_model.tag_constants.SERVING], signature_def_map=signature_def_map, ) builder.save() @@ -107,8 +110,8 @@ def convert_weights_to_tensorflow_saved_model_bundle( model: The bioimageio model description output_path: where to save the tensorflow weights. This path must not exist yet. 
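
    A minimal usage sketch (illustrative; assumes `model` loads to a v0_5
    `ModelDescr` with `keras_hdf5` weights, that the output path does not
    exist yet, and that the source file name is a placeholder):

        from pathlib import Path

        from bioimageio.spec import load_description

        # load a model description that carries keras_hdf5 weights
        model = load_description("my_model.bioimageio.zip")
        convert_weights_to_tensorflow_saved_model_bundle(
            model, Path("converted/tf_bundle")
        )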
""" - assert _tensorflow is not None - tf_major_ver = int(_tensorflow.__version__.split(".")[0]) + assert tensorflow is not None + tf_major_ver = int(tensorflow.__version__.split(".")[0]) if output_path.suffix == ".zip": output_path = output_path.with_suffix("") diff --git a/bioimageio/core/weight_converter/torch/_onnx.py b/bioimageio/core/weight_converter/torch/_onnx.py index 2b4d1caf..50c56fbd 100644 --- a/bioimageio/core/weight_converter/torch/_onnx.py +++ b/bioimageio/core/weight_converter/torch/_onnx.py @@ -1,3 +1,4 @@ +# type: ignore # TODO: type import warnings from pathlib import Path from typing import Any, List, Sequence, cast diff --git a/bioimageio/core/weight_converter/torch/_torchscript.py b/bioimageio/core/weight_converter/torch/_torchscript.py index ee11610c..0d226563 100644 --- a/bioimageio/core/weight_converter/torch/_torchscript.py +++ b/bioimageio/core/weight_converter/torch/_torchscript.py @@ -1,3 +1,4 @@ +# type: ignore # TODO: type from pathlib import Path from typing import List, Sequence, Union diff --git a/bioimageio/core/weight_converter/torch/_utils.py b/bioimageio/core/weight_converter/torch/_utils.py index 2acf17be..d3908f61 100644 --- a/bioimageio/core/weight_converter/torch/_utils.py +++ b/bioimageio/core/weight_converter/torch/_utils.py @@ -1,3 +1,5 @@ +from typing import Union + import torch from bioimageio.core.model_adapters._pytorch_model_adapter import PytorchModelAdapter @@ -7,10 +9,14 @@ # additional convenience for pytorch state dict, eventually we want this in python-bioimageio too # and for each weight format -def load_torch_model( - node: "v0_4.PytorchStateDictWeightsDescr | v0_5.PytorchStateDictWeightsDescr", +def load_torch_model( # pyright: ignore[reportUnknownParameterType] + node: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr], ): - model = PytorchModelAdapter.get_network(node) - state = torch.load(download(node.source).path, map_location="cpu") - _ = model.load_state_dict(state) # FIXME: check incompatible keys? - return model.eval() + model = ( # pyright: ignore[reportUnknownVariableType] + PytorchModelAdapter.get_network(node) + ) + state = torch.load( # pyright: ignore[reportUnknownVariableType] + download(node.source).path, map_location="cpu" + ) + model.load_state_dict(state) # FIXME: check incompatible keys? 
+ return model.eval() # pyright: ignore[reportUnknownVariableType] diff --git a/tests/test_stat_measures.py b/tests/test_stat_measures.py index ea8774b2..1bd6231f 100644 --- a/tests/test_stat_measures.py +++ b/tests/test_stat_measures.py @@ -1,12 +1,12 @@ from itertools import product -from typing import Literal, Optional, Tuple +from typing import Optional, Tuple import numpy as np import pytest import xarray as xr from bioimageio.core import stat_measures -from bioimageio.core.common import AxisId, Sample, TensorId +from bioimageio.core.common import AxisId, Sample, Tensor, TensorId from bioimageio.core.stat_calculators import ( SamplePercentilesCalculator, get_measure_calculators, @@ -29,7 +29,7 @@ def test_individual_normal_measure( measure = getattr(stat_measures, "Sample" + name.title())( axes=axes, tensor_id=data_id ) - data = xr.DataArray(np.random.random((5, 6, 3)), dims=("x", "y", "c")) + data = Tensor(np.random.random((5, 6, 3)), dims=("x", "y", "c")) expected = getattr(data, name)(dim=axes) sample = Sample(data={data_id: data}) @@ -48,7 +48,7 @@ def test_individual_percentile_measure(axes: Optional[Tuple[AxisId, ...]]): calc = calcs[0] assert isinstance(calc, SamplePercentilesCalculator) - data = xr.DataArray(np.random.random((5, 6, 3)), dims=("x", "y", "c")) + data = Tensor(np.random.random((5, 6, 3)), dims=("x", "y", "c")) actual = calc.compute(Sample(data={tid: data})) for m in measures: expected = data.quantile(q=m.n / 100, dim=m.axes) From 7a3c64f6fc7043a4559fb2538f9fd315d99ead4a Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 21 Mar 2024 17:42:25 +0100 Subject: [PATCH 140/244] try post-cleanup: 'all' --- .github/workflows/build.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 945912ba..85d1f69c 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -37,6 +37,7 @@ jobs: environment-file: dev/environment-torch.yaml create-args: >- python=${{ matrix.python-version }} + post-cleanup: 'all' - name: additional setup run: pip install --no-deps -e . 
- name: pytest-spec-conda @@ -57,6 +58,7 @@ jobs: environment-file: dev/environment-torch.yaml create-args: >- python=${{ matrix.python-version }} + post-cleanup: 'all' - name: additional setup run: | conda remove --yes --force bioimageio.spec || true # allow failure for cached env @@ -82,6 +84,7 @@ jobs: channel-priority: flexible create-args: >- python=${{ matrix.python-version }} + post-cleanup: 'all' - name: additional setup run: | conda remove --yes --force bioimageio.spec || true # allow failure for cached env From 9b15feaa67bad0e5f8496c64cda33dd4856f1b59 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 21 Mar 2024 17:49:31 +0100 Subject: [PATCH 141/244] disable parallel pytests --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 083aaf84..e4356c86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ typeCheckingMode = "strict" useLibraryCodeForTypes = true [tool.pytest.ini_options] -addopts = " -n auto --capture=no --doctest-modules --failed-first" +addopts = " -n 0 --capture=no --doctest-modules --failed-first" [tool.ruff] line-length = 88 From 564f58153fb2cf52b224655f179c5fd8b7f22483 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 21 Mar 2024 17:50:27 +0100 Subject: [PATCH 142/244] bump black version --- .github/workflows/build.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 85d1f69c..cfe2710d 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -20,7 +20,7 @@ jobs: options: "--check --verbose" src: "." jupyter: true - version: "23.7" + version: "24.3" test-spec-conda: runs-on: ubuntu-latest From e545a7d27f425d82fdb00f7afd445e199c339a58 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 21 Mar 2024 17:53:56 +0100 Subject: [PATCH 143/244] fix python version matrix --- .github/workflows/build.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index cfe2710d..3fe5552f 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -26,7 +26,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8, 3.9, 3.10, 3.11, 3.12] + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] steps: - uses: actions/checkout@v3 - name: Install Conda environment with Micromamba From 966a08569b7990392ddb96c024aa09913e0f6d8d Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 22 Mar 2024 09:55:35 +0100 Subject: [PATCH 144/244] set keras backend --- bioimageio/core/_settings.py | 20 +++++++++++++++++++ .../model_adapters/_keras_model_adapter.py | 4 ++++ 2 files changed, 24 insertions(+) create mode 100644 bioimageio/core/_settings.py diff --git a/bioimageio/core/_settings.py b/bioimageio/core/_settings.py new file mode 100644 index 00000000..d09f3b8b --- /dev/null +++ b/bioimageio/core/_settings.py @@ -0,0 +1,20 @@ +from typing import Literal + +from dotenv import load_dotenv +from pydantic import Field +from typing_extensions import Annotated + +from bioimageio.spec._internal._settings import Settings as SpecSettings + +_ = load_dotenv() + + +class Settings(SpecSettings): + """environment variables""" + + keras_backend: Annotated[ + Literal["torch", "tensorflow", "jax"], Field(alias="KERAS_BACKEND") + ] = "torch" + + +settings = Settings() diff --git a/bioimageio/core/model_adapters/_keras_model_adapter.py b/bioimageio/core/model_adapters/_keras_model_adapter.py index af69315f..785fd6a7 100644 --- 
a/bioimageio/core/model_adapters/_keras_model_adapter.py
+++ b/bioimageio/core/model_adapters/_keras_model_adapter.py
@@ -1,3 +1,4 @@
+import os
 from typing import Any, List, Optional, Sequence, Union

 from loguru import logger
@@ -8,8 +9,11 @@
 from bioimageio.spec.model import v0_4, v0_5
 from bioimageio.spec.model.v0_5 import Version

+from .._settings import settings
 from ._model_adapter import ModelAdapter

+os.environ["KERAS_BACKEND"] = settings.keras_backend
+
 # by default, we use the keras integrated with tensorflow
 try:
     import tensorflow as tf  # pyright: ignore[reportMissingImports]

From 33d1d114adfa4e8ab9bd32626041081368206bad Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Fri, 22 Mar 2024 09:56:05 +0100
Subject: [PATCH 145/244] fix normalize_axes

---
 bioimageio/core/__init__.py           |  1 +
 bioimageio/core/utils/image_helper.py | 56 ++++++++++++++-------------
 2 files changed, 31 insertions(+), 26 deletions(-)

diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py
index 780e5cc3..ef261839 100644
--- a/bioimageio/core/__init__.py
+++ b/bioimageio/core/__init__.py
@@ -22,6 +22,7 @@
 from ._resource_tests import load_description_and_test as load_description_and_test
 from ._resource_tests import test_description as test_description
 from ._resource_tests import test_model as test_model
+from ._settings import settings as settings
 from .utils import VERSION

 __version__ = VERSION
diff --git a/bioimageio/core/utils/image_helper.py b/bioimageio/core/utils/image_helper.py
index e3b3e5d4..b3e23320 100644
--- a/bioimageio/core/utils/image_helper.py
+++ b/bioimageio/core/utils/image_helper.py
@@ -39,13 +39,18 @@ def normalize_axes(
         "y": "space",
         "z": "space",
     }
+    AXIS_ID_MAP = {
+        "b": "batch",
+        "t": "time",
+        "i": "index",
+        "c": "channel",
+    }
     if isinstance(axes, str):
-        return tuple(Axis(id=AxisId(a), type=AXIS_TYPE_MAP[a]) for a in axes)
-    else:
         return tuple(
-            Axis(id=a.id if isinstance(a.id, AxisId) else AxisId(a.id), type=a.type)
-            for a in axes
+            Axis(id=AxisId(AXIS_ID_MAP.get(a, a)), type=AXIS_TYPE_MAP[a]) for a in axes
         )
+    else:
+        return tuple(Axis(id=AxisId(a.id), type=a.type) for a in axes)


 def _interprete_array_wo_known_axes(array: NDArray[Any]):
@@ -110,34 +115,33 @@ def interprete_array(
         if len(array.shape) == len(axes):
             break

-    if len(array.shape) < len(axes):
-        # add singletons
-        for a in axes:
-            if len(array.shape) == len(axes):
-                break
+    # add singletons if necessary
+    for a in axes:
+        if len(array.shape) >= len(axes):
+            break

-            if isinstance(a, str) or a.size is None:
-                array = array[None]
-                continue
+        if isinstance(a, str) or a.size is None:
+            array = array[None]
+            continue

-            if isinstance(a.size, int):
-                if a.size == 1:
-                    array = array[None]
+        if isinstance(a.size, int):
+            if a.size == 1:
+                array = array[None]

-                continue
+            continue

-            if isinstance(a.size, SizeReference):
-                continue  # TODO: check if singleton is ok for a `SizeReference`
+        if isinstance(a.size, SizeReference):
+            continue  # TODO: check if singleton is ok for a `SizeReference`

-            try:
-                maybe_size_one = a.size.validate_size(
-                    1
-                )  # TODO: refactor validate_size() to have boolean func here
-            except ValueError:
-                continue
+        try:
+            maybe_size_one = a.size.validate_size(
+                1
+            )  # TODO: refactor validate_size() to have boolean func here
+        except ValueError:
+            continue

-            if maybe_size_one == 1:
-                array = array[None]
+        if maybe_size_one == 1:
+            array = array[None]

     if len(array.shape) != len(axes):
         raise ValueError(f"Array shape {original_shape} does not map to axes {axes}")

From
7dda729cd6d7f4bfb257c1ad9f45c71e818b8b26 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 22 Mar 2024 09:58:32 +0100 Subject: [PATCH 146/244] update setup.py and dev envs --- .github/workflows/build.yaml | 33 +++++++++++++++++---------------- README.md | 23 ++++++++++++----------- dev/env-wo-python.yaml | 36 ++++++++++++++++++++++++++++++++++++ dev/env.yaml | 23 +++++++++++++---------- dev/environment-base.yaml | 28 ---------------------------- dev/environment-tf.yaml | 30 +++++++++++++++++++++++------- dev/environment-torch.yaml | 24 ------------------------ setup.py | 12 +++++++++--- 8 files changed, 110 insertions(+), 99 deletions(-) create mode 100644 dev/env-wo-python.yaml delete mode 100644 dev/environment-base.yaml delete mode 100644 dev/environment-torch.yaml diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 3fe5552f..c06980da 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -14,7 +14,7 @@ jobs: black: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: psf/black@stable with: options: "--check --verbose" @@ -28,13 +28,13 @@ jobs: matrix: python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install Conda environment with Micromamba uses: mamba-org/setup-micromamba@v1 with: cache-downloads: true cache-environment: true - environment-file: dev/environment-torch.yaml + environment-file: dev/env-wo-python.yaml create-args: >- python=${{ matrix.python-version }} post-cleanup: 'all' @@ -47,15 +47,15 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8, 3.12] + python-version: ['3.8', '3.12'] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install Conda environment with Micromamba uses: mamba-org/setup-micromamba@v1 with: cache-downloads: true cache-environment: true - environment-file: dev/environment-torch.yaml + environment-file: dev/env-wo-python.yaml create-args: >- python=${{ matrix.python-version }} post-cleanup: 'all' @@ -67,19 +67,19 @@ jobs: - name: pytest-spec-main run: pytest --disable-pytest-warnings - test-spec-tf: + test-tf: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.7, 3.8, 3.9] + python-version: ['3.8', '3.12'] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install Conda environment with Micromamba uses: mamba-org/setup-micromamba@v1 with: cache-downloads: true cache-environment: true - environment-file: dev/environment-tf.yaml + environment-file: dev/env-tf.yaml condarc: | channel-priority: flexible create-args: >- @@ -98,7 +98,7 @@ jobs: needs: test-spec-conda steps: - name: checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 - name: Install Conda environment with Micromamba @@ -120,10 +120,12 @@ jobs: if: github.ref == 'refs/heads/main' runs-on: ubuntu-latest steps: - - name: Install dependencies - run: | - pip install --upgrade pip - pip install -e .[dev] + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + cache: 'pip' + - run: pip install -e .[dev] - name: Generate developer docs run: pdoc -o ./dist bioimageio.spec - run: cp README.md ./dist/README.md @@ -132,4 +134,3 @@ jobs: with: branch: gh-pages folder: dist - diff --git a/README.md b/README.md index 54cfb069..51ac8c37 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,12 @@ Python specific core utilities for running models in the [BioImage Model Zoo](ht ## 
Installation

-### Via Conda
+### Via Mamba/Conda

 The `bioimageio.core` package can be installed from conda-forge via

 ```console
-conda install -c conda-forge bioimageio.core
+mamba install -c conda-forge bioimageio.core
 ```

 If you do not install any additional deep learning libraries, you will only be able to use general convenience
@@ -21,13 +21,13 @@ To install additional deep learning libraries use:
   CPU installation (if you don't have an nvidia graphics card):

   ```console
-  conda install -c pytorch -c conda-forge bioimageio.core pytorch torchvision cpuonly
+  mamba install -c pytorch -c conda-forge bioimageio.core pytorch torchvision cpuonly
   ```

-  GPU installation (for cuda 11.6, please choose the appropriate cuda version for your system):
+  GPU installation (for cuda 11.8, please choose the appropriate cuda version for your system):

   ```console
-  conda install -c pytorch -c nvidia -c conda-forge bioimageio.core pytorch torchvision pytorch-cuda=11.6
+  mamba install -c pytorch -c nvidia -c conda-forge bioimageio.core pytorch torchvision pytorch-cuda=11.8
   ```

   Note that the pytorch installation instructions may change in the future. For the latest instructions please refer to [pytorch.org](https://pytorch.org/).
@@ -37,7 +37,7 @@ To install additional deep learning libraries use:
   Currently only CPU version supported

   ```console
-  conda install -c conda-forge bioimageio.core tensorflow
+  mamba install -c conda-forge bioimageio.core tensorflow
   ```

 * ONNXRuntime
@@ -45,24 +45,25 @@ To install additional deep learning libraries use:
   Currently only CPU version supported

   ```console
-  conda install -c conda-forge bioimageio.core onnxruntime
+  mamba install -c conda-forge bioimageio.core onnxruntime
   ```

 ### Via pip

-The package is also available via pip:
+The package is also available via pip
+(e.g. with recommended extras `onnx` and `pytorch`):

 ```console
-pip install bioimageio.core
+pip install bioimageio.core[onnx,pytorch]
 ```

 ### Set up Development Environment

-To set up a development conda environment run the following commands:
+To set up a development mamba environment run the following commands:

 ```console
-conda env create -f dev/environment-base.yaml
-conda activate bio-core-dev
+mamba env create -f dev/env.yaml
+mamba activate core
 pip install -e . --no-deps
 ```
diff --git a/dev/env-wo-python.yaml b/dev/env-wo-python.yaml
new file mode 100644
index 00000000..4fc66314
--- /dev/null
+++ b/dev/env-wo-python.yaml
@@ -0,0 +1,36 @@
+# modified copy of env.yaml
+name: core
+channels:
+  - conda-forge
+  - defaults
+dependencies:
+  - bioimageio.spec>=0.5.1
+  - black
+  - crick
+  - filelock
+  - imageio>=2.5
+  - keras>=3.0
+  - loguru
+  - numpy
+  - onnxruntime
+  - packaging>=17.0
+  - pip
+  - pre-commit
+  - psutil
+  - pydantic
+  - pydantic-settings
+  - pyright
+  - pytest
+  - pytest-xdist
+  - python-dotenv
+  # - python=3.8 # removed
+  - pytorch>=1.6
+  - ruff
+  - ruyaml
+  - torchvision
+  - tqdm
+  - typer
+  - typing-extensions
+  - xarray
+  - pip:
+    - --no-deps -e ..
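
The extras named in the README changes above map to optional inference backends. A minimal sketch (assuming only that `bioimageio.core` itself is installed; the probed names are the upstream import names, not bioimageio APIs) to check which backends the active environment actually provides:

```python
# Probe which optional deep learning backends are importable.
from importlib.util import find_spec

for backend in ("torch", "tensorflow", "onnxruntime", "keras"):
    print(f"{backend}: {'available' if find_spec(backend) else 'missing'}")
```
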
diff --git a/dev/env.yaml b/dev/env.yaml index 78ed1f8d..02423c46 100644 --- a/dev/env.yaml +++ b/dev/env.yaml @@ -1,28 +1,29 @@ -name: bio38 +name: core channels: - conda-forge - defaults dependencies: - - annotated-types - - bioimageio.spec==0.5.* + - bioimageio.spec>=0.5.1 - black - - deepdiff - - email-validator + - crick - filelock - imageio>=2.5 + - keras>=3.0 - loguru - - lxml - numpy - onnxruntime - packaging>=17.0 - - pooch + - pip - pre-commit - - pydantic>=2.6.4 + - psutil + - pydantic + - pydantic-settings - pyright - pytest - - python-dateutil + - pytest-xdist + - python-dotenv - python=3.8 - - pytorch + - pytorch>=1.6 - ruff - ruyaml - torchvision @@ -30,3 +31,5 @@ dependencies: - typer - typing-extensions - xarray + - pip: + - --no-deps -e .. diff --git a/dev/environment-base.yaml b/dev/environment-base.yaml deleted file mode 100644 index 96a96d91..00000000 --- a/dev/environment-base.yaml +++ /dev/null @@ -1,28 +0,0 @@ -name: bio-core-dev -channels: - - conda-forge - - defaults -dependencies: - - bioimageio.spec==0.5.* - - black - - conda-build - - filelock - - h5py >=2.10,<2.11 - - loguru - - mypy - - onnx - - onnxruntime - - pip - - pre-commit - - psutil - - pytest - - pytest-xdist - - python >=3.7,<3.8 # this environment is only available for python 3.7 - - pytorch - - ruyaml - - tensorflow >=1.12,<2.0 - - tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259 - - typer - - xarray - - pip: - - keras==1.2.2 diff --git a/dev/environment-tf.yaml b/dev/environment-tf.yaml index 03c6b08b..1de415f8 100644 --- a/dev/environment-tf.yaml +++ b/dev/environment-tf.yaml @@ -1,21 +1,37 @@ -name: bio-core-tf +# modified copy of env.yaml +name: core-tf # changed channels: - conda-forge - defaults dependencies: - - bioimageio.spec==0.5.* + - bioimageio.spec>=0.5.1 - black - - conda-build + - crick - filelock + - imageio>=2.5 + - keras>=3.0 - loguru - - mypy + - numpy + - onnxruntime + - packaging>=17.0 - pip + - pre-commit - psutil + - pydantic + - pydantic-settings + - pyright - pytest - pytest-xdist - - python + - python-dotenv + - python=3.8 + # - pytorch>=1.6 # removed + - ruff - ruyaml - - tensorflow >=2.9,<3.0 - - tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259 + - tensorflow>=2.16 # added + # - torchvision # removed + - tqdm - typer + - typing-extensions - xarray + - pip: + - --no-deps -e .. 
diff --git a/dev/environment-torch.yaml b/dev/environment-torch.yaml deleted file mode 100644 index d5809082..00000000 --- a/dev/environment-torch.yaml +++ /dev/null @@ -1,24 +0,0 @@ -name: bio-core-torch -channels: - - conda-forge - - defaults -dependencies: - - bioimageio.spec==0.5.* - - black - - conda-build - - filelock - - h5py - - loguru - - mypy - - onnx - - onnxruntime - - pip - - psutil - - pytest - - pytest-xdist - - python >=3.8 - - pytorch - - ruyaml - - tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259 - - typer - - xarray diff --git a/setup.py b/setup.py index 0af57391..f8aa1ee0 100644 --- a/setup.py +++ b/setup.py @@ -30,27 +30,33 @@ packages=find_namespace_packages(exclude=["tests"]), install_requires=[ "bioimageio.spec==0.5.1.*", + "dotenv", "imageio>=2.5", "loguru", "numpy", + "pydantic-settings", + "pydantic", "ruyaml", - "tifffile", "tqdm", "typer", + "typing-extensions", "xarray", ], include_package_data=True, extras_require={ - "pytorch": ["torch>=1.6", "torchvision"], - "tensorflow": ["tensorflow"], + "pytorch": ["torch>=1.6", "torchvision", "keras>=3.0"], + "tensorflow": ["tensorflow", "keras>=3.0"], "onnx": ["onnxruntime"], "dev": [ "black", "crick", "filelock", + "keras>=3.0", "onnxruntime", + "packaging>=17.0", "pre-commit", "psutil", # parallel pytest with 'pytest -n auto' + "pyright", "pytest-xdist", # parallel pytest "pytest", "torch>=1.6", From 3931d682a6c7665554dcbc94d82c5f9d35a5d770 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 22 Mar 2024 10:58:43 +0100 Subject: [PATCH 147/244] keras needs py >=3.9 --- .github/workflows/build.yaml | 9 +++++++++ README.md | 2 +- dev/env-py38.yaml | 36 ++++++++++++++++++++++++++++++++++++ dev/env-wo-python.yaml | 4 ++-- dev/env.yaml | 4 ++-- dev/environment-tf.yaml | 4 ++-- 6 files changed, 52 insertions(+), 7 deletions(-) create mode 100644 dev/env-py38.yaml diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index c06980da..77bf04c7 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -30,6 +30,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Install Conda environment with Micromamba + if: matrix.python-version != '3.8' uses: mamba-org/setup-micromamba@v1 with: cache-downloads: true @@ -38,6 +39,14 @@ jobs: create-args: >- python=${{ matrix.python-version }} post-cleanup: 'all' + - name: Install py3.8 environment + if: matrix.python-version == '3.8' + uses: mamba-org/setup-micromamba@v1 + with: + cache-downloads: true + cache-environment: true + environment-file: dev/env-py38.yaml + post-cleanup: 'all' - name: additional setup run: pip install --no-deps -e . 
- name: pytest-spec-conda diff --git a/README.md b/README.md index 51ac8c37..dd76c085 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,7 @@ pip install bioimageio.core[onnx,pytorch] ### Set up Development Environment -To set up a development mamba environment run the following commands: +To set up a development conda environment run the following commands: ```console mamba env create -f dev/env.yaml diff --git a/dev/env-py38.yaml b/dev/env-py38.yaml new file mode 100644 index 00000000..4c9cefd8 --- /dev/null +++ b/dev/env-py38.yaml @@ -0,0 +1,36 @@ +# manipulated copy of env.yaml +name: core38 +channels: + - conda-forge + - defaults +dependencies: + - bioimageio.spec>=0.5.1 + - black + - crick + - filelock + - imageio>=2.5 + # - keras>=3.0 # removed + - loguru + - numpy + - onnxruntime + - packaging>=17.0 + - pip + - pre-commit + - psutil + - pydantic + - pydantic-settings + - pyright + - pytest + - pytest-xdist + - python-dotenv + - python=3.8 # changed + - pytorch>=2.1 + - ruff + - ruyaml + - torchvision + - tqdm + - typer + - typing-extensions + - xarray + - pip: + - --no-deps -e .. diff --git a/dev/env-wo-python.yaml b/dev/env-wo-python.yaml index 4fc66314..8ddce65b 100644 --- a/dev/env-wo-python.yaml +++ b/dev/env-wo-python.yaml @@ -23,8 +23,8 @@ dependencies: - pytest - pytest-xdist - python-dotenv - # - python=3.8 # removed - - pytorch>=1.6 + # - python=3.9 # removed + - pytorch>=2.1 - ruff - ruyaml - torchvision diff --git a/dev/env.yaml b/dev/env.yaml index 02423c46..580cf7be 100644 --- a/dev/env.yaml +++ b/dev/env.yaml @@ -22,8 +22,8 @@ dependencies: - pytest - pytest-xdist - python-dotenv - - python=3.8 - - pytorch>=1.6 + - python=3.9 + - pytorch>=2.1 - ruff - ruyaml - torchvision diff --git a/dev/environment-tf.yaml b/dev/environment-tf.yaml index 1de415f8..566b7ca2 100644 --- a/dev/environment-tf.yaml +++ b/dev/environment-tf.yaml @@ -23,8 +23,8 @@ dependencies: - pytest - pytest-xdist - python-dotenv - - python=3.8 - # - pytorch>=1.6 # removed + - python=3.9 + # - pytorch>=2.1 # removed - ruff - ruyaml - tensorflow>=2.16 # added From 6aea2989e55e5d3cf4960323ca227d35ad4c0c88 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 22 Mar 2024 11:15:31 +0100 Subject: [PATCH 148/244] remove --no-deps from env yamls --- dev/env-py38.yaml | 2 +- dev/env-wo-python.yaml | 2 +- dev/env.yaml | 2 +- dev/environment-tf.yaml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dev/env-py38.yaml b/dev/env-py38.yaml index 4c9cefd8..ea5389ff 100644 --- a/dev/env-py38.yaml +++ b/dev/env-py38.yaml @@ -33,4 +33,4 @@ dependencies: - typing-extensions - xarray - pip: - - --no-deps -e .. + - -e .. diff --git a/dev/env-wo-python.yaml b/dev/env-wo-python.yaml index 8ddce65b..bba2340d 100644 --- a/dev/env-wo-python.yaml +++ b/dev/env-wo-python.yaml @@ -33,4 +33,4 @@ dependencies: - typing-extensions - xarray - pip: - - --no-deps -e .. + - -e .. diff --git a/dev/env.yaml b/dev/env.yaml index 580cf7be..f91071a5 100644 --- a/dev/env.yaml +++ b/dev/env.yaml @@ -32,4 +32,4 @@ dependencies: - typing-extensions - xarray - pip: - - --no-deps -e .. + - -e .. diff --git a/dev/environment-tf.yaml b/dev/environment-tf.yaml index 566b7ca2..90bc8668 100644 --- a/dev/environment-tf.yaml +++ b/dev/environment-tf.yaml @@ -34,4 +34,4 @@ dependencies: - typing-extensions - xarray - pip: - - --no-deps -e .. + - -e .. 
From f3fe013e607dc891bce9d2f4c360c0efeebbbeb9 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 22 Mar 2024 11:38:50 +0100 Subject: [PATCH 149/244] fix setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f8aa1ee0..3b321375 100644 --- a/setup.py +++ b/setup.py @@ -30,12 +30,12 @@ packages=find_namespace_packages(exclude=["tests"]), install_requires=[ "bioimageio.spec==0.5.1.*", - "dotenv", "imageio>=2.5", "loguru", "numpy", "pydantic-settings", "pydantic", + "python-dotenv", "ruyaml", "tqdm", "typer", From 1c11aeb27ac166d2ee6a61a9d1dfc64bb9377ef5 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sat, 23 Mar 2024 22:15:06 +0100 Subject: [PATCH 150/244] fix: set ns only for parameterized sizes --- bioimageio/core/_resource_tests.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py index b04e42f8..32ea9760 100644 --- a/bioimageio/core/_resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -195,7 +195,12 @@ def generate_test_cases(): tested: Set[Hashable] = set() def get_ns(n: int): - return {(t.id, a.id): n for t in model.inputs for a in t.axes} + return { + (t.id, a.id): n + for t in model.inputs + for a in t.axes + if isinstance(a.size, v0_5.ParameterizedSize) + } for n, batch_size in test_cases: input_target_sizes, expected_output_sizes = model.get_axis_sizes( From 37f20ce6f3eb26a676aeedfbe69d1e3ba064170f Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sat, 23 Mar 2024 22:16:29 +0100 Subject: [PATCH 151/244] raise more expressive errors --- .../core/model_adapters/_keras_model_adapter.py | 4 +++- bioimageio/core/model_adapters/_model_adapter.py | 5 +++++ .../core/model_adapters/_onnx_model_adapter.py | 4 +++- .../core/model_adapters/_pytorch_model_adapter.py | 12 ++++++++---- .../core/model_adapters/_tensorflow_model_adapter.py | 4 +++- .../model_adapters/_torchscript_model_adapter.py | 4 +++- 6 files changed, 25 insertions(+), 8 deletions(-) diff --git a/bioimageio/core/model_adapters/_keras_model_adapter.py b/bioimageio/core/model_adapters/_keras_model_adapter.py index 785fd6a7..6ab18624 100644 --- a/bioimageio/core/model_adapters/_keras_model_adapter.py +++ b/bioimageio/core/model_adapters/_keras_model_adapter.py @@ -38,7 +38,9 @@ def __init__( model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None, ) -> None: - assert keras is not None + if keras is None: + raise ImportError("keras") + super().__init__() if model_description.weights.keras_hdf5 is None: raise ValueError("model has not keras_hdf5 weights specified") diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py index 3e4da1df..cb4762c4 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ b/bioimageio/core/model_adapters/_model_adapter.py @@ -52,6 +52,11 @@ def create( Note: All specific adapters should happen inside this function to prevent different framework initializations interfering with each other """ + if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)): + raise TypeError( + f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}" + ) + weights = model_description.weights errors: List[Exception] = [] for wf in weight_format_priority_order or DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: diff --git a/bioimageio/core/model_adapters/_onnx_model_adapter.py b/bioimageio/core/model_adapters/_onnx_model_adapter.py index 
19fdf0cc..9811efa2 100644 --- a/bioimageio/core/model_adapters/_onnx_model_adapter.py +++ b/bioimageio/core/model_adapters/_onnx_model_adapter.py @@ -21,7 +21,9 @@ def __init__( model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None, ): - assert rt is not None + if rt is None: + raise ImportError("onnxruntime") + super().__init__() self._internal_output_axes = [ ( diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py index 9a6fd4bf..5a5a9e83 100644 --- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py +++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py @@ -27,7 +27,8 @@ def __init__( ], devices: Optional[Sequence[str]] = None, ): - assert torch is not None + if torch is None: + raise ImportError("torch") super().__init__() self.output_dims = [ tuple(a if isinstance(a, str) else a.id for a in out.axes) @@ -47,7 +48,8 @@ def __init__( self._network = self._network.eval() def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: - assert torch is not None + if torch is None: + raise ImportError("torch") with torch.no_grad(): tensors = [ None if ipt is None else torch.from_numpy(ipt.data) @@ -100,7 +102,8 @@ def get_network( # pyright: ignore[reportUnknownParameterType] v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr ] ) -> "torch.nn.Module": # pyright: ignore[reportInvalidTypeForm] - assert torch is not None + if torch is None: + raise ImportError("torch") arch = import_callable( weight_spec.architecture, sha256=( @@ -126,7 +129,8 @@ def get_network( # pyright: ignore[reportUnknownParameterType] def get_devices( # pyright: ignore[reportUnknownParameterType] devices: Optional[Sequence[str]] = None, ) -> List["torch.device"]: # pyright: ignore[reportInvalidTypeForm] - assert torch is not None + if torch is None: + raise ImportError("torch") if not devices: torch_devices = [ ( diff --git a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py index 0f238925..f2942d89 100644 --- a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py +++ b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py @@ -32,7 +32,9 @@ def __init__( ], model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], ): - assert tf is not None + if tf is None: + raise ImportError("tensorflow") + super().__init__() self.model_description = model_description tf_version = v0_5.Version( diff --git a/bioimageio/core/model_adapters/_torchscript_model_adapter.py b/bioimageio/core/model_adapters/_torchscript_model_adapter.py index c50d131a..ec432d71 100644 --- a/bioimageio/core/model_adapters/_torchscript_model_adapter.py +++ b/bioimageio/core/model_adapters/_torchscript_model_adapter.py @@ -24,7 +24,9 @@ def __init__( model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None, ): - assert torch is not None + if torch is None: + raise ImportError("torch") + super().__init__() if model_description.weights.torchscript is None: raise ValueError( From 020f01869c10ebdf8379ec7ab61b46e11fe201aa Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sat, 23 Mar 2024 22:17:16 +0100 Subject: [PATCH 152/244] fix conftest (skip model_sources dependeing on failing imports) --- tests/conftest.py | 154 +++++++++++++++++++++++++--------------------- 1 file changed, 84 insertions(+), 70 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 
ee302035..189df1f1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,82 +2,13 @@ import subprocess import warnings -from typing import List +from typing import Dict, List from loguru import logger from pytest import FixtureRequest, fixture from bioimageio.spec import __version__ as bioimageio_spec_version -warnings.warn(f"testing with bioimageio.spec {bioimageio_spec_version}") - -# test models for various frameworks -TORCH_MODELS = [ - "unet2d_fixed_shape", - "unet2d_multi_tensor", - "unet2d_nuclei_broad_model", - "unet2d_diff_output_shape", - "shape_change", -] -TORCHSCRIPT_MODELS = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model"] -ONNX_MODELS = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model", "hpa_densenet"] -TENSORFLOW1_MODELS = ["stardist"] -TENSORFLOW2_MODELS = ["unet2d_keras_tf2"] -KERAS_TF1_MODELS = ["unet2d_keras"] -KERAS_TF2_MODELS = ["unet2d_keras_tf2"] -TENSORFLOW_JS_MODELS: List[str] = [] - - -MODEL_SOURCES = { - "unet2d_keras": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_keras_tf/v0_4.bioimageio.yaml" - ), - "unet2d_keras_tf2": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_keras_tf2/v0_4.bioimageio.yaml" - ), - "unet2d_nuclei_broad_model": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_nuclei_broad/bioimageio.yaml" - ), - "unet2d_expand_output_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_nuclei_broad/expand_output_shape_v0_4.bioimageio.yaml" - ), - "unet2d_fixed_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_fixed_shape/v0_4.bioimageio.yaml" - ), - "unet2d_multi_tensor": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_multi_tensor/v0_4.bioimageio.yaml" - ), - "unet2d_diff_output_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_diff_output_shape/v0_4.bioimageio.yaml" - ), - "hpa_densenet": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/hpa-densenet/rdf.yaml" - ), - "stardist": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models" - "/stardist_example_model/v0_4.bioimageio.yaml" - ), - "stardist_wrong_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "stardist_example_model/rdf_wrong_shape.yaml" - ), - "stardist_wrong_shape2": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "stardist_example_model/rdf_wrong_shape2_v0_4.yaml" - ), - "shape_change": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "upsample_test_model/v0_4.bioimageio.yaml" - ), -} - try: import torch @@ -102,10 +33,93 @@ tensorflow = None tf_major_version = None +try: + import keras # type: ignore +except ImportError: + keras = None skip_tensorflow = tensorflow is None skip_tensorflow_js = True # TODO: add a tensorflow_js example model +warnings.warn(f"testing with bioimageio.spec {bioimageio_spec_version}") + +# test models for various frameworks +TORCH_MODELS = [ + "unet2d_fixed_shape", + "unet2d_multi_tensor", + 
"unet2d_nuclei_broad_model", + "unet2d_diff_output_shape", + "shape_change", +] +TORCHSCRIPT_MODELS = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model"] +ONNX_MODELS = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model", "hpa_densenet"] +TENSORFLOW1_MODELS = ["stardist"] +TENSORFLOW2_MODELS = ["unet2d_keras_tf2"] +KERAS_TF1_MODELS = ["unet2d_keras"] +KERAS_TF2_MODELS = ["unet2d_keras_tf2"] +TENSORFLOW_JS_MODELS: List[str] = [] + + +MODEL_SOURCES: Dict[str, str] = {} +if keras is not None: + MODEL_SOURCES.update( + { + "unet2d_keras": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_keras_tf/v0_4.bioimageio.yaml" + ), + "unet2d_keras_tf2": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_keras_tf2/v0_4.bioimageio.yaml" + ), + } + ) +if torch is not None: + MODEL_SOURCES.update( + { + "unet2d_nuclei_broad_model": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_nuclei_broad/bioimageio.yaml" + ), + "unet2d_expand_output_shape": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_nuclei_broad/expand_output_shape_v0_4.bioimageio.yaml" + ), + "unet2d_fixed_shape": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_fixed_shape/v0_4.bioimageio.yaml" + ), + "unet2d_multi_tensor": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_multi_tensor/v0_4.bioimageio.yaml" + ), + "unet2d_diff_output_shape": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_diff_output_shape/v0_4.bioimageio.yaml" + ), + } + ) +if tensorflow is not None: + MODEL_SOURCES.update( + { + "hpa_densenet": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/hpa-densenet/rdf.yaml" + ), + "stardist": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models" + "/stardist_example_model/v0_4.bioimageio.yaml" + ), + "stardist_wrong_shape": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "stardist_example_model/rdf_wrong_shape.yaml" + ), + "stardist_wrong_shape2": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "stardist_example_model/rdf_wrong_shape2_v0_4.yaml" + ), + } + ) + @fixture(scope="session") def mamba_cmd(): From 50184d16d3baeaffb8138e94b60d880f1a7f8e3e Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sat, 23 Mar 2024 22:17:58 +0100 Subject: [PATCH 153/244] report default weight_format_priority_order in error instead of None --- bioimageio/core/model_adapters/_model_adapter.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py index cb4762c4..ec83b5a2 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ b/bioimageio/core/model_adapters/_model_adapter.py @@ -59,7 +59,12 @@ def create( weights = model_description.weights errors: List[Exception] = [] - for wf in weight_format_priority_order or DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: + weight_format_priority_order = ( + DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER + if weight_format_priority_order is None + else 
weight_format_priority_order + ) + for wf in weight_format_priority_order: if wf == "pytorch_state_dict" and weights.pytorch_state_dict is not None: try: from ._pytorch_model_adapter import PytorchModelAdapter From 71bb7dd092e960e10ea7a73904e55417bd768408 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sat, 23 Mar 2024 22:18:37 +0100 Subject: [PATCH 154/244] add ensure dtype ops for v0_4 procs --- bioimageio/core/proc_setup.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py index b7bd54bb..2e84a010 100644 --- a/bioimageio/core/proc_setup.py +++ b/bioimageio/core/proc_setup.py @@ -98,15 +98,31 @@ def prepare_procs(tensor_descrs: Sequence[TensorDescr]): procs: List[Processing] = [] for t_descr in tensor_descrs: if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)): - proc_descrs = t_descr.preprocessing + proc_descrs: List[ + Union[ + v0_4.PreprocessingDescr, + v0_5.PreprocessingDescr, + v0_4.PostprocessingDescr, + v0_5.PostprocessingDescr, + ] + ] = list(t_descr.preprocessing) elif isinstance( t_descr, (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr), ): - proc_descrs = t_descr.postprocessing + proc_descrs = list(t_descr.postprocessing) else: assert_never(t_descr) + if isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)): + ensure_dtype = v0_5.EnsureDtypeDescr( + kwargs=v0_5.EnsureDtypeKwargs(dtype=t_descr.data_type) + ) + if isinstance(t_descr, v0_4.InputTensorDescr) and proc_descrs: + proc_descrs.insert(0, ensure_dtype) + + proc_descrs.append(ensure_dtype) + for proc_d in proc_descrs: proc_class = get_proc_class(proc_d) tensor_id = ( From a1b20731d0b085d02d4498b0e4e324752442c8b5 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sat, 23 Mar 2024 23:15:02 +0100 Subject: [PATCH 155/244] fix hashable_target_size --- bioimageio/core/_resource_tests.py | 3 +-- bioimageio/core/model_adapters/_model_adapter.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py index 32ea9760..6ec00ec4 100644 --- a/bioimageio/core/_resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -207,8 +207,7 @@ def get_ns(n: int): get_ns(n), batch_size=batch_size ) hashable_target_size = tuple( - (input_target_sizes, input_target_sizes[ts]) - for ts in sorted(input_target_sizes) + (k, input_target_sizes[k]) for k in sorted(input_target_sizes) ) if hashable_target_size in tested: continue diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py index ec83b5a2..89b38614 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ b/bioimageio/core/model_adapters/_model_adapter.py @@ -137,7 +137,7 @@ def create( @final def load(self, *, devices: Optional[Sequence[str]] = None) -> None: - warnings.warn("Deprecated. ModelAdapter is always loaded") + warnings.warn("Deprecated. ModelAdapter is loaded on initialization") @abstractmethod def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: @@ -149,7 +149,7 @@ def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: def unload(self): """ Unload model from any devices, freeing their memory. - Note: Use ModelAdapter as context to not worry about calling unload()! + The moder adapter should be considered unusable afterwards. 
""" From 0c26f29ccd59239b380fffd909504fbe81264395 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sat, 23 Mar 2024 23:15:42 +0100 Subject: [PATCH 156/244] do not call deprecated load --- bioimageio/core/_prediction_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index 8f49c654..d8b6aad1 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -126,7 +126,7 @@ def load(self): """ optional step: load model onto devices before calling forward if not using it as context manager """ - self._adapter.load() + pass def unload(self): """ From 643300d4645ba9953bec547e0f5f9b085897d1f0 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sat, 23 Mar 2024 23:52:21 +0100 Subject: [PATCH 157/244] improve _test_model_inference_parametrized --- bioimageio/core/_resource_tests.py | 108 +++++++++++++++-------------- 1 file changed, 56 insertions(+), 52 deletions(-) diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py index 6ec00ec4..fa144775 100644 --- a/bioimageio/core/_resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -171,74 +171,78 @@ def _test_model_inference_parametrized( model: v0_5.ModelDescr, weight_format: Optional[WeightsFormat], devices: Optional[List[str]], - test_cases: Sequence[Tuple[v0_5.ParameterizedSize.N, BatchSize]] = ( + test_cases: Set[Tuple[v0_5.ParameterizedSize.N, BatchSize]] = { (0, 2), (1, 3), (2, 1), (3, 2), - ), + }, ) -> None: + if not test_cases: + return + if not any( isinstance(a.size, v0_5.ParameterizedSize) for ipt in model.inputs for a in ipt.axes ): - # only test different batch sizes for n=0 - test_cases = [tc for tc in test_cases if tc[0] == 0] - if not test_cases: - return - - try: - test_inputs = get_test_inputs(model) - - def generate_test_cases(): - tested: Set[Hashable] = set() - - def get_ns(n: int): - return { - (t.id, a.id): n - for t in model.inputs - for a in t.axes - if isinstance(a.size, v0_5.ParameterizedSize) - } - - for n, batch_size in test_cases: - input_target_sizes, expected_output_sizes = model.get_axis_sizes( - get_ns(n), batch_size=batch_size - ) - hashable_target_size = tuple( - (k, input_target_sizes[k]) for k in sorted(input_target_sizes) - ) - if hashable_target_size in tested: - continue - else: - tested.add(hashable_target_size) - - resized_test_inputs = [ - resize_to( - t, - { - aid: s - for (tid, aid), s in input_target_sizes.items() - if tid == t_descr.id - }, - ) - for t, t_descr in zip(test_inputs, model.inputs) - ] - expected_output_shapes = [ + # no parameterized sizes => set n=0 + test_cases = {(0, b) for _n, b in test_cases} + + if not any(isinstance(a, v0_5.BatchAxis) for ipt in model.inputs for a in ipt.axes): + # no batch axis => set b=1 + test_cases = {(n, 1) for n, _b in test_cases} + + def generate_test_cases(): + tested: Set[Hashable] = set() + + def get_ns(n: int): + return { + (t.id, a.id): n + for t in model.inputs + for a in t.axes + if isinstance(a.size, v0_5.ParameterizedSize) + } + + for n, batch_size in sorted(test_cases): + input_target_sizes, expected_output_sizes = model.get_axis_sizes( + get_ns(n), batch_size=batch_size + ) + hashable_target_size = tuple( + (k, input_target_sizes[k]) for k in sorted(input_target_sizes) + ) + if hashable_target_size in tested: + continue + else: + tested.add(hashable_target_size) + + resized_test_inputs = [ + resize_to( + t, { aid: s - for (tid, aid), s in expected_output_sizes.items() + for (tid, 
aid), s in input_target_sizes.items() if tid == t_descr.id - } - for t_descr in model.outputs - ] - yield n, batch_size, resized_test_inputs, expected_output_shapes + }, + ) + for t, t_descr in zip(test_inputs, model.inputs) + ] + expected_output_shapes = [ + { + aid: s + for (tid, aid), s in expected_output_sizes.items() + if tid == t_descr.id + } + for t_descr in model.outputs + ] + yield n, batch_size, resized_test_inputs, expected_output_shapes + + try: + test_inputs = get_test_inputs(model) with create_prediction_pipeline( bioimageio_model=model, devices=devices, weight_format=weight_format ) as prediction_pipeline: - for n, batch_size, inputs, exptected_output_shape in generate_test_cases(): error: Optional[str] = None results = prediction_pipeline.forward(*inputs) @@ -266,7 +270,7 @@ def get_ns(n: int): if diff: error = ( f"(n={n}) Expected output shape {exp}," - + f" but got {res.sizes} ({diff})\n" + + f" but got {res.sizes} (diff: {diff})" ) break From cd06342ac69189e16536a967273def7891329070 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sat, 23 Mar 2024 23:52:35 +0100 Subject: [PATCH 158/244] fix pad_to --- bioimageio/core/utils/tiling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bioimageio/core/utils/tiling.py b/bioimageio/core/utils/tiling.py index 2b65b361..658edc76 100644 --- a/bioimageio/core/utils/tiling.py +++ b/bioimageio/core/utils/tiling.py @@ -38,7 +38,7 @@ def pad_to( a = AxisId(str(a)) if a not in sizes or sizes[a] == s_is: pad_width[a] = 0 - elif s_is < sizes[a]: + elif s_is > sizes[a]: pad_width[a] = 0 warnings.warn( f"Cannot pad axis {a} of size {s_is} to smaller size {sizes[a]}" @@ -130,7 +130,7 @@ def resize_to( _ = new_axes.pop(a, None) if a not in sizes or sizes[a] == s_is: pass - elif s_is < sizes[a]: + elif s_is > sizes[a]: crop_to_sizes[a] = sizes[a] else: pad_to_sizes[a] = sizes[a] From 49dd4948ecde38c33424407315c9f75920546abb Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sun, 24 Mar 2024 00:00:32 +0100 Subject: [PATCH 159/244] install crick only in env-py38 --- dev/env-py38.yaml | 2 +- dev/{environment-tf.yaml => env-tf.yaml} | 2 +- dev/env-wo-python.yaml | 2 +- dev/env.yaml | 2 +- setup.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) rename dev/{environment-tf.yaml => env-tf.yaml} (92%) diff --git a/dev/env-py38.yaml b/dev/env-py38.yaml index ea5389ff..d6c752a9 100644 --- a/dev/env-py38.yaml +++ b/dev/env-py38.yaml @@ -6,7 +6,7 @@ channels: dependencies: - bioimageio.spec>=0.5.1 - black - - crick + - crick # uncommented - filelock - imageio>=2.5 # - keras>=3.0 # removed diff --git a/dev/environment-tf.yaml b/dev/env-tf.yaml similarity index 92% rename from dev/environment-tf.yaml rename to dev/env-tf.yaml index 90bc8668..5b1b4017 100644 --- a/dev/environment-tf.yaml +++ b/dev/env-tf.yaml @@ -6,7 +6,7 @@ channels: dependencies: - bioimageio.spec>=0.5.1 - black - - crick + # - crick # currently requires python<=3.9 - filelock - imageio>=2.5 - keras>=3.0 diff --git a/dev/env-wo-python.yaml b/dev/env-wo-python.yaml index bba2340d..cfdd48dc 100644 --- a/dev/env-wo-python.yaml +++ b/dev/env-wo-python.yaml @@ -6,7 +6,7 @@ channels: dependencies: - bioimageio.spec>=0.5.1 - black - - crick + # - crick # currently requires python<=3.9 - filelock - imageio>=2.5 - keras>=3.0 diff --git a/dev/env.yaml b/dev/env.yaml index f91071a5..e41eb838 100644 --- a/dev/env.yaml +++ b/dev/env.yaml @@ -5,7 +5,7 @@ channels: dependencies: - bioimageio.spec>=0.5.1 - black - - crick + # - crick # currently requires python<=3.9 - filelock - 
imageio>=2.5 - keras>=3.0 diff --git a/setup.py b/setup.py index 3b321375..36ed51fd 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ "onnx": ["onnxruntime"], "dev": [ "black", - "crick", + # "crick", # currently requires python<=3.9 "filelock", "keras>=3.0", "onnxruntime", From df302829931d44d90ce51e6bbc6ff7cc227f58fb Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sun, 24 Mar 2024 00:10:22 +0100 Subject: [PATCH 160/244] add jupyter(-black) dependencies --- dev/env-py38.yaml | 2 ++ dev/env-tf.yaml | 2 ++ dev/env-wo-python.yaml | 2 ++ dev/env.yaml | 2 ++ setup.py | 2 ++ 5 files changed, 10 insertions(+) diff --git a/dev/env-py38.yaml b/dev/env-py38.yaml index d6c752a9..726ce341 100644 --- a/dev/env-py38.yaml +++ b/dev/env-py38.yaml @@ -9,6 +9,8 @@ dependencies: - crick # uncommented - filelock - imageio>=2.5 + - jupyter + - jupyter-black # - keras>=3.0 # removed - loguru - numpy diff --git a/dev/env-tf.yaml b/dev/env-tf.yaml index 5b1b4017..726087f2 100644 --- a/dev/env-tf.yaml +++ b/dev/env-tf.yaml @@ -9,6 +9,8 @@ dependencies: # - crick # currently requires python<=3.9 - filelock - imageio>=2.5 + - jupyter + - jupyter-black - keras>=3.0 - loguru - numpy diff --git a/dev/env-wo-python.yaml b/dev/env-wo-python.yaml index cfdd48dc..8816ea48 100644 --- a/dev/env-wo-python.yaml +++ b/dev/env-wo-python.yaml @@ -9,6 +9,8 @@ dependencies: # - crick # currently requires python<=3.9 - filelock - imageio>=2.5 + - jupyter + - jupyter-black - keras>=3.0 - loguru - numpy diff --git a/dev/env.yaml b/dev/env.yaml index e41eb838..0aa1660e 100644 --- a/dev/env.yaml +++ b/dev/env.yaml @@ -8,6 +8,8 @@ dependencies: # - crick # currently requires python<=3.9 - filelock - imageio>=2.5 + - jupyter + - jupyter-black - keras>=3.0 - loguru - numpy diff --git a/setup.py b/setup.py index 36ed51fd..34944b0a 100644 --- a/setup.py +++ b/setup.py @@ -51,6 +51,8 @@ "black", # "crick", # currently requires python<=3.9 "filelock", + "jupyter", + "jupyter-black", "keras>=3.0", "onnxruntime", "packaging>=17.0", From 6be34761e394b5d4afe69f70fc91ac9f5071d13f Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sun, 24 Mar 2024 00:10:37 +0100 Subject: [PATCH 161/244] fix ci envs --- .github/workflows/build.yaml | 9 +++++++++ dev/env-tf.yaml | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 77bf04c7..d8c6fb95 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -60,6 +60,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Install Conda environment with Micromamba + if: matrix.python-version != '3.8' uses: mamba-org/setup-micromamba@v1 with: cache-downloads: true @@ -68,6 +69,14 @@ jobs: create-args: >- python=${{ matrix.python-version }} post-cleanup: 'all' + - name: Install py3.8 environment + if: matrix.python-version == '3.8' + uses: mamba-org/setup-micromamba@v1 + with: + cache-downloads: true + cache-environment: true + environment-file: dev/env-py38.yaml + post-cleanup: 'all' - name: additional setup run: | conda remove --yes --force bioimageio.spec || true # allow failure for cached env diff --git a/dev/env-tf.yaml b/dev/env-tf.yaml index 726087f2..304c0193 100644 --- a/dev/env-tf.yaml +++ b/dev/env-tf.yaml @@ -25,11 +25,11 @@ dependencies: - pytest - pytest-xdist - python-dotenv - - python=3.9 + # - python=3.9 # removed # - pytorch>=2.1 # removed - ruff - ruyaml - - tensorflow>=2.16 # added + - tensorflow>=2.15 # added # - torchvision # removed - tqdm - typer From 2230bf054a1b722f0ecf03a93bdea45f7f39e8fb Mon Sep 
17 00:00:00 2001 From: fynnbe Date: Sun, 24 Mar 2024 00:15:25 +0100 Subject: [PATCH 162/244] ruff needs python<3.11 --- dev/env-py38.yaml | 2 +- dev/env-tf.yaml | 2 +- dev/env-wo-python.yaml | 2 +- dev/env.yaml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dev/env-py38.yaml b/dev/env-py38.yaml index 726ce341..0fe96b73 100644 --- a/dev/env-py38.yaml +++ b/dev/env-py38.yaml @@ -27,7 +27,7 @@ dependencies: - python-dotenv - python=3.8 # changed - pytorch>=2.1 - - ruff + - ruff # uncommented - ruyaml - torchvision - tqdm diff --git a/dev/env-tf.yaml b/dev/env-tf.yaml index 304c0193..53a18ae0 100644 --- a/dev/env-tf.yaml +++ b/dev/env-tf.yaml @@ -27,7 +27,7 @@ dependencies: - python-dotenv # - python=3.9 # removed # - pytorch>=2.1 # removed - - ruff + # - ruff # removed - ruyaml - tensorflow>=2.15 # added # - torchvision # removed diff --git a/dev/env-wo-python.yaml b/dev/env-wo-python.yaml index 8816ea48..40fe27b6 100644 --- a/dev/env-wo-python.yaml +++ b/dev/env-wo-python.yaml @@ -27,7 +27,7 @@ dependencies: - python-dotenv # - python=3.9 # removed - pytorch>=2.1 - - ruff + # - ruff # requires python < 3.11 - ruyaml - torchvision - tqdm diff --git a/dev/env.yaml b/dev/env.yaml index 0aa1660e..6f6b8059 100644 --- a/dev/env.yaml +++ b/dev/env.yaml @@ -26,7 +26,7 @@ dependencies: - python-dotenv - python=3.9 - pytorch>=2.1 - - ruff + # - ruff # requires python < 3.11 - ruyaml - torchvision - tqdm From 0270edef4839d5e2a4d0e9a3037ac22f2e0d0fbd Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sun, 24 Mar 2024 00:16:20 +0100 Subject: [PATCH 163/244] black notebooks --- example/dataset_creation.ipynb | 41 +++++++-- example/dataset_statistics_demo.ipynb | 123 ++++++++++++++++++-------- example/model_usage.ipynb | 65 +++++++++----- 3 files changed, 162 insertions(+), 67 deletions(-) diff --git a/example/dataset_creation.ipynb b/example/dataset_creation.ipynb index e7f5d27d..13596956 100644 --- a/example/dataset_creation.ipynb +++ b/example/dataset_creation.ipynb @@ -6,7 +6,9 @@ "metadata": {}, "outputs": [], "source": [ - "from bioimageio.spec.pretty_validation_errors import enable_pretty_validation_errors_in_ipynb\n", + "from bioimageio.spec.pretty_validation_errors import (\n", + " enable_pretty_validation_errors_in_ipynb,\n", + ")\n", "\n", "enable_pretty_validation_errors_in_ipynb()" ] @@ -17,7 +19,13 @@ "metadata": {}, "outputs": [], "source": [ - "from bioimageio.spec.dataset.v0_3 import Author, CiteEntry, Dataset, HttpUrl, RelativeFilePath\n", + "from bioimageio.spec.dataset.v0_3 import (\n", + " Author,\n", + " CiteEntry,\n", + " Dataset,\n", + " HttpUrl,\n", + " RelativeFilePath,\n", + ")\n", "\n", "nuclei_broad_data = Dataset(\n", " name=\"Kaggle 2018 Data Science Bowl\",\n", @@ -27,13 +35,30 @@ " \"segmentation of the nuclei belonging to cells from a breadth of biological contexts.\",\n", " documentation=RelativeFilePath(\"README.md\"),\n", " covers=(\n", - " HttpUrl(\"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage1.png\"),\n", - " HttpUrl(\"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage2.png\"),\n", - " HttpUrl(\"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage3.png\"),\n", - " HttpUrl(\"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage4.png\"),\n", - " HttpUrl(\"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage5.png\"),\n", + " HttpUrl(\n", + " \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage1.png\"\n", + " ),\n", + " HttpUrl(\n", + " 
\"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage2.png\"\n", + " ),\n", + " HttpUrl(\n", + " \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage3.png\"\n", + " ),\n", + " HttpUrl(\n", + " \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage4.png\"\n", + " ),\n", + " HttpUrl(\n", + " \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage5.png\"\n", + " ),\n", + " ),\n", + " authors=(\n", + " Author(\n", + " name=\"Fynn Beuttenmueller\",\n", + " affiliation=\"EMBL\",\n", + " github_user=\"fynnbe\",\n", + " orcid=\"0000-0002-8567-6389\",\n", + " ),\n", " ),\n", - " authors=(Author(name=\"Fynn Beuttenmueller\", affiliation=\"EMBL\", github_user=\"fynnbe\", orcid=\"0000-0002-8567-6389\"),),\n", " source=HttpUrl(\"https://bbbc.lbroadinstitute.org/BBBC038/\"),\n", " cite=(\n", " CiteEntry(\n", diff --git a/example/dataset_statistics_demo.ipynb b/example/dataset_statistics_demo.ipynb index a8f30526..e59f4fd3 100644 --- a/example/dataset_statistics_demo.ipynb +++ b/example/dataset_statistics_demo.ipynb @@ -98,27 +98,41 @@ "source": [ "from bioimageio.core.prediction import get_tiling\n", "\n", - "tile_shape = dict(zip(\n", - " model_resource.inputs[0].axes, \n", - " np.asarray(model_resource.inputs[0].shape.min) + np.asarray(model_resource.inputs[0].shape.step)\n", - "))\n", - "\n", - "tiles = list(get_tiling(\n", - " shape=input_image.shape,\n", - " tile_shape=tile_shape,\n", - " halo=dict(zip(model_resource.inputs[0].axes, model_resource.outputs[0].halo)),\n", - " input_axes=model_resource.inputs[0].axes\n", - "))\n", + "tile_shape = dict(\n", + " zip(\n", + " model_resource.inputs[0].axes,\n", + " np.asarray(model_resource.inputs[0].shape.min)\n", + " + np.asarray(model_resource.inputs[0].shape.step),\n", + " )\n", + ")\n", + "\n", + "tiles = list(\n", + " get_tiling(\n", + " shape=input_image.shape,\n", + " tile_shape=tile_shape,\n", + " halo=dict(zip(model_resource.inputs[0].axes, model_resource.outputs[0].halo)),\n", + " input_axes=model_resource.inputs[0].axes,\n", + " )\n", + ")\n", + "\n", "\n", "def add_tile_box(ax, t):\n", " x = t[\"x\"].start\n", " w = t[\"x\"].stop - x\n", " y = t[\"y\"].start\n", " h = t[\"y\"].stop - y\n", - " \n", - " box = Rectangle((x, y), w, h, linewidth=1, edgecolor=np.random.choice(list(\"rgbcmykw\")), facecolor=\"none\")\n", + "\n", + " box = Rectangle(\n", + " (x, y),\n", + " w,\n", + " h,\n", + " linewidth=1,\n", + " edgecolor=np.random.choice(list(\"rgbcmykw\")),\n", + " facecolor=\"none\",\n", + " )\n", " ax.add_patch(box)\n", "\n", + "\n", "fig, ax = plt.subplots(1, 2)\n", "fig.suptitle(\"'samples' of test image 'dataset'\")\n", "ax[0].set_title(\"input (outer) tiles\")\n", @@ -175,19 +189,28 @@ "def process_dataset(pp, dataset):\n", " stats = pp._ipt_stats.compute_measures()[\"per_dataset\"]\n", " print(f\"initial stats:\")\n", - " pprint(None if not stats else {k: f\"{v.item():.2f}\" for k, v in stats[\"input0\"].items()})\n", + " pprint(\n", + " None\n", + " if not stats\n", + " else {k: f\"{v.item():.2f}\" for k, v in stats[\"input0\"].items()}\n", + " )\n", " stats = {}\n", " sample_dataset = [{\"input0\": s} for s in dataset]\n", " [pp.apply_preprocessing(s, stats) for s in sample_dataset]\n", " print(f\"final stats:\")\n", - " pprint(None if not stats else {k: f\"{v.item():.2f}\" for k, v in stats[\"per_dataset\"][\"input0\"].items()})\n", + " pprint(\n", + " None\n", + " if not stats\n", + " else {k: f\"{v.item():.2f}\" for k, v in stats[\"per_dataset\"][\"input0\"].items()}\n", + " )\n", " 
return [s[\"input0\"] for s in sample_dataset]\n", "\n", + "\n", "# accumulate dataset statistics exclusively while processing samples (no initial dataset statistics are computed)\n", "with create_prediction_pipeline(\n", - " bioimageio_model=model_resource,\n", - " # dataset_for_initial_statistics=tuple(), # an empty dataset is the default\n", - " ) as pp:\n", + " bioimageio_model=model_resource,\n", + " # dataset_for_initial_statistics=tuple(), # an empty dataset is the default\n", + ") as pp:\n", " wo_init_dataset_stats = process_dataset(pp, dataset)" ] }, @@ -220,9 +243,9 @@ "source": [ "# accumulate dataset statistics exclusively while processing samples for a limited number of samples\n", "with create_prediction_pipeline(\n", - " bioimageio_model=model_resource,\n", - " \tupdate_dataset_stats_for_n_samples=len(dataset) // 2,\n", - " ) as pp:\n", + " bioimageio_model=model_resource,\n", + " update_dataset_stats_for_n_samples=len(dataset) // 2,\n", + ") as pp:\n", " wo_init_dataset_stats_limit = process_dataset(pp, dataset)" ] }, @@ -256,14 +279,14 @@ ], "source": [ "# initialize dataset statistics with first n samples and keep update dataset statistics after the n sample\n", - "# this assumes that the n samples present in 'dataset_for_initial_statistics' are those that will be processed \n", - "# by the prediction pipeline and thus should not update the dataset statistics. \n", + "# this assumes that the n samples present in 'dataset_for_initial_statistics' are those that will be processed\n", + "# by the prediction pipeline and thus should not update the dataset statistics.\n", "# Use 'update_dataset_stats_after_n_samples=0' if that is not your use case.\n", "with create_prediction_pipeline(\n", - " bioimageio_model=model_resource,\n", - " dataset_for_initial_statistics=dataset[:len(dataset) // 2],\n", - " # update_dataset_stats_after_n_samples=None, # defaults to len(dataset_for_initial_statistics)\n", - " ) as pp:\n", + " bioimageio_model=model_resource,\n", + " dataset_for_initial_statistics=dataset[: len(dataset) // 2],\n", + " # update_dataset_stats_after_n_samples=None, # defaults to len(dataset_for_initial_statistics)\n", + ") as pp:\n", " partial_init_dataset_stats = process_dataset(pp, dataset)" ] }, @@ -299,11 +322,11 @@ "# compute dataset statistics on all samples\n", "# (in this case we should really use the non-overlapping tiles as samples in dataset_for_initial_statistics)\n", "with create_prediction_pipeline(\n", - " bioimageio_model=model_resource,\n", - " dataset_for_initial_statistics=dataset,\n", - " update_dataset_stats_for_n_samples=0, # if you call the prediciton pipeline more then len(dataset) \n", + " bioimageio_model=model_resource,\n", + " dataset_for_initial_statistics=dataset,\n", + " update_dataset_stats_for_n_samples=0, # if you call the prediciton pipeline more then len(dataset)\n", " # times you might want to set this to zero to avoid further updates to the dataset statistics\n", - " ) as pp:\n", + ") as pp:\n", " only_init_dataset_stats = process_dataset(pp, dataset)" ] }, @@ -323,15 +346,32 @@ "outputs": [], "source": [ "def untile(outputs):\n", - " untiled = xr.DataArray(np.empty((1, 2, *input_image.squeeze().shape)), dims=model_resource.outputs[0].axes)\n", + " untiled = xr.DataArray(\n", + " np.empty((1, 2, *input_image.squeeze().shape)),\n", + " dims=model_resource.outputs[0].axes,\n", + " )\n", " for out, t in zip(outputs, tiles):\n", " untiled[t.inner] = out[0]\n", "\n", " return untiled.data.squeeze()\n", "\n", + "\n", "# prepare image 
comparisons\n", - "titles = [\"wo_init_dataset_stats\", \"wo_init_dataset_stats_limit\", \"partial_init_dataset_stats\", \"only_init_dataset_stats\"]\n", - "images = [untile(out)[0] for out in [wo_init_dataset_stats, wo_init_dataset_stats_limit, partial_init_dataset_stats, only_init_dataset_stats]]" + "titles = [\n", + " \"wo_init_dataset_stats\",\n", + " \"wo_init_dataset_stats_limit\",\n", + " \"partial_init_dataset_stats\",\n", + " \"only_init_dataset_stats\",\n", + "]\n", + "images = [\n", + " untile(out)[0]\n", + " for out in [\n", + " wo_init_dataset_stats,\n", + " wo_init_dataset_stats_limit,\n", + " partial_init_dataset_stats,\n", + " only_init_dataset_stats,\n", + " ]\n", + "]" ] }, { @@ -364,17 +404,22 @@ "source": [ "fig, axes = plt.subplots(2, 4, figsize=(30, 15))\n", "for ax in axes[1]:\n", - " ax.set_axis_off()\n", + " ax.set_axis_off()\n", "\n", "zoom_roi = np.s_[20:60, 50:90]\n", + "\n", + "\n", "def get_box():\n", " return Rectangle(\n", " (zoom_roi[1].start, zoom_roi[0].start),\n", " zoom_roi[1].stop - zoom_roi[1].start,\n", " zoom_roi[0].stop - zoom_roi[0].start,\n", - " linewidth=1, edgecolor='r', facecolor='none'\n", + " linewidth=1,\n", + " edgecolor=\"r\",\n", + " facecolor=\"none\",\n", " )\n", "\n", + "\n", "vmin = min([img.min() for img in images])\n", "vmax = max([img.max() for img in images])\n", "zoom_vmin = min([img[zoom_roi].min() for img in images])\n", @@ -385,7 +430,9 @@ " axes[0, i].add_patch(get_box())\n", " axes[0, i].set_title(f\"{title} (min: {img.min():.2f} max: {img.max():.2f})\")\n", " axes[1, i].imshow(img[zoom_roi], vmin=zoom_vmin, vmax=zoom_vmax)\n", - " axes[1, i].set_title(f\"zooom in (min: {img[zoom_roi].min():.2f} max: {img[zoom_roi].max():.2f})\")\n", + " axes[1, i].set_title(\n", + " f\"zooom in (min: {img[zoom_roi].min():.2f} max: {img[zoom_roi].max():.2f})\"\n", + " )\n", "\n", "plt.show()" ] @@ -421,7 +468,7 @@ "fig, ax = plt.subplots(4, 4, figsize=(20, 20))\n", "for ai, (atitle, a) in enumerate(zip(titles, images)):\n", " for bi, (btitle, b) in enumerate(zip(titles, images)):\n", - " ax[ai, bi].imshow(np.abs(a-b))\n", + " ax[ai, bi].imshow(np.abs(a - b))\n", " if ai == 0:\n", " ax[ai, bi].set_title(btitle)\n", " if bi == 0:\n", diff --git a/example/model_usage.ipynb b/example/model_usage.ipynb index ab9e4c19..5f7c835d 100644 --- a/example/model_usage.ipynb +++ b/example/model_usage.ipynb @@ -37,7 +37,7 @@ "# helper function for showing multiple images in napari\n", "def show_images(*images, names=None):\n", " v = napari.Viewer()\n", - " for i, im in enumerate(images):\n", + " for i, im in enumerate(images):\n", " name = None if names is None else names[i]\n", " if isinstance(im, str):\n", " im = imageio.imread(im)\n", @@ -77,7 +77,9 @@ "# - go to https://bioimage.io/#/?id=10.5281%2Fzenodo.5764892%2F5764893\n", "# - click the download icon\n", "# - select \"ilastik\" weight format\n", - "rdf_path = \"/home/pape/Downloads/nuclei-segmentation-boundarymodel_pytorch_state_dict.zip\"" + "rdf_path = (\n", + " \"/home/pape/Downloads/nuclei-segmentation-boundarymodel_pytorch_state_dict.zip\"\n", + ")" ] }, { @@ -126,7 +128,10 @@ "# we can e.g. 
check what weight formats are available in the model (pytorch_state_dict for the model used here)\n", "print(\"Available weight formats for this model:\", model_resource.weights.keys())\n", "# or where the (downloaded) weight files are stored\n", - "print(\"Pytorch state dict weights are stored at:\", model_resource.weights[\"pytorch_state_dict\"].source)\n", + "print(\n", + " \"Pytorch state dict weights are stored at:\",\n", + " model_resource.weights[\"pytorch_state_dict\"].source,\n", + ")\n", "print()\n", "# or what inputs the model expects\n", "print(\"The model requires as inputs:\")\n", @@ -151,6 +156,7 @@ "# before using a model, it is recommended to check that it properly works with this function\n", "# 'test_model' returns a dict with 'status'='passed'/'failed' and more detailed information\n", "from bioimageio.core.resource_tests import test_model\n", + "\n", "test_result = test_model(model_resource)\n", "if test_result[\"status\"] == \"failed\":\n", " print(\"model test:\", test_result[\"name\"])\n", @@ -235,7 +241,9 @@ "# The prediction pipeline always returns a tuple (even if the model only has a single output tensor).\n", "# So we access the first element of the prediction to get the predicted tensor.\n", "prediction = prediction_pipeline(input_array)[0]\n", - "show_images(input_image, prediction, names=[\"image\", \"prediction\"]) # show the prediction result" + "show_images(\n", + " input_image, prediction, names=[\"image\", \"prediction\"]\n", + ") # show the prediction result" ] }, { @@ -273,7 +281,9 @@ "source": [ "# Instead, we can use the function `predict_with_padding`, which will pad the image to a shape that fits the model.\n", "prediction = bioimageio.core.predict_with_padding(prediction_pipeline, cropped_array)\n", - "show_images(cropped_image, prediction, names=[\"image\", \"prediction\"]) # show the prediction result" + "show_images(\n", + " cropped_image, prediction, names=[\"image\", \"prediction\"]\n", + ") # show the prediction result" ] }, { @@ -290,11 +300,16 @@ "# that is cropped in order to reduce boundary artifacts.\n", "# Alternatively, `tiling` can also be set to `True`, than the tile size and halo will be deduced from the model config\n", "# (this is also the default behavior when the `tiling` parameter is not passed).\n", - "tiling = {\"tile\": {\"x\": 128, \"y\": 128}, \"halo\": {\"x\": 16, \"y\": 16}} # use a tile size of 128x128 and crop a halo of 16 pixels\n", + "tiling = {\n", + " \"tile\": {\"x\": 128, \"y\": 128},\n", + " \"halo\": {\"x\": 16, \"y\": 16},\n", + "} # use a tile size of 128x128 and crop a halo of 16 pixels\n", "\n", - "# if `verbose` is set to True a progress bar will be printed \n", - "prediction = bioimageio.core.predict_with_tiling(prediction_pipeline, cropped_array, tiling=tiling, verbose=True)\n", - "show_images(cropped_image, prediction, names=[\"image\", \"prediction\"]) " + "# if `verbose` is set to True a progress bar will be printed\n", + "prediction = bioimageio.core.predict_with_tiling(\n", + " prediction_pipeline, cropped_array, tiling=tiling, verbose=True\n", + ")\n", + "show_images(cropped_image, prediction, names=[\"image\", \"prediction\"])" ] }, { @@ -321,15 +336,18 @@ "\n", "# The filepath where the output should be stored; supports most common image formats as well as npy fileformat.\n", "outputs = [\"prediction.tif\"]\n", - "predict_image(\n", - " model_resource, model_resource.test_inputs, outputs\n", - ")\n", + "predict_image(model_resource, model_resource.test_inputs, outputs)\n", "\n", "# The output 
tensor contains 2 channels, which is not supported by normal tif.\n", "# Thus, these 2 channels are stored as 2 separate images.\n", "fg_pred = imageio.imread(\"prediction-c0.tif\")\n", "bd_pred = imageio.imread(\"prediction-c1.tif\")\n", - "show_images(input_image, fg_pred, bd_pred, names=[\"image\", \"foreground-prediction\", \"boundary-prediction\"])" + "show_images(\n", + " input_image,\n", + " fg_pred,\n", + " bd_pred,\n", + " names=[\"image\", \"foreground-prediction\", \"boundary-prediction\"],\n", + ")" ] }, { @@ -349,6 +367,7 @@ "\n", "# Get all paths to the images in the \"example-images\" folder.\n", "from glob import glob\n", + "\n", "inputs = glob(\"./example-images/*.png\")\n", "\n", "# Create an output folder and specify the output path for each image.\n", @@ -374,12 +393,14 @@ "# `{\"x\": 512, \"y\": 512, \"mode\": \"fixed\"}` will always pad to a size of 512x512.\n", "# The padding is cropped again after the prediction to restore the input shape.\n", "padding = {\"x\": 16, \"y\": 16, \"mode\": \"dynamic\"}\n", - "predict_images(\n", - " model_resource, inputs, outputs, padding=padding, verbose=True\n", - ")\n", + "predict_images(model_resource, inputs, outputs, padding=padding, verbose=True)\n", "\n", "# check the first input/output\n", - "show_images(inputs[0], outputs[0].replace(\".png\", \"-c0.png\"), outputs[0].replace(\".png\", \"-c1.png\"))" + "show_images(\n", + " inputs[0],\n", + " outputs[0].replace(\".png\", \"-c0.png\"),\n", + " outputs[0].replace(\".png\", \"-c1.png\"),\n", + ")" ] }, { @@ -395,12 +416,14 @@ " \"tile\": {\"x\": 256, \"y\": 256},\n", " \"halo\": {\"x\": 16, \"y\": 16},\n", "}\n", - "predict_images(\n", - " model_resource, inputs, outputs, tiling=tiling, verbose=True\n", - ")\n", + "predict_images(model_resource, inputs, outputs, tiling=tiling, verbose=True)\n", "\n", "# Check the first input output pair.\n", - "show_images(inputs[0], outputs[0].replace(\".png\", \"-c0.png\"), outputs[0].replace(\".png\", \"-c1.png\"))" + "show_images(\n", + " inputs[0],\n", + " outputs[0].replace(\".png\", \"-c0.png\"),\n", + " outputs[0].replace(\".png\", \"-c1.png\"),\n", + ")" ] } ], From 2d9d1536d892867bdcd104a3881f28be4e8d7e4a Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sun, 24 Mar 2024 13:44:09 +0100 Subject: [PATCH 164/244] add missing shape_change model source --- tests/conftest.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 189df1f1..49064e11 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -97,6 +97,10 @@ "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "unet2d_diff_output_shape/v0_4.bioimageio.yaml" ), + "shape_change": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "upsample_test_model/v0_4.bioimageio.yaml" + ), } ) if tensorflow is not None: From b4dbb1a6a8eace3bd324035b58622378864cfb17 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sun, 24 Mar 2024 23:35:39 +0100 Subject: [PATCH 165/244] fix shape_change source --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 49064e11..5e54b17e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -98,7 +98,7 @@ "unet2d_diff_output_shape/v0_4.bioimageio.yaml" ), "shape_change": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" 
"upsample_test_model/v0_4.bioimageio.yaml" ), } From 5de1aa80285e35a2159fe07a4faad0b476e76ab7 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sun, 24 Mar 2024 23:36:10 +0100 Subject: [PATCH 166/244] improve model adapter creation error message --- .../core/model_adapters/_model_adapter.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py index 89b38614..7d206425 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ b/bioimageio/core/model_adapters/_model_adapter.py @@ -58,7 +58,7 @@ def create( ) weights = model_description.weights - errors: List[Exception] = [] + errors: List[str] = [] weight_format_priority_order = ( DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER if weight_format_priority_order is None @@ -75,7 +75,7 @@ def create( devices=devices, ) except Exception as e: - errors.append(e) + errors.append(f"{wf}: {e}") elif ( wf == "tensorflow_saved_model_bundle" and weights.tensorflow_saved_model_bundle is not None @@ -87,7 +87,7 @@ def create( model_description=model_description, devices=devices ) except Exception as e: - errors.append(e) + errors.append(f"{wf}: {e}") elif wf == "onnx" and weights.onnx is not None: try: from ._onnx_model_adapter import ONNXModelAdapter @@ -96,7 +96,7 @@ def create( model_description=model_description, devices=devices ) except Exception as e: - errors.append(e) + errors.append(f"{wf}: {e}") elif wf == "torchscript" and weights.torchscript is not None: try: from ._torchscript_model_adapter import TorchscriptModelAdapter @@ -105,7 +105,7 @@ def create( model_description=model_description, devices=devices ) except Exception as e: - errors.append(e) + errors.append(f"{wf}: {e}") elif wf == "keras_hdf5" and weights.keras_hdf5 is not None: # keras can either be installed as a separate package or used as part of tensorflow # we try to first import the keras model adapter using the separate package and, @@ -123,16 +123,14 @@ def create( model_description=model_description, devices=devices ) except Exception as e: - errors.append(e) - - if errors: - error_msg = f" Errors are: {errors}." - else: - error_msg = "" + errors.append(f"{wf}: {e}") + assert errors + error_list = "\n - ".join(errors) raise ValueError( - f"None of the weight formats {weight_format_priority_order} is " - + f"supported for {model_description.name} in this environment.{error_msg}" + "None of the weight format specific model adapters could be created for" + + f" '{model_description.id or model_description.name}'" + + f" in this environment. 
Errors are:\n\n{error_list}.\n\n" ) @final From fc396397197b1684647b6655d66ece285fc2efe8 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sun, 24 Mar 2024 23:33:40 +0100 Subject: [PATCH 167/244] WIP add Tile --- bioimageio/core/_op_base.py | 2 +- bioimageio/core/_prediction_pipeline.py | 3 +- bioimageio/core/common.py | 21 +---- bioimageio/core/dataset.py | 5 ++ bioimageio/core/proc_ops.py | 2 +- bioimageio/core/proc_setup.py | 2 +- bioimageio/core/sample.py | 105 ++++++++++++++++++++++++ bioimageio/core/stat_calculators.py | 2 +- bioimageio/core/stat_measures.py | 3 +- bioimageio/core/tile.py | 80 ++++++++++++++++++ bioimageio/core/utils/tiling.py | 23 +++--- tests/test_proc_ops.py | 3 +- tests/test_stat_calculators.py | 3 +- tests/test_stat_measures.py | 3 +- 14 files changed, 221 insertions(+), 36 deletions(-) create mode 100644 bioimageio/core/dataset.py create mode 100644 bioimageio/core/sample.py create mode 100644 bioimageio/core/tile.py diff --git a/bioimageio/core/_op_base.py b/bioimageio/core/_op_base.py index 8392f8e5..a0ca7ae1 100644 --- a/bioimageio/core/_op_base.py +++ b/bioimageio/core/_op_base.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import Collection -from bioimageio.core.common import Sample +from bioimageio.core.sample import Sample from bioimageio.core.stat_measures import Measure diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index d8b6aad1..94f86b35 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -2,11 +2,12 @@ from types import MappingProxyType from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union -from bioimageio.core.common import Sample, Tensor, TensorId +from bioimageio.core.common import Tensor, TensorId from bioimageio.core.model_adapters import ModelAdapter, create_model_adapter from bioimageio.core.model_adapters import get_weight_formats as get_weight_formats from bioimageio.core.proc_ops import Processing from bioimageio.core.proc_setup import setup_pre_and_postprocessing +from bioimageio.core.sample import Sample from bioimageio.core.stat_measures import DatasetMeasure, MeasureValue from bioimageio.spec.model import AnyModelDescr, v0_4 from bioimageio.spec.model.v0_5 import WeightsFormat diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index 10878dd4..7b9b3280 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -1,5 +1,7 @@ -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, Iterable, Literal, Protocol +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, Literal, Mapping, Protocol, Tuple, Union import xarray as xr @@ -26,20 +28,5 @@ class AxisLike(Protocol): BatchSize = int Tensor = xr.DataArray - Data = Dict[TensorId, Tensor] Stat = Dict["Measure", "MeasureValue"] - - -@dataclass -class Sample: - """A (dataset) sample""" - - data: Data = field(default_factory=dict) - """the samples tensors""" - - stat: Stat = field(default_factory=dict) - """sample and dataset statistics""" - - -Dataset = Iterable[Sample] diff --git a/bioimageio/core/dataset.py b/bioimageio/core/dataset.py new file mode 100644 index 00000000..59361b2d --- /dev/null +++ b/bioimageio/core/dataset.py @@ -0,0 +1,5 @@ +from typing import Iterable + +from bioimageio.core.sample import Sample + +Dataset = Iterable[Sample] diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 5299df4a..18521fa9 
100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -20,11 +20,11 @@ from bioimageio.core._op_base import Operator from bioimageio.core.common import ( AxisId, - Sample, Stat, Tensor, TensorId, ) +from bioimageio.core.sample import Sample from bioimageio.core.stat_calculators import StatsCalculator from bioimageio.core.stat_measures import ( DatasetMean, diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py index 2e84a010..8202bc97 100644 --- a/bioimageio/core/proc_setup.py +++ b/bioimageio/core/proc_setup.py @@ -11,13 +11,13 @@ from typing_extensions import assert_never -from bioimageio.core.common import Sample from bioimageio.core.proc_ops import ( AddKnownDatasetStats, Processing, UpdateStats, get_proc_class, ) +from bioimageio.core.sample import Sample from bioimageio.core.stat_calculators import StatsCalculator from bioimageio.core.stat_measures import DatasetMeasure, Measure, MeasureValue from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py new file mode 100644 index 00000000..d2b90638 --- /dev/null +++ b/bioimageio/core/sample.py @@ -0,0 +1,105 @@ +from dataclasses import dataclass, field +from typing import Dict, Iterable, Iterator, Mapping, Optional, Tuple, Union, cast +from typing_extensions import Self +from bioimageio.core.common import AxisId, Data, Stat, Tensor, TensorId + +from .tile import SampleSizes, TensorTilePos, Tile, TilePos, tile_tensor + +TiledSample = Iterable[Tile] +"""A dataset sample split into tiles""" + + +@dataclass +class Sample: + """A dataset sample""" + + data: Data + """the sample's tensors""" + + stat: Stat = field(default_factory=dict) + """sample and dataset statistics""" + + @property + def sizes(self) -> SampleSizes: + return {tid: cast(Dict[AxisId, int], dict(t.sizes)) for tid, t in self.data.items()} + + def tile( + self, + tile_shape: Mapping[TensorId, Mapping[AxisId, int]], + pad_width: Mapping[TensorId, Mapping[AxisId, Union[int, Tuple[int, int]]]], + ) -> TiledSample: + return tile_sample(self, tile_shape, pad_width) + + @classmethod + def from_tiles(cls, tiles: Iterable[Tile]) -> Self: + data: Data = {} + stat: Stat = {} + for tile in tiles: + for tid, tile_data in tile.data.items(): + + stat = tile.stat + + return cls(data=data, stat=stat) + +def tile_sample( + sample: Sample, + tile_shape: Mapping[TensorId, Mapping[AxisId, int]], + pad_width: Mapping[TensorId, Mapping[AxisId, Union[int, Tuple[int, int]]]], +): + assert all(tid in sample.data for tid in tile_shape), (tile_shape, sample.data) + assert all(tid in pad_width for tid in tile_shape), (tile_shape, pad_width) + tensor_ids = list(tile_shape) + + tile_generators: Dict[TensorId, Iterable[Tuple[int, TensorTilePos, Tensor]]] = {} + n_tiles: Dict[TensorId, int] = {} + for tid in tensor_ids: + n_tiles[tid], tile_generators[tid] = tile_tensor( + sample.data[tid], tile_shape=tile_shape[tid], pad_width=pad_width[tid] + ) + + n_tiles_common: Optional[int] = None + single_tile_tensors: Dict[TensorId, Tuple[TensorTilePos, Tensor]] = {} + tile_iterators: Dict[TensorId, Iterator[Tuple[int, TensorTilePos, Tensor]]] = {} + for tid, n in n_tiles.items(): + tile_iterator = iter(tile_generators[tid]) + if n == 1: + t0, pos, tensor_tile = next(tile_iterator) + assert t0 == 0 + single_tile_tensors[tid] = (pos, tensor_tile) + continue + + if n_tiles_common is None: + n_tiles_common = n + elif n != n_tiles_common: + raise ValueError( + f"{sample} tiled by {tile_shape} yields different 
numbers of tiles: {n_tiles}" + ) + + tile_iterators[tid] = tile_iterator + + if n_tiles_common is None: + assert not tile_iterators + n_tiles_common = 1 + + for t in range(n_tiles_common): + data: Dict[TensorId, Tensor] = {} + tile_pos: TilePos = {} + for tid, (tensor_pos, tensor_tile) in single_tile_tensors.items(): + data[tid] = tensor_tile + tile_pos[tid] = tensor_pos + + for tid, tile_iterator in tile_iterators.items(): + assert tid not in data + assert tid not in tile_pos + _t, tensor_pos, tensor_tile = next(tile_iterator) + assert _t == t, (_t, t) + data[tid] = tensor_tile + tile_pos[tid] = tensor_pos + + yield Tile( + data=data, + pos=tile_pos, + tile_number=t, + tiles_in_sample=n_tiles_common, + stat=sample.stat, + ) diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py index 818e6303..176cdd35 100644 --- a/bioimageio/core/stat_calculators.py +++ b/bioimageio/core/stat_calculators.py @@ -28,9 +28,9 @@ from bioimageio.core.common import ( AxisId, - Sample, TensorId, ) +from bioimageio.core.sample import Sample from bioimageio.core.stat_measures import ( DatasetMean, DatasetMeasure, diff --git a/bioimageio/core/stat_measures.py b/bioimageio/core/stat_measures.py index 5c599af2..93d4b1fd 100644 --- a/bioimageio/core/stat_measures.py +++ b/bioimageio/core/stat_measures.py @@ -6,7 +6,8 @@ import xarray as xr -from bioimageio.core.common import AxisId, Sample, TensorId +from bioimageio.core.common import AxisId, TensorId +from bioimageio.core.sample import Sample MeasureValue = Union[float, xr.DataArray] diff --git a/bioimageio/core/tile.py b/bioimageio/core/tile.py new file mode 100644 index 00000000..6e4db979 --- /dev/null +++ b/bioimageio/core/tile.py @@ -0,0 +1,80 @@ +import itertools +from dataclasses import dataclass, field +from math import prod +from typing import Dict, Iterable, List, Mapping, Tuple, Union, cast + +from .common import AxisId, Data, LeftRight, Stat, Tensor, TensorId + +TensorTilePos = Dict[AxisId, int] +TilePos = Dict[TensorId, TensorTilePos] +TensorSampleSize = Dict[AxisId, int] +SampleSizes = Dict[TensorId, TensorSampleSize] + + +@dataclass +class Tile: + """A tile of a dataset sample""" + + data: Data + """the tile's tensors""" + + pos: TilePos + """position of the inner origin (origin of tile if halo is cropped) within the sample""" + + halo: Dict[AxisId, LeftRight] + """padded or overlapping border region""" + + tile_number: int + """the n-th tile of the sample""" + + tiles_in_sample: int + """total number of tiles of the sample""" + + sample_sizes: SampleSizes + """the axis sizes of the sample""" + + stat: Stat = field(default_factory=dict) + """sample and dataset statistics""" + + +def _tile_generator(tensor: Tensor, all_1d_tiles: List[List[Tuple[int, slice]]]): + axes = cast(Tuple[AxisId, ...], tensor.dims) + for i, tile in enumerate(itertools.product(*all_1d_tiles)): + pos: TensorTilePos = {a: p for a, (p, _) in zip(axes, tile)} + tile_slice = {a: s for a, (_, s) in zip(axes, tile)} + yield i, pos, tensor[tile_slice] + + +def tile_tensor( + tensor: Tensor, + tile_shape: Mapping[AxisId, int], + pad_width: Mapping[AxisId, Union[int, Tuple[int, int]]], +) -> Tuple[int, Iterable[Tuple[int, TensorTilePos, Tensor]]]: + """tile a tensor + + Args: + tile_shape: output tile shape + pad_width: padding at edge of sample, overlap with neighboring tiles within the sample + + """ + assert all(aid in tensor.dims for aid in tile_shape), (tensor.dims, set(tile_shape)) + assert all(aid in tensor.dims for aid in pad_width), (tensor.dims, 
set(pad_width)) + assert all(aid in tile_shape for aid in tensor.dims), (tensor.dims, set(tile_shape)) + assert all(aid in pad_width for aid in tensor.dims), (tensor.dims, set(pad_width)) + + axes = cast(Tuple[AxisId, ...], tensor.dims) + + all_1d_tiles: List[List[Tuple[int, slice]]] = [] + shape = tensor.shape + for aid, s in zip(axes, shape): + pad = _pad if isinstance(_pad := pad_width[aid], tuple) else (_pad, _pad) + stride = tile_shape[aid] - sum(pad) + tiles_1d = [ + (p, slice(max(0, p - pad[0]), min(s, p + pad[1]))) + for p in range(0, s, stride) + ] + all_1d_tiles.append(tiles_1d) + + n_tiles = prod(map(len, all_1d_tiles)) + + return n_tiles, _tile_generator(tensor, all_1d_tiles) diff --git a/bioimageio/core/utils/tiling.py b/bioimageio/core/utils/tiling.py index 658edc76..fb89a2d2 100644 --- a/bioimageio/core/utils/tiling.py +++ b/bioimageio/core/utils/tiling.py @@ -1,16 +1,19 @@ import warnings -from typing import Dict, Literal, Mapping, Tuple, Union +from typing import Dict, Literal, Mapping, Union from typing_extensions import assert_never -from bioimageio.core.common import Tensor from bioimageio.spec.model.v0_5 import AxisId +from ..common import LeftRight, Tensor + +PadMode = Literal["edge", "reflect", "symmetric"] + def pad( tensor: Tensor, - pad_width: Mapping[AxisId, Union[int, Tuple[int, int]]], - mode: Literal["edge", "reflect", "symmetric"] = "symmetric", + pad_width: Mapping[AxisId, Union[int, LeftRight]], + mode: PadMode = "symmetric", ): return tensor.pad(pad_width={str(k): v for k, v in pad_width.items()}, mode=mode) @@ -22,7 +25,7 @@ def pad_to( Literal["before", "center", "after"], Mapping[AxisId, Literal["before", "center", "after"]], ] = "center", - mode: Literal["edge", "reflect", "symmetric"] = "symmetric", + mode: PadMode = "symmetric", ): """pad `tensor` to match `sizes`""" axes = [AxisId(str(a)) for a in tensor.dims] @@ -33,7 +36,7 @@ def pad_to( else: pad_axis_where = pad_where - pad_width: Dict[AxisId, Union[int, Tuple[int, int]]] = {} + pad_width: Dict[AxisId, Union[int, LeftRight]] = {} for a, s_is in tensor.sizes.items(): a = AxisId(str(a)) if a not in sizes or sizes[a] == s_is: @@ -51,11 +54,11 @@ def pad_to( pad_this_axis_where = pad_axis_where[a] p = sizes[a] - s_is if pad_this_axis_where == "before": - pad_width[a] = (p, 0) + pad_width[a] = LeftRight(p, 0) elif pad_this_axis_where == "after": - pad_width[a] = (0, p) + pad_width[a] = LeftRight(0, p) elif pad_this_axis_where == "center": - pad_width[a] = (left := p // 2, p - left) + pad_width[a] = LeftRight(left := p // 2, p - left) else: assert_never(pad_this_axis_where) @@ -119,7 +122,7 @@ def resize_to( Literal["before", "center", "after"], Mapping[AxisId, Literal["before", "center", "after"]], ] = "center", - pad_mode: Literal["edge", "reflect", "symmetric"] = "symmetric", + pad_mode: PadMode = "symmetric", ): """crop and pad `tensor` to match `sizes`""" crop_to_sizes: Dict[AxisId, int] = {} diff --git a/tests/test_proc_ops.py b/tests/test_proc_ops.py index c8b0b6d5..6c43df8d 100644 --- a/tests/test_proc_ops.py +++ b/tests/test_proc_ops.py @@ -5,7 +5,8 @@ import xarray as xr from typing_extensions import TypeGuard -from bioimageio.core.common import AxisId, Sample, TensorId +from bioimageio.core.common import AxisId, TensorId +from bioimageio.core.sample import Sample from bioimageio.core.stat_calculators import compute_measures from bioimageio.core.stat_measures import SampleMean, SamplePercentile, SampleStd diff --git a/tests/test_stat_calculators.py b/tests/test_stat_calculators.py index 
4d1117a8..68507acb 100644 --- a/tests/test_stat_calculators.py +++ b/tests/test_stat_calculators.py @@ -4,7 +4,8 @@ import pytest from xarray.testing import assert_allclose # pyright: ignore[reportUnknownVariableType] -from bioimageio.core.common import AxisId, Sample, Tensor, TensorId +from bioimageio.core.common import AxisId, Tensor, TensorId +from bioimageio.core.sample import Sample from bioimageio.core.stat_calculators import MeanVarStdCalculator from bioimageio.core.stat_measures import ( DatasetMean, diff --git a/tests/test_stat_measures.py b/tests/test_stat_measures.py index 1bd6231f..762c27d7 100644 --- a/tests/test_stat_measures.py +++ b/tests/test_stat_measures.py @@ -6,7 +6,8 @@ import xarray as xr from bioimageio.core import stat_measures -from bioimageio.core.common import AxisId, Sample, Tensor, TensorId +from bioimageio.core.common import AxisId, Tensor, TensorId +from bioimageio.core.sample import Sample from bioimageio.core.stat_calculators import ( SamplePercentilesCalculator, get_measure_calculators, From 27332c0455337fc7b36cd550c37234b4e0f20be3 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 27 Mar 2024 15:22:40 +0100 Subject: [PATCH 168/244] WIP improve Tile --- bioimageio/core/common.py | 21 +++++++++- bioimageio/core/sample.py | 20 +++++++++- bioimageio/core/tile.py | 84 ++++++++++++++++++++++++++++++++++----- 3 files changed, 112 insertions(+), 13 deletions(-) diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index 7b9b3280..aa98a92c 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -1,7 +1,16 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, Literal, Mapping, Protocol, Tuple, Union +from typing import ( + TYPE_CHECKING, + Dict, + Literal, + Mapping, + NamedTuple, + Protocol, + Tuple, + Union, +) import xarray as xr @@ -30,3 +39,13 @@ class AxisLike(Protocol): Data = Dict[TensorId, Tensor] Stat = Dict["Measure", "MeasureValue"] + + +class LeftRight(NamedTuple): + left: int + right: int + + +class SliceInfo(NamedTuple): + start: int + stop: int diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py index d2b90638..04c3c34a 100644 --- a/bioimageio/core/sample.py +++ b/bioimageio/core/sample.py @@ -1,6 +1,10 @@ from dataclasses import dataclass, field from typing import Dict, Iterable, Iterator, Mapping, Optional, Tuple, Union, cast + +import numpy +import xarray from typing_extensions import Self + from bioimageio.core.common import AxisId, Data, Stat, Tensor, TensorId from .tile import SampleSizes, TensorTilePos, Tile, TilePos, tile_tensor @@ -21,7 +25,9 @@ class Sample: @property def sizes(self) -> SampleSizes: - return {tid: cast(Dict[AxisId, int], dict(t.sizes)) for tid, t in self.data.items()} + return { + tid: cast(Dict[AxisId, int], dict(t.sizes)) for tid, t in self.data.items() + } def tile( self, @@ -32,15 +38,25 @@ def tile( @classmethod def from_tiles(cls, tiles: Iterable[Tile]) -> Self: + # TODO: add `mode: Literal['in-memory', 'to-disk']` or similar to save out of mem samples data: Data = {} stat: Stat = {} for tile in tiles: for tid, tile_data in tile.data.items(): - + if tid not in data: + axes = cast(Tuple[AxisId], tile_data.dims) + data[tid] = Tensor( + numpy.zeros( + tuple(tile.sample_sizes[tid][a] for a in axes), + dtype=tile_data.dtype, # pyright: ignore[reportUnknownArgumentType] + ), + dims=axes, + ) stat = tile.stat return cls(data=data, stat=stat) + def tile_sample( sample: Sample, tile_shape: Mapping[TensorId, Mapping[AxisId, 
int]], diff --git a/bioimageio/core/tile.py b/bioimageio/core/tile.py index 6e4db979..28f942a8 100644 --- a/bioimageio/core/tile.py +++ b/bioimageio/core/tile.py @@ -3,12 +3,18 @@ from math import prod from typing import Dict, Iterable, List, Mapping, Tuple, Union, cast -from .common import AxisId, Data, LeftRight, Stat, Tensor, TensorId +from xarray.core.utils import Frozen -TensorTilePos = Dict[AxisId, int] -TilePos = Dict[TensorId, TensorTilePos] -TensorSampleSize = Dict[AxisId, int] -SampleSizes = Dict[TensorId, TensorSampleSize] +from .common import AxisId, Data, LeftRight, SliceInfo, Stat, Tensor, TensorId + +# TensorTilePos = Mapping[AxisId, int] +# TilePos = Mapping[TensorId, TensorTilePos] +TensorTileSlice = Mapping[AxisId, SliceInfo] +TileSlice = Mapping[TensorId, TensorTileSlice] +TensorSampleSize = Mapping[AxisId, int] +SampleSizes = Mapping[TensorId, TensorSampleSize] +TensorTileHalo = Mapping[AxisId, LeftRight] +TileHalo = Mapping[TensorId, TensorTileHalo] @dataclass @@ -18,11 +24,69 @@ class Tile: data: Data """the tile's tensors""" - pos: TilePos - """position of the inner origin (origin of tile if halo is cropped) within the sample""" - - halo: Dict[AxisId, LeftRight] - """padded or overlapping border region""" + inner_slice: TileSlice + """slice of the inner tile (without padding and overlap) of the sample""" + + halo: TileHalo + """pad/overlap to extend the (inner) tile (to the outer tile)""" + + outer_slice: Frozen[TensorId, Frozen[AxisId, SliceInfo]] = field(init=False) + """slice of the outer tile (including overlap, but not padding) in the sample""" + + overlap: Frozen[TensorId, Frozen[AxisId, LeftRight]] = field(init=False) + """overlap 'into a neighboring tile'""" + + padding: Frozen[TensorId, Frozen[AxisId, LeftRight]] = field(init=False) + """pad (at sample edges where we cannot overlap to realize `halo`""" + + def __post_init__(self): + self.outer_slice = Frozen( + { + t: Frozen( + { + a: SliceInfo( + max(0, self.inner_slice[t][a].start - self.halo[t][a].left), + min( + self.sample_sizes[t][a], + self.inner_slice[t][a].stop + self.halo[t][a].right, + ), + ) + for a in self.inner_slice[t] + } + ) + for t in self.inner_slice + } + ) + self.overlap = Frozen( + { + tid: Frozen( + { + a: LeftRight( + self.inner_slice[tid][a].start + - self.outer_slice[tid][a].start, + self.outer_slice[tid][a].stop + - self.inner_slice[tid][a].stop, + ) + for a in self.inner_slice[tid] + } + ) + for tid in self.inner_slice + } + ) + self.padding = Frozen( + { + tid: Frozen( + { + a: LeftRight( + self.halo[tid][a].left - self.overlap[tid][a].left, + self.halo[tid][a].right - self.overlap[tid][a].right, + ) + for a in self.inner_slice[tid] + } + ) + for tid in self.inner_slice + } + ) tile_number: int """the n-th tile of the sample""" From 075b902a2c0105886b312d6eb314959d030187cf Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 27 Mar 2024 15:23:17 +0100 Subject: [PATCH 169/244] customize pdoc dev docs --- .github/workflows/build.yaml | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index d8c6fb95..d5dea6f1 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -144,8 +144,16 @@ jobs: python-version: '3.12' cache: 'pip' - run: pip install -e .[dev] + - id: get_version + run: python -c 'import bioimageio.core;print(f"version={bioimageio.core.__version__}")' >> $GITHUB_OUTPUT - name: Generate developer docs - run: pdoc -o ./dist bioimageio.spec + run: | + pdoc \ + --logo 
https://bioimage.io/static/img/bioimage-io-logo.svg \
+            --logo-link https://bioimage.io/ \
+            --favicon https://bioimage.io/static/img/bioimage-io-icon-small.svg \
+            --footer-text 'bioimageio.core ${{steps.get_version.outputs.version}}' \
+            -o ./dist bioimageio.core
       - run: cp README.md ./dist/README.md
       - name: Deploy to gh-pages 🚀
         uses: JamesIves/github-pages-deploy-action@v4

From 636861ab08f506b5b6660ced988d28782aea90de Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Thu, 28 Mar 2024 02:10:37 +0100
Subject: [PATCH 170/244] WIP axis, tensor, tile, sample

---
 bioimageio/core/__init__.py                   |   4 +
 bioimageio/core/_prediction_pipeline.py       |   2 +-
 bioimageio/core/_resource_tests.py            |   2 +-
 bioimageio/core/axis.py                       | 112 +++++
 bioimageio/core/common.py                     |  75 ++--
 bioimageio/core/io.py                         |  30 ++
 .../model_adapters/_keras_model_adapter.py    |   2 +-
 .../core/model_adapters/_model_adapter.py     |   2 +-
 .../model_adapters/_onnx_model_adapter.py     |   2 +-
 .../model_adapters/_pytorch_model_adapter.py  |   2 +-
 .../_tensorflow_model_adapter.py              |   2 +-
 .../_torchscript_model_adapter.py             |   2 +-
 bioimageio/core/proc_ops.py                   |   7 +-
 bioimageio/core/sample.py                     | 183 ++++----
 bioimageio/core/stat_calculators.py           |   4 +-
 bioimageio/core/stat_measures.py              |   6 +-
 bioimageio/core/tensor.py                     | 408 ++++++++++++++++++
 bioimageio/core/tile.py                       | 199 ++++-----
 bioimageio/core/utils/_digest_spec.py         |  40 +-
 bioimageio/core/utils/image_helper.py         | 188 --------
 bioimageio/core/utils/testing.py              |   1 +
 bioimageio/core/utils/tiling.py               | 150 -------
 tests/test_proc_ops.py                        |   3 +-
 tests/test_stat_calculators.py                |   3 +-
 tests/test_stat_measures.py                   |   3 +-
 tests/utils/test_image_helper.py              |   4 +-
 26 files changed, 842 insertions(+), 594 deletions(-)
 create mode 100644 bioimageio/core/axis.py
 create mode 100644 bioimageio/core/io.py
 create mode 100644 bioimageio/core/tensor.py
 delete mode 100644 bioimageio/core/utils/image_helper.py
 delete mode 100644 bioimageio/core/utils/tiling.py

diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py
index ef261839..2692cc70 100644
--- a/bioimageio/core/__init__.py
+++ b/bioimageio/core/__init__.py
@@ -24,6 +24,10 @@
 from ._resource_tests import test_model as test_model
 from ._settings import settings as settings
 from .utils import VERSION
+from .tensor import Tensor as Tensor
+from .tile import Tile as Tile
+from .sample import Sample as Sample
+
 
 __version__ = VERSION

diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py
index 94f86b35..33c83eb3 100644
--- a/bioimageio/core/_prediction_pipeline.py
+++ b/bioimageio/core/_prediction_pipeline.py
@@ -2,13 +2,13 @@
 from types import MappingProxyType
 from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union
 
-from bioimageio.core.common import Tensor, TensorId
 from bioimageio.core.model_adapters import ModelAdapter, create_model_adapter
 from bioimageio.core.model_adapters import get_weight_formats as get_weight_formats
 from bioimageio.core.proc_ops import Processing
 from bioimageio.core.proc_setup import setup_pre_and_postprocessing
 from bioimageio.core.sample import Sample
 from bioimageio.core.stat_measures import DatasetMeasure, MeasureValue
+from bioimageio.core.tensor import Tensor, TensorId
 from bioimageio.spec.model import AnyModelDescr, v0_4
 from bioimageio.spec.model.v0_5 import WeightsFormat
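This patch moves `Tensor`/`TensorId` out of `common.py` into a dedicated `tensor` module and re-exports the new types from the package root. A minimal usage sketch, not part of the patch; it assumes `Tensor.from_numpy` accepts an axis-letter sequence the way the `load_tensor` helper added further below calls it:

import numpy as np

from bioimageio.core import Sample, Tensor
from bioimageio.core.tensor import TensorId

# wrap a raw array as a named, axis-aware tensor and bundle it into a sample
array = np.zeros((1, 256, 256), dtype="float32")
tensor = Tensor.from_numpy(array, ("b", "y", "x"), id=TensorId("input0"))
sample = Sample(data={TensorId("input0"): tensor})
print(sample.sizes)  # per-tensor axis sizes, taken from the wrapped data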
diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py
index fa144775..82bd316b 100644
--- a/bioimageio/core/_resource_tests.py
+++ b/bioimageio/core/_resource_tests.py
@@ -5,7 +5,7 @@
 import numpy as np
 
 from bioimageio.core._prediction_pipeline import create_prediction_pipeline
-from bioimageio.core.common import AxisId, BatchSize
+from bioimageio.core.axis import AxisId, BatchSize
 from bioimageio.core.utils import VERSION, get_test_inputs, get_test_outputs
 from bioimageio.core.utils.tiling import resize_to
 from bioimageio.spec import (
diff --git a/bioimageio/core/axis.py b/bioimageio/core/axis.py
new file mode 100644
index 00000000..e86c7911
--- /dev/null
+++ b/bioimageio/core/axis.py
@@ -0,0 +1,112 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Literal, Mapping, Optional, TypeVar, Union
+
+from typing_extensions import assert_never
+
+from bioimageio.spec.model import v0_5
+
+
+def _get_axis_type(a: Literal["b", "t", "i", "c", "x", "y", "z"]):
+    if a == "b":
+        return "batch"
+    elif a == "t":
+        return "time"
+    elif a == "i":
+        return "index"
+    elif a == "c":
+        return "channel"
+    elif a in ("x", "y", "z"):
+        return "space"
+    else:
+        assert_never(a)
+
+
+S = TypeVar("S", bound=str)
+
+
+def _get_axis_id(a: Union[Literal["b", "t", "i", "c"], S]):
+    if a == "b":
+        return AxisId("batch")
+    elif a == "t":
+        return AxisId("time")
+    elif a == "i":
+        return AxisId("index")
+    elif a == "c":
+        return AxisId("channel")
+    else:
+        return AxisId(a)
+
+
+AxisId = v0_5.AxisId
+
+T = TypeVar("T")
+PerAxis = Mapping[AxisId, T]
+
+BatchSize = int
+
+AxisLetter = Literal["b", "i", "t", "c", "z", "y", "x"]
+AxisLike = Union[AxisLetter, v0_5.AnyAxis, "Axis"]
+
+
+@dataclass
+class Axis:
+    id: AxisId
+    type: Literal["batch", "channel", "index", "space", "time"]
+
+    @classmethod
+    def create(cls, axis: AxisLike) -> Axis:
+        if isinstance(axis, cls):
+            return axis
+        elif isinstance(axis, Axis):
+            return Axis(id=axis.id, type=axis.type)
+        elif isinstance(axis, str):
+            return Axis(id=_get_axis_id(axis), type=_get_axis_type(axis))
+        elif isinstance(axis, v0_5.AxisBase):
+            return Axis(id=AxisId(axis.id), type=axis.type)
+        else:
+            assert_never(axis)
+
+
+@dataclass
+class AxisInfo(Axis):
+    maybe_singleton: bool
+
+    @classmethod
+    def create(cls, axis: AxisLike, maybe_singleton: Optional[bool] = None) -> AxisInfo:
+        if isinstance(axis, AxisInfo):
+            return axis
+
+        axis_base = super().create(axis)
+        if maybe_singleton is None:
+            if isinstance(axis, Axis):
+                maybe_singleton = False
+            elif isinstance(axis, str):
+                maybe_singleton = axis == "b"
+            else:
+                if axis.size is None:
+                    maybe_singleton = True
+                elif isinstance(axis.size, int):
+                    maybe_singleton = axis.size == 1
+                elif isinstance(axis.size, v0_5.SizeReference):
+                    maybe_singleton = (
+                        False  # TODO: check if singleton is ok for a `SizeReference`
+                    )
+                elif isinstance(
+                    axis.size, (v0_5.ParameterizedSize, v0_5.DataDependentSize)
+                ):
+                    try:
+                        maybe_size_one = axis.size.validate_size(
+                            1
+                        )  # TODO: refactor validate_size() to have boolean func here
+                    except ValueError:
+                        maybe_singleton = False
+                    else:
+                        maybe_singleton = maybe_size_one == 1
+                else:
+                    assert_never(axis.size)
+
+        return AxisInfo(
+            id=axis_base.id, type=axis_base.type, maybe_singleton=maybe_singleton
+        )
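The axis letter shorthands resolve through the `_get_axis_id`/`_get_axis_type` helpers above; a short illustration of the resulting behavior (sketch only, not part of the patch):

from bioimageio.core.axis import Axis, AxisId, AxisInfo

# letters map to spec axis ids and types
assert Axis.create("b") == Axis(id=AxisId("batch"), type="batch")
assert Axis.create("x") == Axis(id=AxisId("x"), type="space")

# for plain letters only "b" is treated as a potential singleton axis
assert AxisInfo.create("b").maybe_singleton
assert not AxisInfo.create("x").maybe_singleton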
diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py
index aa98a92c..1c94c77f 100644
--- a/bioimageio/core/common.py
+++ b/bioimageio/core/common.py
@@ -1,51 +1,64 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
-from typing import (
-    TYPE_CHECKING,
-    Dict,
-    Literal,
-    Mapping,
-    NamedTuple,
-    Protocol,
-    Tuple,
-    Union,
-)
+from typing import Literal, NamedTuple, Tuple, TypeVar, Union
 
-import xarray as xr
+from typing_extensions import Self, assert_never
 
-from bioimageio.spec.model import v0_5
+DTypeStr = Literal[
+    "bool",
+    "float32",
+    "float64",
+    "int8",
+    "int16",
+    "int32",
+    "int64",
+    "uint8",
+    "uint16",
+    "uint32",
+    "uint64",
+]
 
-if TYPE_CHECKING:
-    from bioimageio.core.stat_measures import Measure, MeasureValue
 
-TensorId = v0_5.TensorId
-AxisId = v0_5.AxisId
+LeftRight_T = TypeVar("LeftRight_T", bound="LeftRight")
+LeftRightLike = Union[int, Tuple[int, int], LeftRight_T]
 
 
-@dataclass
-class Axis:
-    id: AxisId
-    type: Literal["batch", "channel", "index", "space", "time"]
+class LeftRight(NamedTuple):
+    left: int
+    right: int
 
+    @classmethod
+    def create(cls, like: LeftRightLike[Self]) -> Self:
+        if isinstance(like, cls):
+            return like
+        elif isinstance(like, tuple):
+            return cls(*like)
+        elif isinstance(like, int):
+            return cls(like, like)
+        else:
+            assert_never(like)
 
-class AxisLike(Protocol):
-    id: str
-    type: Literal["batch", "channel", "index", "space", "time"]
 
+class Halo(LeftRight):
+    pass
 
-BatchSize = int
 
-Tensor = xr.DataArray
-Data = Dict[TensorId, Tensor]
-Stat = Dict["Measure", "MeasureValue"]
+HaloLike = LeftRightLike[Halo]
 
 
-class LeftRight(NamedTuple):
-    left: int
-    right: int
+class PadWidth(LeftRight):
+    pass
+
+
+PadWidthLike = LeftRightLike[PadWidth]
+
+PadMode = Literal["edge", "reflect", "symmetric"]
+PadWhere = Literal["before", "center", "after"]
 
 
 class SliceInfo(NamedTuple):
     start: int
     stop: int
+
+
+TileNumber = int
+TotalNumberOfTiles = int
diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py
new file mode 100644
index 00000000..557e61bb
--- /dev/null
+++ b/bioimageio/core/io.py
@@ -0,0 +1,30 @@
+from pathlib import Path
+from typing import Optional, Sequence, Union
+
+import imageio
+
+from bioimageio.core.axis import Axis, AxisLike
+from bioimageio.spec.model import v0_5
+from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensorDescr04
+from bioimageio.spec.model.v0_4 import OutputTensorDescr as OutputTensorDescr04
+from bioimageio.spec.utils import load_array
+
+from .tensor import Tensor, TensorId
+
+
+def load_tensor(
+    path: Path, axes: Optional[Sequence[AxisLike]] = None, id: Optional[TensorId] = None
+) -> Tensor:
+
+    ext = path.suffix
+    if ext == ".npy":
+        array = load_array(path)
+    else:
+        is_volume = (
+            True
+            if axes is None
+            else sum(Axis.create(a).type != "channel" for a in axes) > 2
+        )
+        array = imageio.volread(path) if is_volume else imageio.imread(path)
+
+    return Tensor.from_numpy(array, axes, id=TensorId(path.stem) if id is None else id)
diff --git a/bioimageio/core/model_adapters/_keras_model_adapter.py b/bioimageio/core/model_adapters/_keras_model_adapter.py
index 6ab18624..5d956807 100644
--- a/bioimageio/core/model_adapters/_keras_model_adapter.py
+++ b/bioimageio/core/model_adapters/_keras_model_adapter.py
@@ -4,7 +4,7 @@
 from loguru import logger
 from numpy.typing import NDArray
 
-from bioimageio.core.common import Tensor
+from bioimageio.core.tensor import Tensor
 from bioimageio.spec._internal.io_utils import download
 from bioimageio.spec.model import v0_4, v0_5
 from bioimageio.spec.model.v0_5 import Version
diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py
index 7d206425..3560d61d 100644
--- a/bioimageio/core/model_adapters/_model_adapter.py
+++ b/bioimageio/core/model_adapters/_model_adapter.py
@@ -2,7 +2,7 @@
 from abc import ABC, abstractmethod
 from typing import List, Optional, Sequence, Tuple, Union, final
 
-from bioimageio.core.common import Tensor
+from bioimageio.core.tensor import Tensor
 from bioimageio.spec.model import v0_4, v0_5
 
 WeightsFormat = Union[v0_4.WeightsFormat, v0_5.WeightsFormat]
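Given the `load_tensor` helper above, reading an image into a `Tensor` might look as follows (illustrative sketch; the file path is hypothetical and only behavior visible in this patch is relied on):

from pathlib import Path

from bioimageio.core.io import load_tensor

# a 2D RGB image: only two non-channel axes, so imageio.imread (not volread) is used
tensor = load_tensor(Path("example-images/input0.png"), axes=["y", "x", "c"])
print(tensor.sizes)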
diff --git a/bioimageio/core/model_adapters/_onnx_model_adapter.py b/bioimageio/core/model_adapters/_onnx_model_adapter.py
index 9811efa2..fb5c6648 100644
--- a/bioimageio/core/model_adapters/_onnx_model_adapter.py
+++ b/bioimageio/core/model_adapters/_onnx_model_adapter.py
@@ -3,7 +3,7 @@
 
 from numpy.typing import NDArray
 
-from bioimageio.core.common import Tensor
+from bioimageio.core.tensor import Tensor
 from bioimageio.spec.model import v0_4, v0_5
 
 from ._model_adapter import ModelAdapter
diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py
index 5a5a9e83..839919f6 100644
--- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py
+++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py
@@ -2,7 +2,7 @@
 import warnings
 from typing import Any, List, Optional, Sequence, Tuple, Union
 
-from bioimageio.core.common import Tensor
+from bioimageio.core.tensor import Tensor
 from bioimageio.core.utils import import_callable
 from bioimageio.spec.model import v0_4, v0_5
 from bioimageio.spec.utils import download
diff --git a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py
index f2942d89..eecc9b45 100644
--- a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py
+++ b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py
@@ -4,7 +4,7 @@
 
 import numpy as np
 
-from bioimageio.core.common import Tensor
+from bioimageio.core.tensor import Tensor
 from bioimageio.spec.common import FileSource
 from bioimageio.spec.model import v0_4, v0_5
 from bioimageio.spec.utils import download
diff --git a/bioimageio/core/model_adapters/_torchscript_model_adapter.py b/bioimageio/core/model_adapters/_torchscript_model_adapter.py
index ec432d71..4f0a50ba 100644
--- a/bioimageio/core/model_adapters/_torchscript_model_adapter.py
+++ b/bioimageio/core/model_adapters/_torchscript_model_adapter.py
@@ -5,7 +5,7 @@
 import numpy as np
 from numpy.typing import NDArray
 
-from bioimageio.core.common import Tensor
+from bioimageio.core.tensor import Tensor
 from bioimageio.spec.model import v0_4, v0_5
 from bioimageio.spec.utils import download
diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py
index 18521fa9..8523f991 100644
--- a/bioimageio/core/proc_ops.py
+++ b/bioimageio/core/proc_ops.py
@@ -18,11 +18,8 @@
 from typing_extensions import Self, assert_never
 
 from bioimageio.core._op_base import Operator
-from bioimageio.core.common import (
+from bioimageio.core.axis import (
     AxisId,
-    Stat,
-    Tensor,
-    TensorId,
 )
 from bioimageio.core.sample import Sample
 from bioimageio.core.stat_calculators import StatsCalculator
 from bioimageio.core.stat_measures import (
@@ -37,8 +34,10 @@
     SampleMean,
     SamplePercentile,
     SampleStd,
+    Stat,
     StdMeasure,
 )
+from bioimageio.core.tensor import Tensor, TensorId
 from bioimageio.spec.model import v0_4, v0_5
 
diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py
index 04c3c34a..82f2fc75 100644
--- a/bioimageio/core/sample.py
+++ b/bioimageio/core/sample.py
@@ -2,12 +2,14 @@
 from typing import Dict, Iterable, Iterator, Mapping, Optional, Tuple, Union, cast
 
 import numpy
-import 
tile_tensor +from .axis import AxisId, PerAxis +from .common import Halo, HaloLike, PadMode, PadWidth, SliceInfo, TileNumber +from .stat_measures import Stat +from .tensor import PerTensor, Tensor, TensorId +from .tile import Tile, tile_tensor TiledSample = Iterable[Tile] """A dataset sample split into tiles""" @@ -17,105 +19,118 @@ class Sample: """A dataset sample""" - data: Data + data: PerTensor[Tensor] """the sample's tensors""" stat: Stat = field(default_factory=dict) """sample and dataset statistics""" @property - def sizes(self) -> SampleSizes: - return { - tid: cast(Dict[AxisId, int], dict(t.sizes)) for tid, t in self.data.items() - } + def sizes(self) -> PerTensor[PerAxis[int]]: + return {tid: t.sizes for tid, t in self.data.items()} def tile( self, - tile_shape: Mapping[TensorId, Mapping[AxisId, int]], - pad_width: Mapping[TensorId, Mapping[AxisId, Union[int, Tuple[int, int]]]], + tile_sizes: PerTensor[PerAxis[int]], + minimum_halo: PerTensor[PerAxis[HaloLike]], ) -> TiledSample: - return tile_sample(self, tile_shape, pad_width) + assert not ( + missing := [t for t in tile_sizes if t not in self.data] + ), f"`tile_sizes` specified for missing tensors: {missing}" + assert not ( + missing := [t for t in minimum_halo if t not in tile_sizes] + ), f"`minimum_halo` specified for tensors without `tile_sizes`: {missing}" + + tensor_ids = list(tile_sizes) + + tensor_tile_generators: Dict[ + TensorId, Iterable[Tuple[TileNumber, Tensor, PerAxis[SliceInfo]]] + ] = {} + n_tiles: Dict[TensorId, int] = {} + for t in tensor_ids: + n_tiles[t], tensor_tile_generators[t] = tile_tensor( + self.data[t], + tile_sizes=tile_sizes.get(t, self.data[t].sizes), + minimum_halo=minimum_halo.get(t, {a: 0 for a in self.data[t].dims}), + pad_mode=pad_mode, + ) + + n_tiles_common: Optional[int] = None + single_tile_tensors: Dict[TensorId, Tuple[TensorTilePos, Tensor]] = {} + tile_iterators: Dict[TensorId, Iterator[Tuple[int, TensorTilePos, Tensor]]] = {} + for t, n in n_tiles.items(): + tile_iterator = iter(tensor_tile_generators[t]) + if n == 1: + t0, pos, tensor_tile = next(tile_iterator) + assert t0 == 0 + single_tile_tensors[t] = (pos, tensor_tile) + continue + + if n_tiles_common is None: + n_tiles_common = n + elif n != n_tiles_common: + raise ValueError( + f"{self} tiled by {tile_sizes} yields different numbers of tiles: {n_tiles}" + ) + + tile_iterators[t] = tile_iterator + + if n_tiles_common is None: + assert not tile_iterators + n_tiles_common = 1 + + for t in range(n_tiles_common): + data: Dict[TensorId, Tensor] = {} + tile_pos: TilePos = {} + inner_slice: TileSlice = {} + outer_slice: TileSlice = {} + for t, (tensor_tile, tensor_pos) in single_tile_tensors.items(): + data[t] = tensor_tile + tile_pos[t] = tensor_pos + inner_slice[t] = inner_tensor_slice + outer_slice[t] = outer_tensor_slice + + for t, tile_iterator in tile_iterators.items(): + assert t not in data + assert t not in tile_pos + _t, tensor_pos, tensor_tile = next(tile_iterator) + assert _t == t, (_t, t) + data[t] = tensor_tile + tile_pos[t] = tensor_pos + + yield Tile( + data=data, + pos=tile_pos, + inner_slice=inner_slice, + outer_slice=outer_slice, + tile_number=t, + tiles_in_self=n_tiles_common, + stat=self.stat, + ) @classmethod - def from_tiles(cls, tiles: Iterable[Tile]) -> Self: + def from_tiles( + cls, tiles: Iterable[Tile], *, fill_value: float = float("nan") + ) -> Self: # TODO: add `mode: Literal['in-memory', 'to-disk']` or similar to save out of mem samples - data: Data = {} + data: TileData = {} stat: Stat = {} for tile in 
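# [editorial sketch] The shape of the nested mappings `Sample.tile` above
# consumes; tensor and axis ids are hypothetical:
from bioimageio.core.axis import AxisId
from bioimageio.core.tensor import TensorId

tile_sizes = {TensorId("raw"): {AxisId("y"): 256, AxisId("x"): 256}}
minimum_halo = {TensorId("raw"): {AxisId("y"): 16, AxisId("x"): 16}}
# the asserts above require: every tensor in `tile_sizes` is in `sample.data`,
# and every tensor in `minimum_halo` also appears in `tile_sizes`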
tiles: - for tid, tile_data in tile.data.items(): - if tid not in data: + for t, tile_data in tile.inner_data.items(): + if t not in data: axes = cast(Tuple[AxisId], tile_data.dims) - data[tid] = Tensor( - numpy.zeros( - tuple(tile.sample_sizes[tid][a] for a in axes), - dtype=tile_data.dtype, # pyright: ignore[reportUnknownArgumentType] + data[t] = Tensor( + numpy.full( + tuple(tile.sample_sizes[t][a] for a in axes), + fill_value, + dtype=tile_data.dtype, ), dims=axes, + id=t, ) - stat = tile.stat - - return cls(data=data, stat=stat) + data[t][tile.inner_slice[t]] = tile_data -def tile_sample( - sample: Sample, - tile_shape: Mapping[TensorId, Mapping[AxisId, int]], - pad_width: Mapping[TensorId, Mapping[AxisId, Union[int, Tuple[int, int]]]], -): - assert all(tid in sample.data for tid in tile_shape), (tile_shape, sample.data) - assert all(tid in pad_width for tid in tile_shape), (tile_shape, pad_width) - tensor_ids = list(tile_shape) - - tile_generators: Dict[TensorId, Iterable[Tuple[int, TensorTilePos, Tensor]]] = {} - n_tiles: Dict[TensorId, int] = {} - for tid in tensor_ids: - n_tiles[tid], tile_generators[tid] = tile_tensor( - sample.data[tid], tile_shape=tile_shape[tid], pad_width=pad_width[tid] - ) - - n_tiles_common: Optional[int] = None - single_tile_tensors: Dict[TensorId, Tuple[TensorTilePos, Tensor]] = {} - tile_iterators: Dict[TensorId, Iterator[Tuple[int, TensorTilePos, Tensor]]] = {} - for tid, n in n_tiles.items(): - tile_iterator = iter(tile_generators[tid]) - if n == 1: - t0, pos, tensor_tile = next(tile_iterator) - assert t0 == 0 - single_tile_tensors[tid] = (pos, tensor_tile) - continue - - if n_tiles_common is None: - n_tiles_common = n - elif n != n_tiles_common: - raise ValueError( - f"{sample} tiled by {tile_shape} yields different numbers of tiles: {n_tiles}" - ) + stat = tile.stat - tile_iterators[tid] = tile_iterator - - if n_tiles_common is None: - assert not tile_iterators - n_tiles_common = 1 - - for t in range(n_tiles_common): - data: Dict[TensorId, Tensor] = {} - tile_pos: TilePos = {} - for tid, (tensor_pos, tensor_tile) in single_tile_tensors.items(): - data[tid] = tensor_tile - tile_pos[tid] = tensor_pos - - for tid, tile_iterator in tile_iterators.items(): - assert tid not in data - assert tid not in tile_pos - _t, tensor_pos, tensor_tile = next(tile_iterator) - assert _t == t, (_t, t) - data[tid] = tensor_tile - tile_pos[tid] = tensor_pos - - yield Tile( - data=data, - pos=tile_pos, - tile_number=t, - tiles_in_sample=n_tiles_common, - stat=sample.stat, - ) + return cls(data=data, stat=stat) diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py index 176cdd35..851dfba6 100644 --- a/bioimageio/core/stat_calculators.py +++ b/bioimageio/core/stat_calculators.py @@ -26,9 +26,8 @@ from numpy.typing import NDArray from typing_extensions import assert_never -from bioimageio.core.common import ( +from bioimageio.core.axis import ( AxisId, - TensorId, ) from bioimageio.core.sample import Sample from bioimageio.core.stat_measures import ( @@ -46,6 +45,7 @@ SampleStd, SampleVar, ) +from bioimageio.core.Tensor import TensorId try: import crick diff --git a/bioimageio/core/stat_measures.py b/bioimageio/core/stat_measures.py index 93d4b1fd..83775fc9 100644 --- a/bioimageio/core/stat_measures.py +++ b/bioimageio/core/stat_measures.py @@ -2,12 +2,13 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, Tuple, TypeVar, Union +from typing import Dict, Optional, Tuple, TypeVar, Union import 
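# [editorial sketch] Why `from_tiles` above pre-fills with `fill_value` (NaN by
# default): sample regions that no inner tile writes to remain NaN and are
# therefore easy to detect. A minimal numpy analogue:
import numpy as np

out = np.full((4, 4), np.nan, dtype="float32")
out[0:2, 0:2] = 1.0  # analogous to `data[t][tile.inner_slice[t]] = tile_data`
assert np.isnan(out[3, 3])  # never written, still NaN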
xarray as xr
 
-from bioimageio.core.common import AxisId, TensorId
+from bioimageio.core.axis import AxisId
 from bioimageio.core.sample import Sample
+from bioimageio.core.tensor import TensorId
 
 MeasureValue = Union[float, xr.DataArray]
@@ -144,6 +145,7 @@ def __post_init__(self):
 SampleMeasure = Union[SampleMean, SampleStd, SampleVar, SamplePercentile]
 DatasetMeasure = Union[DatasetMean, DatasetStd, DatasetVar, DatasetPercentile]
 Measure = Union[SampleMeasure, DatasetMeasure]
+Stat = Dict[Measure, MeasureValue]
 
 MeanMeasure = Union[SampleMean, DatasetMean]
 StdMeasure = Union[SampleStd, DatasetStd]
diff --git a/bioimageio/core/tensor.py b/bioimageio/core/tensor.py
new file mode 100644
index 00000000..26384f0e
--- /dev/null
+++ b/bioimageio/core/tensor.py
@@ -0,0 +1,408 @@
+from __future__ import annotations
+
+import itertools
+from math import prod
+from typing import (
+    Any,
+    Dict,
+    Generator,
+    List,
+    Literal,  # [editorial fix] needed by the crop_to/resize_to annotations below
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+    get_args,
+)
+
+import numpy as np
+import xarray as xr
+from loguru import logger
+from numpy.typing import NDArray
+from typing_extensions import Self, assert_never
+
+from bioimageio.core.axis import PerAxis
+from bioimageio.core.common import PadMode, PadWhere
+from bioimageio.spec.model import v0_4, v0_5
+
+from .axis import Axis, AxisId, AxisInfo, AxisLike
+from .common import (
+    DTypeStr,
+    Halo,
+    HaloLike,
+    PadWidth,
+    SliceInfo,
+    TileNumber,
+    TotalNumberOfTiles,
+)
+
+TensorId = v0_5.TensorId
+
+T = TypeVar("T")
+PerTensor = Mapping[TensorId, T]
+
+
+class Tensor:
+    def __init__(
+        self,
+        array: NDArray[Any],
+        dims: Union[AxisId, Sequence[AxisId]],
+        id: TensorId,
+    ) -> None:
+        super().__init__()
+        self._data = xr.DataArray(array, dims=dims, name=id)
+        self._id = id
+
+    def __getitem__(self, key: PerAxis[Union[SliceInfo, slice]]) -> Self:
+        key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
+        return self.__class__.from_xarray(self._data[key])
+
+    def __setitem__(self, key: PerAxis[Union[SliceInfo, slice]], value: Tensor) -> None:
+        key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
+        self._data[key] = value._data
+
+    @classmethod
+    def from_xarray(cls, data_array: xr.DataArray) -> Self:
+        if data_array.name is None:
+            raise ValueError(
+                "Expected a named `data_array` to use `data_array.name` as tensor id"
+            )
+
+        return cls(
+            array=data_array.data,
+            dims=tuple(AxisId(d) for d in data_array.dims),
+            id=TensorId(data_array.name),
+        )
+
+    @classmethod
+    def from_numpy(
+        cls, array: NDArray[Any], axes: Optional[Sequence[AxisLike]], id: TensorId
+    ) -> Tensor:
+        if axes is None:
+            return cls._interprete_array_wo_known_axes(array, id=id)
+
+        original_shape = tuple(array.shape)
+        if len(array.shape) > len(axes):
+            # remove singletons
+            # [editorial fix] iterate by the *current* axis index; dropping an
+            # axis shifts the indices of the remaining ones
+            i = 0
+            while len(array.shape) > len(axes) and i < len(array.shape):
+                if array.shape[i] == 1:
+                    array = np.take(array, 0, axis=i)
+                else:
+                    i += 1
+
+        # add singletons if necessary
+        for a in axes:
+            a = AxisInfo.create(a)
+            if len(array.shape) >= len(axes):
+                break
+
+            if a.maybe_singleton:
+                array = array[None]
+
+        if len(array.shape) != len(axes):
+            raise ValueError(
+                f"Array shape {original_shape} does not map to axes {axes}"
+            )
+
+        # [editorial fix] `normalize_axes` is removed with image_helper.py in
+        # this patch series; assuming `AxisInfo.create(a).id` yields the
+        # normalized axis ids, and passing the required `id`:
+        return Tensor(array, dims=tuple(AxisInfo.create(a).id for a in axes), id=id)
+
+    @property
+    def data(self):
+        return self._data
+
+    @property
+    def dims(self):
+        return cast(Tuple[AxisId, ...], self._data.dims)
+
+    @property
+    def dtype(self) -> DTypeStr:
+        dt = str(self.data.dtype)  # pyright: ignore[reportUnknownArgumentType]
+        assert dt in get_args(DTypeStr)
+        return dt  # pyright: ignore[reportReturnType]
+
+    @property
+    def id(self):
+        return self._id
+
+    @property
+    def sizes(self):
+        return cast(Mapping[AxisId, int], self.data.sizes)
+
+    def crop_to(
+        tensor: Tensor,  # first positional parameter binds the instance, like `self`
+        sizes: Mapping[AxisId, int],
+        crop_where: Union[
+            Literal["before", "center", "after"],
+            Mapping[AxisId, Literal["before", "center", "after"]],
+        ] = "center",
+    ):
+        """crop this tensor to match `sizes`"""
+        axes = [AxisId(str(a)) for a in tensor.dims]
+        if crop_where in ("before", "center", "after"):
+            crop_axis_where: Mapping[AxisId, Literal["before", "center", "after"]] = {
+                a: crop_where for a in axes
+            }
+        else:
+            crop_axis_where = crop_where
+
+        slices: Dict[AxisId, slice] = {}
+
+        for a, s_is in tensor.sizes.items():
+            a = AxisId(str(a))
+            if a not in sizes or sizes[a] == s_is:
+                pass
+            elif sizes[a] > s_is:
+                logger.warning(  # [editorial fix] was `warnings.warn`; `warnings` is not imported
+                    f"Cannot crop axis {a} of size {s_is} to larger size {sizes[a]}"
+                )
+            elif a not in crop_axis_where:
+                raise ValueError(
+                    f"Don't know where to crop axis {a}, `crop_where`={crop_where}"
+                )
+            else:
+                crop_this_axis_where = crop_axis_where[a]
+                if crop_this_axis_where == "before":
+                    slices[a] = slice(s_is - sizes[a], s_is)
+                elif crop_this_axis_where == "after":
+                    slices[a] = slice(0, sizes[a])
+                elif crop_this_axis_where == "center":
+                    slices[a] = slice(start := (s_is - sizes[a]) // 2, sizes[a] + start)
+                else:
+                    assert_never(crop_this_axis_where)
+
+        return tensor[slices]  # [editorial fix] was `tensor.isel(...)`; Tensor defines no `isel`
+
+    def mean(self, dim: Union[AxisId, Sequence[AxisId]]) -> Self:
+        return self.__class__.from_xarray(self._data.mean(dim=dim))  # [editorial fix] `dim=`, not `dims=`
+
+    def std(self, dim: Union[AxisId, Sequence[AxisId]]) -> Self:
+        return self.__class__.from_xarray(self._data.std(dim=dim))
+
+    def var(self, dim: Union[AxisId, Sequence[AxisId]]) -> Self:
+        return self.__class__.from_xarray(self._data.var(dim=dim))
+
+    def pad(
+        self,
+        pad_width: PerAxis[PadWidth],
+        mode: PadMode = "symmetric",
+    ) -> Self:
+        return self.__class__.from_xarray(
+            self._data.pad(pad_width=pad_width, mode=mode)
+        )
+
+    def pad_to(
+        self,
+        sizes: PerAxis[int],
+        pad_where: Union[PadWhere, PerAxis[PadWhere]] = "center",
+        mode: PadMode = "symmetric",
+    ) -> Self:
+        """pad this tensor to match `sizes`"""
+        if isinstance(pad_where, str):
+            pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims}
+        else:
+            pad_axis_where = pad_where
+
+        pad_width: Dict[AxisId, PadWidth] = {}
+        for a, s_is in self.sizes.items():
+            if a not in sizes or sizes[a] == s_is:
+                pad_width[a] = PadWidth(0, 0)
+            elif s_is > sizes[a]:
+                pad_width[a] = PadWidth(0, 0)
+                logger.warning(
+                    "Cannot pad axis {} of size {} to smaller size {}",
+                    a,
+                    s_is,
+                    sizes[a],
+                )
+            elif a not in pad_axis_where:
+                raise ValueError(
+                    f"Don't know where to pad axis {a}, `pad_where`={pad_where}"
+                )
+            else:
+                pad_this_axis_where = pad_axis_where[a]
+                p = sizes[a] - s_is
+                if pad_this_axis_where == "before":
+                    pad_width[a] = PadWidth(p, 0)
+                elif pad_this_axis_where == "after":
+                    pad_width[a] = PadWidth(0, p)
+                elif pad_this_axis_where == "center":
+                    pad_width[a] = PadWidth(left := p // 2, p - left)  # e.g. p=5 -> (2, 3)
+                else:
+                    assert_never(pad_this_axis_where)
+
+        return self.pad(pad_width, mode)
+
+    def resize_to(
+        tensor: Tensor,
+        sizes: Mapping[AxisId, int],
+        *,
+        pad_where: Union[
+            Literal["before", "center", "after"],
+            Mapping[AxisId, Literal["before", "center", "after"]],
+        ] = "center",
+        crop_where: 
Union[ + Literal["before", "center", "after"], + Mapping[AxisId, Literal["before", "center", "after"]], + ] = "center", + pad_mode: PadMode = "symmetric", + ): + """crop and pad `tensor` to match `sizes`""" + crop_to_sizes: Dict[AxisId, int] = {} + pad_to_sizes: Dict[AxisId, int] = {} + new_axes = dict(sizes) + for a, s_is in tensor.sizes.items(): + a = AxisId(str(a)) + _ = new_axes.pop(a, None) + if a not in sizes or sizes[a] == s_is: + pass + elif s_is > sizes[a]: + crop_to_sizes[a] = sizes[a] + else: + pad_to_sizes[a] = sizes[a] + + if crop_to_sizes: + tensor = crop_to(tensor, crop_to_sizes, crop_where=crop_where) + + if pad_to_sizes: + tensor = pad_to(tensor, pad_to_sizes, pad_where=pad_where, mode=pad_mode) + + if new_axes: + tensor = tensor.expand_dims({str(k): v for k, v in new_axes}) + + return tensor + + def tile( + self, + tile_size: PerAxis[int], + halo: PerAxis[HaloLike], + pad_mode: PadMode, + ) -> Tuple[ + TotalNumberOfTiles, + Generator[Tuple[TileNumber, Tensor, PerAxis[SliceInfo]], Any, None], + ]: + """tile this tensor into `tile_size` tiles that overlap by `halo`. + At the tensor's edge the `halo` is padded with `pad_mode`. + + Args: + tile_sizes: (Outer) output tile shape. + halo: padding At the tensor's edge, overlap with neighboring tiles within + the tensor; additional padding at the end of dimensions that do not + evenly divide by the tile shape may result in larger halos for edge + tiles. + pad_mode: How to pad at the tensor's edge. + """ + assert all(a in self.dims for a in tile_size), (self.dims, set(tile_size)) + assert all(a in self.dims for a in halo), (self.dims, set(halo)) + + inner_1d_tiles: List[List[SliceInfo]] = [] + halo = {a: Halo.create(h) for a, h in halo.items()} + for a, s in self.sizes.items(): + stride = tile_size[a] - sum(halo[a]) + tiles_1d = [SliceInfo(p, min(s, p + stride)) for p in range(0, s, stride)] + inner_1d_tiles.append(tiles_1d) + + n_tiles = prod(map(len, inner_1d_tiles)) + + return n_tiles, self._tile_generator( + inner_1d_tiles=inner_1d_tiles, halo=halo, pad_mode=pad_mode + ) + + def transpose( + self, + axes: Sequence[AxisId], + ) -> Self: + """return a transposed tensor + + Args: + axes: the desired tensor axes + """ + # expand the missing image axes + current_axes = tuple( + d if isinstance(d, AxisId) else AxisId(d) for d in tensor.dims + ) + missing_axes = tuple(a for a in axes if a not in current_axes) + tensor = tensor.expand_dims(missing_axes) + # transpose to the correct axis order + return tensor.transpose(*map(str, axes)) + + @classmethod + def _interprete_array_wo_known_axes(cls, array: NDArray[Any], id: TensorId): + ndim = array.ndim + if ndim == 2: + current_axes = ( + v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]), + v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]), + ) + elif ndim == 3 and any(s <= 3 for s in array.shape): + current_axes = ( + v0_5.ChannelAxis( + channel_names=[ + v0_5.Identifier(f"channel{i}") for i in range(array.shape[0]) + ] + ), + v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]), + v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]), + ) + elif ndim == 3: + current_axes = ( + v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]), + v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]), + v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]), + ) + elif ndim == 4: + current_axes = ( + v0_5.ChannelAxis( + channel_names=[ + v0_5.Identifier(f"channel{i}") for i in range(array.shape[0]) + ] + ), + v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]), + 
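# [editorial example] The 1d tile positions computed in `tile` above, for an
# axis of size 256 with tile_size 128 and halo (16, 16):
stride = 128 - (16 + 16)  # tile_size[a] - sum(halo[a]) = 96
tiles_1d = [(p, min(256, p + stride)) for p in range(0, 256, stride)]
assert tiles_1d == [(0, 96), (96, 192), (192, 256)]  # last tile clipped at the edge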
v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]), + v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]), + ) + elif ndim == 5: + current_axes = ( + v0_5.BatchAxis(), + v0_5.ChannelAxis( + channel_names=[ + v0_5.Identifier(f"channel{i}") for i in range(array.shape[1]) + ] + ), + v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]), + v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]), + v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]), + ) + else: + raise ValueError(f"Could not guess an axis mapping for {array.shape}") + + return cls(array, dims=tuple(a.id for a in current_axes), id=id) + + def _tile_generator( + self, + *, + inner_1d_tiles: List[List[SliceInfo]], + halo: PerAxis[Halo], + pad_mode: PadMode, + ): + for i, nd_tile in enumerate(itertools.product(*inner_1d_tiles)): + inner_slice: PerAxis[SliceInfo] = dict(zip(self.dims, nd_tile)) + outer_slice = { + a: SliceInfo( + max(0, inner.start - halo[a].left), + min(self.sizes[a], inner.stop + halo[a].right), + ) + for a, inner in inner_slice.items() + } + pad_width: PerAxis[PadWidth] = { + a: PadWidth( + max(0, halo[a].left - inner.start), + max(0, inner.stop + halo[a].right - self.sizes[a]), + ) + for a, inner in inner_slice.items() + } + + yield i, self[outer_slice].pad(pad_width, pad_mode), inner_slice diff --git a/bioimageio/core/tile.py b/bioimageio/core/tile.py index 28f942a8..03703ce3 100644 --- a/bioimageio/core/tile.py +++ b/bioimageio/core/tile.py @@ -1,144 +1,113 @@ -import itertools from dataclasses import dataclass, field -from math import prod -from typing import Dict, Iterable, List, Mapping, Tuple, Union, cast -from xarray.core.utils import Frozen +from bioimageio.core.common import TileNumber, TotalNumberOfTiles -from .common import AxisId, Data, LeftRight, SliceInfo, Stat, Tensor, TensorId - -# TensorTilePos = Mapping[AxisId, int] -# TilePos = Mapping[TensorId, TensorTilePos] -TensorTileSlice = Mapping[AxisId, SliceInfo] -TileSlice = Mapping[TensorId, TensorTileSlice] -TensorSampleSize = Mapping[AxisId, int] -SampleSizes = Mapping[TensorId, TensorSampleSize] -TensorTileHalo = Mapping[AxisId, LeftRight] -TileHalo = Mapping[TensorId, TensorTileHalo] +from .axis import PerAxis +from .common import Halo, LeftRight, PadWidth, SliceInfo +from .stat_measures import Stat +from .tensor import PerTensor, Tensor @dataclass -class Tile: - """A tile of a dataset sample""" +class AbstractTile: + """A tile of a dataset sample without any data""" - data: Data - """the tile's tensors""" - - inner_slice: TileSlice + inner_slice: PerTensor[PerAxis[SliceInfo]] """slice of the inner tile (without padding and overlap) of the sample""" - halo: TileHalo + halo: PerTensor[PerAxis[Halo]] """pad/overlap to extend the (inner) tile (to the outer tile)""" - outer_slice: Frozen[TensorId, Frozen[AxisId, SliceInfo]] = field(init=False) + tile_number: TileNumber + """the n-th tile of the sample""" + + tiles_in_sample: TotalNumberOfTiles + """total number of tiles of the sample""" + + sample_sizes: PerTensor[PerAxis[int]] + """the axis sizes of the sample""" + + stat: Stat + """sample and dataset statistics""" + + outer_slice: PerTensor[PerAxis[SliceInfo]] = field(init=False) """slice of the outer tile (including overlap, but not padding) in the sample""" - overlap: Frozen[TensorId, Frozen[AxisId, LeftRight]] = field(init=False) + local_slice: PerTensor[PerAxis[SliceInfo]] = field(init=False) + """slice to extract the inner tile from the outer tile""" + + overlap: PerTensor[PerAxis[LeftRight]] = field(init=False) 
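# [editorial example] Edge handling in `_tile_generator` above, for the first
# inner tile (0, 96) on an axis of size 256 with halo (16, 16):
inner, halo, size = (0, 96), (16, 16), 256
outer = (max(0, inner[0] - halo[0]), min(size, inner[1] + halo[1]))
pad = (max(0, halo[0] - inner[0]), max(0, inner[1] + halo[1] - size))
assert (outer, pad) == ((0, 112), (16, 0))
# the left halo cannot overlap a neighboring tile, so it is realized as padding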
"""overlap 'into a neighboring tile'""" - padding: Frozen[TensorId, Frozen[AxisId, LeftRight]] = field(init=False) + padding: PerTensor[PerAxis[PadWidth]] = field(init=False) """pad (at sample edges where we cannot overlap to realize `halo`""" def __post_init__(self): - self.outer_slice = Frozen( - { - t: Frozen( - { - a: SliceInfo( - max(0, self.inner_slice[t][a].start - self.halo[t][a].left), - min( - self.sample_sizes[t][a], - self.inner_slice[t][a].stop + self.halo[t][a].right, - ), - ) - for a in self.inner_slice[t] - } + self.outer_slice = { + t: { + a: SliceInfo( + max(0, self.inner_slice[t][a].start - self.halo[t][a].left), + min( + self.sample_sizes[t][a], + self.inner_slice[t][a].stop + self.halo[t][a].right, + ), ) - for t in self.inner_slice + for a in self.inner_slice[t] } - ) - self.overlap = Frozen( - { - tid: Frozen( - { - a: LeftRight( - self.inner_slice[tid][a].start - - self.outer_slice[tid][a].start, - self.outer_slice[tid][a].stop - - self.inner_slice[tid][a].stop, - ) - for a in self.inner_slice[tid] - } + for t in self.inner_slice + } + self.local_slice = { + t: { + a: SliceInfo( + self.inner_slice[t][a].start - self.outer_slice[t][a].start, + self.inner_slice[t][a].stop - self.outer_slice[t][a].start, ) - for tid in self.inner_slice + for a in self.inner_slice[t] } - ) - self.padding = Frozen( - { - tid: Frozen( - { - a: LeftRight( - self.halo[tid][a].left - self.overlap[tid][a].left, - self.halo[tid][a].right - self.overlap[tid][a].right, - ) - for a in self.inner_slice[tid] - } + for t in self.inner_slice + } + self.overlap = { + t: { + a: LeftRight( + self.inner_slice[t][a].start - self.outer_slice[t][a].start, + self.outer_slice[t][a].stop - self.inner_slice[t][a].stop, ) - for tid in self.inner_slice + for a in self.inner_slice[t] } - ) - - tile_number: int - """the n-th tile of the sample""" - - tiles_in_sample: int - """total number of tiles of the sample""" - - sample_sizes: SampleSizes - """the axis sizes of the sample""" - - stat: Stat = field(default_factory=dict) - """sample and dataset statistics""" - - -def _tile_generator(tensor: Tensor, all_1d_tiles: List[List[Tuple[int, slice]]]): - axes = cast(Tuple[AxisId, ...], tensor.dims) - for i, tile in enumerate(itertools.product(*all_1d_tiles)): - pos: TensorTilePos = {a: p for a, (p, _) in zip(axes, tile)} - tile_slice = {a: s for a, (_, s) in zip(axes, tile)} - yield i, pos, tensor[tile_slice] - - -def tile_tensor( - tensor: Tensor, - tile_shape: Mapping[AxisId, int], - pad_width: Mapping[AxisId, Union[int, Tuple[int, int]]], -) -> Tuple[int, Iterable[Tuple[int, TensorTilePos, Tensor]]]: - """tile a tensor - - Args: - tile_shape: output tile shape - pad_width: padding at edge of sample, overlap with neighboring tiles within the sample + for t in self.inner_slice + } + self.padding = { + t: { + a: PadWidth( + self.halo[t][a].left - self.overlap[t][a].left, + self.halo[t][a].right - self.overlap[t][a].right, + ) + for a in self.inner_slice[t] + } + for t in self.inner_slice + } - """ - assert all(aid in tensor.dims for aid in tile_shape), (tensor.dims, set(tile_shape)) - assert all(aid in tensor.dims for aid in pad_width), (tensor.dims, set(pad_width)) - assert all(aid in tile_shape for aid in tensor.dims), (tensor.dims, set(tile_shape)) - assert all(aid in pad_width for aid in tensor.dims), (tensor.dims, set(pad_width)) - axes = cast(Tuple[AxisId, ...], tensor.dims) +@dataclass +class Tile(AbstractTile): + """A tile of a dataset sample""" - all_1d_tiles: List[List[Tuple[int, slice]]] = [] - shape = 
tensor.shape - for aid, s in zip(axes, shape): - pad = _pad if isinstance(_pad := pad_width[aid], tuple) else (_pad, _pad) - stride = tile_shape[aid] - sum(pad) - tiles_1d = [ - (p, slice(max(0, p - pad[0]), min(s, p + pad[1]))) - for p in range(0, s, stride) - ] - all_1d_tiles.append(tiles_1d) + data: PerTensor[Tensor] + """the tile's tensors""" - n_tiles = prod(map(len, all_1d_tiles)) + @property + def inner_data(self): + return {t: self.data[t][self.local_slice[t]] for t in self.data} - return n_tiles, _tile_generator(tensor, all_1d_tiles) + def __post_init__(self): + super().__post_init__() + for t, d in self.data.items(): + assert t == d.id, f"tensor id mismatch: {t} != {d.id}" + for a, s in d.sizes.items(): + slice_ = self.inner_slice[t][a] + halo = self.halo[t][a] + assert s == slice_.stop - slice_.start + halo.left + halo.right, ( + s, + slice_, + halo, + ) diff --git a/bioimageio/core/utils/_digest_spec.py b/bioimageio/core/utils/_digest_spec.py index 7f0b892c..d88ea113 100644 --- a/bioimageio/core/utils/_digest_spec.py +++ b/bioimageio/core/utils/_digest_spec.py @@ -1,27 +1,57 @@ -from typing import List +from typing import List, Sequence, get_args -from bioimageio.core.common import Tensor +from bioimageio.core.axis import AxisLetter, AxisLike from bioimageio.spec.model import AnyModelDescr, v0_4 from bioimageio.spec.utils import load_array -from .image_helper import interprete_array +from ..tensor import Tensor, TensorId def get_test_inputs(model: AnyModelDescr) -> List[Tensor]: axes = [d.axes for d in model.inputs] + if isinstance(axes, str): + core_axes: List[Sequence[AxisLike]] = [ + a if a in get_args(AxisLetter) else "i" for a in axes + ] # pyright: ignore[reportAssignmentType] + else: + core_axes = axes # pyright: ignore[reportAssignmentType] + if isinstance(model, v0_4.ModelDescr): arrays = [load_array(tt) for tt in model.test_inputs] else: arrays = [load_array(d.test_tensor) for d in model.inputs] - return [interprete_array(arr, ax) for arr, ax in zip(arrays, axes)] + if isinstance(model, v0_4.ModelDescr): + tensor_ids = [TensorId(ipt.name) for ipt in model.inputs] + else: + tensor_ids = [ipt.id for ipt in model.inputs] + + return [ + Tensor.from_numpy(arr, ax, t) + for arr, ax, t in zip(arrays, core_axes, tensor_ids) + ] def get_test_outputs(model: AnyModelDescr) -> List[Tensor]: axes = [d.axes for d in model.outputs] + if isinstance(axes, str): + core_axes: List[Sequence[AxisLike]] = [ + a if a in get_args(AxisLetter) else "i" for a in axes + ] # pyright: ignore[reportAssignmentType] + else: + core_axes = axes # pyright: ignore[reportAssignmentType] + if isinstance(model, v0_4.ModelDescr): arrays = [load_array(tt) for tt in model.test_outputs] else: arrays = [load_array(d.test_tensor) for d in model.outputs] - return [interprete_array(arr, ax) for arr, ax in zip(arrays, axes)] + if isinstance(model, v0_4.ModelDescr): + tensor_ids = [TensorId(ipt.name) for ipt in model.inputs] + else: + tensor_ids = [ipt.id for ipt in model.inputs] + + return [ + Tensor.from_numpy(arr, ax, t) + for arr, ax, t in zip(arrays, core_axes, tensor_ids) + ] diff --git a/bioimageio/core/utils/image_helper.py b/bioimageio/core/utils/image_helper.py deleted file mode 100644 index b3e23320..00000000 --- a/bioimageio/core/utils/image_helper.py +++ /dev/null @@ -1,188 +0,0 @@ -from pathlib import Path -from typing import Any, Dict, Literal, Optional, Sequence, Tuple, Union - -import imageio -import numpy as np -from numpy.typing import NDArray - -from bioimageio.core.common import Axis, AxisLike, 
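# [editorial note] The assert in Tile.__post_init__ above encodes the outer
# tile invariant per axis:
#   data size == (inner_slice.stop - inner_slice.start) + halo.left + halo.right
# e.g. inner (90, 100) with halo (16, 16) -> the tile's data axis must hold
# 10 + 16 + 16 = 42 entries.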
Tensor -from bioimageio.spec.model import v0_4 -from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensorDescr04 -from bioimageio.spec.model.v0_4 import OutputTensorDescr as OutputTensorDescr04 -from bioimageio.spec.model.v0_5 import ( - AnyAxis, - AxisId, - BatchAxis, - ChannelAxis, - Identifier, - InputTensorDescr, - OutputTensorDescr, - SizeReference, - SpaceInputAxis, - convert_axes, -) -from bioimageio.spec.utils import load_array - -InputTensor = Union[InputTensorDescr04, InputTensorDescr] -OutputTensor = Union[OutputTensorDescr04, OutputTensorDescr] - - -def normalize_axes( - axes: Union[v0_4.AxesStr, Sequence[Union[AnyAxis, AxisLike]]] -) -> Tuple[Axis, ...]: - AXIS_TYPE_MAP: Dict[str, Literal["batch", "time", "index", "channel", "space"]] = { - "b": "batch", - "t": "time", - "i": "index", - "c": "channel", - "x": "space", - "y": "space", - "z": "space", - } - AXIS_ID_MAP = { - "b": "batch", - "t": "time", - "i": "index", - "c": "channel", - } - if isinstance(axes, str): - return tuple( - Axis(id=AxisId(AXIS_ID_MAP.get(a, a)), type=AXIS_TYPE_MAP[a]) for a in axes - ) - else: - return tuple(Axis(id=AxisId(a.id), type=a.type) for a in axes) - - -def _interprete_array_wo_known_axes(array: NDArray[Any]): - ndim = array.ndim - if ndim == 2: - current_axes = ( - SpaceInputAxis(id=AxisId("y"), size=array.shape[0]), - SpaceInputAxis(id=AxisId("x"), size=array.shape[1]), - ) - elif ndim == 3 and any(s <= 3 for s in array.shape): - current_axes = ( - ChannelAxis( - channel_names=[Identifier(f"channel{i}") for i in range(array.shape[0])] - ), - SpaceInputAxis(id=AxisId("y"), size=array.shape[1]), - SpaceInputAxis(id=AxisId("x"), size=array.shape[2]), - ) - elif ndim == 3: - current_axes = ( - SpaceInputAxis(id=AxisId("z"), size=array.shape[0]), - SpaceInputAxis(id=AxisId("y"), size=array.shape[1]), - SpaceInputAxis(id=AxisId("x"), size=array.shape[2]), - ) - elif ndim == 4: - current_axes = ( - ChannelAxis( - channel_names=[Identifier(f"channel{i}") for i in range(array.shape[0])] - ), - SpaceInputAxis(id=AxisId("z"), size=array.shape[1]), - SpaceInputAxis(id=AxisId("y"), size=array.shape[2]), - SpaceInputAxis(id=AxisId("x"), size=array.shape[3]), - ) - elif ndim == 5: - current_axes = ( - BatchAxis(), - ChannelAxis( - channel_names=[Identifier(f"channel{i}") for i in range(array.shape[1])] - ), - SpaceInputAxis(id=AxisId("z"), size=array.shape[2]), - SpaceInputAxis(id=AxisId("y"), size=array.shape[3]), - SpaceInputAxis(id=AxisId("x"), size=array.shape[4]), - ) - else: - raise ValueError(f"Could not guess an axis mapping for {array.shape}") - - return Tensor(array, dims=tuple(a.id for a in current_axes)) - - -def interprete_array( - array: NDArray[Any], - axes: Optional[Union[v0_4.AxesStr, Sequence[AnyAxis]]], -) -> Tensor: - if axes is None: - return _interprete_array_wo_known_axes(array) - - original_shape = tuple(array.shape) - if len(array.shape) > len(axes): - # remove singletons - for i, s in enumerate(array.shape): - if s == 1: - array = np.take(array, 0, axis=i) - if len(array.shape) == len(axes): - break - - # add singletons if nececsary - for a in axes: - if len(array.shape) >= len(axes): - break - - if isinstance(a, str) or a.size is None: - array = array[None] - continue - - if isinstance(a.size, int): - if a.size == 1: - array = array[None] - - continue - - if isinstance(a.size, SizeReference): - continue # TODO: check if singleton is ok for a `SizeReference` - - try: - maybe_size_one = a.size.validate_size( - 1 - ) # TODO: refactor validate_size() to have boolean 
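# [editorial example] The letter maps in the removed `normalize_axes` above
# translate v0_4 axis strings, e.g. "bcyx":
#   types: batch, channel, space, space   (AXIS_TYPE_MAP)
#   ids:   "batch", "channel", "y", "x"   (AXIS_ID_MAP keeps x/y/z as-is)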
func here - except ValueError: - continue - - if maybe_size_one == 1: - array = array[None] - - if len(array.shape) != len(axes): - raise ValueError(f"Array shape {original_shape} does not map to axes {axes}") - - normalized_axes = normalize_axes(axes) - assert len(normalized_axes) == len(axes) - return Tensor(array, dims=tuple(a.id for a in normalized_axes)) - - -def transpose_tensor( - tensor: Tensor, - axes: Sequence[AxisId], -) -> Tensor: - """Transpose `array` to `axes` order. - - Args: - tensor: the input array - axes: the desired array axes - """ - # expand the missing image axes - current_axes = tuple(d if isinstance(d, AxisId) else AxisId(d) for d in tensor.dims) - missing_axes = tuple(a for a in axes if a not in current_axes) - tensor = tensor.expand_dims(missing_axes) - # transpose to the correct axis order - return tensor.transpose(*map(str, axes)) - - -def convert_v0_4_axes_for_known_shape(axes: v0_4.AxesStr, shape: Sequence[int]): - return convert_axes(axes, shape=shape, tensor_type="input", halo=None, size_refs={}) - - -def load_tensor( - path: Path, - axes: Optional[Sequence[AnyAxis]] = None, -) -> Tensor: - - ext = path.suffix - if ext == ".npy": - array = load_array(path) - else: - is_volume = True if axes is None else sum(a.type != "channel" for a in axes) > 2 - array = imageio.volread(path) if is_volume else imageio.imread(path) - - return interprete_array(array, axes) diff --git a/bioimageio/core/utils/testing.py b/bioimageio/core/utils/testing.py index 2659a2e7..acd65d95 100644 --- a/bioimageio/core/utils/testing.py +++ b/bioimageio/core/utils/testing.py @@ -1,3 +1,4 @@ +# TODO: move to tests/ from functools import wraps from typing import Any, Protocol, Type diff --git a/bioimageio/core/utils/tiling.py b/bioimageio/core/utils/tiling.py deleted file mode 100644 index fb89a2d2..00000000 --- a/bioimageio/core/utils/tiling.py +++ /dev/null @@ -1,150 +0,0 @@ -import warnings -from typing import Dict, Literal, Mapping, Union - -from typing_extensions import assert_never - -from bioimageio.spec.model.v0_5 import AxisId - -from ..common import LeftRight, Tensor - -PadMode = Literal["edge", "reflect", "symmetric"] - - -def pad( - tensor: Tensor, - pad_width: Mapping[AxisId, Union[int, LeftRight]], - mode: PadMode = "symmetric", -): - return tensor.pad(pad_width={str(k): v for k, v in pad_width.items()}, mode=mode) - - -def pad_to( - tensor: Tensor, - sizes: Mapping[AxisId, int], - pad_where: Union[ - Literal["before", "center", "after"], - Mapping[AxisId, Literal["before", "center", "after"]], - ] = "center", - mode: PadMode = "symmetric", -): - """pad `tensor` to match `sizes`""" - axes = [AxisId(str(a)) for a in tensor.dims] - if pad_where in ("before", "center", "after"): - pad_axis_where: Mapping[AxisId, Literal["before", "center", "after"]] = { - a: pad_where for a in axes - } - else: - pad_axis_where = pad_where - - pad_width: Dict[AxisId, Union[int, LeftRight]] = {} - for a, s_is in tensor.sizes.items(): - a = AxisId(str(a)) - if a not in sizes or sizes[a] == s_is: - pad_width[a] = 0 - elif s_is > sizes[a]: - pad_width[a] = 0 - warnings.warn( - f"Cannot pad axis {a} of size {s_is} to smaller size {sizes[a]}" - ) - elif a not in pad_axis_where: - raise ValueError( - f"Don't know where to pad axis {a}, `pad_where`={pad_where}" - ) - else: - pad_this_axis_where = pad_axis_where[a] - p = sizes[a] - s_is - if pad_this_axis_where == "before": - pad_width[a] = LeftRight(p, 0) - elif pad_this_axis_where == "after": - pad_width[a] = LeftRight(0, p) - elif pad_this_axis_where == 
"center": - pad_width[a] = LeftRight(left := p // 2, p - left) - else: - assert_never(pad_this_axis_where) - - return pad(tensor, pad_width, mode) - - -def crop_to( - tensor: Tensor, - sizes: Mapping[AxisId, int], - crop_where: Union[ - Literal["before", "center", "after"], - Mapping[AxisId, Literal["before", "center", "after"]], - ] = "center", -): - """crop `tensor` to match `sizes`""" - axes = [AxisId(str(a)) for a in tensor.dims] - if crop_where in ("before", "center", "after"): - crop_axis_where: Mapping[AxisId, Literal["before", "center", "after"]] = { - a: crop_where for a in axes - } - else: - crop_axis_where = crop_where - - slices: Dict[AxisId, slice] = {} - - for a, s_is in tensor.sizes.items(): - a = AxisId(str(a)) - if a not in sizes or sizes[a] == s_is: - pass - elif sizes[a] > s_is: - warnings.warn( - f"Cannot crop axis {a} of size {s_is} to larger size {sizes[a]}" - ) - elif a not in crop_axis_where: - raise ValueError( - f"Don't know where to crop axis {a}, `crop_where`={crop_where}" - ) - else: - crop_this_axis_where = crop_axis_where[a] - if crop_this_axis_where == "before": - slices[a] = slice(s_is - sizes[a], s_is) - elif crop_this_axis_where == "after": - slices[a] = slice(0, sizes[a]) - elif crop_this_axis_where == "center": - slices[a] = slice(start := (s_is - sizes[a]) // 2, sizes[a] + start) - else: - assert_never(crop_this_axis_where) - - return tensor.isel({str(a): s for a, s in slices.items()}) - - -def resize_to( - tensor: Tensor, - sizes: Mapping[AxisId, int], - *, - pad_where: Union[ - Literal["before", "center", "after"], - Mapping[AxisId, Literal["before", "center", "after"]], - ] = "center", - crop_where: Union[ - Literal["before", "center", "after"], - Mapping[AxisId, Literal["before", "center", "after"]], - ] = "center", - pad_mode: PadMode = "symmetric", -): - """crop and pad `tensor` to match `sizes`""" - crop_to_sizes: Dict[AxisId, int] = {} - pad_to_sizes: Dict[AxisId, int] = {} - new_axes = dict(sizes) - for a, s_is in tensor.sizes.items(): - a = AxisId(str(a)) - _ = new_axes.pop(a, None) - if a not in sizes or sizes[a] == s_is: - pass - elif s_is > sizes[a]: - crop_to_sizes[a] = sizes[a] - else: - pad_to_sizes[a] = sizes[a] - - if crop_to_sizes: - tensor = crop_to(tensor, crop_to_sizes, crop_where=crop_where) - - if pad_to_sizes: - tensor = pad_to(tensor, pad_to_sizes, pad_where=pad_where, mode=pad_mode) - - if new_axes: - tensor = tensor.expand_dims({str(k): v for k, v in new_axes}) - - return tensor diff --git a/tests/test_proc_ops.py b/tests/test_proc_ops.py index 6c43df8d..e8bfa427 100644 --- a/tests/test_proc_ops.py +++ b/tests/test_proc_ops.py @@ -5,10 +5,11 @@ import xarray as xr from typing_extensions import TypeGuard -from bioimageio.core.common import AxisId, TensorId +from bioimageio.core.axis import AxisId from bioimageio.core.sample import Sample from bioimageio.core.stat_calculators import compute_measures from bioimageio.core.stat_measures import SampleMean, SamplePercentile, SampleStd +from bioimageio.core.Tensor import TensorId @pytest.fixture(scope="module") diff --git a/tests/test_stat_calculators.py b/tests/test_stat_calculators.py index 68507acb..0e023ba4 100644 --- a/tests/test_stat_calculators.py +++ b/tests/test_stat_calculators.py @@ -4,7 +4,7 @@ import pytest from xarray.testing import assert_allclose # pyright: ignore[reportUnknownVariableType] -from bioimageio.core.common import AxisId, Tensor, TensorId +from bioimageio.core.axis import AxisId from bioimageio.core.sample import Sample from 
bioimageio.core.stat_calculators import MeanVarStdCalculator from bioimageio.core.stat_measures import ( @@ -12,6 +12,7 @@ DatasetStd, DatasetVar, ) +from bioimageio.core.Tensor import Tensor, TensorId def create_random_dataset(tid: TensorId, axes: Tuple[str, ...], n: int = 3): diff --git a/tests/test_stat_measures.py b/tests/test_stat_measures.py index 762c27d7..efddd03f 100644 --- a/tests/test_stat_measures.py +++ b/tests/test_stat_measures.py @@ -6,13 +6,14 @@ import xarray as xr from bioimageio.core import stat_measures -from bioimageio.core.common import AxisId, Tensor, TensorId +from bioimageio.core.axis import AxisId from bioimageio.core.sample import Sample from bioimageio.core.stat_calculators import ( SamplePercentilesCalculator, get_measure_calculators, ) from bioimageio.core.stat_measures import SamplePercentile +from bioimageio.core.Tensor import Tensor, TensorId @pytest.mark.parametrize( diff --git a/tests/utils/test_image_helper.py b/tests/utils/test_image_helper.py index ea3b4f24..96176f88 100644 --- a/tests/utils/test_image_helper.py +++ b/tests/utils/test_image_helper.py @@ -3,8 +3,8 @@ import xarray as xr from xarray.testing import assert_equal # pyright: ignore[reportUnknownVariableType] -from bioimageio.core.common import AxisId -from bioimageio.core.utils.image_helper import ( +from bioimageio.core.axis import AxisId +from bioimageio.core.io import ( interprete_array, transpose_tensor, ) From 4fa24ce308261b5a3f14dc238fc6c66e79fc0430 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 29 Mar 2024 23:50:21 +0100 Subject: [PATCH 171/244] WIP use Tensor --- bioimageio/core/__init__.py | 8 +- bioimageio/core/_magic_tensor_ops.py | 235 +++++++++++++++ bioimageio/core/_prediction_pipeline.py | 24 +- bioimageio/core/_resource_tests.py | 10 +- bioimageio/core/common.py | 33 +- bioimageio/core/io.py | 9 +- .../core/model_adapters/_model_adapter.py | 3 +- bioimageio/core/proc_ops.py | 49 ++- bioimageio/core/sample.py | 132 ++++---- bioimageio/core/stat_calculators.py | 32 +- bioimageio/core/stat_measures.py | 33 +- bioimageio/core/tensor.py | 284 +++++++++++++----- bioimageio/core/tile.py | 6 +- bioimageio/core/utils/_digest_spec.py | 4 +- tests/test_tensor.py | 41 +++ tests/utils/test_image_helper.py | 52 ---- 16 files changed, 667 insertions(+), 288 deletions(-) create mode 100644 bioimageio/core/_magic_tensor_ops.py create mode 100644 tests/test_tensor.py delete mode 100644 tests/utils/test_image_helper.py diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index 2692cc70..c3cb1db6 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -23,11 +23,13 @@ from ._resource_tests import test_description as test_description from ._resource_tests import test_model as test_model from ._settings import settings as settings -from .utils import VERSION +from .axis import Axis as Axis +from .axis import AxisId as AxisId +from .sample import Sample as Sample from .tensor import Tensor as Tensor +from .tensor import TensorId as TensorId from .tile import Tile as Tile -from .sample import Sample as Sample - +from .utils import VERSION __version__ = VERSION diff --git a/bioimageio/core/_magic_tensor_ops.py b/bioimageio/core/_magic_tensor_ops.py new file mode 100644 index 00000000..c1526fef --- /dev/null +++ b/bioimageio/core/_magic_tensor_ops.py @@ -0,0 +1,235 @@ +# this file was modified from the generated +# https://github.com/pydata/xarray/blob/cf3655968b8b12cc0ecd28fb324e63fb94d5e7e2/xarray/core/_typed_ops.py +# TODO: should we generate this 
ourselves? +# TODO: test these magic methods +import operator +from typing import Any, Callable + +from typing_extensions import Self +from xarray.core import nputils, ops + + +class MagicTensorOpsMixin: + __slots__ = () + _Compatible = Any + + def _binary_op( + self, + other: _Compatible, + f: Callable[[Any, Any], Any], + reflexive: bool = False, + ) -> Self: + raise NotImplementedError + + def __add__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.add) + + def __sub__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.sub) + + def __mul__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.mul) + + def __pow__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.pow) + + def __truediv__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.truediv) + + def __floordiv__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.floordiv) + + def __mod__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.mod) + + def __and__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.and_) + + def __xor__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.xor) + + def __or__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.or_) + + def __lshift__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.lshift) + + def __rshift__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.rshift) + + def __lt__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.lt) + + def __le__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.le) + + def __gt__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.gt) + + def __ge__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.ge) + + def __eq__(self, other: _Compatible) -> Self: # type: ignore[override] + return self._binary_op( + other, nputils.array_eq # pyright: ignore[reportUnknownArgumentType] + ) + + def __ne__(self, other: _Compatible) -> Self: # type: ignore[override] + return self._binary_op( + other, nputils.array_ne # pyright: ignore[reportUnknownArgumentType] + ) + + # When __eq__ is defined but __hash__ is not, then an object is unhashable, + # and it should be declared as follows: + __hash__: None # type:ignore[assignment] + + def __radd__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.add, reflexive=True) + + def __rsub__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.sub, reflexive=True) + + def __rmul__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.mul, reflexive=True) + + def __rpow__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.pow, reflexive=True) + + def __rtruediv__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.truediv, reflexive=True) + + def __rfloordiv__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.floordiv, reflexive=True) + + def __rmod__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.mod, reflexive=True) + + def __rand__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.and_, reflexive=True) + + def __rxor__(self, other: _Compatible) -> Self: + return self._binary_op(other, operator.xor, reflexive=True) + + def __ror__(self, 
other: _Compatible) -> Self:
+        return self._binary_op(other, operator.or_, reflexive=True)
+
+    def _inplace_binary_op(
+        self, other: _Compatible, f: Callable[[Any, Any], Any]
+    ) -> Self:
+        raise NotImplementedError
+
+    def __iadd__(self, other: _Compatible) -> Self:
+        return self._inplace_binary_op(other, operator.iadd)
+
+    def __isub__(self, other: _Compatible) -> Self:
+        return self._inplace_binary_op(other, operator.isub)
+
+    def __imul__(self, other: _Compatible) -> Self:
+        return self._inplace_binary_op(other, operator.imul)
+
+    def __ipow__(self, other: _Compatible) -> Self:
+        return self._inplace_binary_op(other, operator.ipow)
+
+    def __itruediv__(self, other: _Compatible) -> Self:
+        return self._inplace_binary_op(other, operator.itruediv)
+
+    def __ifloordiv__(self, other: _Compatible) -> Self:
+        return self._inplace_binary_op(other, operator.ifloordiv)
+
+    def __imod__(self, other: _Compatible) -> Self:
+        return self._inplace_binary_op(other, operator.imod)
+
+    def __iand__(self, other: _Compatible) -> Self:
+        return self._inplace_binary_op(other, operator.iand)
+
+    def __ixor__(self, other: _Compatible) -> Self:
+        return self._inplace_binary_op(other, operator.ixor)
+
+    def __ior__(self, other: _Compatible) -> Self:
+        return self._inplace_binary_op(other, operator.ior)
+
+    def __ilshift__(self, other: _Compatible) -> Self:
+        return self._inplace_binary_op(other, operator.ilshift)
+
+    def __irshift__(self, other: _Compatible) -> Self:
+        return self._inplace_binary_op(other, operator.irshift)
+
+    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
+        raise NotImplementedError
+
+    def __neg__(self) -> Self:
+        return self._unary_op(operator.neg)
+
+    def __pos__(self) -> Self:
+        return self._unary_op(operator.pos)
+
+    def __abs__(self) -> Self:
+        return self._unary_op(operator.abs)
+
+    def __invert__(self) -> Self:
+        return self._unary_op(operator.invert)
+
+    def round(self, *args: Any, **kwargs: Any) -> Self:
+        return self._unary_op(
+            ops.round_, *args, **kwargs  # pyright: ignore[reportUnknownArgumentType]
+        )
+
+    def argsort(self, *args: Any, **kwargs: Any) -> Self:
+        return self._unary_op(
+            ops.argsort, *args, **kwargs  # pyright: ignore[reportUnknownArgumentType]
+        )
+
+    def conj(self, *args: Any, **kwargs: Any) -> Self:
+        return self._unary_op(
+            ops.conj, *args, **kwargs  # pyright: ignore[reportUnknownArgumentType]
+        )
+
+    def conjugate(self, *args: Any, **kwargs: Any) -> Self:
+        return self._unary_op(
+            ops.conjugate, *args, **kwargs  # pyright: ignore[reportUnknownArgumentType]
+        )
+
+    __add__.__doc__ = operator.add.__doc__
+    __sub__.__doc__ = operator.sub.__doc__
+    __mul__.__doc__ = operator.mul.__doc__
+    __pow__.__doc__ = operator.pow.__doc__
+    __truediv__.__doc__ = operator.truediv.__doc__
+    __floordiv__.__doc__ = operator.floordiv.__doc__
+    __mod__.__doc__ = operator.mod.__doc__
+    __and__.__doc__ = operator.and_.__doc__
+    __xor__.__doc__ = operator.xor.__doc__
+    __or__.__doc__ = operator.or_.__doc__
+    __lshift__.__doc__ = operator.lshift.__doc__
+    __rshift__.__doc__ = operator.rshift.__doc__
+    __lt__.__doc__ = operator.lt.__doc__
+    __le__.__doc__ = operator.le.__doc__
+    __gt__.__doc__ = operator.gt.__doc__
+    __ge__.__doc__ = operator.ge.__doc__
+    __eq__.__doc__ = nputils.array_eq.__doc__
+    __ne__.__doc__ = nputils.array_ne.__doc__
+    __radd__.__doc__ = operator.add.__doc__
+    __rsub__.__doc__ = operator.sub.__doc__
+    __rmul__.__doc__ = operator.mul.__doc__
+    __rpow__.__doc__ = operator.pow.__doc__
+    __rtruediv__.__doc__ = 
operator.truediv.__doc__ + __rfloordiv__.__doc__ = operator.floordiv.__doc__ + __rmod__.__doc__ = operator.mod.__doc__ + __rand__.__doc__ = operator.and_.__doc__ + __rxor__.__doc__ = operator.xor.__doc__ + __ror__.__doc__ = operator.or_.__doc__ + __iadd__.__doc__ = operator.iadd.__doc__ + __isub__.__doc__ = operator.isub.__doc__ + __imul__.__doc__ = operator.imul.__doc__ + __ipow__.__doc__ = operator.ipow.__doc__ + __itruediv__.__doc__ = operator.itruediv.__doc__ + __ifloordiv__.__doc__ = operator.ifloordiv.__doc__ + __imod__.__doc__ = operator.imod.__doc__ + __iand__.__doc__ = operator.iand.__doc__ + __ixor__.__doc__ = operator.ixor.__doc__ + __ior__.__doc__ = operator.ior.__doc__ + __ilshift__.__doc__ = operator.ilshift.__doc__ + __irshift__.__doc__ = operator.irshift.__doc__ + __neg__.__doc__ = operator.neg.__doc__ + __pos__.__doc__ = operator.pos.__doc__ + __abs__.__doc__ = operator.abs.__doc__ + __invert__.__doc__ = operator.invert.__doc__ diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index 33c83eb3..56a75407 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -1,17 +1,18 @@ import warnings from types import MappingProxyType -from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union - -from bioimageio.core.model_adapters import ModelAdapter, create_model_adapter -from bioimageio.core.model_adapters import get_weight_formats as get_weight_formats -from bioimageio.core.proc_ops import Processing -from bioimageio.core.proc_setup import setup_pre_and_postprocessing -from bioimageio.core.sample import Sample -from bioimageio.core.stat_measures import DatasetMeasure, MeasureValue -from bioimageio.core.Tensor import Tensor, TensorId +from typing import Any, Iterable, List, Mapping, Optional, Sequence, Union + from bioimageio.spec.model import AnyModelDescr, v0_4 from bioimageio.spec.model.v0_5 import WeightsFormat +from .model_adapters import ModelAdapter, create_model_adapter +from .model_adapters import get_weight_formats as get_weight_formats +from .proc_ops import Processing +from .proc_setup import setup_pre_and_postprocessing +from .sample import Sample +from .stat_measures import DatasetMeasure, MeasureValue +from .tensor import PerTensor, Tensor, TensorId + class PredictionPipeline: """ @@ -64,8 +65,7 @@ def predict( ) -> List[Optional[Tensor]]: """Predict input_tensor with the model without applying pre/postprocessing.""" named_tensors = [ - named_input_tensors.get(str(k)) - for k in self.input_ids[len(input_tensors) :] + named_input_tensors.get(k) for k in self.input_ids[len(input_tensors) :] ] return self._adapter.forward(*input_tensors, *named_tensors) @@ -99,7 +99,7 @@ def forward_sample(self, input_sample: Sample) -> Sample: def forward_tensors( self, *input_tensors: Optional[Tensor], **named_input_tensors: Optional[Tensor] - ) -> Dict[TensorId, Tensor]: + ) -> PerTensor[Tensor]: """Apply preprocessing, run prediction and apply postprocessing.""" assert all(TensorId(k) in self.input_ids for k in named_input_tensors) input_sample = Sample( diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py index 82bd316b..d14e836e 100644 --- a/bioimageio/core/_resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -1,13 +1,12 @@ import traceback import warnings -from typing import Dict, Hashable, List, Literal, Optional, Sequence, Set, Tuple, Union +from typing import Dict, Hashable, List, Literal, Optional, Set, Tuple, Union import 
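# [editorial sketch] A plausible concrete `_binary_op` for the mixin completed
# above, assuming the subclass wraps an xr.DataArray; the actual wiring into
# `Tensor` is not part of this hunk:
from typing import Any, Callable

import numpy as np
import xarray as xr

from bioimageio.core._magic_tensor_ops import MagicTensorOpsMixin


class WrappedArray(MagicTensorOpsMixin):
    def __init__(self, data: xr.DataArray) -> None:
        self._data = data

    def _binary_op(
        self, other: Any, f: Callable[[Any, Any], Any], reflexive: bool = False
    ) -> "WrappedArray":
        other_data = other._data if isinstance(other, WrappedArray) else other
        args = (other_data, self._data) if reflexive else (self._data, other_data)
        return self.__class__(f(*args))


a = WrappedArray(xr.DataArray(np.ones((2,)), dims=("x",)))
b = (a + 1) * 2  # routed through _binary_op via the mixin's __add__/__mul__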
numpy as np from bioimageio.core._prediction_pipeline import create_prediction_pipeline from bioimageio.core.axis import AxisId, BatchSize from bioimageio.core.utils import VERSION, get_test_inputs, get_test_outputs -from bioimageio.core.utils.tiling import resize_to from bioimageio.spec import ( InvalidDescr, ResourceDescr, @@ -135,7 +134,9 @@ def _test_model_inference( error = "Output tensors for test case may not be None" break try: - np.testing.assert_array_almost_equal(res, exp, decimal=decimal) + np.testing.assert_array_almost_equal( + res.data, exp.data, decimal=decimal + ) except AssertionError as e: error = f"Output and expected output disagree:\n {e}" break @@ -217,8 +218,7 @@ def get_ns(n: int): tested.add(hashable_target_size) resized_test_inputs = [ - resize_to( - t, + t.resize_to( { aid: s for (tid, aid), s in input_target_sizes.items() diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index 1c94c77f..5542e897 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -19,16 +19,16 @@ ] -LeftRight_T = TypeVar("LeftRight_T", bound="LeftRight") -LeftRightLike = Union[int, Tuple[int, int], LeftRight_T] +_LeftRight_T = TypeVar("_LeftRight_T", bound="_LeftRight") +_LeftRightLike = Union[int, Tuple[int, int], _LeftRight_T] -class LeftRight(NamedTuple): +class _LeftRight(NamedTuple): left: int right: int @classmethod - def create(cls, like: LeftRightLike[Self]) -> Self: + def create(cls, like: _LeftRightLike[Self]) -> Self: if isinstance(like, cls): return like elif isinstance(like, tuple): @@ -39,20 +39,35 @@ def create(cls, like: LeftRightLike[Self]) -> Self: assert_never(like) -class Halo(LeftRight): +_Where = Literal["left", "right", "left_and_right"] + + +class CropWidth(_LeftRight): + pass + + +CropWidthLike = _LeftRightLike[CropWidth] +CropWhere = _Where + + +class Halo(_LeftRight): pass -HaloLike = LeftRightLike[Halo] +HaloLike = _LeftRightLike[Halo] + + +class OverlapWidth(_LeftRight): + pass -class PadWidth(LeftRight): +class PadWidth(_LeftRight): pass -PadWidthLike = LeftRightLike[PadWidth] +PadWidthLike = _LeftRightLike[PadWidth] PadMode = Literal["edge", "reflect", "symmetric"] -PadWhere = Literal["before", "center", "after"] +PadWhere = _Where class SliceInfo(NamedTuple): diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py index 557e61bb..9bcab722 100644 --- a/bioimageio/core/io.py +++ b/bioimageio/core/io.py @@ -1,12 +1,9 @@ from pathlib import Path -from typing import Optional, Sequence, Union +from typing import Optional, Sequence import imageio from bioimageio.core.axis import Axis, AxisLike -from bioimageio.spec.model import v0_5 -from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensorDescr04 -from bioimageio.spec.model.v0_4 import OutputTensorDescr as OutputTensorDescr04 from bioimageio.spec.utils import load_array from .tensor import Tensor, TensorId @@ -27,4 +24,6 @@ def load_tensor( ) array = imageio.volread(path) if is_volume else imageio.imread(path) - return Tensor.from_numpy(array, axes, id=TensorId(path.stem) if id is None else id) + return Tensor.from_numpy( + array, dims=axes, id=TensorId(path.stem) if id is None else id + ) diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py index 3560d61d..633ee342 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ b/bioimageio/core/model_adapters/_model_adapter.py @@ -2,9 +2,10 @@ from abc import ABC, abstractmethod from typing import List, Optional, Sequence, Tuple, Union, final -from 
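# [editorial note] On the comparison change above: `Tensor` no longer behaves
# like an xr.DataArray itself, so the raw arrays are reached via `.data`;
# assert_array_almost_equal(..., decimal=d) checks
# abs(desired - actual) < 1.5 * 10**-d elementwise.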
bioimageio.core.Tensor import Tensor from bioimageio.spec.model import v0_4, v0_5 +from ..tensor import Tensor + WeightsFormat = Union[v0_4.WeightsFormat, v0_5.WeightsFormat] # Known weight formats in order of priority diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 8523f991..984d53e8 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -14,16 +14,16 @@ import numpy as np import xarray as xr -from numpy.typing import DTypeLike from typing_extensions import Self, assert_never -from bioimageio.core._op_base import Operator -from bioimageio.core.axis import ( - AxisId, -) -from bioimageio.core.sample import Sample -from bioimageio.core.stat_calculators import StatsCalculator -from bioimageio.core.stat_measures import ( +from bioimageio.core.common import DTypeStr +from bioimageio.spec.model import v0_4, v0_5 + +from ._op_base import Operator +from .axis import AxisId +from .sample import Sample +from .stat_calculators import StatsCalculator +from .stat_measures import ( DatasetMean, DatasetMeasure, DatasetPercentile, @@ -37,8 +37,7 @@ Stat, StdMeasure, ) -from bioimageio.core.Tensor import Tensor, TensorId -from bioimageio.spec.model import v0_4, v0_5 +from .tensor import Tensor, TensorId def convert_axis_ids( @@ -169,7 +168,7 @@ class Binarize(_SimpleOperator): threshold: Union[float, Sequence[float]] axis: Optional[AxisId] = None - def _apply(self, input: Tensor, stat: Stat) -> xr.DataArray: + def _apply(self, input: Tensor, stat: Stat) -> Tensor: return input > self.threshold @classmethod @@ -221,18 +220,14 @@ def from_proc_descr( @dataclass class EnsureDtype(_SimpleOperator): - dtype: DTypeLike + dtype: DTypeStr @classmethod def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, tensor_id: TensorId): return cls(input=tensor_id, output=tensor_id, dtype=descr.kwargs.dtype) def get_descr(self): - return v0_5.EnsureDtypeDescr( - kwargs=v0_5.EnsureDtypeKwargs( - dtype=str(self.dtype) # pyright: ignore[reportArgumentType] - ) - ) + return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype)) def _apply(self, input: Tensor, stat: Stat) -> Tensor: return input.astype(self.dtype) @@ -377,17 +372,17 @@ def __post_init__( ): if lower_percentile is None: tid = self.input if upper_percentile is None else upper_percentile.tensor_id - self.lower = DatasetPercentile(n=0, tensor_id=tid) + self.lower = DatasetPercentile(q=0.0, tensor_id=tid) else: self.lower = lower_percentile if upper_percentile is None: - self.upper = DatasetPercentile(n=100, tensor_id=self.lower.tensor_id) + self.upper = DatasetPercentile(q=1.0, tensor_id=self.lower.tensor_id) else: self.upper = upper_percentile assert self.lower.tensor_id == self.upper.tensor_id - assert self.lower.n < self.upper.n + assert self.lower.q < self.upper.q assert self.lower.axes == self.upper.axes @property @@ -416,14 +411,14 @@ def from_proc_descr( input=tensor_id, output=tensor_id, lower_percentile=Percentile( - n=kwargs.min_percentile, axes=axes, tensor_id=ref_tensor + q=kwargs.min_percentile / 100, axes=axes, tensor_id=ref_tensor ), upper_percentile=Percentile( - n=kwargs.max_percentile, axes=axes, tensor_id=ref_tensor + q=kwargs.max_percentile / 100, axes=axes, tensor_id=ref_tensor ), ) - def _apply(self, input: xr.DataArray, stat: Stat) -> xr.DataArray: + def _apply(self, input: Tensor, stat: Stat) -> Tensor: lower = stat[self.lower] upper = stat[self.upper] return (input - lower) / (upper - lower + self.eps) @@ -435,8 +430,8 @@ def get_descr(self): return v0_5.ScaleRangeDescr( 
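# [editorial note] The percentile refactor above replaces `n` in [0, 100] with
# `q` in [0, 1]; the spec kwargs keep percentages, hence the conversions
#   q = min_percentile / 100    and    min_percentile = q * 100
# e.g. min_percentile=1.0, max_percentile=99.8 -> q=0.01 and q=0.998.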
kwargs=v0_5.ScaleRangeKwargs( axes=self.lower.axes, - min_percentile=self.lower.n, - max_percentile=self.upper.n, + min_percentile=self.lower.q * 100, + max_percentile=self.upper.q * 100, eps=self.eps, reference_tensor=self.lower.tensor_id, ) @@ -503,7 +498,7 @@ def from_proc_descr( std=Std(axes=axes, tensor_id=tensor_id), ) - def _apply(self, input: xr.DataArray, stat: Stat) -> xr.DataArray: + def _apply(self, input: Tensor, stat: Stat) -> Tensor: mean = stat[self.mean] std = stat[self.std] return (input - mean) / (std + self.eps) @@ -565,7 +560,7 @@ def get_descr(self): return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs) - def _apply(self, input: xr.DataArray, stat: Stat) -> xr.DataArray: + def _apply(self, input: Tensor, stat: Stat) -> Tensor: return (input - self.mean) / (self.std + self.eps) diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py index 82f2fc75..aed8b633 100644 --- a/bioimageio/core/sample.py +++ b/bioimageio/core/sample.py @@ -1,15 +1,16 @@ from dataclasses import dataclass, field -from typing import Dict, Iterable, Iterator, Mapping, Optional, Tuple, Union, cast +from pprint import pformat +from typing import Dict, Iterable, Iterator, Optional, Tuple, cast import numpy +import xarray as xr from typing_extensions import Self -from xarray.core.utils import Frozen from .axis import AxisId, PerAxis -from .common import Halo, HaloLike, PadMode, PadWidth, SliceInfo, TileNumber +from .common import Halo, HaloLike, PadMode, SliceInfo, TileNumber from .stat_measures import Stat from .tensor import PerTensor, Tensor, TensorId -from .tile import Tile, tile_tensor +from .tile import Tile TiledSample = Iterable[Tile] """A dataset sample split into tiles""" @@ -19,7 +20,7 @@ class Sample: """A dataset sample""" - data: PerTensor[Tensor] + data: Dict[TensorId, Tensor] """the sample's tensors""" stat: Stat = field(default_factory=dict) @@ -32,79 +33,90 @@ def sizes(self) -> PerTensor[PerAxis[int]]: def tile( self, tile_sizes: PerTensor[PerAxis[int]], - minimum_halo: PerTensor[PerAxis[HaloLike]], + halo: PerTensor[PerAxis[HaloLike]], + pad_mode: PadMode, ) -> TiledSample: assert not ( missing := [t for t in tile_sizes if t not in self.data] ), f"`tile_sizes` specified for missing tensors: {missing}" assert not ( - missing := [t for t in minimum_halo if t not in tile_sizes] - ), f"`minimum_halo` specified for tensors without `tile_sizes`: {missing}" - - tensor_ids = list(tile_sizes) + missing := [t for t in halo if t not in tile_sizes] + ), f"`halo` specified for tensors without `tile_sizes`: {missing}" + + # any axis not given in `tile_sizes` is treated + # as tile size equal to the tensor axis' size + explicit_tile_sizes = { + t: {a: tile_sizes.get(t, {}).get(a, s) for a, s in tdata.sizes.items()} + for t, tdata in self.data.items() + } + + tensor_ids = tuple(self.data) + broadcasted_tensors = { + t: Tensor.from_xarray(d) + for t, d in zip( + tensor_ids, xr.broadcast(*(self.data[tt].data for tt in tensor_ids)) + ) + } - tensor_tile_generators: Dict[ - TensorId, Iterable[Tuple[TileNumber, Tensor, PerAxis[SliceInfo]]] + tile_iterators: Dict[ + TensorId, Iterator[Tuple[TileNumber, Tensor, PerAxis[SliceInfo]]] ] = {} - n_tiles: Dict[TensorId, int] = {} + + n_tiles_common = 1 + last_non_trivial: Optional[TensorId] = None for t in tensor_ids: - n_tiles[t], tensor_tile_generators[t] = tile_tensor( - self.data[t], - tile_sizes=tile_sizes.get(t, self.data[t].sizes), - minimum_halo=minimum_halo.get(t, {a: 0 for a in self.data[t].dims}), + n_tiles, generator = 
broadcasted_tensors[t].tile( + tile_size=explicit_tile_sizes[t], + halo=halo.get(t, {}), pad_mode=pad_mode, ) - - n_tiles_common: Optional[int] = None - single_tile_tensors: Dict[TensorId, Tuple[TensorTilePos, Tensor]] = {} - tile_iterators: Dict[TensorId, Iterator[Tuple[int, TensorTilePos, Tensor]]] = {} - for t, n in n_tiles.items(): - tile_iterator = iter(tensor_tile_generators[t]) - if n == 1: - t0, pos, tensor_tile = next(tile_iterator) - assert t0 == 0 - single_tile_tensors[t] = (pos, tensor_tile) - continue - - if n_tiles_common is None: - n_tiles_common = n - elif n != n_tiles_common: + tile_iterators[t] = iter(generator) + if n_tiles in (1, n_tiles_common): + pass + elif n_tiles_common == 1: + last_non_trivial = t + n_tiles_common = n_tiles + else: + assert last_non_trivial is not None + mismatch = { + last_non_trivial: { + "original sizes": self.data[last_non_trivial].sizes, + "broadcasted sizes": broadcasted_tensors[ + last_non_trivial + ].sizes, + "n_tiles": n_tiles_common, + }, + t: { + "original sizes": self.data[t].sizes, + "broadcasted sizes": broadcasted_tensors[t].sizes, + "n_tiles": n_tiles, + }, + } raise ValueError( - f"{self} tiled by {tile_sizes} yields different numbers of tiles: {n_tiles}" + f"broadcasted tensors {last_non_trivial, t} do not tile to the same" + + f" number of tiles {n_tiles_common, n_tiles}. Details\n" + + pformat(mismatch) ) - tile_iterators[t] = tile_iterator - - if n_tiles_common is None: - assert not tile_iterators - n_tiles_common = 1 - - for t in range(n_tiles_common): + for i in range(n_tiles_common): data: Dict[TensorId, Tensor] = {} - tile_pos: TilePos = {} - inner_slice: TileSlice = {} - outer_slice: TileSlice = {} - for t, (tensor_tile, tensor_pos) in single_tile_tensors.items(): - data[t] = tensor_tile - tile_pos[t] = tensor_pos - inner_slice[t] = inner_tensor_slice - outer_slice[t] = outer_tensor_slice - - for t, tile_iterator in tile_iterators.items(): - assert t not in data - assert t not in tile_pos - _t, tensor_pos, tensor_tile = next(tile_iterator) - assert _t == t, (_t, t) + inner_slice: Dict[TensorId, PerAxis[SliceInfo]] = {} + for t, iterator in tile_iterators.items(): + tn, tensor_tile, tensor_slice = next(iterator) + assert tn == i, f"expected tile number {i}, but got {tn}" data[t] = tensor_tile - tile_pos[t] = tensor_pos + inner_slice[t] = tensor_slice yield Tile( data=data, - pos=tile_pos, inner_slice=inner_slice, - outer_slice=outer_slice, - tile_number=t, - tiles_in_self=n_tiles_common, + halo={ + t: {a: Halo.create(h) for a, h in th.items()} + for t, th in halo.items() + }, + sample_sizes=self.sizes, + tile_number=i, + tiles_in_sample=n_tiles_common, stat=self.stat, ) @@ -113,7 +125,7 @@ def from_tiles( cls, tiles: Iterable[Tile], *, fill_value: float = float("nan") ) -> Self: # TODO: add `mode: Literal['in-memory', 'to-disk']` or similar to save out of mem samples - data: TileData = {} + data: PerTensor[Tensor] = {} stat: Stat = {} for tile in tiles: for t, tile_data in tile.inner_data.items(): diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py index 851dfba6..380c41b1 100644 --- a/bioimageio/core/stat_calculators.py +++ b/bioimageio/core/stat_calculators.py @@ -26,11 +26,9 @@ from numpy.typing import NDArray from typing_extensions import assert_never -from bioimageio.core.axis import ( - AxisId, -) -from bioimageio.core.sample import Sample -from bioimageio.core.stat_measures import ( +from .axis import AxisId +from .sample import Sample +from .stat_measures import ( DatasetMean, 
DatasetMeasure, DatasetMeasureBase, @@ -45,7 +43,7 @@ SampleStd, SampleVar, ) -from bioimageio.core.Tensor import TensorId +from .tensor import Tensor, TensorId try: import crick @@ -70,7 +68,7 @@ class MeanCalculator: def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]): super().__init__() self._n: int = 0 - self._mean: Optional[xr.DataArray] = None + self._mean: Optional[Tensor] = None self._axes = None if axes is None else tuple(axes) self._tensor_id = tensor_id self._sample_mean = SampleMean(tensor_id=self._tensor_id, axes=self._axes) @@ -79,8 +77,8 @@ def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]): def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]: return {self._sample_mean: self._compute_impl(sample)} - def _compute_impl(self, sample: Sample) -> xr.DataArray: - tensor = sample.data[self._tensor_id].astype(np.float64, copy=False) + def _compute_impl(self, sample: Sample) -> Tensor: + tensor = sample.data[self._tensor_id].astype("float64", copy=False) return tensor.mean(dim=self._axes) def update(self, sample: Sample) -> None: @@ -92,8 +90,8 @@ def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]: self._update_impl(sample.data[self._tensor_id], mean) return {self._sample_mean: mean} - def _update_impl(self, tensor: xr.DataArray, tensor_mean: xr.DataArray): - assert tensor_mean.dtype == np.float64 + def _update_impl(self, tensor: Tensor, tensor_mean: Tensor): + assert tensor_mean.dtype == "float64" # reduced voxel count n_b = int(np.prod(tensor.shape) / np.prod(tensor_mean.shape)) @@ -132,7 +130,7 @@ def compute( ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]: tensor = sample.data[self._tensor_id] mean = tensor.mean(dim=self._axes) - c = tensor - mean + c = (tensor - mean).data if self._axes is None: n = tensor.size else: @@ -144,12 +142,16 @@ def compute( assert isinstance(std, xr.DataArray) return { SampleMean(axes=self._axes, tensor_id=self._tensor_id): mean, - SampleVar(axes=self._axes, tensor_id=self._tensor_id): var, - SampleStd(axes=self._axes, tensor_id=self._tensor_id): std, + SampleVar(axes=self._axes, tensor_id=self._tensor_id): Tensor.from_xarray( + var + ), + SampleStd(axes=self._axes, tensor_id=self._tensor_id): Tensor.from_xarray( + std + ), } def update(self, sample: Sample): - tensor = sample.data[self._tensor_id].astype(np.float64, copy=False) + tensor = sample.data[self._tensor_id].astype("float64", copy=False) mean_b = tensor.mean(dim=self._axes) assert mean_b.dtype == np.float64 # reduced voxel count diff --git a/bioimageio/core/stat_measures.py b/bioimageio/core/stat_measures.py index 83775fc9..fa928eae 100644 --- a/bioimageio/core/stat_measures.py +++ b/bioimageio/core/stat_measures.py @@ -2,15 +2,18 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, Optional, Tuple, TypeVar, Union +from typing import Dict, Optional, Protocol, Tuple, TypeVar, Union -import xarray as xr +from .axis import AxisId +from .tensor import PerTensor, Tensor, TensorId -from bioimageio.core.axis import AxisId -from bioimageio.core.sample import Sample -from bioimageio.core.Tensor import TensorId +MeasureValue = Union[float, Tensor] -MeasureValue = Union[float, xr.DataArray] + +# using Sample Protocol really only to avoid circular imports +class SampleLike(Protocol): + @property + def data(self) -> PerTensor[Tensor]: ... 
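# [editor's note] The `SampleLike` Protocol above works by structural typing:
# any object exposing a read-only `data` mapping satisfies it, so this module
# no longer has to import `Sample` (which itself imports from this module).
# A minimal, self-contained sketch of the pattern; `HasData`, `MiniSample`
# and `smallest_value` are illustrative names, not part of this patch:

from typing import Mapping, Protocol


class HasData(Protocol):
    @property
    def data(self) -> Mapping[str, float]: ...


class MiniSample:  # satisfies HasData without inheriting from it
    def __init__(self) -> None:
        self._data = {"t0": 0.5, "t1": 1.5}

    @property
    def data(self) -> Mapping[str, float]:
        return self._data


def smallest_value(sample: HasData) -> float:
    # only the `data` property is required, no concrete Sample class
    return min(sample.data.values())


assert smallest_value(MiniSample()) == 0.5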
@dataclass(frozen=True) @@ -21,7 +24,7 @@ class MeasureBase: @dataclass(frozen=True) class SampleMeasureBase(MeasureBase, ABC): @abstractmethod - def compute(self, sample: Sample) -> MeasureValue: + def compute(self, sample: SampleLike) -> MeasureValue: """compute the measure""" ... @@ -41,7 +44,7 @@ class _Mean: class SampleMean(_Mean, SampleMeasureBase): """The mean value of a single tensor""" - def compute(self, sample: Sample) -> MeasureValue: + def compute(self, sample: SampleLike) -> MeasureValue: tensor = sample.data[self.tensor_id] return tensor.mean(dim=self.axes) @@ -67,7 +70,7 @@ class _Std: class SampleStd(_Std, SampleMeasureBase): """The standard deviation of a single tensor""" - def compute(self, sample: Sample) -> MeasureValue: + def compute(self, sample: SampleLike) -> MeasureValue: tensor = sample.data[self.tensor_id] return tensor.std(dim=self.axes) @@ -93,7 +96,7 @@ class _Var: class SampleVar(_Var, SampleMeasureBase): """The variance of a single tensor""" - def compute(self, sample: Sample) -> MeasureValue: + def compute(self, sample: SampleLike) -> MeasureValue: tensor = sample.data[self.tensor_id] return tensor.var(dim=self.axes) @@ -111,22 +114,22 @@ def __post_init__(self): @dataclass(frozen=True) class _Percentile: - n: float + q: float axes: Optional[Tuple[AxisId, ...]] = None """`axes` to reduce""" def __post_init__(self): - assert self.n >= 0 - assert self.n <= 100 + assert self.q >= 0.0 + assert self.q <= 1.0 @dataclass(frozen=True) class SamplePercentile(_Percentile, SampleMeasureBase): """The `n`th percentile of a single tensor""" - def compute(self, sample: Sample) -> MeasureValue: + def compute(self, sample: SampleLike) -> MeasureValue: tensor = sample.data[self.tensor_id] - return tensor.quantile(self.n / 100.0, dim=self.axes) + return tensor.quantile(self.q, dim=self.axes) def __post_init__(self): super().__post_init__() diff --git a/bioimageio/core/tensor.py b/bioimageio/core/tensor.py index 26384f0e..e63380ea 100644 --- a/bioimageio/core/tensor.py +++ b/bioimageio/core/tensor.py @@ -3,7 +3,9 @@ import itertools from math import prod from typing import ( + TYPE_CHECKING, Any, + Callable, Dict, Generator, List, @@ -25,31 +27,48 @@ from bioimageio.core.axis import PerAxis from bioimageio.core.common import PadMode, PadWhere -from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.model import v0_5 +from ._magic_tensor_ops import MagicTensorOpsMixin from .axis import Axis, AxisId, AxisInfo, AxisLike from .common import ( + CropWhere, DTypeStr, Halo, HaloLike, PadWidth, + PadWidthLike, SliceInfo, TileNumber, TotalNumberOfTiles, ) +if TYPE_CHECKING: + from numpy.typing import ArrayLike, NDArray TensorId = v0_5.TensorId T = TypeVar("T") + PerTensor = Mapping[TensorId, T] -class Tensor: +_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"] # TODO: add "DaskArray" + + +# TODO: make Tensor a numpy compatible array type, to use e.g. \ +# with `np.testing.assert_array_almost_equal`. 
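# [editor's note] The TODO above is resolved later in this series (in the
# commit "make Tensor numpy arraylike") by adding an `__array__` method. A
# minimal sketch of that protocol with a toy wrapper; `ArrayWrapper` is an
# illustrative name, not part of the patch:

import numpy as np
from numpy.typing import DTypeLike, NDArray


class ArrayWrapper:
    """toy wrapper that numpy functions can consume directly"""

    def __init__(self, array: "NDArray[np.float64]") -> None:
        self._array = array

    def __array__(self, dtype: DTypeLike = None) -> "NDArray[np.generic]":
        # numpy calls this when the wrapper is passed to np.asarray and friends
        return np.asarray(self._array, dtype=dtype)


w = ArrayWrapper(np.zeros((2, 3)))
np.testing.assert_array_almost_equal(w, np.zeros((2, 3)))  # works via __array__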
+# TODO: complete docstrings +class Tensor(MagicTensorOpsMixin): + """A wrapper around an xr.DataArray for better integration with bioimageio.spec + and improved type annotations.""" + + _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray] + def __init__( self, array: NDArray[Any], dims: Union[AxisId, Sequence[AxisId]], - id: TensorId, + id: Optional[TensorId] = None, ) -> None: super().__init__() self._data = xr.DataArray(array, dims=dims, name=id) @@ -63,61 +82,139 @@ def __setitem__(self, key: PerAxis[Union[SliceInfo, slice]], value: Tensor) -> N key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()} self._data[key] = value._data + def _binary_op( + self, + other: _Compatible, + f: Callable[[Any, Any], Any], + reflexive: bool = False, + ) -> Self: + data = self._data._binary_op( # pyright: ignore[reportPrivateUsage] + (other._data if isinstance(other, Tensor) else other), + f, + reflexive, + ) + return self.__class__.from_xarray(data) + + def _inplace_binary_op( + self, + other: _Compatible, + f: Callable[[Any, Any], Any], + ) -> Self: + _ = self._data._inplace_binary_op( # pyright: ignore[reportPrivateUsage] + ( + other_d + if (other_d := getattr(other, "data")) is not None + and isinstance( + other_d, + xr.DataArray, + ) + else other + ), + f, + ) + return self + + def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self: + data = self._data._unary_op( # pyright: ignore[reportPrivateUsage] + f, *args, **kwargs + ) + return self.__class__.from_xarray(data) + @classmethod def from_xarray(cls, data_array: xr.DataArray) -> Self: - if data_array.name is None: - raise ValueError( - "Expected a named `data_array` to use `data_array.name` as tensor id" - ) + """create a `Tensor` from an xarray data array + note for internal use: this factory method is round-trip save + for any `Tensor`'s `data` property (an xarray.DataArray). + """ return cls( array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims), - id=TensorId(data_array.name), + id=None if data_array.name is None else TensorId(data_array.name), ) @classmethod def from_numpy( - cls, array: NDArray[Any], axes: Optional[Sequence[AxisLike]], id: TensorId + cls, + array: NDArray[Any], + *, + dims: Optional[Union[AxisLike, Sequence[AxisLike]]], + id: TensorId, ) -> Tensor: - if axes is None: + """create a `Tensor` from a numpy array + + Args: + array: the nd numpy array + axes: A description of the array's axes, + if None axes are guessed (which might fail and raise a ValueError.) + id: the created tensor's identifier + + Raises: + ValueError: if `axes` is None and axes guessing fails. 
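        Example (editor's sketch, not part of the patch; it mirrors the
        usage added in tests/test_tensor.py later in this series, the
        values are arbitrary):

            t = Tensor.from_numpy(np.zeros((2, 3)), dims=None, id=TensorId("t"))
            # axes are guessed for the 2d array; pass e.g. dims=("y", "x")
            # (or a sequence of Axis objects) to name them explicitly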
+ """ + + if dims is None: return cls._interprete_array_wo_known_axes(array, id=id) + elif isinstance(dims, (str, Axis, v0_5.AxisBase)): + dims = [dims] + axis_infos = [AxisInfo.create(a) for a in dims] original_shape = tuple(array.shape) - if len(array.shape) > len(axes): + if len(array.shape) > len(dims): # remove singletons for i, s in enumerate(array.shape): if s == 1: array = np.take(array, 0, axis=i) - if len(array.shape) == len(axes): + if len(array.shape) == len(dims): break # add singletons if nececsary - for a in axes: - a = AxisInfo.create(a) - if len(array.shape) >= len(axes): + for a in axis_infos: + + if len(array.shape) >= len(dims): break if a.maybe_singleton: array = array[None] - if len(array.shape) != len(axes): + if len(array.shape) != len(dims): raise ValueError( - f"Array shape {original_shape} does not map to axes {axes}" + f"Array shape {original_shape} does not map to axes {dims}" ) - normalized_axes = normalize_axes(axes) - assert len(normalized_axes) == len(axes) - return Tensor(array, dims=tuple(a.id for a in normalized_axes)) + return Tensor(array, dims=tuple(a.id for a in axis_infos), id=id) @property def data(self): return self._data @property - def dims(self): + def dims(self): # TODO: rename to `axes`? + """Tuple of dimension names associated with this tensor.""" return cast(Tuple[AxisId, ...], self._data.dims) + @property + def shape(self): + """Tuple of tensor dimension lenghts""" + return self._data.shape + + @property + def size(self): + """Number of elements in the tensor. + + Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions. + """ + return self._data.size + + def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: + """Reduce this Tensor's data by applying sum along some dimension(s).""" + return self.__class__.from_xarray(self._data.sum(dim=dim)) + + @property + def ndim(self): + """Number of tensor dimensions.""" + return self._data.ndim + @property def dtype(self) -> DTypeStr: dt = str(self.data.dtype) # pyright: ignore[reportUnknownArgumentType] @@ -126,38 +223,50 @@ def dtype(self) -> DTypeStr: @property def id(self): + """the tensor's identifier""" return self._id @property def sizes(self): + """Ordered, immutable mapping from axis ids to lengths.""" return cast(Mapping[AxisId, int], self.data.sizes) + def astype(self, dtype: DTypeStr, *, copy: bool = False): + """Return tensor cast to `dtype` + + note: if dtype is already satisfied copy if `copy`""" + return self.__class__.from_xarray(self._data.astype(dtype, copy=copy)) + + def clip(self, min: Optional[float] = None, max: Optional[float] = None): + """Return a tensor whose values are limited to [min, max]. 
+ At least one of max or min must be given.""" + return self.__class__.from_xarray(self._data.clip(min, max)) + def crop_to( - tensor: Tensor, - sizes: Mapping[AxisId, int], + self, + sizes: PerAxis[int], crop_where: Union[ - Literal["before", "center", "after"], - Mapping[AxisId, Literal["before", "center", "after"]], - ] = "center", - ): - """crop `tensor` to match `sizes`""" - axes = [AxisId(str(a)) for a in tensor.dims] - if crop_where in ("before", "center", "after"): - crop_axis_where: Mapping[AxisId, Literal["before", "center", "after"]] = { - a: crop_where for a in axes - } + CropWhere, + PerAxis[CropWhere], + ] = "left_and_right", + ) -> Self: + """crop to match `sizes`""" + if isinstance(crop_where, str): + crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims} else: crop_axis_where = crop_where - slices: Dict[AxisId, slice] = {} + slices: Dict[AxisId, SliceInfo] = {} - for a, s_is in tensor.sizes.items(): - a = AxisId(str(a)) + for a, s_is in self.sizes.items(): if a not in sizes or sizes[a] == s_is: pass elif sizes[a] > s_is: - warnings.warn( - f"Cannot crop axis {a} of size {s_is} to larger size {sizes[a]}" + logger.warning( + "Cannot crop axis {} of size {} to larger size {}", + a, + s_is, + sizes[a], ) elif a not in crop_axis_where: raise ValueError( @@ -165,31 +274,37 @@ def crop_to( ) else: crop_this_axis_where = crop_axis_where[a] - if crop_this_axis_where == "before": - slices[a] = slice(s_is - sizes[a], s_is) - elif crop_this_axis_where == "after": - slices[a] = slice(0, sizes[a]) - elif crop_this_axis_where == "center": - slices[a] = slice(start := (s_is - sizes[a]) // 2, sizes[a] + start) + if crop_this_axis_where == "left": + slices[a] = SliceInfo(s_is - sizes[a], s_is) + elif crop_this_axis_where == "right": + slices[a] = SliceInfo(0, sizes[a]) + elif crop_this_axis_where == "left_and_right": + slices[a] = SliceInfo( + start := (s_is - sizes[a]) // 2, sizes[a] + start + ) else: assert_never(crop_this_axis_where) - return tensor.isel({str(a): s for a, s in slices.items()}) + return self[slices] - def mean(self, dim: Union[AxisId, Sequence[AxisId]]) -> Self: + def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self: + return self.__class__.from_xarray(self._data.expand_dims(dims=dims)) + + def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: return self.__class__.from_xarray(self._data.mean(dims=dim)) - def std(self, dim: Union[AxisId, Sequence[AxisId]]) -> Self: + def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: return self.__class__.from_xarray(self._data.std(dims=dim)) - def var(self, dim: Union[AxisId, Sequence[AxisId]]) -> Self: + def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: return self.__class__.from_xarray(self._data.var(dims=dim)) def pad( self, - pad_width: PerAxis[PadWidth], + pad_width: PerAxis[PadWidthLike], mode: PadMode = "symmetric", ) -> Self: + pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()} return self.__class__.from_xarray( self._data.pad(pad_width=pad_width, mode=mode) ) @@ -197,7 +312,7 @@ def pad( def pad_to( self, sizes: PerAxis[int], - pad_where: Union[PadWhere, PerAxis[PadWhere]] = "center", + pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right", mode: PadMode = "symmetric", ) -> Self: """pad `tensor` to match `sizes`""" @@ -224,37 +339,44 @@ def pad_to( ) else: pad_this_axis_where = pad_axis_where[a] - p = sizes[a] - s_is - if pad_this_axis_where == "before": - pad_width[a] = PadWidth(p, 0) 
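# [editor's note] The hunk continuing below renames the placement options
# ("before"/"center"/"after" become "left"/"right"/"left_and_right", cf. the
# shared `_Where` literal introduced in common.py in this series) while
# keeping the arithmetic. Worked example: padding an axis from size 7 up to
# size 10 gives d = 3, so
#   "left"           -> PadWidth(3, 0)
#   "right"          -> PadWidth(0, 3)
#   "left_and_right" -> PadWidth(1, 2)   # left = 3 // 2 = 1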
- elif pad_this_axis_where == "after": - pad_width[a] = PadWidth(0, p) - elif pad_this_axis_where == "center": - pad_width[a] = PadWidth(left := p // 2, p - left) + d = sizes[a] - s_is + if pad_this_axis_where == "left": + pad_width[a] = PadWidth(d, 0) + elif pad_this_axis_where == "right": + pad_width[a] = PadWidth(0, d) + elif pad_this_axis_where == "left_and_right": + pad_width[a] = PadWidth(left := d // 2, d - left) else: assert_never(pad_this_axis_where) return self.pad(pad_width, mode) + def quantile( + self, q: float, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None + ) -> Self: + assert q >= 0.0 + assert q <= 1.0 + return self.__class__.from_xarray(self._data.quantile(q, dim=dim)) + def resize_to( - tensor: Tensor, - sizes: Mapping[AxisId, int], + self, + sizes: PerAxis[int], *, pad_where: Union[ - Literal["before", "center", "after"], - Mapping[AxisId, Literal["before", "center", "after"]], - ] = "center", + PadWhere, + PerAxis[PadWhere], + ] = "left_and_right", crop_where: Union[ - Literal["before", "center", "after"], - Mapping[AxisId, Literal["before", "center", "after"]], - ] = "center", + CropWhere, + PerAxis[CropWhere], + ] = "left_and_right", pad_mode: PadMode = "symmetric", ): - """crop and pad `tensor` to match `sizes`""" + """return cropped/padded tensor with `sizes`""" crop_to_sizes: Dict[AxisId, int] = {} pad_to_sizes: Dict[AxisId, int] = {} new_axes = dict(sizes) - for a, s_is in tensor.sizes.items(): + for a, s_is in self.sizes.items(): a = AxisId(str(a)) _ = new_axes.pop(a, None) if a not in sizes or sizes[a] == s_is: @@ -264,14 +386,15 @@ def resize_to( else: pad_to_sizes[a] = sizes[a] + tensor = self if crop_to_sizes: - tensor = crop_to(tensor, crop_to_sizes, crop_where=crop_where) + tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where) if pad_to_sizes: - tensor = pad_to(tensor, pad_to_sizes, pad_where=pad_where, mode=pad_mode) + tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode) if new_axes: - tensor = tensor.expand_dims({str(k): v for k, v in new_axes}) + tensor = tensor.expand_dims(new_axes) return tensor @@ -282,7 +405,7 @@ def tile( pad_mode: PadMode, ) -> Tuple[ TotalNumberOfTiles, - Generator[Tuple[TileNumber, Tensor, PerAxis[SliceInfo]], Any, None], + Generator[Tuple[TileNumber, Tensor, PedrAxis[SliceInfo]], Any, None], ]: """tile this tensor into `tile_size` tiles that overlap by `halo`. At the tensor's edge the `halo` is padded with `pad_mode`. 
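# [editor's note] A worked example of the 1d tiling arithmetic in the next
# hunk: each axis is split with `stride = tile_size[a] - sum(halo[a])`, so
# for an axis of size 10, tile size 5 and halo (1, 1) the inner tiles are
# (0, 3), (3, 6), (6, 9) and (9, 10); the halo is then added around each
# inner tile and padded with `pad_mode` where it reaches past the tensor.
# A self-contained check of that arithmetic (illustrative, not patch code):

tile_size, s = 5, 10
halo = (1, 1)
stride = tile_size - sum(halo)
inner_tiles = [(p, min(s, p + stride)) for p in range(0, s, stride)]
assert inner_tiles == [(0, 3), (3, 6), (6, 9), (9, 10)]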
@@ -298,8 +421,11 @@ def tile( assert all(a in self.dims for a in tile_size), (self.dims, set(tile_size)) assert all(a in self.dims for a in halo), (self.dims, set(halo)) + # fill in default halo (0) and tile_size (tensor size) + halo = {a: Halo.create(halo.get(a, 0)) for a in self.dims} + tile_size = {a: tile_size.get(a, s) for a, s in self.sizes.items()} + inner_1d_tiles: List[List[SliceInfo]] = [] - halo = {a: Halo.create(h) for a, h in halo.items()} for a, s in self.sizes.items(): stride = tile_size[a] - sum(halo[a]) tiles_1d = [SliceInfo(p, min(s, p + stride)) for p in range(0, s, stride)] @@ -320,14 +446,14 @@ def transpose( Args: axes: the desired tensor axes """ - # expand the missing image axes - current_axes = tuple( - d if isinstance(d, AxisId) else AxisId(d) for d in tensor.dims - ) - missing_axes = tuple(a for a in axes if a not in current_axes) - tensor = tensor.expand_dims(missing_axes) + # expand missing tensor axes + missing_axes = tuple(a for a in axes if a not in self.dims) + array = self._data + if missing_axes: + array = array.expand_dims(missing_axes) + # transpose to the correct axis order - return tensor.transpose(*map(str, axes)) + return self.__class__.from_xarray(array.transpose(*axes)) @classmethod def _interprete_array_wo_known_axes(cls, array: NDArray[Any], id: TensorId): diff --git a/bioimageio/core/tile.py b/bioimageio/core/tile.py index 03703ce3..d8180af4 100644 --- a/bioimageio/core/tile.py +++ b/bioimageio/core/tile.py @@ -3,7 +3,7 @@ from bioimageio.core.common import TileNumber, TotalNumberOfTiles from .axis import PerAxis -from .common import Halo, LeftRight, PadWidth, SliceInfo +from .common import Halo, OverlapWidth, PadWidth, SliceInfo from .stat_measures import Stat from .tensor import PerTensor, Tensor @@ -36,7 +36,7 @@ class AbstractTile: local_slice: PerTensor[PerAxis[SliceInfo]] = field(init=False) """slice to extract the inner tile from the outer tile""" - overlap: PerTensor[PerAxis[LeftRight]] = field(init=False) + overlap: PerTensor[PerAxis[OverlapWidth]] = field(init=False) """overlap 'into a neighboring tile'""" padding: PerTensor[PerAxis[PadWidth]] = field(init=False) @@ -68,7 +68,7 @@ def __post_init__(self): } self.overlap = { t: { - a: LeftRight( + a: OverlapWidth( self.inner_slice[t][a].start - self.outer_slice[t][a].start, self.outer_slice[t][a].stop - self.inner_slice[t][a].stop, ) diff --git a/bioimageio/core/utils/_digest_spec.py b/bioimageio/core/utils/_digest_spec.py index d88ea113..3fe41c02 100644 --- a/bioimageio/core/utils/_digest_spec.py +++ b/bioimageio/core/utils/_digest_spec.py @@ -27,7 +27,7 @@ def get_test_inputs(model: AnyModelDescr) -> List[Tensor]: tensor_ids = [ipt.id for ipt in model.inputs] return [ - Tensor.from_numpy(arr, ax, t) + Tensor.from_numpy(arr, dims=ax, id=t) for arr, ax, t in zip(arrays, core_axes, tensor_ids) ] @@ -52,6 +52,6 @@ def get_test_outputs(model: AnyModelDescr) -> List[Tensor]: tensor_ids = [ipt.id for ipt in model.inputs] return [ - Tensor.from_numpy(arr, ax, t) + Tensor.from_numpy(arr, dims=ax, id=t) for arr, ax, t in zip(arrays, core_axes, tensor_ids) ] diff --git a/tests/test_tensor.py b/tests/test_tensor.py new file mode 100644 index 00000000..076e0961 --- /dev/null +++ b/tests/test_tensor.py @@ -0,0 +1,41 @@ +import numpy as np +import pytest +import xarray as xr +from xarray.testing import assert_equal # pyright: ignore[reportUnknownVariableType] + +from bioimageio.core import AxisId, Tensor, TensorId + + +@pytest.mark.parametrize( + "axes", + ["yx", "xy", "cyx", "yxc", "bczyx", 
"xyz", "xyzc", "bzyxc"], +) +def test_transpose_tensor_2d(axes: str): + + tensor = Tensor.from_numpy(np.random.rand(256, 256), dims=None, id=TensorId("id")) + transposed = tensor.transpose([AxisId(a) for a in axes]) + assert transposed.ndim == len(axes) + + +@pytest.mark.parametrize( + "axes", + ["zyx", "cyzx", "yzixc", "bczyx", "xyz", "xyzc", "bzyxtc"], +) +def test_transpose_tensor_3d(axes: str): + tensor = Tensor.from_numpy(np.random.rand(64, 64, 64), dims=None, id=TensorId("id")) + transposed = tensor.transpose([AxisId(a) for a in axes]) + assert transposed.ndim == len(axes) + + +def test_crop_and_pad(): + tensor = Tensor.from_xarray( + xr.DataArray(np.random.rand(10, 20), dims=("x", "y"), name="id") + ) + padded = tensor.pad({AxisId("x"): 7, AxisId("y"): (3, 3)}) + cropped = padded.crop_to(tensor.sizes) + assert_equal(tensor, cropped) + + +def test_some_magic_ops(): + tensor = Tensor.from_numpy(np.random.rand(256, 256), dims=None, id=TensorId("id")) + assert tensor + 2 == 2 + tensor diff --git a/tests/utils/test_image_helper.py b/tests/utils/test_image_helper.py deleted file mode 100644 index 96176f88..00000000 --- a/tests/utils/test_image_helper.py +++ /dev/null @@ -1,52 +0,0 @@ -import numpy as np -import pytest -import xarray as xr -from xarray.testing import assert_equal # pyright: ignore[reportUnknownVariableType] - -from bioimageio.core.axis import AxisId -from bioimageio.core.io import ( - interprete_array, - transpose_tensor, -) -from bioimageio.core.utils.tiling import crop_to, pad - - -@pytest.mark.parametrize( - "axes", - ["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"], -) -def test_transpose_tensor_2d(axes: str): - - tensor = interprete_array(np.random.rand(256, 256), None) - transposed = transpose_tensor(tensor, [AxisId(a) for a in axes]) - assert transposed.ndim == len(axes) - - -@pytest.mark.parametrize( - "axes", - ["zyx", "cyzx", "yzixc", "bczyx", "xyz", "xyzc", "bzyxtc"], -) -def test_transpose_tensor_3d(axes: str): - tensor = interprete_array(np.random.rand(64, 64, 64), None) - transposed = transpose_tensor(tensor, [AxisId(a) for a in axes]) - assert transposed.ndim == len(axes) - - -def test_crop_and_pad(): - tensor = xr.DataArray(np.random.rand(10, 20), dims=("x", "y")) - sizes = {AxisId(str(k)): v for k, v in tensor.sizes.items()} - padded = pad(tensor, {AxisId("x"): 7, AxisId("y"): (3, 3)}) - cropped = crop_to(padded, sizes) - assert_equal(tensor, cropped) - - -# def test_transform_output_tensor(): -# from bioimageio.core.utils.image_helper import transform_output_tensor - -# tensor = np.random.rand(1, 3, 64, 64, 64) -# tensor_axes = "bczyx" - -# out_ax_list = ["bczyx", "cyx", "xyc", "byxc", "zyx", "xyz"] -# for out_axes in out_ax_list: -# out = transform_output_tensor(tensor, tensor_axes, out_axes) -# assert out.ndim == len(out_axes) From 96cc0ef456ecaf325347efa237f34520b1d217d7 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sat, 30 Mar 2024 13:40:29 +0100 Subject: [PATCH 172/244] make Tensor numpy arraylike --- bioimageio/core/tensor.py | 7 ++++--- tests/test_prediction_pipeline_device_management.py | 2 ++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/bioimageio/core/tensor.py b/bioimageio/core/tensor.py index e63380ea..67d8f8cf 100644 --- a/bioimageio/core/tensor.py +++ b/bioimageio/core/tensor.py @@ -22,7 +22,7 @@ import numpy as np import xarray as xr from loguru import logger -from numpy.typing import NDArray +from numpy.typing import DTypeLike, NDArray from typing_extensions import Self, assert_never from bioimageio.core.axis 
import PerAxis @@ -55,8 +55,6 @@ _ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"] # TODO: add "DaskArray" -# TODO: make Tensor a numpy compatible array type, to use e.g. \ -# with `np.testing.assert_array_almost_equal`. # TODO: complete docstrings class Tensor(MagicTensorOpsMixin): """A wrapper around an xr.DataArray for better integration with bioimageio.spec @@ -74,6 +72,9 @@ def __init__( self._data = xr.DataArray(array, dims=dims, name=id) self._id = id + def __array__(self, dtype: DTypeLike = None): + return np.asarray(self._data, dtype=dtype) + def __getitem__(self, key: PerAxis[Union[SliceInfo, slice]]) -> Self: key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()} return self.__class__.from_xarray(self._data[key]) diff --git a/tests/test_prediction_pipeline_device_management.py b/tests/test_prediction_pipeline_device_management.py index 4c7ac1a0..b244ee26 100644 --- a/tests/test_prediction_pipeline_device_management.py +++ b/tests/test_prediction_pipeline_device_management.py @@ -37,6 +37,7 @@ def _test_device_management(model_package: Path, weight_format: WeightsFormat): assert len(outputs) == len(expected_outputs) for out, exp in zip(outputs, expected_outputs): + assert out is not None assert_array_almost_equal(out, exp, decimal=4) # repeat inference with context manager to test load/unload/load/forward @@ -45,6 +46,7 @@ def _test_device_management(model_package: Path, weight_format: WeightsFormat): assert len(outputs) == len(expected_outputs) for out, exp in zip(outputs, expected_outputs): + assert out is not None assert_array_almost_equal(out, exp, decimal=4) From 9daffe9b202f4d6f36861c921f69be1183bf8eda Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sat, 30 Mar 2024 14:24:50 +0100 Subject: [PATCH 173/244] percentile`s n -> q and "use tensor" --- bioimageio/core/stat_calculators.py | 86 ++++++++++++++--------------- bioimageio/core/tensor.py | 37 +++++++++++-- 2 files changed, 74 insertions(+), 49 deletions(-) diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py index 380c41b1..9319443d 100644 --- a/bioimageio/core/stat_calculators.py +++ b/bioimageio/core/stat_calculators.py @@ -3,11 +3,11 @@ import collections.abc import warnings from itertools import product +from math import prod from typing import ( Any, Collection, Dict, - Hashable, Iterable, Iterator, List, @@ -26,7 +26,7 @@ from numpy.typing import NDArray from typing_extensions import assert_never -from .axis import AxisId +from .axis import AxisId, PerAxis from .sample import Sample from .stat_measures import ( DatasetMean, @@ -122,8 +122,8 @@ def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]): self._axes = None if axes is None else tuple(axes) self._tensor_id = tensor_id self._n: int = 0 - self._mean: Optional[xr.DataArray] = None - self._m2: Optional[xr.DataArray] = None + self._mean: Optional[Tensor] = None + self._m2: Optional[Tensor] = None def compute( self, sample: Sample @@ -136,9 +136,9 @@ def compute( else: n = int(np.prod([tensor.sizes[d] for d in self._axes])) - var: xr.DataArray = xr.dot(c, c, dims=self._axes) / n + var = xr.dot(c, c, dims=self._axes) / n assert isinstance(var, xr.DataArray) - std: xr.DataArray = np.sqrt(var) # type: ignore + std = np.sqrt(var) assert isinstance(std, xr.DataArray) return { SampleMean(axes=self._axes, tensor_id=self._tensor_id): mean, @@ -153,11 +153,11 @@ def compute( def update(self, sample: Sample): tensor = sample.data[self._tensor_id].astype("float64", copy=False) mean_b = 
tensor.mean(dim=self._axes) - assert mean_b.dtype == np.float64 + assert mean_b.dtype == "float64" # reduced voxel count - n_b = int(np.prod(tensor.shape) / np.prod(mean_b.shape)) + n_b = int(prod(tensor.shape) / prod(mean_b.shape)) m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes) - assert m2_b.dtype == np.float64 + assert m2_b.dtype == "float64" if self._mean is None: assert self._m2 is None self._n = n_b @@ -182,11 +182,14 @@ def finalize( else: assert self._m2 is not None var = self._m2 / self._n - sqrt: xr.DataArray = np.sqrt(var) # type: ignore + sqrt = np.sqrt(var) + assert isinstance(sqrt, xr.DataArray) return { DatasetMean(tensor_id=self._tensor_id, axes=self._axes): self._mean, DatasetVar(tensor_id=self._tensor_id, axes=self._axes): var, - DatasetStd(tensor_id=self._tensor_id, axes=self._axes): sqrt, + DatasetStd( + tensor_id=self._tensor_id, axes=self._axes + ): Tensor.from_xarray(sqrt), } @@ -197,12 +200,11 @@ def __init__( self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], - ns: Collection[float], + qs: Collection[float], ): super().__init__() - assert all(0 <= n <= 100 for n in ns) - self.ns = ns - self._qs = [n / 100 for n in ns] + assert all(0.0 <= q <= 1.0 for q in qs) + self._qs = sorted(set(qs)) self._axes = None if axes is None else tuple(axes) self._tensor_id = tensor_id @@ -210,8 +212,8 @@ def compute(self, sample: Sample) -> Dict[SamplePercentile, MeasureValue]: tensor = sample.data[self._tensor_id] ps = tensor.quantile(self._qs, dim=self._axes) return { - SamplePercentile(n=n, axes=self._axes, tensor_id=self._tensor_id): p - for n, p in zip(self.ns, ps) + SamplePercentile(q=q, axes=self._axes, tensor_id=self._tensor_id): p + for q, p in zip(self._qs, ps) } @@ -224,21 +226,20 @@ def __init__( self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], - ns: Collection[float], + qs: Collection[float], ): super().__init__() - assert all(0 <= n <= 100 for n in ns) - self._ns = ns - self._qs = np.asarray([n / 100 for n in ns]) + assert all(0.0 <= q <= 1.0 for q in qs) + self._qs = sorted(set(qs)) self._axes = None if axes is None else tuple(axes) self._tensor_id = tensor_id self._n: int = 0 - self._estimates: Optional[xr.DataArray] = None + self._estimates: Optional[Tensor] = None def update(self, sample: Sample): tensor = sample.data[self._tensor_id] sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype( - np.float64, copy=False + "float64", copy=False ) # reduced voxel count @@ -263,8 +264,8 @@ def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: "Computed dataset percentiles naively by averaging percentiles of samples." ) return { - DatasetPercentile(n=n, axes=self._axes, tensor_id=self._tensor_id): e - for n, e in zip(self._ns, self._estimates) + DatasetPercentile(q=q, axes=self._axes, tensor_id=self._tensor_id): e + for q, e in zip(self._qs, self._estimates) } @@ -275,27 +276,26 @@ def __init__( self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], - ns: Collection[float], + qs: Collection[float], ): warnings.warn( "Computing dataset percentiles with experimental 'crick' library." 
) super().__init__() - assert all(0 <= n <= 100 for n in ns) + assert all(0.0 <= q <= 1.0 for q in qs) assert axes is None or "_percentiles" not in axes - self._ns = ns - self._qs = [n / 100 for n in ns] + self._qs = sorted(set(qs)) self._axes = None if axes is None else tuple(axes) self._tensor_id = tensor_id self._digest: Optional[List[TDigest]] = None - self._dims: Optional[Tuple[Hashable, ...]] = None + self._dims: Optional[Tuple[AxisId, ...]] = None self._indices: Optional[Iterator[Tuple[int, ...]]] = None self._shape: Optional[Tuple[int, ...]] = None - def _initialize(self, tensor_sizes: Mapping[Hashable, int]): + def _initialize(self, tensor_sizes: PerAxis[int]): assert crick is not None - out_sizes: OrderedDict[Hashable, int] = collections.OrderedDict( - _percentiles=len(self._ns) + out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict( + _percentiles=len(self._qs) ) if self._axes is not None: for d, s in tensor_sizes.items(): @@ -317,7 +317,7 @@ def update(self, sample: Sample): assert self._indices is not None assert self._dims is not None for i, idx in enumerate(self._indices): - self._digest[i].update(tensor.isel(dict(zip(self._dims[1:], idx)))) + self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))]) def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: if self._digest is None: @@ -331,9 +331,9 @@ def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: ).reshape(self._shape) return { DatasetPercentile( - n=n, axes=self._axes, tensor_id=self._tensor_id - ): xr.DataArray(v, dims=self._dims[1:]) - for n, v in zip(self._ns, vs) + q=q, axes=self._axes, tensor_id=self._tensor_id + ): Tensor(v, dims=self._dims[1:]) + for q, v in zip(self._qs, vs) } @@ -499,11 +499,11 @@ def get_measure_calculators( assert rm in required_dataset_mean_var_std elif isinstance(rm, SamplePercentile): required_sample_percentiles.setdefault((rm.tensor_id, rm.axes), set()).add( - rm.n + rm.q ) elif isinstance(rm, DatasetPercentile): required_dataset_percentiles.setdefault((rm.tensor_id, rm.axes), set()).add( - rm.n + rm.q ) else: assert_never(rm) @@ -532,14 +532,14 @@ def get_measure_calculators( MeanVarStdCalculator(tensor_id=rm.tensor_id, axes=rm.axes) ) - for (tid, axes), ns in required_sample_percentiles.items(): + for (tid, axes), qs in required_sample_percentiles.items(): sample_calculators.append( - SamplePercentilesCalculator(tensor_id=tid, axes=axes, ns=ns) + SamplePercentilesCalculator(tensor_id=tid, axes=axes, qs=qs) ) - for (tid, axes), ns in required_dataset_percentiles.items(): + for (tid, axes), qs in required_dataset_percentiles.items(): dataset_calculators.append( - DatasetPercentilesCalculator(tensor_id=tid, axes=axes, ns=ns) + DatasetPercentilesCalculator(tensor_id=tid, axes=axes, qs=qs) ) return sample_calculators, dataset_calculators diff --git a/bioimageio/core/tensor.py b/bioimageio/core/tensor.py index 67d8f8cf..39162413 100644 --- a/bioimageio/core/tensor.py +++ b/bioimageio/core/tensor.py @@ -8,6 +8,7 @@ Callable, Dict, Generator, + Iterator, List, Mapping, Optional, @@ -75,14 +76,26 @@ def __init__( def __array__(self, dtype: DTypeLike = None): return np.asarray(self._data, dtype=dtype) - def __getitem__(self, key: PerAxis[Union[SliceInfo, slice]]) -> Self: - key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()} + def __getitem__(self, key: PerAxis[Union[SliceInfo, slice, int]]) -> Self: + key = { + a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s) + for a, s in key.items() + } return 
self.__class__.from_xarray(self._data[key]) def __setitem__(self, key: PerAxis[Union[SliceInfo, slice]], value: Tensor) -> None: key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()} self._data[key] = value._data + def _iter(self: Any) -> Iterator[Any]: + for n in range(len(self)): + yield self[n] + + def __iter__(self: Any) -> Iterator[Any]: + if self.ndim == 0: + raise TypeError("iteration over a 0-d array") + return self._iter() + def _binary_op( self, other: _Compatible, @@ -353,10 +366,22 @@ def pad_to( return self.pad(pad_width, mode) def quantile( - self, q: float, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None + self, + q: Union[float, Sequence[float]], + dim: Optional[Union[AxisId, Sequence[AxisId]]] = None, ) -> Self: - assert q >= 0.0 - assert q <= 1.0 + assert ( + isinstance(q, (float, int)) + and q >= 0.0 + or not isinstance(q, (float, int)) + and all(qq >= 0.0 for qq in q) + ) + assert ( + isinstance(q, (float, int)) + and q <= 1.0 + or not isinstance(q, (float, int)) + and all(qq <= 1.0 for qq in q) + ) return self.__class__.from_xarray(self._data.quantile(q, dim=dim)) def resize_to( @@ -406,7 +431,7 @@ def tile( pad_mode: PadMode, ) -> Tuple[ TotalNumberOfTiles, - Generator[Tuple[TileNumber, Tensor, PedrAxis[SliceInfo]], Any, None], + Generator[Tuple[TileNumber, Tensor, PerAxis[SliceInfo]], Any, None], ]: """tile this tensor into `tile_size` tiles that overlap by `halo`. At the tensor's edge the `halo` is padded with `pad_mode`. From 860ff2582c4413c0571878a5607a471b2086f281 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sat, 30 Mar 2024 23:58:23 +0100 Subject: [PATCH 174/244] fix imports in model_adapters and update tests --- .../core/model_adapters/_keras_model_adapter.py | 2 +- .../core/model_adapters/_onnx_model_adapter.py | 2 +- .../model_adapters/_pytorch_model_adapter.py | 2 +- .../model_adapters/_tensorflow_model_adapter.py | 2 +- .../_torchscript_model_adapter.py | 2 +- tests/test_prediction.py | 2 +- tests/test_proc_ops.py | 2 +- tests/test_stat_calculators.py | 17 +++++++++-------- tests/test_stat_measures.py | 16 ++++++++++------ 9 files changed, 26 insertions(+), 21 deletions(-) diff --git a/bioimageio/core/model_adapters/_keras_model_adapter.py b/bioimageio/core/model_adapters/_keras_model_adapter.py index 5d956807..5e74b084 100644 --- a/bioimageio/core/model_adapters/_keras_model_adapter.py +++ b/bioimageio/core/model_adapters/_keras_model_adapter.py @@ -4,7 +4,7 @@ from loguru import logger from numpy.typing import NDArray -from bioimageio.core.Tensor import Tensor +from bioimageio.core.tensor import Tensor from bioimageio.spec._internal.io_utils import download from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import Version diff --git a/bioimageio/core/model_adapters/_onnx_model_adapter.py b/bioimageio/core/model_adapters/_onnx_model_adapter.py index fb5c6648..e42b8912 100644 --- a/bioimageio/core/model_adapters/_onnx_model_adapter.py +++ b/bioimageio/core/model_adapters/_onnx_model_adapter.py @@ -3,7 +3,7 @@ from numpy.typing import NDArray -from bioimageio.core.Tensor import Tensor +from bioimageio.core.tensor import Tensor from bioimageio.spec.model import v0_4, v0_5 from ._model_adapter import ModelAdapter diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py index 839919f6..b3454582 100644 --- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py +++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py @@ -2,7 +2,7 
@@ import warnings from typing import Any, List, Optional, Sequence, Tuple, Union -from bioimageio.core.Tensor import Tensor +from bioimageio.core.tensor import Tensor from bioimageio.core.utils import import_callable from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.utils import download diff --git a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py index eecc9b45..390b1b05 100644 --- a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py +++ b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py @@ -4,7 +4,7 @@ import numpy as np -from bioimageio.core.Tensor import Tensor +from bioimageio.core.tensor import Tensor from bioimageio.spec.common import FileSource from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.utils import download diff --git a/bioimageio/core/model_adapters/_torchscript_model_adapter.py b/bioimageio/core/model_adapters/_torchscript_model_adapter.py index 4f0a50ba..d9454854 100644 --- a/bioimageio/core/model_adapters/_torchscript_model_adapter.py +++ b/bioimageio/core/model_adapters/_torchscript_model_adapter.py @@ -5,7 +5,7 @@ import numpy as np from numpy.typing import NDArray -from bioimageio.core.Tensor import Tensor +from bioimageio.core.tensor import Tensor from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.utils import download diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 2a6c4487..b2547171 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -8,7 +8,7 @@ from bioimageio.spec import load_description from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensorDescr_v0_4 from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr_v0_4 -from bioimageio.spec.model.v0_5 import InputTensorDescr, ModelDescr +from bioimageio.spec.model.v0_5 import ModelDescr def test_predict_image(any_model: Path, tmpdir: Path): diff --git a/tests/test_proc_ops.py b/tests/test_proc_ops.py index e8bfa427..94ce3320 100644 --- a/tests/test_proc_ops.py +++ b/tests/test_proc_ops.py @@ -9,7 +9,7 @@ from bioimageio.core.sample import Sample from bioimageio.core.stat_calculators import compute_measures from bioimageio.core.stat_measures import SampleMean, SamplePercentile, SampleStd -from bioimageio.core.Tensor import TensorId +from bioimageio.core.tensor import TensorId @pytest.fixture(scope="module") diff --git a/tests/test_stat_calculators.py b/tests/test_stat_calculators.py index 0e023ba4..6e963272 100644 --- a/tests/test_stat_calculators.py +++ b/tests/test_stat_calculators.py @@ -12,16 +12,17 @@ DatasetStd, DatasetVar, ) -from bioimageio.core.Tensor import Tensor, TensorId +from bioimageio.core.tensor import Tensor, TensorId -def create_random_dataset(tid: TensorId, axes: Tuple[str, ...], n: int = 3): - assert axes[0] == "batch" - sizes = list(range(1, len(axes) + 1)) - b = sizes[0] - ds_array = Tensor(np.random.rand(n * b, *sizes[1:]), dims=axes) - ds = [Sample(data={tid: ds_array[i * b : (i + 1) * b]}) for i in range(n)] - return ds_array, ds +def create_random_dataset(tid: TensorId, axes: Tuple[AxisId, ...]): + n = 3 + sizes = list(range(n, len(axes) + 1)) + data = np.asarray(np.random.rand(*sizes)) + ds = [ + Sample(data={tid: Tensor(data[i : i + 1], dims=axes, id=tid)}) for i in range(n) + ] + return Tensor(data, dims=axes), ds @pytest.mark.parametrize( diff --git a/tests/test_stat_measures.py b/tests/test_stat_measures.py index efddd03f..04b9ed3f 100644 --- a/tests/test_stat_measures.py +++ 
b/tests/test_stat_measures.py @@ -13,7 +13,7 @@ get_measure_calculators, ) from bioimageio.core.stat_measures import SamplePercentile -from bioimageio.core.Tensor import Tensor, TensorId +from bioimageio.core.tensor import Tensor, TensorId @pytest.mark.parametrize( @@ -31,7 +31,9 @@ def test_individual_normal_measure( measure = getattr(stat_measures, "Sample" + name.title())( axes=axes, tensor_id=data_id ) - data = Tensor(np.random.random((5, 6, 3)), dims=("x", "y", "c")) + data = Tensor( + np.random.random((5, 6, 3)), dims=(AxisId("x"), AxisId("y"), AxisId("c")) + ) expected = getattr(data, name)(dim=axes) sample = Sample(data={data_id: data}) @@ -41,17 +43,19 @@ def test_individual_normal_measure( @pytest.mark.parametrize("axes", [None, (AxisId("x"), AxisId("y"))]) def test_individual_percentile_measure(axes: Optional[Tuple[AxisId, ...]]): - ns = [0, 10, 50, 100] + qs = [0, 0.1, 0.5, 1.0] tid = TensorId("tensor") - measures = [SamplePercentile(tensor_id=tid, axes=axes, n=n) for n in ns] + measures = [SamplePercentile(tensor_id=tid, axes=axes, q=q) for q in qs] calcs, _ = get_measure_calculators(measures) assert len(calcs) == 1 calc = calcs[0] assert isinstance(calc, SamplePercentilesCalculator) - data = Tensor(np.random.random((5, 6, 3)), dims=("x", "y", "c")) + data = Tensor( + np.random.random((5, 6, 3)), dims=(AxisId("x"), AxisId("y"), AxisId("c")) + ) actual = calc.compute(Sample(data={tid: data})) for m in measures: - expected = data.quantile(q=m.n / 100, dim=m.axes) + expected = data.quantile(q=m.q, dim=m.axes) xr.testing.assert_allclose(expected, actual[m]) From 987f5ed1ec853511f2ecd5c1c7566c7fca053d34 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sun, 31 Mar 2024 10:40:47 +0200 Subject: [PATCH 175/244] add get_sample_axes --- bioimageio/core/_prediction_pipeline.py | 4 + bioimageio/core/axis.py | 2 +- bioimageio/core/prediction.py | 171 +++++++++--------------- bioimageio/core/utils/__init__.py | 1 + bioimageio/core/utils/_digest_spec.py | 51 ++++--- 5 files changed, 100 insertions(+), 129 deletions(-) diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index 56a75407..5c5f648d 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -12,6 +12,7 @@ from .sample import Sample from .stat_measures import DatasetMeasure, MeasureValue from .tensor import PerTensor, Tensor, TensorId +from .utils import get_sample_axes class PredictionPipeline: @@ -45,6 +46,9 @@ def __init__( self.input_ids = [d.id for d in bioimageio_model.inputs] self.output_ids = [d.id for d in bioimageio_model.outputs] + self.input_axes = get_sample_axes(bioimageio_model.inputs) + self.output_axes = get_sample_axes(bioimageio_model.outputs) + self._adapter: ModelAdapter = model def __call__( diff --git a/bioimageio/core/axis.py b/bioimageio/core/axis.py index e86c7911..033b68d7 100644 --- a/bioimageio/core/axis.py +++ b/bioimageio/core/axis.py @@ -20,7 +20,7 @@ def _get_axis_type(a: Literal["b", "t", "i", "c", "x", "y", "z"]): elif a in ("x", "y", "z"): return "space" else: - assert_never(a) + return "index" # return most unspecific axis S = TypeVar("S", bound=str) diff --git a/bioimageio/core/prediction.py b/bioimageio/core/prediction.py index e9ec7256..d6f0ec30 100644 --- a/bioimageio/core/prediction.py +++ b/bioimageio/core/prediction.py @@ -1,92 +1,38 @@ """coming soon""" # TODO: update -# import collections -# import os -# from fractions import Fraction -# from itertools import product -# from pathlib import Path -# from 
typing import Any, Dict, Hashable, Iterator, List, NamedTuple, Optional, OrderedDict, Sequence, Tuple, Union - -# import numpy as np -# import xarray as xr -# from bioimageio.spec import ResourceDescr -# from bioimageio.spec.model.v0_5 import AxisType -# from numpy.typing import NDArray -# from pydantic import HttpUrl -# from tqdm import tqdm - -# from bioimageio.core import image_helper, load_description -# from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline -# from bioimageio.core.resource_io.nodes import ImplicitOutputShape, Model, ResourceDescr - -# Axis = Hashable - - -# class TileDef(NamedTuple): -# outer: Dict[Axis, slice] -# inner: Dict[Axis, slice] -# local: Dict[Axis, slice] - - -# def get_tiling( -# shape: Sequence[int], -# tile_shape: Dict[Axis, int], -# halo: Dict[Axis, int], -# input_axes: Sequence[Axis], -# axis_types: Dict[Axis, AxisType], -# scaling: Dict[Axis, float], -# ) -> Iterator[TileDef]: -# # outer_tile is the "input" tile, inner_tile is the "output" tile with the halo removed -# # tile_shape is the shape of the outer_tile -# assert len(shape) == len(input_axes) -# scaling_fractions = {ax: Fraction(sc).limit_denominator() for ax, sc in scaling.items()} - -# shape_ = [sh for sh, ax in zip(shape, input_axes) if axis_types[ax] == "space"] -# spatial_axes = [ax for ax in input_axes if axis_types[ax] == "space"] -# inner_tile_shape_ = [tile_shape[ax] - 2 * halo[ax] for ax in spatial_axes] -# scaling_ = [scaling_fractions[ax] for ax in spatial_axes] -# assert all([sh % fr.denominator == 0 for sh, fr in zip(shape_, scaling_)]) -# assert all([ish % fr.denominator == 0 for ish, fr in zip(inner_tile_shape_, scaling_)]) -# halo_ = [halo[ax] for ax in spatial_axes] -# assert len(shape_) == len(inner_tile_shape_) == len(spatial_axes) == len(halo_) - -# ranges = [range(sh // tsh if sh % tsh == 0 else sh // tsh + 1) for sh, tsh in zip(shape_, inner_tile_shape_)] -# start_points = product(*ranges) - -# for start_point in start_points: -# positions = [sp * tsh for sp, tsh in zip(start_point, inner_tile_shape_)] - -# inner_tile = { -# ax: slice(int(pos * fr), int(min(pos + tsh, sh) * fr)) -# for ax, pos, tsh, sh, fr in zip(spatial_axes, positions, inner_tile_shape_, shape_, scaling_) -# } -# # inner_tile["b"] = slice(None) -# # inner_tile["c"] = slice(None) - -# outer_tile = { -# ax: slice(max(pos - ha, 0), min(pos + tsh + ha, sh)) -# for ax, pos, tsh, sh, ha in zip(spatial_axes, positions, inner_tile_shape_, shape_, halo_) -# } -# # outer_tile["b"] = slice(None) -# # outer_tile["c"] = slice(None) - -# local_tile = { -# ax: slice( -# inner_tile[ax].start - int(outer_tile[ax].start * scaling[ax]), -# ( -# -(int(outer_tile[ax].stop * scaling[ax]) - inner_tile[ax].stop) -# if int(outer_tile[ax].stop * scaling[ax]) != inner_tile[ax].stop -# else None -# ), -# ) -# for ax in spatial_axes -# } -# # local_tile["b"] = slice(None) -# # local_tile["c"] = slice(None) - -# yield TileDef(outer_tile, inner_tile, local_tile) - +import collections.abc +import os +from fractions import Fraction +from itertools import product +from pathlib import Path +from typing import ( + Any, + Dict, + Hashable, + Iterator, + List, + Mapping, + NamedTuple, + Optional, + OrderedDict, + Sequence, + Tuple, + Union, +) + +import numpy as np +import xarray as xr +from numpy.typing import NDArray +from pydantic import HttpUrl +from tqdm import tqdm + +from bioimageio.core.tensor import Tensor, TensorId +from bioimageio.spec import ResourceDescr, load_description +from 
bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.model.v0_5 import AxisType + +from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline # def _predict_with_tiling_impl( # prediction_pipeline: PredictionPipeline, @@ -138,27 +84,38 @@ # output[inner_tile] = out[local_tile] -# def predict( -# prediction_pipeline: PredictionPipeline, -# inputs: Union[ -# xr.DataArray, List[xr.DataArray], Tuple[xr.DataArray], NDArray[Any], List[NDArray[Any]], Tuple[NDArray[Any]] -# ], -# ) -> List[xr.DataArray]: -# """Run prediction for a single set of input(s) with a bioimage.io model - -# Args: -# prediction_pipeline: the prediction pipeline for the input model. -# inputs: the input(s) for this model represented as xarray data or numpy nd array. -# """ -# if not isinstance(inputs, (tuple, list)): -# inputs = [inputs] - -# assert len(inputs) == len(prediction_pipeline.input_specs) -# tagged_data = [ -# ipt if isinstance(ipt, xr.DataArray) else xr.DataArray(ipt, dims=ipt_spec.axes) -# for ipt, ipt_spec in zip(inputs, prediction_pipeline.input_specs) -# ] -# return prediction_pipeline.forward(*tagged_data) +def predict( + prediction_pipeline: PredictionPipeline, + inputs: Union[ + Tensor, + NDArray[Any], + Sequence[Union[Tensor, NDArray[Any]]], + Mapping[Union[TensorId, str], Union[Tensor, NDArray[Any]]], + ], +) -> List[xr.DataArray]: + """Run prediction for a single set of input(s) with a bioimage.io model + + Args: + prediction_pipeline: the prediction pipeline for the input model. + inputs: the input(s) for this model represented as xarray data or numpy nd array. + """ + if isinstance(inputs, collections.abc.Mapping): + inputs_seq = [ + inputs.get(str(tid), inputs[tid]) for tid in prediction_pipeline.input_ids + ] + else: + if isinstance(inputs, (Tensor, np.ndarray)): + inputs_seq = [inputs] + else: + inputs_seq = inputs + + assert len(inputs_seq) == len(prediction_pipeline.input_ids) + + tagged_data = [ + ipt if isinstance(ipt, Tensor) else Tensor.from_numpy(ipt, dims=ipt_spec.axes) + for ipt, ipt_spec in zip(inputs, prediction_pipeline.input_axes) + ] + return prediction_pipeline.forward(*tagged_data) # def _parse_padding(padding, input_specs): diff --git a/bioimageio/core/utils/__init__.py b/bioimageio/core/utils/__init__.py index 4037be8a..bbfd90c4 100644 --- a/bioimageio/core/utils/__init__.py +++ b/bioimageio/core/utils/__init__.py @@ -2,6 +2,7 @@ import sys from pathlib import Path +from ._digest_spec import get_sample_axes as get_sample_axes from ._digest_spec import get_test_inputs as get_test_inputs from ._digest_spec import get_test_outputs as get_test_outputs from ._import_callable import import_callable as import_callable diff --git a/bioimageio/core/utils/_digest_spec.py b/bioimageio/core/utils/_digest_spec.py index 3fe41c02..42ddd15e 100644 --- a/bioimageio/core/utils/_digest_spec.py +++ b/bioimageio/core/utils/_digest_spec.py @@ -1,21 +1,37 @@ -from typing import List, Sequence, get_args +from typing import List, Sequence, Union -from bioimageio.core.axis import AxisLetter, AxisLike -from bioimageio.spec.model import AnyModelDescr, v0_4 +from bioimageio.core.axis import AxisInfo +from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 from bioimageio.spec.utils import load_array from ..tensor import Tensor, TensorId -def get_test_inputs(model: AnyModelDescr) -> List[Tensor]: - axes = [d.axes for d in model.inputs] - if isinstance(axes, str): - core_axes: List[Sequence[AxisLike]] = [ - a if a in get_args(AxisLetter) else "i" for a in axes - ] # pyright: 
ignore[reportAssignmentType] - else: - core_axes = axes # pyright: ignore[reportAssignmentType] +def get_sample_axes( + io_descr: Sequence[ + Union[ + v0_4.InputTensorDescr, + v0_4.OutputTensorDescr, + v0_5.InputTensorDescr, + v0_5.OutputTensorDescr, + ] + ] +): + return [ + [ + ( + AxisInfo.create("i") + if isinstance(a, str) and a not in ("b", "i", "t", "c", "z", "y", "x") + else AxisInfo.create(a) + ) + for a in d.axes + ] + for d in io_descr + ] + +def get_test_inputs(model: AnyModelDescr) -> List[Tensor]: + axes = get_sample_axes(model.inputs) if isinstance(model, v0_4.ModelDescr): arrays = [load_array(tt) for tt in model.test_inputs] else: @@ -28,19 +44,12 @@ def get_test_inputs(model: AnyModelDescr) -> List[Tensor]: return [ Tensor.from_numpy(arr, dims=ax, id=t) - for arr, ax, t in zip(arrays, core_axes, tensor_ids) + for arr, ax, t in zip(arrays, axes, tensor_ids) ] def get_test_outputs(model: AnyModelDescr) -> List[Tensor]: - axes = [d.axes for d in model.outputs] - if isinstance(axes, str): - core_axes: List[Sequence[AxisLike]] = [ - a if a in get_args(AxisLetter) else "i" for a in axes - ] # pyright: ignore[reportAssignmentType] - else: - core_axes = axes # pyright: ignore[reportAssignmentType] - + axes = get_sample_axes(model.outputs) if isinstance(model, v0_4.ModelDescr): arrays = [load_array(tt) for tt in model.test_outputs] else: @@ -53,5 +62,5 @@ def get_test_outputs(model: AnyModelDescr) -> List[Tensor]: return [ Tensor.from_numpy(arr, dims=ax, id=t) - for arr, ax, t in zip(arrays, core_axes, tensor_ids) + for arr, ax, t in zip(arrays, axes, tensor_ids) ] From c63081722e7b6479db0bd9772d1d7ff350b597ad Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sun, 31 Mar 2024 14:58:49 +0200 Subject: [PATCH 176/244] simplify PredictionPipeline --- bioimageio/core/_prediction_pipeline.py | 125 +++++++++++------------- bioimageio/core/prediction.py | 9 +- bioimageio/core/tensor.py | 30 +++--- bioimageio/core/utils/__init__.py | 2 +- bioimageio/core/utils/_digest_spec.py | 77 ++++++++------- 5 files changed, 118 insertions(+), 125 deletions(-) diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index 5c5f648d..80a89877 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -1,7 +1,10 @@ +import collections import warnings +from dataclasses import dataclass from types import MappingProxyType from typing import Any, Iterable, List, Mapping, Optional, Sequence, Union +from bioimageio.core.axis import AxisInfo from bioimageio.spec.model import AnyModelDescr, v0_4 from bioimageio.spec.model.v0_5 import WeightsFormat @@ -11,8 +14,15 @@ from .proc_setup import setup_pre_and_postprocessing from .sample import Sample from .stat_measures import DatasetMeasure, MeasureValue -from .tensor import PerTensor, Tensor, TensorId -from .utils import get_sample_axes +from .tensor import Tensor, TensorId +from .utils import get_axes_infos + + +@dataclass +class CoreTensorDescr: + id: TensorId + axes: Sequence[AxisInfo] + optional: bool class PredictionPipeline: @@ -39,22 +49,42 @@ def __init__( self.name = name self._preprocessing = preprocessing self._postprocessing = postprocessing - if isinstance(bioimageio_model, v0_4.ModelDescr): - self.input_ids = [TensorId(str(d.name)) for d in bioimageio_model.inputs] - self.output_ids = [TensorId(str(d.name)) for d in bioimageio_model.outputs] - else: - self.input_ids = [d.id for d in bioimageio_model.inputs] - self.output_ids = [d.id for d in bioimageio_model.outputs] - 
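
For illustration: `get_sample_axes` normalizes per-tensor axis descriptions into lists of `AxisInfo`, falling back to an index axis ("i") for any single-letter axis that is not one of the known letters b/i/t/c/z/y/x. A minimal sketch of that behavior (the `SimpleNamespace` stand-in for a v0_4-style description is hypothetical, not part of the codebase):

    from types import SimpleNamespace

    descr = SimpleNamespace(axes="bcyx")  # hypothetical v0_4-style tensor description
    per_tensor_axes = get_sample_axes([descr])
    # -> [[AxisInfo.create("b"), AxisInfo.create("c"),
    #      AxisInfo.create("y"), AxisInfo.create("x")]]
    # an unrecognized letter such as "q" would instead map to AxisInfo.create("i")
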
self.input_axes = get_sample_axes(bioimageio_model.inputs) - self.output_axes = get_sample_axes(bioimageio_model.outputs) + self.input_ids = tuple( + (TensorId(str(t.name)) if isinstance(t, v0_4.InputTensorDescr) else t.id) + for t in bioimageio_model.inputs + ) + self.inputs = collections.OrderedDict( + ( + tid, + CoreTensorDescr( + id=tid, + axes=get_axes_infos(t), + optional=not isinstance(t, v0_4.InputTensorDescr) and t.optional, + ), + ) + for tid, t in zip(self.input_ids, bioimageio_model.inputs) + ) + self.output_ids = tuple( + (TensorId(str(t.name)) if isinstance(t, v0_4.OutputTensorDescr) else t.id) + for t in bioimageio_model.outputs + ) + self.outputs = collections.OrderedDict( + ( + tid, + CoreTensorDescr( + id=tid, + axes=get_axes_infos(t), + optional=False, + ), + ) + for tid, t in zip(self.output_ids, bioimageio_model.outputs) + ) self._adapter: ModelAdapter = model - def __call__( - self, *input_tensors: Optional[Tensor], **named_input_tensors: Optional[Tensor] - ) -> List[Optional[Tensor]]: - return self.forward(*input_tensors, **named_input_tensors) + def __call__(self, sample: Sample) -> Sample: + return self.predict(sample) def __enter__(self): self.load() @@ -64,14 +94,21 @@ def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore self.unload() return False - def predict( - self, *input_tensors: Optional[Tensor], **named_input_tensors: Optional[Tensor] - ) -> List[Optional[Tensor]]: - """Predict input_tensor with the model without applying pre/postprocessing.""" - named_tensors = [ - named_input_tensors.get(k) for k in self.input_ids[len(input_tensors) :] - ] - return self._adapter.forward(*input_tensors, *named_tensors) + def predict(self, sample: Sample) -> Sample: + """Run model prediction **including** pre/postprocessing.""" + self.apply_preprocessing(sample) + output = Sample( + data={ + tid: out + for tid, out in zip( + self.output_ids, + self._adapter.forward(*(sample.data[t] for t in self.input_ids)), + ) + if out is not None + } + ) + self.apply_postprocessing(output) + return output def apply_preprocessing(self, sample: Sample) -> None: """apply preprocessing in-place, also updates sample stats""" @@ -83,50 +120,6 @@ def apply_postprocessing(self, sample: Sample) -> None: for op in self._postprocessing: op(sample) - def forward_sample(self, input_sample: Sample) -> Sample: - """Apply preprocessing, run prediction and apply postprocessing.""" - self.apply_preprocessing(input_sample) - - prediction_tensors = self.predict( - **{str(k): v for k, v in input_sample.data.items()} - ) - prediction = Sample( - data={ - tid: t - for tid, t in zip(self.output_ids, prediction_tensors) - if t is not None - }, - stat=input_sample.stat, - ) - self.apply_postprocessing(prediction) - return prediction - - def forward_tensors( - self, *input_tensors: Optional[Tensor], **named_input_tensors: Optional[Tensor] - ) -> PerTensor[Tensor]: - """Apply preprocessing, run prediction and apply postprocessing.""" - assert all(TensorId(k) in self.input_ids for k in named_input_tensors) - input_sample = Sample( - data={ - **{ - k: v for k, v in zip(self.input_ids, input_tensors) if v is not None - }, - **{ - TensorId(k): v - for k, v in named_input_tensors.items() - if v is not None - }, - } - ) - return self.forward_sample(input_sample).data - - def forward( - self, *input_tensors: Optional[Tensor], **named_input_tensors: Optional[Tensor] - ) -> List[Optional[Tensor]]: - """Apply preprocessing, run prediction and apply postprocessing.""" - named_outputs = 
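
With this simplification the pipeline consumes and produces `Sample` objects and can be used as a context manager that loads and unloads the model adapter. A hedged usage sketch (the model description and input tensors are assumed to exist; names are placeholders):

    from bioimageio.core._prediction_pipeline import create_prediction_pipeline
    from bioimageio.core.sample import Sample

    pp = create_prediction_pipeline(model_description)  # placeholder descr
    input_sample = Sample(
        data={tid: tensor for tid, tensor in zip(pp.input_ids, input_tensors)}
    )
    with pp:  # loads the model adapter, unloads it on exit
        output_sample = pp.predict(input_sample)
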
self.forward_tensors(*input_tensors, **named_input_tensors) - return [named_outputs.get(x) for x in self.output_ids] - def load(self): """ optional step: load model onto devices before calling forward if not using it as context manager diff --git a/bioimageio/core/prediction.py b/bioimageio/core/prediction.py index d6f0ec30..eda54ff2 100644 --- a/bioimageio/core/prediction.py +++ b/bioimageio/core/prediction.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import ( Any, + Collection, Dict, Hashable, Iterator, @@ -27,6 +28,8 @@ from pydantic import HttpUrl from tqdm import tqdm +from bioimageio.core.axis import AxisInfo +from bioimageio.core.sample import Sample from bioimageio.core.tensor import Tensor, TensorId from bioimageio.spec import ResourceDescr, load_description from bioimageio.spec.model import v0_4, v0_5 @@ -112,8 +115,10 @@ def predict( assert len(inputs_seq) == len(prediction_pipeline.input_ids) tagged_data = [ - ipt if isinstance(ipt, Tensor) else Tensor.from_numpy(ipt, dims=ipt_spec.axes) - for ipt, ipt_spec in zip(inputs, prediction_pipeline.input_axes) + ipt if isinstance(ipt, Tensor) else Tensor.from_numpy(ipt, dims=axes, id=tid) + for ipt, axes, tid in zip( + inputs_seq, prediction_pipeline.input_axes, prediction_pipeline.input_ids + ) ] return prediction_pipeline.forward(*tagged_data) diff --git a/bioimageio/core/tensor.py b/bioimageio/core/tensor.py index 39162413..dbd0cc91 100644 --- a/bioimageio/core/tensor.py +++ b/bioimageio/core/tensor.py @@ -66,12 +66,15 @@ class Tensor(MagicTensorOpsMixin): def __init__( self, array: NDArray[Any], - dims: Union[AxisId, Sequence[AxisId]], - id: Optional[TensorId] = None, + dims: Sequence[AxisId], ) -> None: super().__init__() - self._data = xr.DataArray(array, dims=dims, name=id) - self._id = id + if any(not isinstance(d, AxisId) for d in dims): + raise TypeError( + f"Expected sequence of `AxisId`, but got {list(map(type, dims))}" + ) + + self._data = xr.DataArray(array, dims=dims) def __array__(self, dtype: DTypeLike = None): return np.asarray(self._data, dtype=dtype) @@ -142,9 +145,7 @@ def from_xarray(cls, data_array: xr.DataArray) -> Self: for any `Tensor`'s `data` property (an xarray.DataArray). """ return cls( - array=data_array.data, - dims=tuple(AxisId(d) for d in data_array.dims), - id=None if data_array.name is None else TensorId(data_array.name), + array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims) ) @classmethod @@ -153,7 +154,6 @@ def from_numpy( array: NDArray[Any], *, dims: Optional[Union[AxisLike, Sequence[AxisLike]]], - id: TensorId, ) -> Tensor: """create a `Tensor` from a numpy array @@ -161,14 +161,13 @@ def from_numpy( array: the nd numpy array axes: A description of the array's axes, if None axes are guessed (which might fail and raise a ValueError.) - id: the created tensor's identifier Raises: ValueError: if `axes` is None and axes guessing fails. 
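
The reworked `predict` helper accepts a single tensor or array, a sequence ordered like `input_ids`, or a mapping keyed by tensor id (or its string form). For example (a sketch; the pipeline `pp`, its single input id "raw", and the shape are illustrative only):

    import numpy as np

    outputs = predict(pp, np.zeros((1, 1, 64, 64), dtype="float32"))
    outputs = predict(pp, {"raw": np.zeros((1, 1, 64, 64), dtype="float32")})
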
""" if dims is None: - return cls._interprete_array_wo_known_axes(array, id=id) + return cls._interprete_array_wo_known_axes(array) elif isinstance(dims, (str, Axis, v0_5.AxisBase)): dims = [dims] @@ -196,7 +195,7 @@ def from_numpy( f"Array shape {original_shape} does not map to axes {dims}" ) - return Tensor(array, dims=tuple(a.id for a in axis_infos), id=id) + return Tensor(array, dims=tuple(a.id for a in axis_infos)) @property def data(self): @@ -235,11 +234,6 @@ def dtype(self) -> DTypeStr: assert dt in get_args(DTypeStr) return dt # pyright: ignore[reportReturnType] - @property - def id(self): - """the tensor's identifier""" - return self._id - @property def sizes(self): """Ordered, immutable mapping from axis ids to lengths.""" @@ -482,7 +476,7 @@ def transpose( return self.__class__.from_xarray(array.transpose(*axes)) @classmethod - def _interprete_array_wo_known_axes(cls, array: NDArray[Any], id: TensorId): + def _interprete_array_wo_known_axes(cls, array: NDArray[Any]): ndim = array.ndim if ndim == 2: current_axes = ( @@ -531,7 +525,7 @@ def _interprete_array_wo_known_axes(cls, array: NDArray[Any], id: TensorId): else: raise ValueError(f"Could not guess an axis mapping for {array.shape}") - return cls(array, dims=tuple(a.id for a in current_axes), id=id) + return cls(array, dims=tuple(a.id for a in current_axes)) def _tile_generator( self, diff --git a/bioimageio/core/utils/__init__.py b/bioimageio/core/utils/__init__.py index bbfd90c4..ddc519f7 100644 --- a/bioimageio/core/utils/__init__.py +++ b/bioimageio/core/utils/__init__.py @@ -2,7 +2,7 @@ import sys from pathlib import Path -from ._digest_spec import get_sample_axes as get_sample_axes +from ._digest_spec import get_axes_infos as get_axes_infos from ._digest_spec import get_test_inputs as get_test_inputs from ._digest_spec import get_test_outputs as get_test_outputs from ._import_callable import import_callable as import_callable diff --git a/bioimageio/core/utils/_digest_spec.py b/bioimageio/core/utils/_digest_spec.py index 42ddd15e..7480a255 100644 --- a/bioimageio/core/utils/_digest_spec.py +++ b/bioimageio/core/utils/_digest_spec.py @@ -1,66 +1,67 @@ -from typing import List, Sequence, Union +from typing import Union from bioimageio.core.axis import AxisInfo +from bioimageio.core.sample import Sample from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 from bioimageio.spec.utils import load_array from ..tensor import Tensor, TensorId -def get_sample_axes( - io_descr: Sequence[ - Union[ - v0_4.InputTensorDescr, - v0_4.OutputTensorDescr, - v0_5.InputTensorDescr, - v0_5.OutputTensorDescr, - ] +def get_axes_infos( + io_descr: Union[ + v0_4.InputTensorDescr, + v0_4.OutputTensorDescr, + v0_5.InputTensorDescr, + v0_5.OutputTensorDescr, ] ): return [ - [ - ( - AxisInfo.create("i") - if isinstance(a, str) and a not in ("b", "i", "t", "c", "z", "y", "x") - else AxisInfo.create(a) - ) - for a in d.axes - ] - for d in io_descr + ( + AxisInfo.create("i") + if isinstance(a, str) and a not in ("b", "i", "t", "c", "z", "y", "x") + else AxisInfo.create(a) + ) + for a in io_descr.axes ] -def get_test_inputs(model: AnyModelDescr) -> List[Tensor]: - axes = get_sample_axes(model.inputs) +def get_test_inputs(model: AnyModelDescr) -> Sample: if isinstance(model, v0_4.ModelDescr): - arrays = [load_array(tt) for tt in model.test_inputs] + tensor_ids = [TensorId(t.name) for t in model.inputs] else: - arrays = [load_array(d.test_tensor) for d in model.inputs] + tensor_ids = [t.id for t in model.inputs] if isinstance(model, 
v0_4.ModelDescr): - tensor_ids = [TensorId(ipt.name) for ipt in model.inputs] + arrays = [load_array(tt) for tt in model.test_inputs] else: - tensor_ids = [ipt.id for ipt in model.inputs] + arrays = [load_array(d.test_tensor) for d in model.inputs] - return [ - Tensor.from_numpy(arr, dims=ax, id=t) - for arr, ax, t in zip(arrays, axes, tensor_ids) - ] + axes = [get_axes_infos(t) for t in model.inputs] + return Sample( + data={ + tid: Tensor.from_numpy(arr, dims=ax) + for tid, arr, ax in zip(tensor_ids, arrays, axes) + } + ) -def get_test_outputs(model: AnyModelDescr) -> List[Tensor]: - axes = get_sample_axes(model.outputs) +def get_test_outputs(model: AnyModelDescr) -> Sample: if isinstance(model, v0_4.ModelDescr): - arrays = [load_array(tt) for tt in model.test_outputs] + tensor_ids = [TensorId(t.name) for t in model.outputs] else: - arrays = [load_array(d.test_tensor) for d in model.outputs] + tensor_ids = [t.id for t in model.outputs] if isinstance(model, v0_4.ModelDescr): - tensor_ids = [TensorId(ipt.name) for ipt in model.inputs] + arrays = [load_array(tt) for tt in model.test_outputs] else: - tensor_ids = [ipt.id for ipt in model.inputs] + arrays = [load_array(d.test_tensor) for d in model.outputs] - return [ - Tensor.from_numpy(arr, dims=ax, id=t) - for arr, ax, t in zip(arrays, axes, tensor_ids) - ] + axes = [get_axes_infos(t) for t in model.outputs] + + return Sample( + data={ + tid: Tensor.from_numpy(arr, dims=ax) + for tid, arr, ax in zip(tensor_ids, arrays, axes) + } + ) From 946fa5025b0149c5745aca3ebfda101e2b406708 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 3 Apr 2024 01:10:09 +0200 Subject: [PATCH 177/244] WIP block --- bioimageio/core/__init__.py | 2 +- bioimageio/core/_op_base.py | 4 +- bioimageio/core/_prediction_pipeline.py | 115 ++++++++++---- bioimageio/core/block.py | 193 ++++++++++++++++++++++++ bioimageio/core/common.py | 4 +- bioimageio/core/prediction.py | 34 +---- bioimageio/core/proc_ops.py | 8 +- bioimageio/core/proc_setup.py | 4 +- bioimageio/core/sample.py | 101 +++++++++++-- bioimageio/core/stat_calculators.py | 46 +++--- bioimageio/core/tensor.py | 92 ++--------- bioimageio/core/tensor_block.py | 106 +++++++++++++ bioimageio/core/tile.py | 113 -------------- bioimageio/core/utils/_digest_spec.py | 44 +++++- tests/test_proc_ops.py | 32 ++-- tests/test_stat_calculators.py | 5 +- tests/test_stat_measures.py | 6 +- 17 files changed, 589 insertions(+), 320 deletions(-) create mode 100644 bioimageio/core/block.py create mode 100644 bioimageio/core/tensor_block.py delete mode 100644 bioimageio/core/tile.py diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index c3cb1db6..de4ffe29 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -25,7 +25,7 @@ from ._settings import settings as settings from .axis import Axis as Axis from .axis import AxisId as AxisId -from .sample import Sample as Sample +from .sample import UntiledSample as UntiledSample from .tensor import Tensor as Tensor from .tensor import TensorId as TensorId from .tile import Tile as Tile diff --git a/bioimageio/core/_op_base.py b/bioimageio/core/_op_base.py index a0ca7ae1..e3e420d1 100644 --- a/bioimageio/core/_op_base.py +++ b/bioimageio/core/_op_base.py @@ -2,14 +2,14 @@ from dataclasses import dataclass from typing import Collection -from bioimageio.core.sample import Sample +from bioimageio.core.sample import UntiledSample from bioimageio.core.stat_measures import Measure @dataclass class Operator(ABC): @abstractmethod - def __call__(self, sample: 
Sample) -> None: ... + def __call__(self, sample: UntiledSample) -> None: ... @property @abstractmethod diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index 80a89877..a95e0932 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -1,20 +1,33 @@ -import collections +import collections.abc import warnings from dataclasses import dataclass from types import MappingProxyType -from typing import Any, Iterable, List, Mapping, Optional, Sequence, Union +from typing import ( + Any, + Iterable, + List, + Mapping, + Optional, + Sequence, + TypeVar, + Union, +) + +from numpy.typing import NDArray +from typing_extensions import assert_never -from bioimageio.core.axis import AxisInfo from bioimageio.spec.model import AnyModelDescr, v0_4 from bioimageio.spec.model.v0_5 import WeightsFormat +from .axis import AxisInfo from .model_adapters import ModelAdapter, create_model_adapter from .model_adapters import get_weight_formats as get_weight_formats from .proc_ops import Processing from .proc_setup import setup_pre_and_postprocessing -from .sample import Sample +from .sample import TiledSample, UntiledSample from .stat_measures import DatasetMeasure, MeasureValue from .tensor import Tensor, TensorId +from .tile import Tile from .utils import get_axes_infos @@ -25,6 +38,19 @@ class CoreTensorDescr: optional: bool +Data = TypeVar( + "Data", + TiledSample, + UntiledSample, + Tile, + Iterable[TiledSample], + Iterable[UntiledSample], + NDArray[Any], + Sequence[Optional[NDArray[Any]]], + Mapping[Union[TensorId, str], Optional[NDArray[Any]]], +) + + class PredictionPipeline: """ Represents model computation including preprocessing and postprocessing @@ -83,8 +109,8 @@ def __init__( self._adapter: ModelAdapter = model - def __call__(self, sample: Sample) -> Sample: - return self.predict(sample) + def __call__(self, data: Data) -> Data: + return self.predict(data) def __enter__(self): self.load() @@ -94,31 +120,66 @@ def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore self.unload() return False - def predict(self, sample: Sample) -> Sample: + def predict(self, inputs: Data) -> Data: """Run model prediction **including** pre/postprocessing.""" - self.apply_preprocessing(sample) - output = Sample( - data={ - tid: out - for tid, out in zip( - self.output_ids, - self._adapter.forward(*(sample.data[t] for t in self.input_ids)), - ) - if out is not None - } - ) - self.apply_postprocessing(output) + + if isinstance(inputs, Tile): + self.apply_preprocessing(inputs) + output_tile = Tile( + data={ + tid: out + for tid, out in zip( + self.output_ids, + self._adapter.forward( + *(inputs.data[t] for t in self.input_ids) + ), + ) + if out is not None + } + ) + self.apply_postprocessing(output_tile) + return output_tile + + else: + assert_never(inputs) + return output - def apply_preprocessing(self, sample: Sample) -> None: + # if isinstance(inputs, collections.abc.Mapping): + # data = { + # tid: d + # for tid in self.input_ids + # if (d := inputs.get(tid, inputs.get(str(tid)))) is not None + # } + # else: + # if isinstance(inputs, (Tensor, np.ndarray)): + # inputs_seq = [inputs] + # else: + # inputs_seq = inputs + + # assert len(inputs_seq) == len(self.input_ids) + # data = { + # tid: d for tid, d in zip(self.input_ids, inputs_seq) if d is not None + # } + + # sample = UntiledSample( + # data={ + # tid: Tensor.from_numpy(d, dims=self.inputs[tid].axes) + # for tid, d in data.items() + # } + # ) + # output = 
self.predict(sample) + # return {tid: out.data.data for } + + def apply_preprocessing(self, tile: Tile) -> None: """apply preprocessing in-place, also updates sample stats""" for op in self._preprocessing: - op(sample) + op(tile) - def apply_postprocessing(self, sample: Sample) -> None: + def apply_postprocessing(self, tile: Tile) -> None: """apply postprocessing in-place, also updates samples stats""" for op in self._postprocessing: - op(sample) + op(tile) def load(self): """ @@ -139,7 +200,9 @@ def create_prediction_pipeline( devices: Optional[Sequence[str]] = None, weight_format: Optional[WeightsFormat] = None, weights_format: Optional[WeightsFormat] = None, - dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[Tensor]]] = tuple(), + dataset_for_initial_statistics: Iterable[ + Union[UntiledSample, Sequence[Tensor]] + ] = tuple(), keep_updating_initial_dataset_statistics: bool = False, fixed_dataset_statistics: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType( {} @@ -175,10 +238,10 @@ def create_prediction_pipeline( def dataset(): for x in dataset_for_initial_statistics: - if isinstance(x, Sample): + if isinstance(x, UntiledSample): yield x else: - yield Sample(data=dict(zip(input_ids, x))) + yield UntiledSample(data=dict(zip(input_ids, x))) preprocessing, postprocessing = setup_pre_and_postprocessing( bioimageio_model, diff --git a/bioimageio/core/block.py b/bioimageio/core/block.py new file mode 100644 index 00000000..b678d82a --- /dev/null +++ b/bioimageio/core/block.py @@ -0,0 +1,193 @@ +import itertools +from dataclasses import dataclass, field +from math import prod +from typing import Any, Dict, Generator, List, Optional, Tuple + +from .axis import AxisId, PerAxis +from .common import ( + BlockNumber, + Halo, + HaloLike, + PadWidth, + SliceInfo, + TotalNumberOfBlocks, +) + + +@dataclass +class Block: + """Block of a sample + + Figure for illustration: + The first 2d block (dashed) of a sample (**bold**). + The inner slice (thin) is expanded by a halo in both dimensions on both sides. + The outer slice reaches from the sample origin (0, 0) to the right halo point. + + ```terminal + ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ + ╷ halo(left) ╷ + ╷ ╷ + ╷ (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔ + ╷ ┃ │ ╷ sample + ╷ ┃ inner │ ╷ + ╷ ┃ (and outer) │ outer ╷ + ╷ ┃ slice │ slice ╷ + ╷ ┃ │ ╷ + ╷ ┣─────────────────┘ ╷ + ╷ ┃ outer slice ╷ + ╷ ┃ halo(right) ╷ + └ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘ + ⬇ + ``` + + note: + - Inner and outer slices are specified in sample coordinates. + - The outer_slice of a block at the sample edge may overlap by more than the + halo with the neighboring block (the inner slices will not overlap though). 
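
A 1d numeric sketch of this geometry, using the class defined here (values chosen for illustration; `Halo` is the (left, right) named tuple from `.common`):

    from bioimageio.core.axis import AxisId
    from bioimageio.core.common import Halo, SliceInfo

    x = AxisId("x")
    first = Block(
        sample_shape={x: 12},
        inner_slice={x: SliceInfo(0, 4)},
        halo={x: Halo(2, 2)},
        block_number=0,
        blocks_in_sample=3,
    )
    # outer_slice is clipped to the sample: (0, 6); the left halo cannot be
    # realized inside the sample and remains as padding (2, 0);
    # local_slice is (2, 6) and the block shape is 4 + 2 + 2 = 8
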
+ 

    """

    sample_shape: PerAxis[int]
    """the axis sizes of the whole (unblocked) sample"""

    inner_slice: PerAxis[SliceInfo]
    """inner region (without halo) wrt the sample"""

    halo: PerAxis[Halo]
    """halo enlarging the inner region to the block's sizes"""

    block_number: BlockNumber
    """the n-th block of the sample"""

    blocks_in_sample: TotalNumberOfBlocks
    """total number of blocks in the sample"""

    shape: PerAxis[int] = field(init=False)
    """axis lengths of the block"""

    padding: PerAxis[PadWidth] = field(init=False)
    """padding to realize the halo at the sample edge
    where we cannot simply enlarge the inner slice"""

    outer_slice: PerAxis[SliceInfo] = field(init=False)
    """slice of the outer block (without padding) wrt the sample"""

    inner_shape: PerAxis[int] = field(init=False)
    """axis lengths of the inner region (without halo)"""

    local_slice: PerAxis[SliceInfo] = field(init=False)
    """inner slice wrt the block, **not** the sample"""

    def __post_init__(self):
        assert all(
            a in self.sample_shape for a in self.inner_slice
        ), "block has axes not present in sample"
        assert all(
            a in self.inner_slice for a in self.halo
        ), "halo has axes not present in block"

        self.shape = {
            a: s.stop - s.start + sum(self.halo[a]) for a, s in self.inner_slice.items()
        }
        assert all(
            s <= self.sample_shape[a] for a, s in self.shape.items()
        ), "block larger than sample"

        self.inner_shape = {a: s.stop - s.start for a, s in self.inner_slice.items()}
        self.outer_slice = {
            a: SliceInfo(
                max(
                    0,
                    min(
                        self.inner_slice[a].start - self.halo[a].left,
                        self.sample_shape[a] - self.inner_shape[a] - self.halo[a].left,
                    ),
                ),
                min(
                    self.sample_shape[a],
                    self.inner_slice[a].stop + self.halo[a].right,
                ),
            )
            for a in self.inner_slice
        }
        self.padding = {
            a: PadWidth(
                max(
                    0,
                    self.halo[a].left
                    - (self.inner_slice[a].start - self.outer_slice[a].start),
                ),
                max(
                    0,
                    self.halo[a].right
                    - (self.outer_slice[a].stop - self.inner_slice[a].stop),
                ),
            )
            for a in self.inner_slice
        }
        self.local_slice = {
            a: SliceInfo(
                self.padding[a].left,
                self.padding[a].left + self.inner_shape[a],
            )
            for a in self.inner_slice
        }


def split_shape_into_blocks(
    shape: PerAxis[int],
    block_shape: PerAxis[int],
    halo: PerAxis[HaloLike],
    stride: Optional[PerAxis[int]] = None,
) -> Tuple[TotalNumberOfBlocks, Generator[Block, Any, None]]:
    assert all(a in shape for a in block_shape), (
        tuple(shape),
        set(block_shape),
    )
    assert all(a in shape for a in halo), (tuple(shape), set(halo))

    # fill in default halo (0) and block_shape (tensor size)
    halo = {a: Halo.create(h) for a, h in halo.items()}
    block_shape = {a: block_shape.get(a, s) for a, s in shape.items()}
    if stride is None:
        stride = {}

    inner_1d_slices: Dict[AxisId, List[SliceInfo]] = {}
    for a, s in shape.items():
        inner_size = block_shape[a] - sum(halo[a])
        stride_1d = stride.get(a, inner_size)
        inner_1d_slices[a] = [
            SliceInfo(min(p, s - inner_size), min(p + inner_size, s))
            for p in range(0, s, stride_1d)
        ]

    n_blocks = prod(map(len, inner_1d_slices.values()))

    return n_blocks, _block_generator(
        shape,
        blocks_in_sample=n_blocks,
        inner_1d_slices=inner_1d_slices,
        halo=halo,
    )


def _block_generator(
    sample_shape: PerAxis[int],
    *,
    blocks_in_sample: int,
    inner_1d_slices: Dict[AxisId, List[SliceInfo]],
    halo: PerAxis[HaloLike],
):
    assert all(a in sample_shape for a in halo)

    halo = {a: Halo.create(halo.get(a, 0)) 
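
For example, splitting a 2d shape with a symmetric halo (a sketch using `split_shape_into_blocks` as defined above; sizes are illustrative):

    from bioimageio.core.axis import AxisId

    x, y = AxisId("x"), AxisId("y")
    n_blocks, blocks = split_shape_into_blocks(
        shape={x: 128, y: 128},
        block_shape={x: 64, y: 64},
        halo={x: 8, y: 8},
    )
    # the inner block size is 64 - (8 + 8) = 48 per axis, so each axis yields
    # ceil(128 / 48) = 3 inner slices and n_blocks == 9
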
for a in inner_1d_slices} + for i, nd_tile in enumerate(itertools.product(*inner_1d_slices.values())): + inner_slice: PerAxis[SliceInfo] = dict(zip(inner_1d_slices, nd_tile)) + + yield Block( + sample_shape=sample_shape, + inner_slice=inner_slice, + halo=halo, + block_number=i, + blocks_in_sample=blocks_in_sample, + ) diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index 5542e897..faff0de0 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -75,5 +75,5 @@ class SliceInfo(NamedTuple): stop: int -TileNumber = int -TotalNumberOfTiles = int +BlockNumber = int +TotalNumberOfBlocks = int diff --git a/bioimageio/core/prediction.py b/bioimageio/core/prediction.py index eda54ff2..c3562d82 100644 --- a/bioimageio/core/prediction.py +++ b/bioimageio/core/prediction.py @@ -28,14 +28,14 @@ from pydantic import HttpUrl from tqdm import tqdm -from bioimageio.core.axis import AxisInfo -from bioimageio.core.sample import Sample -from bioimageio.core.tensor import Tensor, TensorId from bioimageio.spec import ResourceDescr, load_description from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import AxisType from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline +from .axis import AxisInfo +from .sample import UntiledSample +from .tensor import Tensor, TensorId # def _predict_with_tiling_impl( # prediction_pipeline: PredictionPipeline, @@ -87,39 +87,15 @@ # output[inner_tile] = out[local_tile] -def predict( +def predict_numpy( prediction_pipeline: PredictionPipeline, - inputs: Union[ - Tensor, - NDArray[Any], - Sequence[Union[Tensor, NDArray[Any]]], - Mapping[Union[TensorId, str], Union[Tensor, NDArray[Any]]], - ], -) -> List[xr.DataArray]: + """Run prediction for a single set of input(s) with a bioimage.io model Args: prediction_pipeline: the prediction pipeline for the input model. inputs: the input(s) for this model represented as xarray data or numpy nd array. 
""" - if isinstance(inputs, collections.abc.Mapping): - inputs_seq = [ - inputs.get(str(tid), inputs[tid]) for tid in prediction_pipeline.input_ids - ] - else: - if isinstance(inputs, (Tensor, np.ndarray)): - inputs_seq = [inputs] - else: - inputs_seq = inputs - - assert len(inputs_seq) == len(prediction_pipeline.input_ids) - - tagged_data = [ - ipt if isinstance(ipt, Tensor) else Tensor.from_numpy(ipt, dims=axes, id=tid) - for ipt, axes, tid in zip( - inputs_seq, prediction_pipeline.input_axes, prediction_pipeline.input_ids - ) - ] return prediction_pipeline.forward(*tagged_data) diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 984d53e8..0dc9ee20 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -21,7 +21,7 @@ from ._op_base import Operator from .axis import AxisId -from .sample import Sample +from .sample import UntiledSample from .stat_calculators import StatsCalculator from .stat_measures import ( DatasetMean, @@ -76,7 +76,7 @@ def required_measures(self) -> Collection[Measure]: # def produced_tensors(self) -> Set[TensorId]: # return {self.output} - def __call__(self, sample: Sample) -> None: + def __call__(self, sample: UntiledSample) -> None: sample.data[self.output] = self._apply(sample.data[self.input], sample.stat) @abstractmethod @@ -91,7 +91,7 @@ class AddKnownDatasetStats(Operator): def required_measures(self) -> Set[Measure]: return set() - def __call__(self, sample: Sample) -> None: + def __call__(self, sample: UntiledSample) -> None: sample.stat.update(self.dataset_stats.items()) @@ -154,7 +154,7 @@ def __post_init__(self): or not self.stats_calculator.has_dataset_measures ) - def __call__(self, sample: Sample) -> None: + def __call__(self, sample: UntiledSample) -> None: if self._keep_updating_dataset_stats: sample.stat.update(self.stats_calculator.update_and_get_all(sample)) else: diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py index 8202bc97..0979773f 100644 --- a/bioimageio/core/proc_setup.py +++ b/bioimageio/core/proc_setup.py @@ -17,7 +17,7 @@ UpdateStats, get_proc_class, ) -from bioimageio.core.sample import Sample +from bioimageio.core.sample import UntiledSample from bioimageio.core.stat_calculators import StatsCalculator from bioimageio.core.stat_measures import DatasetMeasure, Measure, MeasureValue from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 @@ -45,7 +45,7 @@ class _SetupProcessing(NamedTuple): def setup_pre_and_postprocessing( model: AnyModelDescr, - dataset_for_initial_statistics: Iterable[Sample], + dataset_for_initial_statistics: Iterable[UntiledSample], keep_updating_initial_dataset_stats: bool = False, fixed_dataset_stats: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType({}), ) -> PreAndPostprocessing: diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py index aed8b633..e3d003e2 100644 --- a/bioimageio/core/sample.py +++ b/bioimageio/core/sample.py @@ -1,19 +1,92 @@ from dataclasses import dataclass, field from pprint import pformat -from typing import Dict, Iterable, Iterator, Optional, Tuple, cast +from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union, cast -import numpy +import numpy as np import xarray as xr from typing_extensions import Self +from bioimageio.core.tensor_block import TensorBlock + from .axis import AxisId, PerAxis -from .common import Halo, HaloLike, PadMode, SliceInfo, TileNumber +from .block import Block, BlockNumber, TotalNumberOfBlocks, split_shape_into_blocks +from .common import BlockNumber, 
Halo, HaloLike, PadMode, SliceInfo from .stat_measures import Stat from .tensor import PerTensor, Tensor, TensorId -from .tile import Tile -TiledSample = Iterable[Tile] -"""A dataset sample split into tiles""" + +def split_multiple_shapes_into_blocks( + shapes: PerTensor[PerAxis[int]], + block_shapes: PerTensor[PerAxis[int]], + *, + strides: Optional[PerTensor[PerAxis[int]]] = None, + halo: PerTensor[PerAxis[HaloLike]], + pad_mode: PadMode, + broadcast: bool = False, +) -> Tuple[TotalNumberOfBlocks, Iterable[PerTensor[Block]]]: + assert not ( + missing := [t for t in block_shapes if t not in shapes] + ), f"block shape specified for unknown tensors: {missing}" + assert broadcast or not ( + missing := [t for t in shapes if t not in block_shapes] + ), f"no block shape specified for {missing} (set `broadcast` to True if these tensors should be repeated for each block)" + assert not ( + missing := [t for t in halo if t not in block_shapes] + ), f"`halo` specified for tensors without block shape: {missing}" + + if strides is None: + strides = {} + + assert not ( + missing := [t for t in strides if t not in block_shapes] + ), f"`stride` specified for tensors without block shape: {missing}" + + blocks: Dict[TensorId, Iterable[Block]] = {} + n_blocks: Dict[TensorId, TotalNumberOfBlocks] = {} + for t in block_shapes: + n_blocks[t], blocks[t] = split_shape_into_blocks( + shape=shapes[t], + block_shape=block_shapes[t], + halo=halo.get(t, {}), + stride=strides.get(t), + ) + assert n_blocks[t] > 0 + + unique_n_blocks = set(n_blocks.values()) + n = max(unique_n_blocks) + if len(unique_n_blocks) == 2 and 1 in unique_n_blocks: + if not broadcast: + raise ValueError( + f"Mismatch for total number of blocks due to unsplit (single block) tensors: {n_blocks}." + + " Set `broadcast` to True if you want to repeat unsplit (single block) tensors." 
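
For example, a tensor whose shape admits only a single block (here a per-channel parameter tensor) can be repeated for every block of the data tensor by setting `broadcast=True`. A hedged sketch (tensor ids and sizes are illustrative; `pad_mode` is only passed through here):

    from bioimageio.core.axis import AxisId
    from bioimageio.core.tensor import TensorId

    x, c = AxisId("x"), AxisId("c")
    raw, params = TensorId("raw"), TensorId("params")
    n, aligned_blocks = split_multiple_shapes_into_blocks(
        shapes={raw: {x: 128}, params: {c: 3}},
        block_shapes={raw: {x: 64}, params: {c: 3}},
        halo={},
        pad_mode="edge",
        broadcast=True,
    )
    # raw yields 2 blocks; the single params block is repeated for each of them
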
+ ) + + blocks = { + t: _repeat_single_block(block_gen, n) if n_blocks[t] == 1 else block_gen + for t, block_gen in blocks.items() + } + elif len(unique_n_blocks) != 1: + raise ValueError(f"Mismatch for total number of blocks: {n_blocks}") + + return n, _aligned_blocks_generator(n, blocks) + + +def _aligned_blocks_generator( + n: TotalNumberOfBlocks, blocks: Dict[TensorId, Iterable[Block]] +): + iterators = {t: iter(gen) for t, gen in blocks.items()} + for _ in range(n): + yield {t: next(it) for t, it in iterators.items()} + + +def _repeat_single_block(block_generator: Iterable[Block], n: TotalNumberOfBlocks): + round_two = False + for block in block_generator: + assert not round_two + for _ in range(n): + yield block + + round_two = True @dataclass @@ -30,7 +103,7 @@ class Sample: def sizes(self) -> PerTensor[PerAxis[int]]: return {tid: t.sizes for tid, t in self.data.items()} - def tile( + def split_into_blocks( self, tile_sizes: PerTensor[PerAxis[int]], halo: PerTensor[PerAxis[HaloLike]], @@ -59,15 +132,15 @@ def tile( } tile_iterators: Dict[ - TensorId, Iterator[Tuple[TileNumber, Tensor, PerAxis[SliceInfo]]] + TensorId, Iterator[Tuple[BlockNumber, Tensor, PerAxis[SliceInfo]]] ] = {} n_tiles_common = 1 last_non_trivial: Optional[TensorId] = None for t in tensor_ids: - n_tiles, generator = broadcasted_tensors[t].tile( - tile_size=explicit_tile_sizes[t], - halo=halo.get(t, {}), + n_tiles, generator = broadcasted_tensors[t].block( + block_size=explicit_tile_sizes[t], + explicit_halo=halo.get(t, {}), pad_mode=pad_mode, ) tile_iterators[t] = iter(generator) @@ -132,13 +205,12 @@ def from_tiles( if t not in data: axes = cast(Tuple[AxisId], tile_data.dims) data[t] = Tensor( - numpy.full( + np.full( tuple(tile.sample_sizes[t][a] for a in axes), fill_value, dtype=tile_data.dtype, ), dims=axes, - id=t, ) data[t][tile.inner_slice[t]] = tile_data @@ -146,3 +218,6 @@ def from_tiles( stat = tile.stat return cls(data=data, stat=stat) + + +Sample = Union[UntiledSample, TiledSample] diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py index 9319443d..643c697a 100644 --- a/bioimageio/core/stat_calculators.py +++ b/bioimageio/core/stat_calculators.py @@ -27,7 +27,7 @@ from typing_extensions import assert_never from .axis import AxisId, PerAxis -from .sample import Sample +from .sample import UntiledSample from .stat_measures import ( DatasetMean, DatasetMeasure, @@ -74,18 +74,20 @@ def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]): self._sample_mean = SampleMean(tensor_id=self._tensor_id, axes=self._axes) self._dataset_mean = DatasetMean(tensor_id=self._tensor_id, axes=self._axes) - def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]: + def compute(self, sample: UntiledSample) -> Dict[SampleMean, MeasureValue]: return {self._sample_mean: self._compute_impl(sample)} - def _compute_impl(self, sample: Sample) -> Tensor: + def _compute_impl(self, sample: UntiledSample) -> Tensor: tensor = sample.data[self._tensor_id].astype("float64", copy=False) return tensor.mean(dim=self._axes) - def update(self, sample: Sample) -> None: + def update(self, sample: UntiledSample) -> None: mean = self._compute_impl(sample) self._update_impl(sample.data[self._tensor_id], mean) - def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]: + def compute_and_update( + self, sample: UntiledSample + ) -> Dict[SampleMean, MeasureValue]: mean = self._compute_impl(sample) self._update_impl(sample.data[self._tensor_id], mean) return 
{self._sample_mean: mean} @@ -126,7 +128,7 @@ def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]): self._m2: Optional[Tensor] = None def compute( - self, sample: Sample + self, sample: UntiledSample ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]: tensor = sample.data[self._tensor_id] mean = tensor.mean(dim=self._axes) @@ -150,7 +152,7 @@ def compute( ), } - def update(self, sample: Sample): + def update(self, sample: UntiledSample): tensor = sample.data[self._tensor_id].astype("float64", copy=False) mean_b = tensor.mean(dim=self._axes) assert mean_b.dtype == "float64" @@ -208,7 +210,7 @@ def __init__( self._axes = None if axes is None else tuple(axes) self._tensor_id = tensor_id - def compute(self, sample: Sample) -> Dict[SamplePercentile, MeasureValue]: + def compute(self, sample: UntiledSample) -> Dict[SamplePercentile, MeasureValue]: tensor = sample.data[self._tensor_id] ps = tensor.quantile(self._qs, dim=self._axes) return { @@ -236,7 +238,7 @@ def __init__( self._n: int = 0 self._estimates: Optional[Tensor] = None - def update(self, sample: Sample): + def update(self, sample: UntiledSample): tensor = sample.data[self._tensor_id] sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype( "float64", copy=False @@ -307,7 +309,7 @@ def _initialize(self, tensor_sizes: PerAxis[int]): self._digest = [TDigest() for _ in range(d)] self._indices = product(*map(range, self._shape[1:])) - def update(self, sample: Sample): + def update(self, sample: UntiledSample): tensor = sample.data[self._tensor_id] assert "_percentiles" not in tensor.dims if self._digest is None: @@ -353,7 +355,7 @@ def __init__(self, tensor_id: TensorId, measure: SampleMeasure): self.tensor_name = tensor_id self.measure = measure - def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]: + def compute(self, sample: UntiledSample) -> Dict[SampleMeasure, MeasureValue]: return {self.measure: self.measure.compute(sample)} @@ -406,7 +408,7 @@ def __init__( def has_dataset_measures(self): return self._current_dataset_measures is not None - def update(self, sample: Union[Sample, Iterable[Sample]]) -> None: + def update(self, sample: Union[UntiledSample, Iterable[UntiledSample]]) -> None: _ = self._update(sample) def finalize(self) -> Dict[DatasetMeasure, MeasureValue]: @@ -420,7 +422,7 @@ def finalize(self) -> Dict[DatasetMeasure, MeasureValue]: return self._current_dataset_measures def update_and_get_all( - self, sample: Union[Sample, Iterable[Sample]] + self, sample: Union[UntiledSample, Iterable[UntiledSample]] ) -> Dict[Measure, MeasureValue]: """Returns sample as well as updated dataset statistics""" last_sample = self._update(sample) @@ -429,11 +431,13 @@ def update_and_get_all( return {**self._compute(last_sample), **self.finalize()} - def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]: + def skip_update_and_get_all( + self, sample: UntiledSample + ) -> Dict[Measure, MeasureValue]: """Returns sample as well as previously computed dataset statistics""" return {**self._compute(sample), **self.finalize()} - def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]: + def _compute(self, sample: UntiledSample) -> Dict[SampleMeasure, MeasureValue]: ret: Dict[SampleMeasure, MeasureValue] = {} for calc in self.sample_calculators: values = calc.compute(sample) @@ -441,9 +445,11 @@ def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]: return ret - def _update(self, sample: Union[Sample, Iterable[Sample]]) -> 
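
`MeanVarStdCalculator.update` maintains a running mean and sum of squared deviations (M2) by merging two partial aggregates, following the standard parallel-variance recurrence (Chan et al.). A scalar sketch of that merge; the actual implementation broadcasts it over the non-reduced axes:

    def merge(n_a, mean_a, m2_a, n_b, mean_b, m2_b):
        n = n_a + n_b
        delta = mean_b - mean_a
        mean = (n_a * mean_a + n_b * mean_b) / n
        m2 = m2_a + m2_b + delta**2 * n_a * n_b / n
        return n, mean, m2  # var = m2 / n, std = sqrt(var)
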
Optional[Sample]: + def _update( + self, sample: Union[UntiledSample, Iterable[UntiledSample]] + ) -> Optional[UntiledSample]: self.sample_count += 1 - samples = [sample] if isinstance(sample, Sample) else sample + samples = [sample] if isinstance(sample, UntiledSample) else sample last_sample = None for s in samples: last_sample = s @@ -546,7 +552,7 @@ def get_measure_calculators( def compute_dataset_measures( - measures: Iterable[DatasetMeasure], dataset: Iterable[Sample] + measures: Iterable[DatasetMeasure], dataset: Iterable[UntiledSample] ) -> Dict[DatasetMeasure, MeasureValue]: """compute all dataset `measures` for the given `dataset`""" sample_calculators, calculators = get_measure_calculators(measures) @@ -565,7 +571,7 @@ def compute_dataset_measures( def compute_sample_measures( - measures: Iterable[SampleMeasure], sample: Sample + measures: Iterable[SampleMeasure], sample: UntiledSample ) -> Dict[SampleMeasure, MeasureValue]: """compute all sample `measures` for the given `sample`""" calculators, dataset_calculators = get_measure_calculators(measures) @@ -579,7 +585,7 @@ def compute_sample_measures( def compute_measures( - measures: Iterable[Measure], dataset: Iterable[Sample] + measures: Iterable[Measure], dataset: Iterable[UntiledSample] ) -> Dict[Measure, MeasureValue]: """compute all `measures` for the given `dataset` sample measures are computed for the last sample in `dataset`""" diff --git a/bioimageio/core/tensor.py b/bioimageio/core/tensor.py index dbd0cc91..910aad01 100644 --- a/bioimageio/core/tensor.py +++ b/bioimageio/core/tensor.py @@ -1,15 +1,11 @@ from __future__ import annotations -import itertools -from math import prod from typing import ( TYPE_CHECKING, Any, Callable, Dict, - Generator, Iterator, - List, Mapping, Optional, Sequence, @@ -26,22 +22,18 @@ from numpy.typing import DTypeLike, NDArray from typing_extensions import Self, assert_never -from bioimageio.core.axis import PerAxis -from bioimageio.core.common import PadMode, PadWhere from bioimageio.spec.model import v0_5 from ._magic_tensor_ops import MagicTensorOpsMixin -from .axis import Axis, AxisId, AxisInfo, AxisLike +from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis from .common import ( CropWhere, DTypeStr, - Halo, - HaloLike, + PadMode, + PadWhere, PadWidth, PadWidthLike, SliceInfo, - TileNumber, - TotalNumberOfTiles, ) if TYPE_CHECKING: @@ -207,9 +199,9 @@ def dims(self): # TODO: rename to `axes`? return cast(Tuple[AxisId, ...], self._data.dims) @property - def shape(self): - """Tuple of tensor dimension lenghts""" - return self._data.shape + def tagged_shape(self): + """alias for `sizes`""" + return self.sizes @property def size(self): @@ -236,9 +228,14 @@ def dtype(self) -> DTypeStr: @property def sizes(self): - """Ordered, immutable mapping from axis ids to lengths.""" + """Ordered, immutable mapping from axis ids to axis lengths.""" return cast(Mapping[AxisId, int], self.data.sizes) + # @property + # def tagged_shape(self): + # """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths.""" + # return cast(Mapping[AxisId, int], self.data.sizes) + def astype(self, dtype: DTypeStr, *, copy: bool = False): """Return tensor cast to `dtype` @@ -418,45 +415,6 @@ def resize_to( return tensor - def tile( - self, - tile_size: PerAxis[int], - halo: PerAxis[HaloLike], - pad_mode: PadMode, - ) -> Tuple[ - TotalNumberOfTiles, - Generator[Tuple[TileNumber, Tensor, PerAxis[SliceInfo]], Any, None], - ]: - """tile this tensor into `tile_size` tiles that overlap by `halo`. 
- At the tensor's edge the `halo` is padded with `pad_mode`. - - Args: - tile_sizes: (Outer) output tile shape. - halo: padding At the tensor's edge, overlap with neighboring tiles within - the tensor; additional padding at the end of dimensions that do not - evenly divide by the tile shape may result in larger halos for edge - tiles. - pad_mode: How to pad at the tensor's edge. - """ - assert all(a in self.dims for a in tile_size), (self.dims, set(tile_size)) - assert all(a in self.dims for a in halo), (self.dims, set(halo)) - - # fill in default halo (0) and tile_size (tensor size) - halo = {a: Halo.create(halo.get(a, 0)) for a in self.dims} - tile_size = {a: tile_size.get(a, s) for a, s in self.sizes.items()} - - inner_1d_tiles: List[List[SliceInfo]] = [] - for a, s in self.sizes.items(): - stride = tile_size[a] - sum(halo[a]) - tiles_1d = [SliceInfo(p, min(s, p + stride)) for p in range(0, s, stride)] - inner_1d_tiles.append(tiles_1d) - - n_tiles = prod(map(len, inner_1d_tiles)) - - return n_tiles, self._tile_generator( - inner_1d_tiles=inner_1d_tiles, halo=halo, pad_mode=pad_mode - ) - def transpose( self, axes: Sequence[AxisId], @@ -526,29 +484,3 @@ def _interprete_array_wo_known_axes(cls, array: NDArray[Any]): raise ValueError(f"Could not guess an axis mapping for {array.shape}") return cls(array, dims=tuple(a.id for a in current_axes)) - - def _tile_generator( - self, - *, - inner_1d_tiles: List[List[SliceInfo]], - halo: PerAxis[Halo], - pad_mode: PadMode, - ): - for i, nd_tile in enumerate(itertools.product(*inner_1d_tiles)): - inner_slice: PerAxis[SliceInfo] = dict(zip(self.dims, nd_tile)) - outer_slice = { - a: SliceInfo( - max(0, inner.start - halo[a].left), - min(self.sizes[a], inner.stop + halo[a].right), - ) - for a, inner in inner_slice.items() - } - pad_width: PerAxis[PadWidth] = { - a: PadWidth( - max(0, halo[a].left - inner.start), - max(0, inner.stop + halo[a].right - self.sizes[a]), - ) - for a, inner in inner_slice.items() - } - - yield i, self[outer_slice].pad(pad_width, pad_mode), inner_slice diff --git a/bioimageio/core/tensor_block.py b/bioimageio/core/tensor_block.py new file mode 100644 index 00000000..449dc596 --- /dev/null +++ b/bioimageio/core/tensor_block.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +from typing import Any, Generator, Iterable, Optional, Tuple + +from typing_extensions import Self + +from bioimageio.core.axis import PerAxis +from bioimageio.core.common import ( + Halo, + HaloLike, + PadMode, + SliceInfo, + TotalNumberOfBlocks, +) + +from .block import Block, split_shape_into_blocks +from .stat_measures import Stat +from .tensor import Tensor + + +@dataclass(init=False) +class TensorBlock(Block): + """A block with data""" + + stat: Stat + """sample and dataset statistics""" + + data: Tensor + """the block's tensor""" + + def __init__( + self, + data: Tensor, + *, + inner_slice: PerAxis[SliceInfo], + halo: PerAxis[Halo], + block_number: int, + blocks_in_sample: int, + stat: Stat, + ): + super().__init__( + sample_shape=data.tagged_shape, + inner_slice=inner_slice, + halo=halo, + block_number=block_number, + blocks_in_sample=blocks_in_sample, + ) + self.data = data + self.stat = stat + + @property + def inner_data(self): + return {t: self.data[self.local_slice] for t in self.data} + + def __post_init__(self): + super().__post_init__() + for a, s in self.data.sizes.items(): + slice_ = self.inner_slice[a] + halo = self.halo[a] + assert s == slice_.stop - slice_.start + halo.left + halo.right, ( + s, + slice_, + halo, + ) + + 
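
`split_tensor_into_blocks` (defined below) pairs the pure geometry from block.py with actual data: each yielded `TensorBlock` carries the padded outer slice of the sample tensor. A hedged sketch, assuming `tensor` is a `Tensor` with a 128-long "x" axis:

    from bioimageio.core.axis import AxisId

    n, block_gen = split_tensor_into_blocks(
        tensor,
        block_shape={AxisId("x"): 64},
        halo={AxisId("x"): 8},
        pad_mode="reflect",
        stat={},
    )
    for block in block_gen:
        ...  # block.data spans 64 along "x": 48 inner plus 8 halo per side
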
@classmethod + def from_sample( + cls, + sample: Tensor, + block: Block, + *, + pad_mode: PadMode, + stat: Stat, + ) -> Self: + return cls( + data=sample[block.outer_slice].pad(block.padding, pad_mode), + inner_slice=block.inner_slice, + halo=block.halo, + block_number=block.block_number, + blocks_in_sample=block.blocks_in_sample, + stat=stat, + ) + + +def split_tensor_into_blocks( + sample: Tensor, + block_shape: PerAxis[int], + *, + halo: PerAxis[HaloLike], + stride: Optional[PerAxis[int]] = None, + pad_mode: PadMode, + stat: Stat, +) -> Tuple[TotalNumberOfBlocks, Generator[TensorBlock, Any, None]]: + """divide a sample tensor into tensor blocks.""" + n_blocks, block_gen = split_shape_into_blocks( + sample.tagged_shape, block_shape=block_shape, halo=halo + ) + return n_blocks, _tensor_block_generator( + sample, block_gen, pad_mode=pad_mode, stat=stat + ) + + +def _tensor_block_generator( + sample: Tensor, blocks: Iterable[Block], *, pad_mode: PadMode, stat: Stat +): + for block in blocks: + yield TensorBlock.from_sample(sample, block, pad_mode=pad_mode, stat=stat) diff --git a/bioimageio/core/tile.py b/bioimageio/core/tile.py deleted file mode 100644 index d8180af4..00000000 --- a/bioimageio/core/tile.py +++ /dev/null @@ -1,113 +0,0 @@ -from dataclasses import dataclass, field - -from bioimageio.core.common import TileNumber, TotalNumberOfTiles - -from .axis import PerAxis -from .common import Halo, OverlapWidth, PadWidth, SliceInfo -from .stat_measures import Stat -from .tensor import PerTensor, Tensor - - -@dataclass -class AbstractTile: - """A tile of a dataset sample without any data""" - - inner_slice: PerTensor[PerAxis[SliceInfo]] - """slice of the inner tile (without padding and overlap) of the sample""" - - halo: PerTensor[PerAxis[Halo]] - """pad/overlap to extend the (inner) tile (to the outer tile)""" - - tile_number: TileNumber - """the n-th tile of the sample""" - - tiles_in_sample: TotalNumberOfTiles - """total number of tiles of the sample""" - - sample_sizes: PerTensor[PerAxis[int]] - """the axis sizes of the sample""" - - stat: Stat - """sample and dataset statistics""" - - outer_slice: PerTensor[PerAxis[SliceInfo]] = field(init=False) - """slice of the outer tile (including overlap, but not padding) in the sample""" - - local_slice: PerTensor[PerAxis[SliceInfo]] = field(init=False) - """slice to extract the inner tile from the outer tile""" - - overlap: PerTensor[PerAxis[OverlapWidth]] = field(init=False) - """overlap 'into a neighboring tile'""" - - padding: PerTensor[PerAxis[PadWidth]] = field(init=False) - """pad (at sample edges where we cannot overlap to realize `halo`""" - - def __post_init__(self): - self.outer_slice = { - t: { - a: SliceInfo( - max(0, self.inner_slice[t][a].start - self.halo[t][a].left), - min( - self.sample_sizes[t][a], - self.inner_slice[t][a].stop + self.halo[t][a].right, - ), - ) - for a in self.inner_slice[t] - } - for t in self.inner_slice - } - self.local_slice = { - t: { - a: SliceInfo( - self.inner_slice[t][a].start - self.outer_slice[t][a].start, - self.inner_slice[t][a].stop - self.outer_slice[t][a].start, - ) - for a in self.inner_slice[t] - } - for t in self.inner_slice - } - self.overlap = { - t: { - a: OverlapWidth( - self.inner_slice[t][a].start - self.outer_slice[t][a].start, - self.outer_slice[t][a].stop - self.inner_slice[t][a].stop, - ) - for a in self.inner_slice[t] - } - for t in self.inner_slice - } - self.padding = { - t: { - a: PadWidth( - self.halo[t][a].left - self.overlap[t][a].left, - self.halo[t][a].right - 
self.overlap[t][a].right, - ) - for a in self.inner_slice[t] - } - for t in self.inner_slice - } - - -@dataclass -class Tile(AbstractTile): - """A tile of a dataset sample""" - - data: PerTensor[Tensor] - """the tile's tensors""" - - @property - def inner_data(self): - return {t: self.data[t][self.local_slice[t]] for t in self.data} - - def __post_init__(self): - super().__post_init__() - for t, d in self.data.items(): - assert t == d.id, f"tensor id mismatch: {t} != {d.id}" - for a, s in d.sizes.items(): - slice_ = self.inner_slice[t][a] - halo = self.halo[t][a] - assert s == slice_.stop - slice_.start + halo.left + halo.right, ( - s, - slice_, - halo, - ) diff --git a/bioimageio/core/utils/_digest_spec.py b/bioimageio/core/utils/_digest_spec.py index 7480a255..09a290f8 100644 --- a/bioimageio/core/utils/_digest_spec.py +++ b/bioimageio/core/utils/_digest_spec.py @@ -1,10 +1,11 @@ -from typing import Union +from typing import Iterable, Union -from bioimageio.core.axis import AxisInfo -from bioimageio.core.sample import Sample +from bioimageio.core.tile import AbstractTile from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 from bioimageio.spec.utils import load_array +from ..axis import AxisInfo +from ..sample import UntiledSample from ..tensor import Tensor, TensorId @@ -26,7 +27,7 @@ def get_axes_infos( ] -def get_test_inputs(model: AnyModelDescr) -> Sample: +def get_test_inputs(model: AnyModelDescr) -> UntiledSample: if isinstance(model, v0_4.ModelDescr): tensor_ids = [TensorId(t.name) for t in model.inputs] else: @@ -38,7 +39,7 @@ def get_test_inputs(model: AnyModelDescr) -> Sample: arrays = [load_array(d.test_tensor) for d in model.inputs] axes = [get_axes_infos(t) for t in model.inputs] - return Sample( + return UntiledSample( data={ tid: Tensor.from_numpy(arr, dims=ax) for tid, arr, ax in zip(tensor_ids, arrays, axes) @@ -46,7 +47,7 @@ def get_test_inputs(model: AnyModelDescr) -> Sample: ) -def get_test_outputs(model: AnyModelDescr) -> Sample: +def get_test_outputs(model: AnyModelDescr) -> UntiledSample: if isinstance(model, v0_4.ModelDescr): tensor_ids = [TensorId(t.name) for t in model.outputs] else: @@ -59,9 +60,38 @@ def get_test_outputs(model: AnyModelDescr) -> Sample: axes = [get_axes_infos(t) for t in model.outputs] - return Sample( + return UntiledSample( data={ tid: Tensor.from_numpy(arr, dims=ax) for tid, arr, ax in zip(tensor_ids, arrays, axes) } ) + + +def get_abstract_output_tiles( + input_tiles: Iterable[AbstractTile], model: v0_5.ModelDescr +): + if not isinstance(model, v0_5.ModelDescr): + raise TypeError(f"get_abstract_output_tile() not implemented for {type(model)}") + + sample_sizes = model.get_output_tensor_sizes(input_tile.sample_sizes) + outer_sizes = model.get_output_tensor_sizes(input_tile.outer_sizes) + UntiledSample() + halo = { + t.id: {a.id: a.halo for a in t.axes if isinstance(a, v0_5.WithHalo)} + for t in model.outputs + if t.id in outer_sizes + } + inner_sizes = { + t: { + a: outer_sizes[t][a] - 2 * halo.get(t, {}).get(a, 0) for a in outer_sizes[t] + } + for t in outer_sizes + } + + return AbstractTile( + halo=halo, + tile_number=input_tile.tile_number, + tiles_in_sample=input_tile.tiles_in_sample, + stat={}, + ) diff --git a/tests/test_proc_ops.py b/tests/test_proc_ops.py index 94ce3320..9ec34ec6 100644 --- a/tests/test_proc_ops.py +++ b/tests/test_proc_ops.py @@ -6,7 +6,7 @@ from typing_extensions import TypeGuard from bioimageio.core.axis import AxisId -from bioimageio.core.sample import Sample +from bioimageio.core.sample import 
UntiledSample from bioimageio.core.stat_calculators import compute_measures from bioimageio.core.stat_measures import SampleMean, SamplePercentile, SampleStd from bioimageio.core.tensor import TensorId @@ -23,7 +23,7 @@ def test_scale_linear(tid: TensorId): offset = xr.DataArray([1, 2, 42], dims=("c")) gain = xr.DataArray([1, 2, 3], dims=("c")) data = xr.DataArray(np.arange(6).reshape((1, 2, 3)), dims=("x", "y", "c")) - sample = Sample(data={tid: data}) + sample = UntiledSample(data={tid: data}) op = ScaleLinear(input=tid, output=tid, offset=offset, gain=gain) op(sample) @@ -37,7 +37,7 @@ def test_scale_linear_no_channel(tid: TensorId): op = ScaleLinear(tid, tid, offset=1, gain=2) data = xr.DataArray(np.arange(6).reshape(2, 3), dims=("x", "y")) - sample = Sample(data={tid: data}) + sample = UntiledSample(data={tid: data}) op(sample) expected = xr.DataArray(np.array([[1, 3, 5], [7, 9, 11]]), dims=("x", "y")) @@ -56,7 +56,7 @@ def test_zero_mean_unit_variance(tid: TensorId): from bioimageio.core.proc_ops import ZeroMeanUnitVariance data = xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y")) - sample = Sample(data={tid: data}) + sample = UntiledSample(data={tid: data}) m = SampleMean(tid) std = SampleStd(tid) op = ZeroMeanUnitVariance(tid, tid, m, std) @@ -99,7 +99,7 @@ def test_zero_mean_unit_variance_fixed(tid: TensorId): ), dims=("b", "c", "x"), ) - sample = Sample(data={tid: data}) + sample = UntiledSample(data={tid: data}) op(sample) xr.testing.assert_allclose(expected, sample.data[tid]) @@ -115,7 +115,7 @@ def test_zero_mean_unit_across_axes(tid: TensorId): SampleMean(tid, (AxisId("x"), AxisId("y"))), SampleStd(tid, (AxisId("x"), AxisId("y"))), ) - sample = Sample(data={tid: data}) + sample = UntiledSample(data={tid: data}) sample.stat = compute_measures(op.required_measures, [sample]) expected = xr.concat( @@ -135,7 +135,7 @@ def test_zero_mean_unit_variance_fixed2(tid: TensorId): op = FixedZeroMeanUnitVariance(tid, tid, mean=mean, std=std, eps=eps) data = xr.DataArray(np_data, dims=("x", "y")) - sample = Sample(data={tid: data}) + sample = UntiledSample(data={tid: data}) expected = xr.DataArray((np_data - mean) / (std + eps), dims=("x", "y")) op(sample) xr.testing.assert_allclose(expected, sample.data[tid]) @@ -146,7 +146,7 @@ def test_binarize(tid: TensorId): op = Binarize(tid, tid, threshold=14) data = xr.DataArray(np.arange(30).reshape((2, 3, 5)), dims=("x", "y", "c")) - sample = Sample(data={tid: data}) + sample = UntiledSample(data={tid: data}) expected = xr.zeros_like(data) expected[{"x": slice(1, None)}] = 1 op(sample) @@ -164,7 +164,7 @@ def test_binarize2(tid: TensorId): threshold = 0.5 exp = xr.DataArray(np_data > threshold, dims=axes) - sample = Sample(data={tid: data}) + sample = UntiledSample(data={tid: data}) binarize = Binarize(tid, tid, threshold=threshold) binarize(sample) xr.testing.assert_allclose(exp, sample.data[tid]) @@ -175,7 +175,7 @@ def test_clip(tid: TensorId): op = Clip(tid, tid, min=3, max=5) data = xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y")) - sample = Sample(data={tid: data}) + sample = UntiledSample(data={tid: data}) expected = xr.DataArray( np.array([[3, 3, 3], [3, 4, 5], [5, 5, 5]]), dims=("x", "y") @@ -188,7 +188,7 @@ def test_combination_of_op_steps_with_dims_specified(tid: TensorId): from bioimageio.core.proc_ops import ZeroMeanUnitVariance data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) - sample = Sample(data={tid: data}) + sample = UntiledSample(data={tid: data}) op = ZeroMeanUnitVariance( tid, tid, 
@@ -244,7 +244,7 @@ def test_scale_mean_variance(tid: TensorId, axes: Optional[Tuple[AxisId, ...]]): ref_data = xr.DataArray((np_data * 2) + 3, dims=ipt_axes) op = ScaleMeanVariance(tid, tid, reference_tensor=TensorId("ref_name"), axes=axes) - sample = Sample(data={tid: ipt_data, TensorId("ref_name"): ref_data}) + sample = UntiledSample(data={tid: ipt_data, TensorId("ref_name"): ref_data}) sample.stat = compute_measures(op.required_measures, [sample]) op(sample) xr.testing.assert_allclose(ref_data, sample.data[tid]) @@ -269,7 +269,7 @@ def test_scale_mean_variance_per_channel(tid: TensorId, axes_str: Optional[str]) ref_data = xr.DataArray(np_ref_data, dims=ipt_axes) op = ScaleMeanVariance(tid, tid, reference_tensor=TensorId("ref_name"), axes=axes) - sample = Sample(data={tid: ipt_data, TensorId("ref_name"): ref_data}) + sample = UntiledSample(data={tid: ipt_data, TensorId("ref_name"): ref_data}) sample.stat = compute_measures(op.required_measures, [sample]) op(sample) @@ -288,7 +288,7 @@ def test_scale_range(tid: TensorId): op = ScaleRange(tid, tid) np_data = np.arange(9).reshape(3, 3).astype("float32") data = xr.DataArray(np_data, dims=("x", "y")) - sample = Sample(data={tid: data}) + sample = UntiledSample(data={tid: data}) sample.stat = compute_measures(op.required_measures, [sample]) eps = 1.0e-6 @@ -310,7 +310,7 @@ def test_scale_range_axes(tid: TensorId): np_data = np.arange(18).reshape((2, 3, 3)).astype("float32") data = xr.DataArray(np_data, dims=("c", "x", "y")) - sample = Sample(data={tid: data}) + sample = UntiledSample(data={tid: data}) sample.stat = compute_measures(op.required_measures, [sample]) eps = 1.0e-6 @@ -331,7 +331,7 @@ def test_sigmoid(tid: TensorId): axes = ("c", "y", "x") np_data = np.random.rand(*shape) data = xr.DataArray(np_data, dims=axes) - sample = Sample(data={tid: data}) + sample = UntiledSample(data={tid: data}) sigmoid = Sigmoid(tid, tid) sigmoid(sample) diff --git a/tests/test_stat_calculators.py b/tests/test_stat_calculators.py index 6e963272..b1468609 100644 --- a/tests/test_stat_calculators.py +++ b/tests/test_stat_calculators.py @@ -5,7 +5,7 @@ from xarray.testing import assert_allclose # pyright: ignore[reportUnknownVariableType] from bioimageio.core.axis import AxisId -from bioimageio.core.sample import Sample +from bioimageio.core.sample import UntiledSample from bioimageio.core.stat_calculators import MeanVarStdCalculator from bioimageio.core.stat_measures import ( DatasetMean, @@ -20,7 +20,8 @@ def create_random_dataset(tid: TensorId, axes: Tuple[AxisId, ...]): sizes = list(range(n, len(axes) + 1)) data = np.asarray(np.random.rand(*sizes)) ds = [ - Sample(data={tid: Tensor(data[i : i + 1], dims=axes, id=tid)}) for i in range(n) + UntiledSample(data={tid: Tensor(data[i : i + 1], dims=axes, id=tid)}) + for i in range(n) ] return Tensor(data, dims=axes), ds diff --git a/tests/test_stat_measures.py b/tests/test_stat_measures.py index 04b9ed3f..2fa69c2a 100644 --- a/tests/test_stat_measures.py +++ b/tests/test_stat_measures.py @@ -7,7 +7,7 @@ from bioimageio.core import stat_measures from bioimageio.core.axis import AxisId -from bioimageio.core.sample import Sample +from bioimageio.core.sample import UntiledSample from bioimageio.core.stat_calculators import ( SamplePercentilesCalculator, get_measure_calculators, @@ -36,7 +36,7 @@ def test_individual_normal_measure( ) expected = getattr(data, name)(dim=axes) - sample = Sample(data={data_id: data}) + sample = UntiledSample(data={data_id: data}) actual = measure.compute(sample) 
xr.testing.assert_allclose(expected, actual) @@ -55,7 +55,7 @@ def test_individual_percentile_measure(axes: Optional[Tuple[AxisId, ...]]): data = Tensor( np.random.random((5, 6, 3)), dims=(AxisId("x"), AxisId("y"), AxisId("c")) ) - actual = calc.compute(Sample(data={tid: data})) + actual = calc.compute(UntiledSample(data={tid: data})) for m in measures: expected = data.quantile(q=m.q, dim=m.axes) xr.testing.assert_allclose(expected, actual[m]) From 9e806c7701d119b69f4000f2c7d5978eb6f316b2 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 5 Apr 2024 16:24:57 +0200 Subject: [PATCH 178/244] WIP sample_block --- bioimageio/core/__init__.py | 6 +- bioimageio/core/_op_base.py | 6 +- bioimageio/core/_prediction_pipeline.py | 218 +++++++++----- bioimageio/core/block.py | 251 +++++----------- bioimageio/core/block_meta.py | 296 +++++++++++++++++++ bioimageio/core/common.py | 9 +- bioimageio/core/digest_spec.py | 215 ++++++++++++++ bioimageio/core/proc_ops.py | 137 +++++---- bioimageio/core/proc_setup.py | 32 +-- bioimageio/core/sample.py | 335 ++++++++++------------ bioimageio/core/stat_calculators.py | 160 ++++++----- bioimageio/core/stat_measures.py | 15 +- bioimageio/core/tensor.py | 21 +- bioimageio/core/tensor_block.py | 106 ------- bioimageio/core/utils/__init__.py | 5 - bioimageio/core/utils/_digest_spec.py | 97 ------- bioimageio/core/utils/_import_callable.py | 66 ----- tests/test_stat_measures.py | 4 +- 18 files changed, 1098 insertions(+), 881 deletions(-) create mode 100644 bioimageio/core/block_meta.py create mode 100644 bioimageio/core/digest_spec.py delete mode 100644 bioimageio/core/tensor_block.py delete mode 100644 bioimageio/core/utils/_digest_spec.py delete mode 100644 bioimageio/core/utils/_import_callable.py diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index de4ffe29..27794eb6 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -25,10 +25,10 @@ from ._settings import settings as settings from .axis import Axis as Axis from .axis import AxisId as AxisId -from .sample import UntiledSample as UntiledSample +from .block_meta import BlockMeta as BlockMeta +from .common import MemberId as MemberId +from .sample import Sample as Sample from .tensor import Tensor as Tensor -from .tensor import TensorId as TensorId -from .tile import Tile as Tile from .utils import VERSION __version__ = VERSION diff --git a/bioimageio/core/_op_base.py b/bioimageio/core/_op_base.py index e3e420d1..78d13b52 100644 --- a/bioimageio/core/_op_base.py +++ b/bioimageio/core/_op_base.py @@ -2,14 +2,14 @@ from dataclasses import dataclass from typing import Collection -from bioimageio.core.sample import UntiledSample -from bioimageio.core.stat_measures import Measure +from .sample import SampleBlock +from .stat_measures import Measure @dataclass class Operator(ABC): @abstractmethod - def __call__(self, sample: UntiledSample) -> None: ... + def __call__(self, sample_block: SampleBlock) -> None: ... 
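+    # Operators mutate a sample block in place; a minimal sketch (with a
+    # hypothetical member id):
+    #   op = Binarize(input=MemberId("raw"), output=MemberId("raw"), threshold=0.5)
+    #   op(sample_block)  # overwrites the "raw" member with the thresholded tensor
+    # `required_measures` (below) declares the statistics an operator expects
+    # to find in `sample_block.stat`.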
@property @abstractmethod diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index a95e0932..b9498095 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -9,47 +9,49 @@ Mapping, Optional, Sequence, + Tuple, TypeVar, Union, ) -from numpy.typing import NDArray +from tqdm import tqdm from typing_extensions import assert_never -from bioimageio.spec.model import AnyModelDescr, v0_4 +from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 from bioimageio.spec.model.v0_5 import WeightsFormat -from .axis import AxisInfo +from .axis import AxisId, AxisInfo +from .block import Block +from .common import MemberId, PadMode, PerMember +from .digest_spec import get_axes_infos, get_block_meta from .model_adapters import ModelAdapter, create_model_adapter from .model_adapters import get_weight_formats as get_weight_formats from .proc_ops import Processing from .proc_setup import setup_pre_and_postprocessing -from .sample import TiledSample, UntiledSample +from .sample import Sample, SampleBlock from .stat_measures import DatasetMeasure, MeasureValue -from .tensor import Tensor, TensorId -from .tile import Tile -from .utils import get_axes_infos +from .tensor import Tensor @dataclass -class CoreTensorDescr: - id: TensorId +class MemberDescr: + id: MemberId axes: Sequence[AxisInfo] optional: bool -Data = TypeVar( - "Data", - TiledSample, - UntiledSample, - Tile, - Iterable[TiledSample], - Iterable[UntiledSample], - NDArray[Any], - Sequence[Optional[NDArray[Any]]], - Mapping[Union[TensorId, str], Optional[NDArray[Any]]], +Predict_IO = TypeVar( + "Predict_IO", + Sample, + SampleBlock, + Iterable[Sample], + Iterable[SampleBlock], ) +# NDArray[Any], +# Sequence[Optional[NDArray[Any]]], +# Mapping[Union[MemberId, str], Optional[NDArray[Any]]], + class PredictionPipeline: """ @@ -61,55 +63,80 @@ def __init__( self, *, name: str, - bioimageio_model: AnyModelDescr, + model_description: AnyModelDescr, preprocessing: List[Processing], postprocessing: List[Processing], - model: ModelAdapter, + model_adapter: ModelAdapter, + ns: Union[ + v0_5.ParameterizedSize.N, + Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize.N], + ], ) -> None: super().__init__() - if bioimageio_model.run_mode: + if model_description.run_mode: warnings.warn( - f"Not yet implemented inference for run mode '{bioimageio_model.run_mode.name}'" + f"Not yet implemented inference for run mode '{model_description.run_mode.name}'" ) self.name = name self._preprocessing = preprocessing self._postprocessing = postprocessing + self.model_description = model_description + if isinstance(ns, int): + if isinstance(model_description, v0_4.ModelDescr): + self.ns = None + else: + self.ns = { + (ipt.id, a.id): ns + for ipt in model_description.inputs + for a in ipt.axes + if isinstance(a.size, v0_5.ParameterizedSize) + } + else: + self.ns = ns + # if isinstance(model_description, v0_4.ModelDescr): + # self.default_sample_block_shape = None + # else: + + # self.default_sample_block_shape = model_description.get_tensor_sizes( + # ns, 1 + # ).inputs + self.input_ids = tuple( - (TensorId(str(t.name)) if isinstance(t, v0_4.InputTensorDescr) else t.id) - for t in bioimageio_model.inputs + (MemberId(str(t.name)) if isinstance(t, v0_4.InputTensorDescr) else t.id) + for t in model_description.inputs ) self.inputs = collections.OrderedDict( ( tid, - CoreTensorDescr( + MemberDescr( id=tid, axes=get_axes_infos(t), optional=not isinstance(t, v0_4.InputTensorDescr) and 
t.optional,
                 ),
             )
-            for tid, t in zip(self.input_ids, bioimageio_model.inputs)
+            for tid, t in zip(self.input_ids, model_description.inputs)
         )
 
         self.output_ids = tuple(
-            (TensorId(str(t.name)) if isinstance(t, v0_4.OutputTensorDescr) else t.id)
-            for t in bioimageio_model.outputs
+            (MemberId(str(t.name)) if isinstance(t, v0_4.OutputTensorDescr) else t.id)
+            for t in model_description.outputs
         )
         self.outputs = collections.OrderedDict(
             (
                 tid,
-                CoreTensorDescr(
+                MemberDescr(
                     id=tid,
                     axes=get_axes_infos(t),
                     optional=False,
                 ),
             )
-            for tid, t in zip(self.output_ids, bioimageio_model.outputs)
+            for tid, t in zip(self.output_ids, model_description.outputs)
         )
 
-        self._adapter: ModelAdapter = model
+        self._adapter: ModelAdapter = model_adapter
 
-    def __call__(self, data: Data) -> Data:
+    def __call__(self, data: Predict_IO) -> Predict_IO:
         return self.predict(data)
 
     def __enter__(self):
@@ -120,30 +147,88 @@ def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
         self.unload()
         return False
 
-    def predict(self, inputs: Data) -> Data:
+    # def predict_sample(
+    #     self,
+    #     sample: Sample,
+    #     parameterized_size_n: Optional[int] = None,
+    #     pad_mode: PadMode = "reflect",
+    # ) -> Sample:
+    #     if parameterized_size_n is None:
+    #         # TODO guess n
+    #         parameterized_size_n = 10
+
+    #     return Sample.from_blocks(
+    #         map(
+    #             self.predict_sample_block,
+    #             sample.split_into_blocks(
+    #                 block_shapes={m: ipt.axes for m, ipt in self.inputs.items()},
+    #                 halo={
+    #                     m: ipt.axes.halo
+    #                     for m, ipt in self.inputs.items()
+    #                     if isinstance(ipt.axes, v0_5.WithHalo)
+    #                 },
+    #                 pad_mode=pad_mode,
+    #             ),
+    #         )
+    #     )
+
+    # def predict_sample_block(self, inputs: SampleBlock) -> SampleBlock:
+    #     self.apply_preprocessing(inputs)
+    #     output = Block(
+    #         data={
+    #             tid: out
+    #             for tid, out in zip(
+    #                 self.output_ids,
+    #                 self._adapter.forward(
+    #                     *(inputs.data[t] for t in self.input_ids)
+    #                 ),
+    #             )
+    #             if out is not None
+    #         }
+    #     )
+    #     self.apply_postprocessing(output)
+    #     return output
+
+    # else:
+    #     assert_never(inputs)
+
+    #     return output
+
+    def predict(self, inputs: Predict_IO) -> Predict_IO:
         """Run model prediction **including** pre/postprocessing."""
-        if isinstance(inputs, Tile):
-            self.apply_preprocessing(inputs)
-            output_tile = Tile(
-                data={
-                    tid: out
-                    for tid, out in zip(
-                        self.output_ids,
-                        self._adapter.forward(
-                            *(inputs.data[t] for t in self.input_ids)
-                        ),
-                    )
-                    if out is not None
-                }
+        if isinstance(inputs, Sample):
+            if isinstance(self.model_description, v0_4.ModelDescr):
+                raise NotImplementedError(
+                    "predicting `Sample`s is not implemented for model"
+                    + f" {self.model_description.format_version}."
+                    + " Please divide the sample into blocks using `sample.split_into_blocks()`."
+                )
+
+            assert self.ns is not None
+            n_blocks, block_metas = get_block_meta(
+                self.model_description, input_sample_shape=inputs.shape, ns=self.ns
             )
-            self.apply_postprocessing(output_tile)
-            return output_tile
+            # TODO (WIP): blockwise prediction of a whole `Sample`; roughly:
+            # for block_meta in tqdm(block_metas, desc=f"predict sample {inputs.id or ''} with {self.model_description.id or self.model_description.name}", unit="block", total=n_blocks):
+            #     ...predict each input block and stitch the output blocks...
+            # n_blocks, blocks = inputs.split_into_blocks(
+            #     block_shapes=...,  # derive from `block_metas`
+            #     halo={
+            #         m: ipt.axes.halo
+            #         for m, ipt in self.inputs.items()
+            #         if isinstance(ipt.axes, v0_5.WithHalo)
+            #     },
+            #     pad_mode="reflect",
+            # )
+            # return Sample.from_blocks(map(self.predict, blocks))
+            raise NotImplementedError(
+                "predicting a whole `Sample` is not yet implemented"
+            )
 
         else:
-            assert_never(inputs)
-
-            return output
+            return self.predict_sample_block(inputs)
 
         # if isinstance(inputs, collections.abc.Mapping):
         #     data = {
@@ -162,7 +247,7 @@
         #         tid: d for tid, d in zip(self.input_ids, inputs_seq) if d is not None
         #     }
 
-        #     sample = UntiledSample(
+        #     sample = Sample(
         #         data={
         #             tid: Tensor.from_numpy(d, dims=self.inputs[tid].axes)
         #             for tid, d in data.items()
@@ -171,15 +256,15 @@
         #     output = self.predict(sample)
         #     return {tid: out.data.data for }
 
-    def apply_preprocessing(self, tile: Tile) -> None:
+    def apply_preprocessing(self, sample_block: SampleBlock) -> None:
         """apply preprocessing in-place, also updates sample stats"""
         for op in self._preprocessing:
-            op(tile)
+            op(sample_block)
 
-    def apply_postprocessing(self, tile: Tile) -> None:
+    def apply_postprocessing(self, sample_block: SampleBlock) -> None:
         """apply postprocessing in-place, also updates samples stats"""
         for op in self._postprocessing:
-            op(tile)
+            op(sample_block)
 
     def load(self):
         """
@@ -200,14 +285,16 @@ def create_prediction_pipeline(
     devices: Optional[Sequence[str]] = None,
     weight_format: Optional[WeightsFormat] = None,
     weights_format: Optional[WeightsFormat] = None,
-    dataset_for_initial_statistics: Iterable[
-        Union[UntiledSample, Sequence[Tensor]]
-    ] = tuple(),
+    dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[Tensor]]] = tuple(),
     keep_updating_initial_dataset_statistics: bool = False,
     fixed_dataset_statistics: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType(
         {}
    ),
     model_adapter: Optional[ModelAdapter] = None,
+    ns: Union[
+        v0_5.ParameterizedSize.N,
+        Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize.N],
+    ] = 10,
     **deprecated_kwargs: Any,
 ) -> PredictionPipeline:
     """
@@ -232,16 +319,16 @@
     )
 
     if isinstance(bioimageio_model, v0_4.ModelDescr):
-        input_ids = [TensorId(str(ipt.name)) for ipt in bioimageio_model.inputs]
+        input_ids = [MemberId(str(ipt.name)) for ipt in bioimageio_model.inputs]
     else:
         input_ids = [ipt.id for ipt in bioimageio_model.inputs]
 
     def dataset():
         for x in dataset_for_initial_statistics:
-            if isinstance(x, UntiledSample):
+            if isinstance(x, Sample):
                 yield x
             else:
-                yield UntiledSample(data=dict(zip(input_ids, x)))
+                yield Sample(members=dict(zip(input_ids, x)))
 
     preprocessing, postprocessing = setup_pre_and_postprocessing(
         bioimageio_model,
@@ -252,8 +339,9 @@ def dataset():
 
     return PredictionPipeline(
         name=bioimageio_model.name,
-        bioimageio_model=bioimageio_model,
-        model=model_adapter,
+        model_description=bioimageio_model,
+        model_adapter=model_adapter,
         preprocessing=preprocessing,
         postprocessing=postprocessing,
+        ns=ns,
     )
diff --git a/bioimageio/core/block.py b/bioimageio/core/block.py
index b678d82a..7d78e56c 100644 --- a/bioimageio/core/block.py +++ b/bioimageio/core/block.py @@ -1,193 +1,98 @@ -import itertools -from dataclasses import dataclass, field -from math import prod -from typing import Any, Dict, Generator, List, Optional, Tuple +from dataclasses import dataclass +from typing import ( + Any, + Generator, + Iterable, + Optional, + Tuple, +) + +from typing_extensions import Self -from .axis import AxisId, PerAxis +from .axis import PerAxis +from .block_meta import BlockMeta, split_shape_into_blocks from .common import ( - BlockNumber, Halo, HaloLike, - PadWidth, + PadMode, SliceInfo, TotalNumberOfBlocks, ) +from .tensor import Tensor + + +@dataclass(init=False) +class Block(BlockMeta): + """A block/tile of a (larger) tensor""" + + data: Tensor + """the block's tensor, e.g. a (padded) slice of some larger, original tensor""" + + def __init__( + self, + data: Tensor, + *, + inner_slice: PerAxis[SliceInfo], + halo: PerAxis[Halo], + block_number: int, + blocks_in_sample: int, + ): + super().__init__( + sample_shape=data.tagged_shape, + inner_slice=inner_slice, + halo=halo, + block_number=block_number, + blocks_in_sample=blocks_in_sample, + ) + self.data = data - -@dataclass -class Block: - """Block of a sample - - Figure for illustration: - The first 2d block (dashed) of a sample (**bold**). - The inner slice (thin) is expanded by a halo in both dimensions on both sides. - The outer slice reaches from the sample origin (0, 0) to the right halo point. - - ```terminal - ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ - ╷ halo(left) ╷ - ╷ ╷ - ╷ (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔ - ╷ ┃ │ ╷ sample - ╷ ┃ inner │ ╷ - ╷ ┃ (and outer) │ outer ╷ - ╷ ┃ slice │ slice ╷ - ╷ ┃ │ ╷ - ╷ ┣─────────────────┘ ╷ - ╷ ┃ outer slice ╷ - ╷ ┃ halo(right) ╷ - └ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘ - ⬇ - ``` - - note: - - Inner and outer slices are specified in sample coordinates. - - The outer_slice of a block at the sample edge may overlap by more than the - halo with the neighboring block (the inner slices will not overlap though). 
- - """ - - sample_shape: PerAxis[int] - """the axis sizes of the whole (unblocked) sample""" - - inner_slice: PerAxis[SliceInfo] - """inner region (without halo) wrt the sample""" - - halo: PerAxis[Halo] - """halo enlarging the inner region to the block's sizes""" - - block_number: BlockNumber - """the n-th block of the sample""" - - blocks_in_sample: TotalNumberOfBlocks - """total number of blocks in the sample""" - - shape: PerAxis[int] = field(init=False) - """axis lengths of the block""" - - padding: PerAxis[PadWidth] = field(init=False) - """padding to realize the halo at the sample edge - where we cannot simply enlarge the inner slice""" - - outer_slice: PerAxis[SliceInfo] = field(init=False) - """slice of the outer block (without padding) wrt the sample""" - - inner_shape: PerAxis[int] = field(init=False) - """axis lengths of the inner region (without halo)""" - - local_slice: PerAxis[SliceInfo] = field(init=False) - """inner slice wrt the block, **not** the sample""" + @property + def inner_data(self): + return self.data[self.local_slice] def __post_init__(self): - assert all( - a in self.sample_shape for a in self.inner_slice - ), "block has axes not present in sample" - assert all( - a in self.inner_slice for a in self.halo - ), "halo has axes not present in block" - - self.shape = { - a: s.stop - s.start + sum(self.halo[a]) for a, s in self.inner_slice.items() - } - assert all( - s <= self.sample_shape[a] for a, s in self.shape.items() - ), "block larger than sample" - - self.inner_shape = {a: s.stop - s.start for a, s in self.inner_slice.items()} - self.outer_slice = { - a: SliceInfo( - max( - 0, - min( - self.inner_slice[a].start - self.halo[a].left, - self.sample_shape[a] - self.inner_shape[a] - self.halo[a].left, - ), - ), - min( - self.sample_shape[a], - self.inner_slice[a].stop + self.halo[a].right, - ), - ) - for a in self.inner_slice - } - self.padding = { - a: PadWidth( - max( - 0, - self.halo[a].left - - (self.inner_slice[a].start + self.outer_slice[a].start), - ), - max( - 0, - self.halo[a].right - - (self.outer_slice[a].stop + self.inner_slice[a].stop), - ), + super().__post_init__() + for a, s in self.data.sizes.items(): + slice_ = self.inner_slice[a] + halo = self.halo[a] + assert s == slice_.stop - slice_.start + halo.left + halo.right, ( + s, + slice_, + halo, ) - for a in self.inner_slice - } - self.local_slice = { - a: SliceInfo( - self.padding[a].left, - self.padding[a].left + self.inner_shape[a], - ) - for a in self.inner_slice - } + + @classmethod + def from_sample_member( + cls, + sample_member: Tensor, + block: BlockMeta, + *, + pad_mode: PadMode, + ) -> Self: + return cls( + data=sample_member[block.outer_slice].pad(block.padding, pad_mode), + inner_slice=block.inner_slice, + halo=block.halo, + block_number=block.block_number, + blocks_in_sample=block.blocks_in_sample, + ) -def split_shape_into_blocks( - shape: PerAxis[int], +def split_tensor_into_blocks( + tensor: Tensor, block_shape: PerAxis[int], + *, halo: PerAxis[HaloLike], stride: Optional[PerAxis[int]] = None, + pad_mode: PadMode, ) -> Tuple[TotalNumberOfBlocks, Generator[Block, Any, None]]: - assert all(a in shape for a in block_shape), ( - tuple(shape), - set(block_shape), - ) - assert all(a in shape for a in halo), (tuple(shape), set(halo)) - - # fill in default halo (0) and tile_size (tensor size) - halo = {a: Halo.create(h) for a, h in halo.items()} - block_shape = {a: block_shape.get(a, s) for a, s in shape.items()} - if stride is None: - stride = {} - - inner_1d_slices: Dict[AxisId, 
List[SliceInfo]] = {} - for a, s in shape.items(): - inner_size = block_shape[a] - sum(halo[a]) - stride_1d = stride.get(a, inner_size) - inner_1d_slices[a] = [ - SliceInfo(min(p, s - inner_size), min(p + inner_size, s)) - for p in range(0, s, stride_1d) - ] - - n_blocks = prod(map(len, inner_1d_slices.values())) - - return n_blocks, _block_generator( - shape, - blocks_in_sample=n_blocks, - inner_1d_slices=inner_1d_slices, - halo=halo, + """divide a sample tensor into tensor blocks.""" + n_blocks, block_gen = split_shape_into_blocks( + tensor.tagged_shape, block_shape=block_shape, halo=halo, stride=stride ) + return n_blocks, _block_generator(tensor, block_gen, pad_mode=pad_mode) -def _block_generator( - sample_shape: PerAxis[int], - *, - blocks_in_sample: int, - inner_1d_slices: Dict[AxisId, List[SliceInfo]], - halo: PerAxis[HaloLike], -): - assert all(a in sample_shape for a in halo) - - halo = {a: Halo.create(halo.get(a, 0)) for a in inner_1d_slices} - for i, nd_tile in enumerate(itertools.product(*inner_1d_slices.values())): - inner_slice: PerAxis[SliceInfo] = dict(zip(inner_1d_slices, nd_tile)) - - yield Block( - sample_shape=sample_shape, - inner_slice=inner_slice, - halo=halo, - block_number=i, - blocks_in_sample=blocks_in_sample, - ) +def _block_generator(sample: Tensor, blocks: Iterable[BlockMeta], *, pad_mode: PadMode): + for block in blocks: + yield Block.from_sample_member(sample, block, pad_mode=pad_mode) diff --git a/bioimageio/core/block_meta.py b/bioimageio/core/block_meta.py new file mode 100644 index 00000000..785abc6c --- /dev/null +++ b/bioimageio/core/block_meta.py @@ -0,0 +1,296 @@ +import itertools +from dataclasses import dataclass, field +from math import prod +from typing import ( + Any, + Collection, + Dict, + Generator, + Iterable, + List, + Optional, + Tuple, +) + +from .axis import AxisId, PerAxis +from .common import ( + BlockNumber, + Halo, + HaloLike, + MemberId, + PadWidth, + PerMember, + SliceInfo, + TotalNumberOfBlocks, +) + + +@dataclass +class BlockMeta: + """Block meta data of a sample member (a tensor in a sample) + + Figure for illustration: + The first 2d block (dashed) of a sample member (**bold**). + The inner slice (thin) is expanded by a halo in both dimensions on both sides. + The outer slice reaches from the sample member origin (0, 0) to the right halo point. + + ```terminal + ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ + ╷ halo(left) ╷ + ╷ ╷ + ╷ (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔ + ╷ ┃ │ ╷ sample member + ╷ ┃ inner │ ╷ + ╷ ┃ (and outer) │ outer ╷ + ╷ ┃ slice │ slice ╷ + ╷ ┃ │ ╷ + ╷ ┣─────────────────┘ ╷ + ╷ ┃ outer slice ╷ + ╷ ┃ halo(right) ╷ + └ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘ + ⬇ + ``` + + note: + - Inner and outer slices are specified in sample member coordinates. + - The outer_slice of a block at the sample edge may overlap by more than the + halo with the neighboring block (the inner slices will not overlap though). 
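+    - Example (sketch; the numbers follow from the formulas implemented below):
+      splitting a 1d sample of size 10 with block shape 6 and a halo of 1 on
+      both sides yields inner slices 0:4, 4:8 and 6:10; the first block then
+      has outer slice 0:5 and a left padding of 1 to realize the halo at the
+      sample edge.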
+
+    """
+
+    sample_shape: PerAxis[int]
+    """the axis sizes of the whole (unblocked) sample"""
+
+    inner_slice: PerAxis[SliceInfo]
+    """inner region (without halo) wrt the sample"""
+
+    halo: PerAxis[Halo]
+    """halo enlarging the inner region to the block's sizes"""
+
+    block_number: BlockNumber
+    """the n-th block of the sample"""
+
+    blocks_in_sample: TotalNumberOfBlocks
+    """total number of blocks in the sample"""
+
+    shape: PerAxis[int] = field(init=False)
+    """axis lengths of the block"""
+
+    padding: PerAxis[PadWidth] = field(init=False)
+    """padding to realize the halo at the sample edge
+    where we cannot simply enlarge the inner slice"""
+
+    outer_slice: PerAxis[SliceInfo] = field(init=False)
+    """slice of the outer block (without padding) wrt the sample"""
+
+    inner_shape: PerAxis[int] = field(init=False)
+    """axis lengths of the inner region (without halo)"""
+
+    local_slice: PerAxis[SliceInfo] = field(init=False)
+    """inner slice wrt the block, **not** the sample"""
+
+    @property
+    def dims(self) -> Collection[AxisId]:
+        return set(self.inner_shape)
+
+    @property
+    def tagged_shape(self) -> PerAxis[int]:
+        """alias for shape"""
+        return self.shape
+
+    @property
+    def inner_slice_wo_overlap(self):
+        """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be
+        stitched together trivially to form the original sample.
+
+        This can also be used to calculate statistics
+        without overrepresenting edge regions."""
+        # TODO: update inner_slice_wo_overlap when adding block overlap
+        return self.inner_slice
+
+    def __post_init__(self):
+        assert all(
+            a in self.sample_shape for a in self.inner_slice
+        ), "block has axes not present in sample"
+        assert all(
+            a in self.inner_slice for a in self.halo
+        ), "halo has axes not present in block"
+
+        self.shape = {
+            a: s.stop - s.start + sum(self.halo[a]) for a, s in self.inner_slice.items()
+        }
+        assert all(
+            s <= self.sample_shape[a] for a, s in self.shape.items()
+        ), "block larger than sample"
+
+        self.inner_shape = {a: s.stop - s.start for a, s in self.inner_slice.items()}
+        self.outer_slice = {
+            a: SliceInfo(
+                max(
+                    0,
+                    min(
+                        self.inner_slice[a].start - self.halo[a].left,
+                        self.sample_shape[a] - self.inner_shape[a] - self.halo[a].left,
+                    ),
+                ),
+                min(
+                    self.sample_shape[a],
+                    self.inner_slice[a].stop + self.halo[a].right,
+                ),
+            )
+            for a in self.inner_slice
+        }
+        # pad by however much of the halo is not covered by the outer slice
+        self.padding = {
+            a: PadWidth(
+                max(
+                    0,
+                    self.halo[a].left
+                    - (self.inner_slice[a].start - self.outer_slice[a].start),
+                ),
+                max(
+                    0,
+                    self.halo[a].right
+                    - (self.outer_slice[a].stop - self.inner_slice[a].stop),
+                ),
+            )
+            for a in self.inner_slice
+        }
+        self.local_slice = {
+            a: SliceInfo(
+                self.padding[a].left,
+                self.padding[a].left + self.inner_shape[a],
+            )
+            for a in self.inner_slice
+        }
+
+
+def split_shape_into_blocks(
+    shape: PerAxis[int],
+    block_shape: PerAxis[int],
+    halo: PerAxis[HaloLike],
+    stride: Optional[PerAxis[int]] = None,
+) -> Tuple[TotalNumberOfBlocks, Generator[BlockMeta, Any, None]]:
+    assert all(a in shape for a in block_shape), (
+        tuple(shape),
+        set(block_shape),
+    )
+    assert all(a in shape for a in halo), (tuple(shape), set(halo))
+
+    # fill in default halo (0) and tile_size (tensor size)
+    halo = {a: Halo.create(h) for a, h in halo.items()}
+    block_shape = {a: block_shape.get(a, s) for a, s in shape.items()}
+    if stride is None:
+        stride = {}
+
+    inner_1d_slices: Dict[AxisId, List[SliceInfo]] = {}
+    for a, s in shape.items():
+        inner_size = block_shape[a] - sum(halo[a])
+        stride_1d = stride.get(a, inner_size)
+
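+        # inner slices advance by `stride_1d`; the last slice is shifted left
+        # so that it still fits inside the sample (trailing blocks may thus
+        # overlap by more than the halo)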
inner_1d_slices[a] = [ + SliceInfo(min(p, s - inner_size), min(p + inner_size, s)) + for p in range(0, s, stride_1d) + ] + + n_blocks = prod(map(len, inner_1d_slices.values())) + + return n_blocks, _block_meta_generator( + shape, + blocks_in_sample=n_blocks, + inner_1d_slices=inner_1d_slices, + halo=halo, + ) + + +def _block_meta_generator( + sample_shape: PerAxis[int], + *, + blocks_in_sample: int, + inner_1d_slices: Dict[AxisId, List[SliceInfo]], + halo: PerAxis[HaloLike], +): + assert all(a in sample_shape for a in halo) + + halo = {a: Halo.create(halo.get(a, 0)) for a in inner_1d_slices} + for i, nd_tile in enumerate(itertools.product(*inner_1d_slices.values())): + inner_slice: PerAxis[SliceInfo] = dict(zip(inner_1d_slices, nd_tile)) + + yield BlockMeta( + sample_shape=sample_shape, + inner_slice=inner_slice, + halo=halo, + block_number=i, + blocks_in_sample=blocks_in_sample, + ) + + +def split_multiple_shapes_into_blocks( + shapes: PerMember[PerAxis[int]], + block_shapes: PerMember[PerAxis[int]], + *, + halo: PerMember[PerAxis[HaloLike]], + strides: Optional[PerMember[PerAxis[int]]] = None, + broadcast: bool = False, +) -> Tuple[TotalNumberOfBlocks, Iterable[PerMember[BlockMeta]]]: + assert not ( + missing := [t for t in block_shapes if t not in shapes] + ), f"block shape specified for unknown tensors: {missing}" + assert broadcast or not ( + missing := [t for t in shapes if t not in block_shapes] + ), f"no block shape specified for {missing} (set `broadcast` to True if these tensors should be repeated for each block)" + assert not ( + missing := [t for t in halo if t not in block_shapes] + ), f"`halo` specified for tensors without block shape: {missing}" + + if strides is None: + strides = {} + + assert not ( + missing := [t for t in strides if t not in block_shapes] + ), f"`stride` specified for tensors without block shape: {missing}" + + blocks: Dict[MemberId, Iterable[BlockMeta]] = {} + n_blocks: Dict[MemberId, TotalNumberOfBlocks] = {} + for t in block_shapes: + n_blocks[t], blocks[t] = split_shape_into_blocks( + shape=shapes[t], + block_shape=block_shapes[t], + halo=halo.get(t, {}), + stride=strides.get(t), + ) + assert n_blocks[t] > 0 + + unique_n_blocks = set(n_blocks.values()) + n = max(unique_n_blocks) + if len(unique_n_blocks) == 2 and 1 in unique_n_blocks: + if not broadcast: + raise ValueError( + f"Mismatch for total number of blocks due to unsplit (single block) tensors: {n_blocks}." + + " Set `broadcast` to True if you want to repeat unsplit (single block) tensors." 
+ ) + + blocks = { + t: _repeat_single_block(block_gen, n) if n_blocks[t] == 1 else block_gen + for t, block_gen in blocks.items() + } + elif len(unique_n_blocks) != 1: + raise ValueError(f"Mismatch for total number of blocks: {n_blocks}") + + return n, _aligned_blocks_generator(n, blocks) + + +def _aligned_blocks_generator( + n: TotalNumberOfBlocks, blocks: Dict[MemberId, Iterable[BlockMeta]] +): + iterators = {t: iter(gen) for t, gen in blocks.items()} + for _ in range(n): + yield {t: next(it) for t, it in iterators.items()} + + +def _repeat_single_block(block_generator: Iterable[BlockMeta], n: TotalNumberOfBlocks): + round_two = False + for block in block_generator: + assert not round_two + for _ in range(n): + yield block + + round_two = True diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index faff0de0..d5c825c1 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -1,9 +1,11 @@ from __future__ import annotations -from typing import Literal, NamedTuple, Tuple, TypeVar, Union +from typing import Hashable, Literal, Mapping, NamedTuple, Tuple, TypeVar, Union from typing_extensions import Self, assert_never +from bioimageio.spec.model import v0_5 + DTypeStr = Literal[ "bool", "float32", @@ -75,5 +77,10 @@ class SliceInfo(NamedTuple): stop: int +SampleId = Hashable +MemberId = v0_5.TensorId +T = TypeVar("T") +PerMember = Mapping[MemberId, T] + BlockNumber = int TotalNumberOfBlocks = int diff --git a/bioimageio/core/digest_spec.py b/bioimageio/core/digest_spec.py new file mode 100644 index 00000000..7594d0c1 --- /dev/null +++ b/bioimageio/core/digest_spec.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +import importlib.util +from functools import singledispatch +from typing import Any, Callable, Dict, Iterable, Mapping, NamedTuple, Tuple, Union + +from typing_extensions import Unpack + +from bioimageio.spec._internal.io_utils import HashKwargs, download +from bioimageio.spec.common import FileSource +from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 +from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile +from bioimageio.spec.model.v0_5 import ( + ArchitectureFromFileDescr, + ArchitectureFromLibraryDescr, + ParameterizedSize, + TensorId, +) +from bioimageio.spec.utils import load_array + +from .axis import AxisId, AxisInfo, PerAxis +from .block_meta import split_multiple_shapes_into_blocks +from .common import Halo, MemberId, PerMember, TotalNumberOfBlocks +from .sample import Sample, SampleBlockMeta, sample_block_meta_generator +from .tensor import Tensor + + +@singledispatch +def import_callable(node: type, /) -> Callable[..., Any]: + raise TypeError(type(node)) + + +@import_callable.register +def _(node: CallableFromDepencency) -> Callable[..., Any]: + module = importlib.import_module(node.module_name) + c = getattr(module, str(node.callable_name)) + if not callable(c): + raise ValueError(f"{node} (imported: {c}) is not callable") + + return c + + +@import_callable.register +def _(node: ArchitectureFromLibraryDescr) -> Callable[..., Any]: + module = importlib.import_module(node.import_from) + c = getattr(module, str(node.callable)) + if not callable(c): + raise ValueError(f"{node} (imported: {c}) is not callable") + + return c + + +@import_callable.register +def _(node: CallableFromFile, **kwargs: Unpack[HashKwargs]): + return _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs) + + +@import_callable.register +def _(node: ArchitectureFromFileDescr, **kwargs: Unpack[HashKwargs]): 
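+    # note: unlike `CallableFromFile`, an `ArchitectureFromFileDescr` carries
+    # its own sha256, so the HashKwargs passed in here are intentionally unused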
+    return _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256)
+
+
+def _import_from_file_impl(
+    source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
+):
+    local_file = download(source, **kwargs)
+    module_name = local_file.path.stem
+    importlib_spec = importlib.util.spec_from_file_location(
+        module_name, local_file.path
+    )
+    if importlib_spec is None:
+        raise ImportError(f"Failed to import {module_name} from {source}.")
+
+    dep = importlib.util.module_from_spec(importlib_spec)
+    importlib_spec.loader.exec_module(dep)  # type: ignore  # todo: possible to use "loader.load_module"?
+    return getattr(dep, callable_name)
+
+
+def get_axes_infos(
+    io_descr: Union[
+        v0_4.InputTensorDescr,
+        v0_4.OutputTensorDescr,
+        v0_5.InputTensorDescr,
+        v0_5.OutputTensorDescr,
+    ]
+):
+    """get a unified, simplified axis representation from spec axes"""
+    return [
+        (
+            AxisInfo.create("i")
+            if isinstance(a, str) and a not in ("b", "i", "t", "c", "z", "y", "x")
+            else AxisInfo.create(a)
+        )
+        for a in io_descr.axes
+    ]
+
+
+def get_test_inputs(model: AnyModelDescr) -> Sample:
+    """returns a model's test input sample"""
+    if isinstance(model, v0_4.ModelDescr):
+        tensor_ids = [TensorId(t.name) for t in model.inputs]
+    else:
+        tensor_ids = [t.id for t in model.inputs]
+
+    if isinstance(model, v0_4.ModelDescr):
+        arrays = [load_array(tt) for tt in model.test_inputs]
+    else:
+        arrays = [load_array(d.test_tensor) for d in model.inputs]
+
+    axes = [get_axes_infos(t) for t in model.inputs]
+    return Sample(
+        members={
+            tid: Tensor.from_numpy(arr, dims=ax)
+            for tid, arr, ax in zip(tensor_ids, arrays, axes)
+        }
+    )
+
+
+def get_test_outputs(model: AnyModelDescr) -> Sample:
+    """returns a model's test output sample"""
+    if isinstance(model, v0_4.ModelDescr):
+        tensor_ids = [TensorId(t.name) for t in model.outputs]
+    else:
+        tensor_ids = [t.id for t in model.outputs]
+
+    if isinstance(model, v0_4.ModelDescr):
+        arrays = [load_array(tt) for tt in model.test_outputs]
+    else:
+        arrays = [load_array(d.test_tensor) for d in model.outputs]
+
+    axes = [get_axes_infos(t) for t in model.outputs]
+
+    return Sample(
+        members={
+            tid: Tensor.from_numpy(arr, dims=ax)
+            for tid, arr, ax in zip(tensor_ids, arrays, axes)
+        }
+    )
+
+
+class IO_SampleBlockMeta(NamedTuple):
+    input: SampleBlockMeta
+    output: SampleBlockMeta
+
+
+def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
+    """map a halo specified for output tensors to the corresponding input halo
+    via the outputs' `SizeReference` axes"""
+    halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
+    outputs = {t.id: t for t in model.outputs}
+    all_tensors = {**{t.id: t for t in model.inputs}, **outputs}
+
+    for t, th in output_halo.items():
+        axes = {a.id: a for a in outputs[t].axes}
+
+        for a, ah in th.items():
+            s = axes[a].size
+            if not isinstance(s, v0_5.SizeReference):
+                raise ValueError(
+                    f"Unable to map output halo for {t}.{a} to an input axis"
+                )
+
+            axis = axes[a]
+            ref_axis = {a.id: a for a in all_tensors[s.tensor_id].axes}[s.axis_id]
+
+            total_output_halo = sum(ah)
+            total_input_halo = total_output_halo * axis.scale / ref_axis.scale
+            if total_input_halo != int(total_input_halo):
+                raise ValueError(
+                    f"Output halo {ah} of {t}.{a} does not correspond to an"
+                    + f" integer input halo (axis scales: {axis.scale}/{ref_axis.scale})."
+                )
+
+            # scale each halo side from the output axis to the referenced input axis
+            halo.setdefault(s.tensor_id, {})[s.axis_id] = Halo(
+                int(ah.left * axis.scale / ref_axis.scale),
+                int(ah.right * axis.scale / ref_axis.scale),
+            )
+
+    return halo
+
+
+def get_block_meta(
+    model: v0_5.ModelDescr,
+    input_sample_shape: PerMember[PerAxis[int]],
+    ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize.N],
+) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]:
+    """returns an iterable yielding meta data for corresponding input and output samples"""
+    if not isinstance(model, v0_5.ModelDescr):
+        raise
TypeError(f"get_block_meta() not implemented for {type(model)}") + + block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=1) + input_block_shape = { + t: {aa: s for (tt, aa), s in block_axis_sizes.inputs.items() if tt == t} + for t in {tt for tt, _ in block_axis_sizes.inputs} + } + output_block_shape = { + t: { + aa: s + for (tt, aa), s in block_axis_sizes.outputs.items() + if tt == t and not isinstance(s, tuple) + } + for t in {tt for tt, _ in block_axis_sizes.outputs} + } + output_halo = {t.id: {a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)} for t in model.outputs} + input_halo = get_input_halo(model, output_halo) + output_sample_shape_data_dep = model.get_output_tensor_sizes(input_sample_shape) + output_sample_shape = { + t: { + a: -1 if isinstance(s, tuple) else s + for a, s in output_sample_shape_data_dep[t].items() + } + for t in output_sample_shape_data_dep + } + n_input_blocks, input_blocks = split_multiple_shapes_into_blocks( + input_sample_shape, input_block_shape, halo=input_halo + ) + n_output_blocks, output_blocks = split_multiple_shapes_into_blocks( + output_sample_shape, output_block_shape, halo=output_halo + ) + assert n_input_blocks == n_output_blocks + return n_input_blocks, ( + IO_SampleBlockMeta(ipt, out) + for ipt, out in zip( + sample_block_meta_generator(input_blocks, origin=input_sample_shape), + sample_block_meta_generator(output_blocks, origin=output_sample_shape), + ) + ) diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 0dc9ee20..0676b87b 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -1,6 +1,6 @@ import collections.abc from abc import ABC, abstractmethod -from dataclasses import InitVar, dataclass, field +from dataclasses import InitVar, dataclass, field, replace from typing import ( Collection, Literal, @@ -16,12 +16,12 @@ import xarray as xr from typing_extensions import Self, assert_never -from bioimageio.core.common import DTypeStr +from bioimageio.core.sample import SampleBlock from bioimageio.spec.model import v0_4, v0_5 from ._op_base import Operator from .axis import AxisId -from .sample import UntiledSample +from .common import DTypeStr, MemberId from .stat_calculators import StatsCalculator from .stat_measures import ( DatasetMean, @@ -37,7 +37,7 @@ Stat, StdMeasure, ) -from .tensor import Tensor, TensorId +from .tensor import Tensor def convert_axis_ids( @@ -61,23 +61,35 @@ def convert_axis_ids( @dataclass class _SimpleOperator(Operator, ABC): - input: TensorId - output: TensorId + input: MemberId + output: MemberId @property def required_measures(self) -> Collection[Measure]: return set() # @property - # def required_tensors(self) -> Set[TensorId]: + # def required_tensors(self) -> Set[MemberId]: # return {self.input} # @property - # def produced_tensors(self) -> Set[TensorId]: + # def produced_tensors(self) -> Set[MemberId]: # return {self.output} - def __call__(self, sample: UntiledSample) -> None: - sample.data[self.output] = self._apply(sample.data[self.input], sample.stat) + def __call__(self, sample_block: SampleBlock) -> None: + input_tensor = sample_block.members[self.input] + output_tensor = self._apply(input_tensor, sample_block.stat) + + if self.output in sample_block.blocks: + assert ( + sample_block.blocks[self.output].tagged_shape + == output_tensor.tagged_shape + ) + sample_block.blocks[self.output].data = output_tensor + else: + sample_block.blocks[self.output] = replace( + sample_block.blocks[self.input], data=output_tensor + ) 
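+    # `dataclasses.replace` above reuses the input block's meta data (inner
+    # slice, halo, block number) for the newly created output block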
 
     @abstractmethod
     def _apply(self, input: Tensor, stat: Stat) -> Tensor: ...
@@ -91,8 +103,8 @@ class AddKnownDatasetStats(Operator):
     def required_measures(self) -> Set[Measure]:
         return set()
 
-    def __call__(self, sample: UntiledSample) -> None:
-        sample.stat.update(self.dataset_stats.items())
+    def __call__(self, sample_block: SampleBlock) -> None:
+        sample_block.stat.update(self.dataset_stats.items())
 
 
 # @dataclass
@@ -124,7 +136,7 @@ def __call__(self, sample: UntiledSample) -> None:
 #         else:
 #             self._keep_updating_dataset_stats = self.keep_updating_dataset_stats
 
-# def __call__(self, sample: Sample) -> None:
+# def __call__(self, sample_block: SampleBlock) -> None:
 #     if self._keep_updating_dataset_stats:
 #         sample.stat.update(self._stats_calculator.update_and_get_all(sample))
 #     else:
@@ -154,11 +166,18 @@ def __post_init__(self):
             or not self.stats_calculator.has_dataset_measures
         )
 
-    def __call__(self, sample: UntiledSample) -> None:
+    def __call__(self, sample_block: SampleBlock) -> None:
+        if sample_block.block_number != 0:
+            return  # stats are computed from the whole sample on the first block only
+
         if self._keep_updating_dataset_stats:
-            sample.stat.update(self.stats_calculator.update_and_get_all(sample))
+            sample_block.stat.update(
+                self.stats_calculator.update_and_get_all(sample_block.origin)
+            )
         else:
-            sample.stat.update(self.stats_calculator.skip_update_and_get_all(sample))
+            sample_block.stat.update(
+                self.stats_calculator.skip_update_and_get_all(sample_block.origin)
+            )
 
 
 @dataclass
@@ -173,16 +192,16 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor:
 
     @classmethod
     def from_proc_descr(
-        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], tensor_id: TensorId
+        cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId
     ) -> Self:
         if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)):
             return cls(
-                input=tensor_id, output=tensor_id, threshold=descr.kwargs.threshold
+                input=member_id, output=member_id, threshold=descr.kwargs.threshold
             )
         elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs):
             return cls(
-                input=tensor_id,
-                output=tensor_id,
+                input=member_id,
+                output=member_id,
                 threshold=descr.kwargs.threshold,
                 axis=descr.kwargs.axis,
             )
@@ -208,11 +227,11 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor:
 
     @classmethod
     def from_proc_descr(
-        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], tensor_id: TensorId
+        cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId
     ) -> Self:
         return cls(
-            input=tensor_id,
-            output=tensor_id,
+            input=member_id,
+            output=member_id,
             min=descr.kwargs.min,
             max=descr.kwargs.max,
         )
@@ -223,8 +242,8 @@ class EnsureDtype(_SimpleOperator):
     dtype: DTypeStr
 
     @classmethod
-    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, tensor_id: TensorId):
-        return cls(input=tensor_id, output=tensor_id, dtype=descr.kwargs.dtype)
+    def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId):
+        return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype)
 
     def get_descr(self):
         return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype))
@@ -248,7 +267,7 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor:
     def from_proc_descr(
         cls,
         descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr],
-        tensor_id: TensorId,
+        member_id: MemberId,
     ) -> Self:
         kwargs = descr.kwargs
         if isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs):
@@ -273,13 +292,13 @@ def from_proc_descr(
             else kwargs.offset[0]
         )
 
-        return cls(input=tensor_id, output=tensor_id, gain=gain, offset=offset)
+        return
cls(input=member_id, output=member_id, gain=gain, offset=offset) @dataclass class ScaleMeanVariance(_SimpleOperator): axes: Optional[Sequence[AxisId]] = None - reference_tensor: Optional[TensorId] = None + reference_tensor: Optional[MemberId] = None eps: float = 1e-6 mean: Union[SampleMean, DatasetMean] = field(init=False) std: Union[SampleStd, DatasetStd] = field(init=False) @@ -300,10 +319,10 @@ def __post_init__(self): Mean = DatasetMean Std = DatasetStd - self.mean = Mean(tensor_id=self.input, axes=axes) - self.std = Std(tensor_id=self.input, axes=axes) - self.ref_mean = Mean(tensor_id=ref_tensor, axes=axes) - self.ref_std = Std(tensor_id=ref_tensor, axes=axes) + self.mean = Mean(member_id=self.input, axes=axes) + self.std = Std(member_id=self.input, axes=axes) + self.ref_mean = Mean(member_id=ref_tensor, axes=axes) + self.ref_std = Std(member_id=ref_tensor, axes=axes) def _apply(self, input: Tensor, stat: Stat) -> Tensor: mean = stat[self.mean] @@ -316,15 +335,15 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor: def from_proc_descr( cls, descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr], - tensor_id: TensorId, + member_id: MemberId, ) -> Self: kwargs = descr.kwargs axes = _get_axes(descr.kwargs) return cls( - input=tensor_id, - output=tensor_id, - reference_tensor=TensorId(str(kwargs.reference_tensor)), + input=member_id, + output=member_id, + reference_tensor=MemberId(str(kwargs.reference_tensor)), axes=axes, eps=kwargs.eps, ) @@ -371,17 +390,17 @@ def __post_init__( upper_percentile: Optional[Union[SamplePercentile, DatasetPercentile]], ): if lower_percentile is None: - tid = self.input if upper_percentile is None else upper_percentile.tensor_id - self.lower = DatasetPercentile(q=0.0, tensor_id=tid) + tid = self.input if upper_percentile is None else upper_percentile.member_id + self.lower = DatasetPercentile(q=0.0, member_id=tid) else: self.lower = lower_percentile if upper_percentile is None: - self.upper = DatasetPercentile(q=1.0, tensor_id=self.lower.tensor_id) + self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id) else: self.upper = upper_percentile - assert self.lower.tensor_id == self.upper.tensor_id + assert self.lower.member_id == self.upper.member_id assert self.lower.q < self.upper.q assert self.lower.axes == self.upper.axes @@ -393,13 +412,13 @@ def required_measures(self): def from_proc_descr( cls, descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr], - tensor_id: TensorId, + member_id: MemberId, ): kwargs = descr.kwargs ref_tensor = ( - tensor_id + member_id if kwargs.reference_tensor is None - else TensorId(str(kwargs.reference_tensor)) + else MemberId(str(kwargs.reference_tensor)) ) axes = _get_axes(descr.kwargs) if axes is None or AxisId("batch") in axes: @@ -408,13 +427,13 @@ def from_proc_descr( Percentile = SamplePercentile return cls( - input=tensor_id, - output=tensor_id, + input=member_id, + output=member_id, lower_percentile=Percentile( - q=kwargs.min_percentile / 100, axes=axes, tensor_id=ref_tensor + q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor ), upper_percentile=Percentile( - q=kwargs.max_percentile / 100, axes=axes, tensor_id=ref_tensor + q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor ), ) @@ -425,7 +444,7 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor: def get_descr(self): assert self.lower.axes == self.upper.axes - assert self.lower.tensor_id == self.upper.tensor_id + assert self.lower.member_id == self.upper.member_id return v0_5.ScaleRangeDescr( 
kwargs=v0_5.ScaleRangeKwargs( @@ -433,7 +452,7 @@ def get_descr(self): min_percentile=self.lower.q * 100, max_percentile=self.upper.q * 100, eps=self.eps, - reference_tensor=self.lower.tensor_id, + reference_tensor=self.lower.member_id, ) ) @@ -451,10 +470,10 @@ def required_measures(self) -> Collection[Measure]: @classmethod def from_proc_descr( - cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], tensor_id: TensorId + cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId ) -> Self: assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)) - return cls(input=tensor_id, output=tensor_id) + return cls(input=member_id, output=member_id) def get_descr(self): return v0_5.SigmoidDescr() @@ -480,7 +499,7 @@ def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]: def from_proc_descr( cls, descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr], - tensor_id: TensorId, + member_id: MemberId, ): axes = _get_axes(descr.kwargs) @@ -492,10 +511,10 @@ def from_proc_descr( Std = SampleStd return cls( - input=tensor_id, - output=tensor_id, - mean=Mean(axes=axes, tensor_id=tensor_id), - std=Std(axes=axes, tensor_id=tensor_id), + input=member_id, + output=member_id, + mean=Mean(axes=axes, member_id=member_id), + std=Std(axes=axes, member_id=member_id), ) def _apply(self, input: Tensor, stat: Stat) -> Tensor: @@ -529,7 +548,7 @@ def __post_init__(self): def from_proc_descr( cls, descr: v0_5.FixedZeroMeanUnitVarianceDescr, - tensor_id: TensorId, + member_id: MemberId, ) -> Self: if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs): dims = None @@ -539,8 +558,8 @@ def from_proc_descr( assert_never(descr.kwargs) return cls( - input=tensor_id, - output=tensor_id, + input=member_id, + output=member_id, mean=xr.DataArray(descr.kwargs.mean, dims=dims), std=xr.DataArray(descr.kwargs.std, dims=dims), ) diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py index 0979773f..64168ce9 100644 --- a/bioimageio/core/proc_setup.py +++ b/bioimageio/core/proc_setup.py @@ -1,9 +1,9 @@ -from types import MappingProxyType from typing import ( Iterable, List, Mapping, NamedTuple, + Optional, Sequence, Set, Union, @@ -11,18 +11,14 @@ from typing_extensions import assert_never -from bioimageio.core.proc_ops import ( - AddKnownDatasetStats, - Processing, - UpdateStats, - get_proc_class, -) -from bioimageio.core.sample import UntiledSample -from bioimageio.core.stat_calculators import StatsCalculator -from bioimageio.core.stat_measures import DatasetMeasure, Measure, MeasureValue from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 from bioimageio.spec.model.v0_5 import TensorId +from .proc_ops import AddKnownDatasetStats, Processing, UpdateStats, get_proc_class +from .sample import Sample +from .stat_calculators import StatsCalculator +from .stat_measures import DatasetMeasure, Measure, MeasureValue + TensorDescr = Union[ v0_4.InputTensorDescr, v0_4.OutputTensorDescr, @@ -45,9 +41,9 @@ class _SetupProcessing(NamedTuple): def setup_pre_and_postprocessing( model: AnyModelDescr, - dataset_for_initial_statistics: Iterable[UntiledSample], + dataset_for_initial_statistics: Iterable[Sample], keep_updating_initial_dataset_stats: bool = False, - fixed_dataset_stats: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType({}), + fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None, ) -> PreAndPostprocessing: """ Get pre- and postprocessing operators for a `model` description. 
@@ -55,7 +51,9 @@ def setup_pre_and_postprocessing( prep, post, prep_meas, post_meas = _prepare_setup_pre_and_postprocessing(model) missing_dataset_stats = { - m for m in prep_meas | post_meas if m not in fixed_dataset_stats + m + for m in prep_meas | post_meas + if fixed_dataset_stats is None or m not in fixed_dataset_stats } initial_stats_calc = StatsCalculator(missing_dataset_stats) for sample in dataset_for_initial_statistics: @@ -125,18 +123,18 @@ def prepare_procs(tensor_descrs: Sequence[TensorDescr]): for proc_d in proc_descrs: proc_class = get_proc_class(proc_d) - tensor_id = ( + member_id = ( TensorId(str(t_descr.name)) if isinstance(t_descr, v0_4.TensorDescrBase) else t_descr.id ) req = proc_class.from_proc_descr( - proc_d, tensor_id # pyright: ignore[reportArgumentType] + proc_d, member_id # pyright: ignore[reportArgumentType] ) for m in req.required_measures: - if m.tensor_id in input_ids: + if m.member_id in input_ids: pre_measures.add(m) - elif m.tensor_id in output_ids: + elif m.member_id in output_ids: post_measures.add(m) else: raise ValueError("When to raise ") diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py index e3d003e2..c67ef769 100644 --- a/bioimageio/core/sample.py +++ b/bioimageio/core/sample.py @@ -1,223 +1,186 @@ +from __future__ import annotations + +from abc import abstractmethod from dataclasses import dataclass, field -from pprint import pformat -from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union, cast +from typing import Dict, Generic, Iterable, Optional, Tuple, TypeVar import numpy as np -import xarray as xr from typing_extensions import Self -from bioimageio.core.tensor_block import TensorBlock - -from .axis import AxisId, PerAxis -from .block import Block, BlockNumber, TotalNumberOfBlocks, split_shape_into_blocks -from .common import BlockNumber, Halo, HaloLike, PadMode, SliceInfo +from bioimageio.core.block import Block + +from .axis import PerAxis +from .block_meta import BlockMeta, split_multiple_shapes_into_blocks +from .common import ( + BlockNumber, + HaloLike, + MemberId, + PadMode, + PerMember, + SampleId, + TotalNumberOfBlocks, +) from .stat_measures import Stat -from .tensor import PerTensor, Tensor, TensorId - - -def split_multiple_shapes_into_blocks( - shapes: PerTensor[PerAxis[int]], - block_shapes: PerTensor[PerAxis[int]], - *, - strides: Optional[PerTensor[PerAxis[int]]] = None, - halo: PerTensor[PerAxis[HaloLike]], - pad_mode: PadMode, - broadcast: bool = False, -) -> Tuple[TotalNumberOfBlocks, Iterable[PerTensor[Block]]]: - assert not ( - missing := [t for t in block_shapes if t not in shapes] - ), f"block shape specified for unknown tensors: {missing}" - assert broadcast or not ( - missing := [t for t in shapes if t not in block_shapes] - ), f"no block shape specified for {missing} (set `broadcast` to True if these tensors should be repeated for each block)" - assert not ( - missing := [t for t in halo if t not in block_shapes] - ), f"`halo` specified for tensors without block shape: {missing}" - - if strides is None: - strides = {} - - assert not ( - missing := [t for t in strides if t not in block_shapes] - ), f"`stride` specified for tensors without block shape: {missing}" - - blocks: Dict[TensorId, Iterable[Block]] = {} - n_blocks: Dict[TensorId, TotalNumberOfBlocks] = {} - for t in block_shapes: - n_blocks[t], blocks[t] = split_shape_into_blocks( - shape=shapes[t], - block_shape=block_shapes[t], - halo=halo.get(t, {}), - stride=strides.get(t), - ) - assert n_blocks[t] > 0 - - unique_n_blocks 
= set(n_blocks.values()) - n = max(unique_n_blocks) - if len(unique_n_blocks) == 2 and 1 in unique_n_blocks: - if not broadcast: - raise ValueError( - f"Mismatch for total number of blocks due to unsplit (single block) tensors: {n_blocks}." - + " Set `broadcast` to True if you want to repeat unsplit (single block) tensors." - ) - - blocks = { - t: _repeat_single_block(block_gen, n) if n_blocks[t] == 1 else block_gen - for t, block_gen in blocks.items() - } - elif len(unique_n_blocks) != 1: - raise ValueError(f"Mismatch for total number of blocks: {n_blocks}") - - return n, _aligned_blocks_generator(n, blocks) - - -def _aligned_blocks_generator( - n: TotalNumberOfBlocks, blocks: Dict[TensorId, Iterable[Block]] -): - iterators = {t: iter(gen) for t, gen in blocks.items()} - for _ in range(n): - yield {t: next(it) for t, it in iterators.items()} - - -def _repeat_single_block(block_generator: Iterable[Block], n: TotalNumberOfBlocks): - round_two = False - for block in block_generator: - assert not round_two - for _ in range(n): - yield block +from .tensor import Tensor - round_two = True +# TODO: allow for lazy samples to read/write to disk @dataclass class Sample: """A dataset sample""" - data: Dict[TensorId, Tensor] + members: Dict[MemberId, Tensor] """the sample's tensors""" stat: Stat = field(default_factory=dict) """sample and dataset statistics""" + id: Optional[SampleId] = None + """identifier within the sample's dataset""" + @property - def sizes(self) -> PerTensor[PerAxis[int]]: - return {tid: t.sizes for tid, t in self.data.items()} + def shape(self) -> PerMember[PerAxis[int]]: + return {tid: t.sizes for tid, t in self.members.items()} def split_into_blocks( self, - tile_sizes: PerTensor[PerAxis[int]], - halo: PerTensor[PerAxis[HaloLike]], + block_shapes: PerMember[PerAxis[int]], + halo: PerMember[PerAxis[HaloLike]], pad_mode: PadMode, - ) -> TiledSample: + broadcast: bool = False, + ) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlock]]: assert not ( - missing := [t for t in tile_sizes if t not in self.data] - ), f"`tile_sizes` specified for missing tensors: {missing}" + missing := [m for m in block_shapes if m not in self.members] + ), f"`block_shapes` specified for unknown members: {missing}" assert not ( - missing := [t for t in halo if t not in tile_sizes] - ), f"`halo` specified for tensors without `tile_sizes`: {missing}" - - # any axis not given in `tile_sizes` is treated - # as tile size equal to the tensor axis' size - explicit_tile_sizes = { - t: {a: tile_sizes.get(t, {}).get(a, s) for a, s in tdata.sizes.items()} - for t, tdata in self.data.items() - } - - tensor_ids = tuple(self.data) - broadcasted_tensors = { - t: Tensor.from_xarray(d) - for t, d in zip( - tensor_ids, xr.broadcast(*(self.data[tt].data for tt in tensor_ids)) - ) - } - - tile_iterators: Dict[ - TensorId, Iterator[Tuple[BlockNumber, Tensor, PerAxis[SliceInfo]]] - ] = {} - - n_tiles_common = 1 - last_non_trivial: Optional[TensorId] = None - for t in tensor_ids: - n_tiles, generator = broadcasted_tensors[t].block( - block_size=explicit_tile_sizes[t], - explicit_halo=halo.get(t, {}), - pad_mode=pad_mode, - ) - tile_iterators[t] = iter(generator) - if n_tiles in (1, n_tiles_common): - pass - elif n_tiles_common == 1: - last_non_trivial = t - n_tiles_common = n_tiles - else: - assert last_non_trivial is not None - mismatch = { - last_non_trivial: { - "original sizes": self.data[last_non_trivial].sizes, - "broadcasted sizes": broadcasted_tensors[ - last_non_trivial - ].sizes, - "n_tiles": n_tiles_common, - }, - 
t: { - "original sizes": self.data[t].sizes, - "broadcasted sizes": broadcasted_tensors[t].sizes, - "n_tiles": n_tiles, - }, - } - raise ValueError( - f"broadcasted tensors {last_non_trivial, t} do not tile to the same" - + f" number of tiles {n_tiles_common, n_tiles}. Details\n" - + pformat(mismatch) - ) - - for i in range(n_tiles_common): - data: Dict[TensorId, Tensor] = {} - inner_slice: Dict[TensorId, PerAxis[SliceInfo]] = {} - for t, iterator in tile_iterators.items(): - tn, tensor_tile, tensor_slice = next(iterator) - assert tn == i, f"expected tile number {i}, but got {tn}" - data[t] = tensor_tile - inner_slice[t] = tensor_slice - - yield Tile( - data=data, - inner_slice=inner_slice, - halo={ - t: {a: Halo.create(h) for a, h in th.items()} - for t, th in halo.items() - }, - sample_sizes=self.sizes, - tile_number=i, - tiles_in_sample=n_tiles_common, - stat=self.stat, - ) + missing := [m for m in halo if m not in block_shapes] + ), f"`halo` specified for members without `block_shape`: {missing}" + + n_blocks, blocks = split_multiple_shapes_into_blocks( + shapes=self.shape, + block_shapes=block_shapes, + halo=halo, + broadcast=broadcast, + ) + return n_blocks, sample_block_generator(blocks, origin=self, pad_mode=pad_mode) @classmethod - def from_tiles( - cls, tiles: Iterable[Tile], *, fill_value: float = float("nan") + def from_blocks( + cls, + sample_blocks: Iterable[SampleBlock], + *, + fill_value: float = float("nan"), ) -> Self: - # TODO: add `mode: Literal['in-memory', 'to-disk']` or similar to save out of mem samples - data: PerTensor[Tensor] = {} - stat: Stat = {} - for tile in tiles: - for t, tile_data in tile.inner_data.items(): - if t not in data: - axes = cast(Tuple[AxisId], tile_data.dims) - data[t] = Tensor( + members: PerMember[Tensor] = {} + for member_blocks in sample_blocks: + for m, block in member_blocks.blocks.items(): + if m not in members: + members[m] = Tensor( np.full( - tuple(tile.sample_sizes[t][a] for a in axes), + tuple(block.sample_shape[a] for a in block.data.dims), fill_value, - dtype=tile_data.dtype, + dtype=block.data.dtype, ), - dims=axes, + dims=block.data.dims, ) - data[t][tile.inner_slice[t]] = tile_data + members[m][block.inner_slice] = block.inner_data - stat = tile.stat + return cls(members=members) - return cls(data=data, stat=stat) +BlockT = TypeVar("BlockT", Block, BlockMeta) + + +@dataclass +class SampleBlockBase(Generic[BlockT]): + """base class for `SampleBlockMeta` and `SampleBlock`""" + + blocks: Dict[MemberId, BlockT] + + block_number: BlockNumber = field(init=False) + """the n-th block of the sample""" + + blocks_in_sample: TotalNumberOfBlocks = field(init=False) + """total number of blocks in the sample""" + + def __post_init__(self): + a_block = next(iter(self.blocks.values())) + self.block_number = a_block.block_number + self.blocks_in_sample = a_block.blocks_in_sample + + @property + def shape(self) -> PerMember[PerAxis[int]]: + return {mid: b.shape for mid, b in self.blocks.items()} + + @property + def inner_shape(self) -> PerMember[PerAxis[int]]: + return {mid: b.inner_shape for mid, b in self.blocks.items()} -Sample = Union[UntiledSample, TiledSample] + @property + @abstractmethod + def origin_shape(self) -> PerMember[PerAxis[int]]: ... 
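For orientation, a minimal sketch (not part of the patch series itself) of the blockwise round trip that `Sample.split_into_blocks` and `Sample.from_blocks` enable; the member id, axis names, and sizes below are made up for illustration:

import numpy as np

from bioimageio.core.axis import AxisId
from bioimageio.core.common import MemberId
from bioimageio.core.sample import Sample
from bioimageio.core.tensor import Tensor

member = MemberId("raw")  # hypothetical member id
x, y = AxisId("x"), AxisId("y")
sample = Sample(members={member: Tensor(np.zeros((256, 256)), dims=(x, y))})

# split into 128x128 blocks with a halo of 16 on either side, padding by reflection
n_blocks, blocks = sample.split_into_blocks(
    block_shapes={member: {x: 128, y: 128}},
    halo={member: {x: 16, y: 16}},
    pad_mode="reflect",
)

# (process each block here) ... stitching the blocks' inner regions back together
# should restore the original sample shape
reassembled = Sample.from_blocks(blocks)
assert reassembled.shape == sample.shape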
+
+
+@dataclass
+class SampleBlockMeta(SampleBlockBase[BlockMeta]):
+    """Meta data of a dataset sample block"""
+
+    origin: PerMember[PerAxis[int]]
+    """the sample shape the blocking for this block was based on"""
+
+    @property
+    def origin_shape(self):
+        return self.origin
+
+
+@dataclass
+class SampleBlock(SampleBlockBase[Block]):
+    """A block of a dataset sample"""
+
+    origin: Sample
+    """the sample this sample block was taken from"""
+
+    @property
+    def origin_shape(self):
+        return self.origin.shape
+
+    @property
+    def members(self) -> PerMember[Tensor]:
+        """the sample block's tensors"""
+        return {m: b.data for m, b in self.blocks.items()}
+
+    @property
+    def stat(self):
+        return self.origin.stat
+
+
+def sample_block_meta_generator(
+    blocks: Iterable[PerMember[BlockMeta]],
+    *,
+    origin: PerMember[PerAxis[int]],
+):
+    for member_blocks in blocks:
+        yield SampleBlockMeta(
+            blocks=dict(member_blocks),
+            origin=origin,
+        )
+
+
+def sample_block_generator(
+    blocks: Iterable[PerMember[BlockMeta]],
+    *,
+    origin: Sample,
+    pad_mode: PadMode,
+):
+    for member_blocks in blocks:
+        yield SampleBlock(
+            blocks={
+                m: Block.from_sample_member(
+                    origin.members[m], block=member_blocks[m], pad_mode=pad_mode
+                )
+                for m in origin.members
+            },
+            origin=origin,
+        )
diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py
index 643c697a..9d6717e4 100644
--- a/bioimageio/core/stat_calculators.py
+++ b/bioimageio/core/stat_calculators.py
@@ -3,7 +3,6 @@
 import collections.abc
 import warnings
 from itertools import product
-from math import prod
 from typing import (
     Any,
     Collection,
@@ -27,7 +26,8 @@
 from typing_extensions import assert_never
 
 from .axis import AxisId, PerAxis
-from .sample import UntiledSample
+from .common import MemberId
+from .sample import Sample
 from .stat_measures import (
     DatasetMean,
     DatasetMeasure,
@@ -43,7 +43,7 @@
     SampleStd,
     SampleVar,
 )
-from .tensor import Tensor, TensorId
+from .tensor import Tensor
 
 try:
     import crick
@@ -63,39 +63,37 @@ def quantile(self, q: Any) -> Any:
 
 
 class MeanCalculator:
-    """to calculate sample and dataset mean"""
+    """to calculate sample and dataset mean for in-memory samples"""
 
-    def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]):
+    def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]):
         super().__init__()
         self._n: int = 0
         self._mean: Optional[Tensor] = None
         self._axes = None if axes is None else tuple(axes)
-        self._tensor_id = tensor_id
-        self._sample_mean = SampleMean(tensor_id=self._tensor_id, axes=self._axes)
-        self._dataset_mean = DatasetMean(tensor_id=self._tensor_id, axes=self._axes)
+        self._member_id = member_id
+        self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes)
+        self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes)
 
-    def compute(self, sample: UntiledSample) -> Dict[SampleMean, MeasureValue]:
+    def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]:
         return {self._sample_mean: self._compute_impl(sample)}
 
-    def _compute_impl(self, sample: UntiledSample) -> Tensor:
-        tensor = sample.data[self._tensor_id].astype("float64", copy=False)
+    def _compute_impl(self, sample: Sample) -> Tensor:
+        tensor = sample.members[self._member_id].astype("float64", copy=False)
         return tensor.mean(dim=self._axes)
 
-    def update(self, sample: UntiledSample) -> None:
+    def update(self, sample: Sample) -> None:
         mean = self._compute_impl(sample)
-        self._update_impl(sample.data[self._tensor_id], mean)
+
self._update_impl(sample.members[self._member_id], mean) - def compute_and_update( - self, sample: UntiledSample - ) -> Dict[SampleMean, MeasureValue]: + def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]: mean = self._compute_impl(sample) - self._update_impl(sample.data[self._tensor_id], mean) + self._update_impl(sample.members[self._member_id], mean) return {self._sample_mean: mean} def _update_impl(self, tensor: Tensor, tensor_mean: Tensor): assert tensor_mean.dtype == "float64" # reduced voxel count - n_b = int(np.prod(tensor.shape) / np.prod(tensor_mean.shape)) + n_b = int(tensor.size / tensor_mean.size) if self._mean is None: assert self._n == 0 @@ -119,18 +117,18 @@ def finalize(self) -> Dict[DatasetMean, MeasureValue]: class MeanVarStdCalculator: """to calculate sample and dataset mean, variance or standard deviation""" - def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]): + def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]): super().__init__() self._axes = None if axes is None else tuple(axes) - self._tensor_id = tensor_id + self._member_id = member_id self._n: int = 0 self._mean: Optional[Tensor] = None self._m2: Optional[Tensor] = None def compute( - self, sample: UntiledSample + self, sample: Sample ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]: - tensor = sample.data[self._tensor_id] + tensor = sample.members[self._member_id] mean = tensor.mean(dim=self._axes) c = (tensor - mean).data if self._axes is None: @@ -143,21 +141,21 @@ def compute( std = np.sqrt(var) assert isinstance(std, xr.DataArray) return { - SampleMean(axes=self._axes, tensor_id=self._tensor_id): mean, - SampleVar(axes=self._axes, tensor_id=self._tensor_id): Tensor.from_xarray( + SampleMean(axes=self._axes, member_id=self._member_id): mean, + SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray( var ), - SampleStd(axes=self._axes, tensor_id=self._tensor_id): Tensor.from_xarray( + SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray( std ), } - def update(self, sample: UntiledSample): - tensor = sample.data[self._tensor_id].astype("float64", copy=False) + def update(self, sample: Sample): + tensor = sample.members[self._member_id].astype("float64", copy=False) mean_b = tensor.mean(dim=self._axes) assert mean_b.dtype == "float64" # reduced voxel count - n_b = int(prod(tensor.shape) / prod(mean_b.shape)) + n_b = int(tensor.size / mean_b.size) m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes) assert m2_b.dtype == "float64" if self._mean is None: @@ -187,10 +185,10 @@ def finalize( sqrt = np.sqrt(var) assert isinstance(sqrt, xr.DataArray) return { - DatasetMean(tensor_id=self._tensor_id, axes=self._axes): self._mean, - DatasetVar(tensor_id=self._tensor_id, axes=self._axes): var, + DatasetMean(member_id=self._member_id, axes=self._axes): self._mean, + DatasetVar(member_id=self._member_id, axes=self._axes): var, DatasetStd( - tensor_id=self._tensor_id, axes=self._axes + member_id=self._member_id, axes=self._axes ): Tensor.from_xarray(sqrt), } @@ -200,7 +198,7 @@ class SamplePercentilesCalculator: def __init__( self, - tensor_id: TensorId, + member_id: MemberId, axes: Optional[Sequence[AxisId]], qs: Collection[float], ): @@ -208,13 +206,13 @@ def __init__( assert all(0.0 <= q <= 1.0 for q in qs) self._qs = sorted(set(qs)) self._axes = None if axes is None else tuple(axes) - self._tensor_id = tensor_id + self._member_id = member_id - def compute(self, sample: UntiledSample) -> 
Dict[SamplePercentile, MeasureValue]: - tensor = sample.data[self._tensor_id] + def compute(self, sample: Sample) -> Dict[SamplePercentile, MeasureValue]: + tensor = sample.members[self._member_id] ps = tensor.quantile(self._qs, dim=self._axes) return { - SamplePercentile(q=q, axes=self._axes, tensor_id=self._tensor_id): p + SamplePercentile(q=q, axes=self._axes, member_id=self._member_id): p for q, p in zip(self._qs, ps) } @@ -226,7 +224,7 @@ class MeanPercentilesCalculator: def __init__( self, - tensor_id: TensorId, + member_id: MemberId, axes: Optional[Sequence[AxisId]], qs: Collection[float], ): @@ -234,18 +232,18 @@ def __init__( assert all(0.0 <= q <= 1.0 for q in qs) self._qs = sorted(set(qs)) self._axes = None if axes is None else tuple(axes) - self._tensor_id = tensor_id + self._member_id = member_id self._n: int = 0 self._estimates: Optional[Tensor] = None - def update(self, sample: UntiledSample): - tensor = sample.data[self._tensor_id] + def update(self, sample: Sample): + tensor = sample.members[self._member_id] sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype( "float64", copy=False ) # reduced voxel count - n = int(np.prod(tensor.shape) / np.prod(sample_estimates.shape[1:])) + n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:])) if self._estimates is None: assert self._n == 0 @@ -266,7 +264,7 @@ def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: "Computed dataset percentiles naively by averaging percentiles of samples." ) return { - DatasetPercentile(q=q, axes=self._axes, tensor_id=self._tensor_id): e + DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e for q, e in zip(self._qs, self._estimates) } @@ -276,7 +274,7 @@ class CrickPercentilesCalculator: def __init__( self, - tensor_id: TensorId, + member_id: MemberId, axes: Optional[Sequence[AxisId]], qs: Collection[float], ): @@ -288,7 +286,7 @@ def __init__( assert axes is None or "_percentiles" not in axes self._qs = sorted(set(qs)) self._axes = None if axes is None else tuple(axes) - self._tensor_id = tensor_id + self._member_id = member_id self._digest: Optional[List[TDigest]] = None self._dims: Optional[Tuple[AxisId, ...]] = None self._indices: Optional[Iterator[Tuple[int, ...]]] = None @@ -309,11 +307,15 @@ def _initialize(self, tensor_sizes: PerAxis[int]): self._digest = [TDigest() for _ in range(d)] self._indices = product(*map(range, self._shape[1:])) - def update(self, sample: UntiledSample): - tensor = sample.data[self._tensor_id] + def update(self, part: Sample): + tensor = ( + part.members[self._member_id] + if isinstance(part, Sample) + else part.members[self._member_id].data + ) assert "_percentiles" not in tensor.dims if self._digest is None: - self._initialize(tensor.sizes) + self._initialize(tensor.tagged_shape) assert self._digest is not None assert self._indices is not None @@ -333,7 +335,7 @@ def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: ).reshape(self._shape) return { DatasetPercentile( - q=q, axes=self._axes, tensor_id=self._tensor_id + q=q, axes=self._axes, member_id=self._member_id ): Tensor(v, dims=self._dims[1:]) for q, v in zip(self._qs, vs) } @@ -350,12 +352,12 @@ def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: class NaiveSampleMeasureCalculator: """wrapper for measures to match interface of other sample measure calculators""" - def __init__(self, tensor_id: TensorId, measure: SampleMeasure): + def __init__(self, member_id: MemberId, measure: SampleMeasure): super().__init__() - self.tensor_name = tensor_id + 
self.tensor_name = member_id self.measure = measure - def compute(self, sample: UntiledSample) -> Dict[SampleMeasure, MeasureValue]: + def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]: return {self.measure: self.measure.compute(sample)} @@ -408,7 +410,10 @@ def __init__( def has_dataset_measures(self): return self._current_dataset_measures is not None - def update(self, sample: Union[UntiledSample, Iterable[UntiledSample]]) -> None: + def update( + self, + sample: Union[Sample, Iterable[Sample]], + ) -> None: _ = self._update(sample) def finalize(self) -> Dict[DatasetMeasure, MeasureValue]: @@ -422,7 +427,8 @@ def finalize(self) -> Dict[DatasetMeasure, MeasureValue]: return self._current_dataset_measures def update_and_get_all( - self, sample: Union[UntiledSample, Iterable[UntiledSample]] + self, + sample: Union[Sample, Iterable[Sample]], ) -> Dict[Measure, MeasureValue]: """Returns sample as well as updated dataset statistics""" last_sample = self._update(sample) @@ -431,13 +437,11 @@ def update_and_get_all( return {**self._compute(last_sample), **self.finalize()} - def skip_update_and_get_all( - self, sample: UntiledSample - ) -> Dict[Measure, MeasureValue]: + def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]: """Returns sample as well as previously computed dataset statistics""" return {**self._compute(sample), **self.finalize()} - def _compute(self, sample: UntiledSample) -> Dict[SampleMeasure, MeasureValue]: + def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]: ret: Dict[SampleMeasure, MeasureValue] = {} for calc in self.sample_calculators: values = calc.compute(sample) @@ -445,16 +449,14 @@ def _compute(self, sample: UntiledSample) -> Dict[SampleMeasure, MeasureValue]: return ret - def _update( - self, sample: Union[UntiledSample, Iterable[UntiledSample]] - ) -> Optional[UntiledSample]: + def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]: self.sample_count += 1 - samples = [sample] if isinstance(sample, UntiledSample) else sample + samples = [sample] if isinstance(sample, Sample) else sample last_sample = None - for s in samples: - last_sample = s + for el in samples: + last_sample = el for calc in self.dataset_calculators: - calc.update(s) + calc.update(el) self._current_dataset_measures = None return last_sample @@ -476,10 +478,10 @@ def get_measure_calculators( set() ) required_sample_percentiles: Dict[ - Tuple[TensorId, Optional[Tuple[AxisId, ...]]], Set[float] + Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float] ] = {} required_dataset_percentiles: Dict[ - Tuple[TensorId, Optional[Tuple[AxisId, ...]]], Set[float] + Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float] ] = {} for rm in required_measures: @@ -490,7 +492,7 @@ def get_measure_calculators( elif isinstance(rm, (SampleVar, SampleStd)): required_sample_mean_var_std.update( { - msv(axes=rm.axes, tensor_id=rm.tensor_id) + msv(axes=rm.axes, member_id=rm.member_id) for msv in (SampleMean, SampleStd, SampleVar) } ) @@ -498,17 +500,17 @@ def get_measure_calculators( elif isinstance(rm, (DatasetVar, DatasetStd)): required_dataset_mean_var_std.update( { - msv(axes=rm.axes, tensor_id=rm.tensor_id) + msv(axes=rm.axes, member_id=rm.member_id) for msv in (DatasetMean, DatasetStd, DatasetVar) } ) assert rm in required_dataset_mean_var_std elif isinstance(rm, SamplePercentile): - required_sample_percentiles.setdefault((rm.tensor_id, rm.axes), set()).add( + required_sample_percentiles.setdefault((rm.member_id, rm.axes), 
set()).add( rm.q ) elif isinstance(rm, DatasetPercentile): - required_dataset_percentiles.setdefault((rm.tensor_id, rm.axes), set()).add( + required_dataset_percentiles.setdefault((rm.member_id, rm.axes), set()).add( rm.q ) else: @@ -519,11 +521,11 @@ def get_measure_calculators( # computed togehter with var and std continue - sample_calculators.append(MeanCalculator(tensor_id=rm.tensor_id, axes=rm.axes)) + sample_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes)) for rm in required_sample_mean_var_std: sample_calculators.append( - MeanVarStdCalculator(tensor_id=rm.tensor_id, axes=rm.axes) + MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes) ) for rm in required_dataset_means: @@ -531,28 +533,28 @@ def get_measure_calculators( # computed togehter with var and std continue - dataset_calculators.append(MeanCalculator(tensor_id=rm.tensor_id, axes=rm.axes)) + dataset_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes)) for rm in required_dataset_mean_var_std: dataset_calculators.append( - MeanVarStdCalculator(tensor_id=rm.tensor_id, axes=rm.axes) + MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes) ) for (tid, axes), qs in required_sample_percentiles.items(): sample_calculators.append( - SamplePercentilesCalculator(tensor_id=tid, axes=axes, qs=qs) + SamplePercentilesCalculator(member_id=tid, axes=axes, qs=qs) ) for (tid, axes), qs in required_dataset_percentiles.items(): dataset_calculators.append( - DatasetPercentilesCalculator(tensor_id=tid, axes=axes, qs=qs) + DatasetPercentilesCalculator(member_id=tid, axes=axes, qs=qs) ) return sample_calculators, dataset_calculators def compute_dataset_measures( - measures: Iterable[DatasetMeasure], dataset: Iterable[UntiledSample] + measures: Iterable[DatasetMeasure], dataset: Iterable[Sample] ) -> Dict[DatasetMeasure, MeasureValue]: """compute all dataset `measures` for the given `dataset`""" sample_calculators, calculators = get_measure_calculators(measures) @@ -571,7 +573,7 @@ def compute_dataset_measures( def compute_sample_measures( - measures: Iterable[SampleMeasure], sample: UntiledSample + measures: Iterable[SampleMeasure], sample: Sample ) -> Dict[SampleMeasure, MeasureValue]: """compute all sample `measures` for the given `sample`""" calculators, dataset_calculators = get_measure_calculators(measures) @@ -585,7 +587,7 @@ def compute_sample_measures( def compute_measures( - measures: Iterable[Measure], dataset: Iterable[UntiledSample] + measures: Iterable[Measure], dataset: Iterable[Sample] ) -> Dict[Measure, MeasureValue]: """compute all `measures` for the given `dataset` sample measures are computed for the last sample in `dataset`""" diff --git a/bioimageio/core/stat_measures.py b/bioimageio/core/stat_measures.py index fa928eae..ec25b954 100644 --- a/bioimageio/core/stat_measures.py +++ b/bioimageio/core/stat_measures.py @@ -5,7 +5,8 @@ from typing import Dict, Optional, Protocol, Tuple, TypeVar, Union from .axis import AxisId -from .tensor import PerTensor, Tensor, TensorId +from .common import MemberId, PerMember +from .tensor import Tensor MeasureValue = Union[float, Tensor] @@ -13,12 +14,12 @@ # using Sample Protocol really only to avoid circular imports class SampleLike(Protocol): @property - def data(self) -> PerTensor[Tensor]: ... + def members(self) -> PerMember[Tensor]: ... 
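To see the renamed `member_id` keyword (formerly `tensor_id`) in context, a small sketch of requesting and computing per-sample statistics; the member id, axes, and array shape are illustrative:

import numpy as np

from bioimageio.core.axis import AxisId
from bioimageio.core.common import MemberId
from bioimageio.core.sample import Sample
from bioimageio.core.stat_calculators import compute_sample_measures
from bioimageio.core.stat_measures import SampleMean, SampleStd
from bioimageio.core.tensor import Tensor

member = MemberId("raw")  # hypothetical member id
c, y, x = AxisId("c"), AxisId("y"), AxisId("x")
sample = Sample(
    members={member: Tensor(np.random.rand(3, 64, 64), dims=(c, y, x))}
)

# measures are keyed by member id and the axes to reduce over;
# here we keep one mean/std value per channel
measures = [
    SampleMean(member_id=member, axes=(y, x)),
    SampleStd(member_id=member, axes=(y, x)),
]
stats = compute_sample_measures(measures, sample)  # Dict[SampleMeasure, MeasureValue]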
@dataclass(frozen=True) class MeasureBase: - tensor_id: TensorId + member_id: MemberId @dataclass(frozen=True) @@ -45,7 +46,7 @@ class SampleMean(_Mean, SampleMeasureBase): """The mean value of a single tensor""" def compute(self, sample: SampleLike) -> MeasureValue: - tensor = sample.data[self.tensor_id] + tensor = sample.members[self.member_id] return tensor.mean(dim=self.axes) def __post_init__(self): @@ -71,7 +72,7 @@ class SampleStd(_Std, SampleMeasureBase): """The standard deviation of a single tensor""" def compute(self, sample: SampleLike) -> MeasureValue: - tensor = sample.data[self.tensor_id] + tensor = sample.members[self.member_id] return tensor.std(dim=self.axes) def __post_init__(self): @@ -97,7 +98,7 @@ class SampleVar(_Var, SampleMeasureBase): """The variance of a single tensor""" def compute(self, sample: SampleLike) -> MeasureValue: - tensor = sample.data[self.tensor_id] + tensor = sample.members[self.member_id] return tensor.var(dim=self.axes) def __post_init__(self): @@ -128,7 +129,7 @@ class SamplePercentile(_Percentile, SampleMeasureBase): """The `n`th percentile of a single tensor""" def compute(self, sample: SampleLike) -> MeasureValue: - tensor = sample.data[self.tensor_id] + tensor = sample.members[self.member_id] return tensor.quantile(self.q, dim=self.axes) def __post_init__(self): diff --git a/bioimageio/core/tensor.py b/bioimageio/core/tensor.py index 910aad01..2ed28424 100644 --- a/bioimageio/core/tensor.py +++ b/bioimageio/core/tensor.py @@ -10,7 +10,6 @@ Optional, Sequence, Tuple, - TypeVar, Union, cast, get_args, @@ -38,11 +37,6 @@ if TYPE_CHECKING: from numpy.typing import ArrayLike, NDArray -TensorId = v0_5.TensorId - -T = TypeVar("T") - -PerTensor = Mapping[TensorId, T] _ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"] # TODO: add "DaskArray" @@ -200,9 +194,14 @@ def dims(self): # TODO: rename to `axes`? @property def tagged_shape(self): - """alias for `sizes`""" + """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths.""" return self.sizes + @property + def shape_tuple(self): + """Tuple of tensor axes lengths""" + return self._data.shape + @property def size(self): """Number of elements in the tensor. 
@@ -231,11 +230,6 @@ def sizes(self): """Ordered, immutable mapping from axis ids to axis lengths.""" return cast(Mapping[AxisId, int], self.data.sizes) - # @property - # def tagged_shape(self): - # """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths.""" - # return cast(Mapping[AxisId, int], self.data.sizes) - def astype(self, dtype: DTypeStr, *, copy: bool = False): """Return tensor cast to `dtype` @@ -373,6 +367,9 @@ def quantile( or not isinstance(q, (float, int)) and all(qq <= 1.0 for qq in q) ) + assert dim is None or ( + (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim) + ) return self.__class__.from_xarray(self._data.quantile(q, dim=dim)) def resize_to( diff --git a/bioimageio/core/tensor_block.py b/bioimageio/core/tensor_block.py deleted file mode 100644 index 449dc596..00000000 --- a/bioimageio/core/tensor_block.py +++ /dev/null @@ -1,106 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Generator, Iterable, Optional, Tuple - -from typing_extensions import Self - -from bioimageio.core.axis import PerAxis -from bioimageio.core.common import ( - Halo, - HaloLike, - PadMode, - SliceInfo, - TotalNumberOfBlocks, -) - -from .block import Block, split_shape_into_blocks -from .stat_measures import Stat -from .tensor import Tensor - - -@dataclass(init=False) -class TensorBlock(Block): - """A block with data""" - - stat: Stat - """sample and dataset statistics""" - - data: Tensor - """the block's tensor""" - - def __init__( - self, - data: Tensor, - *, - inner_slice: PerAxis[SliceInfo], - halo: PerAxis[Halo], - block_number: int, - blocks_in_sample: int, - stat: Stat, - ): - super().__init__( - sample_shape=data.tagged_shape, - inner_slice=inner_slice, - halo=halo, - block_number=block_number, - blocks_in_sample=blocks_in_sample, - ) - self.data = data - self.stat = stat - - @property - def inner_data(self): - return {t: self.data[self.local_slice] for t in self.data} - - def __post_init__(self): - super().__post_init__() - for a, s in self.data.sizes.items(): - slice_ = self.inner_slice[a] - halo = self.halo[a] - assert s == slice_.stop - slice_.start + halo.left + halo.right, ( - s, - slice_, - halo, - ) - - @classmethod - def from_sample( - cls, - sample: Tensor, - block: Block, - *, - pad_mode: PadMode, - stat: Stat, - ) -> Self: - return cls( - data=sample[block.outer_slice].pad(block.padding, pad_mode), - inner_slice=block.inner_slice, - halo=block.halo, - block_number=block.block_number, - blocks_in_sample=block.blocks_in_sample, - stat=stat, - ) - - -def split_tensor_into_blocks( - sample: Tensor, - block_shape: PerAxis[int], - *, - halo: PerAxis[HaloLike], - stride: Optional[PerAxis[int]] = None, - pad_mode: PadMode, - stat: Stat, -) -> Tuple[TotalNumberOfBlocks, Generator[TensorBlock, Any, None]]: - """divide a sample tensor into tensor blocks.""" - n_blocks, block_gen = split_shape_into_blocks( - sample.tagged_shape, block_shape=block_shape, halo=halo - ) - return n_blocks, _tensor_block_generator( - sample, block_gen, pad_mode=pad_mode, stat=stat - ) - - -def _tensor_block_generator( - sample: Tensor, blocks: Iterable[Block], *, pad_mode: PadMode, stat: Stat -): - for block in blocks: - yield TensorBlock.from_sample(sample, block, pad_mode=pad_mode, stat=stat) diff --git a/bioimageio/core/utils/__init__.py b/bioimageio/core/utils/__init__.py index ddc519f7..84e94d38 100644 --- a/bioimageio/core/utils/__init__.py +++ b/bioimageio/core/utils/__init__.py @@ -2,11 +2,6 @@ import sys from pathlib import Path -from 
._digest_spec import get_axes_infos as get_axes_infos -from ._digest_spec import get_test_inputs as get_test_inputs -from ._digest_spec import get_test_outputs as get_test_outputs -from ._import_callable import import_callable as import_callable - if sys.version_info < (3, 9): def files(package_name: str): diff --git a/bioimageio/core/utils/_digest_spec.py b/bioimageio/core/utils/_digest_spec.py deleted file mode 100644 index 09a290f8..00000000 --- a/bioimageio/core/utils/_digest_spec.py +++ /dev/null @@ -1,97 +0,0 @@ -from typing import Iterable, Union - -from bioimageio.core.tile import AbstractTile -from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 -from bioimageio.spec.utils import load_array - -from ..axis import AxisInfo -from ..sample import UntiledSample -from ..tensor import Tensor, TensorId - - -def get_axes_infos( - io_descr: Union[ - v0_4.InputTensorDescr, - v0_4.OutputTensorDescr, - v0_5.InputTensorDescr, - v0_5.OutputTensorDescr, - ] -): - return [ - ( - AxisInfo.create("i") - if isinstance(a, str) and a not in ("b", "i", "t", "c", "z", "y", "x") - else AxisInfo.create(a) - ) - for a in io_descr.axes - ] - - -def get_test_inputs(model: AnyModelDescr) -> UntiledSample: - if isinstance(model, v0_4.ModelDescr): - tensor_ids = [TensorId(t.name) for t in model.inputs] - else: - tensor_ids = [t.id for t in model.inputs] - - if isinstance(model, v0_4.ModelDescr): - arrays = [load_array(tt) for tt in model.test_inputs] - else: - arrays = [load_array(d.test_tensor) for d in model.inputs] - - axes = [get_axes_infos(t) for t in model.inputs] - return UntiledSample( - data={ - tid: Tensor.from_numpy(arr, dims=ax) - for tid, arr, ax in zip(tensor_ids, arrays, axes) - } - ) - - -def get_test_outputs(model: AnyModelDescr) -> UntiledSample: - if isinstance(model, v0_4.ModelDescr): - tensor_ids = [TensorId(t.name) for t in model.outputs] - else: - tensor_ids = [t.id for t in model.outputs] - - if isinstance(model, v0_4.ModelDescr): - arrays = [load_array(tt) for tt in model.test_outputs] - else: - arrays = [load_array(d.test_tensor) for d in model.outputs] - - axes = [get_axes_infos(t) for t in model.outputs] - - return UntiledSample( - data={ - tid: Tensor.from_numpy(arr, dims=ax) - for tid, arr, ax in zip(tensor_ids, arrays, axes) - } - ) - - -def get_abstract_output_tiles( - input_tiles: Iterable[AbstractTile], model: v0_5.ModelDescr -): - if not isinstance(model, v0_5.ModelDescr): - raise TypeError(f"get_abstract_output_tile() not implemented for {type(model)}") - - sample_sizes = model.get_output_tensor_sizes(input_tile.sample_sizes) - outer_sizes = model.get_output_tensor_sizes(input_tile.outer_sizes) - UntiledSample() - halo = { - t.id: {a.id: a.halo for a in t.axes if isinstance(a, v0_5.WithHalo)} - for t in model.outputs - if t.id in outer_sizes - } - inner_sizes = { - t: { - a: outer_sizes[t][a] - 2 * halo.get(t, {}).get(a, 0) for a in outer_sizes[t] - } - for t in outer_sizes - } - - return AbstractTile( - halo=halo, - tile_number=input_tile.tile_number, - tiles_in_sample=input_tile.tiles_in_sample, - stat={}, - ) diff --git a/bioimageio/core/utils/_import_callable.py b/bioimageio/core/utils/_import_callable.py deleted file mode 100644 index 3e1569b7..00000000 --- a/bioimageio/core/utils/_import_callable.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations - -import importlib.util -from functools import singledispatch -from typing import Any, Callable - -from typing_extensions import Unpack - -from bioimageio.spec._internal.io_utils import HashKwargs, 
download -from bioimageio.spec.common import FileSource -from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile -from bioimageio.spec.model.v0_5 import ( - ArchitectureFromFileDescr, - ArchitectureFromLibraryDescr, -) - - -@singledispatch -def import_callable(node: type, /) -> Callable[..., Any]: - raise TypeError(type(node)) - - -@import_callable.register -def import_from_dependency04(node: CallableFromDepencency) -> Callable[..., Any]: - module = importlib.import_module(node.module_name) - c = getattr(module, str(node.callable_name)) - if not callable(c): - raise ValueError(f"{node} (imported: {c}) is not callable") - - return c - - -@import_callable.register -def import_from_dependency05(node: ArchitectureFromLibraryDescr) -> Callable[..., Any]: - module = importlib.import_module(node.import_from) - c = getattr(module, str(node.callable)) - if not callable(c): - raise ValueError(f"{node} (imported: {c}) is not callable") - - return c - - -@import_callable.register -def import_from_file04(node: CallableFromFile, **kwargs: Unpack[HashKwargs]): - return _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs) - - -@import_callable.register -def import_from_file05(node: ArchitectureFromFileDescr, **kwargs: Unpack[HashKwargs]): - return _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256) - - -def _import_from_file_impl( - source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs] -): - local_file = download(source, **kwargs) - module_name = local_file.path.stem - importlib_spec = importlib.util.spec_from_file_location( - module_name, local_file.path - ) - if importlib_spec is None: - raise ImportError(f"Failed to import {module_name} from {source}.") - - dep = importlib.util.module_from_spec(importlib_spec) - importlib_spec.loader.exec_module(dep) # type: ignore # todo: possible to use "loader.load_module"? 
- return getattr(dep, callable_name) diff --git a/tests/test_stat_measures.py b/tests/test_stat_measures.py index 2fa69c2a..6986e6e5 100644 --- a/tests/test_stat_measures.py +++ b/tests/test_stat_measures.py @@ -29,7 +29,7 @@ def test_individual_normal_measure( ): data_id = TensorId("test_data") measure = getattr(stat_measures, "Sample" + name.title())( - axes=axes, tensor_id=data_id + axes=axes, member_id=data_id ) data = Tensor( np.random.random((5, 6, 3)), dims=(AxisId("x"), AxisId("y"), AxisId("c")) @@ -46,7 +46,7 @@ def test_individual_percentile_measure(axes: Optional[Tuple[AxisId, ...]]): qs = [0, 0.1, 0.5, 1.0] tid = TensorId("tensor") - measures = [SamplePercentile(tensor_id=tid, axes=axes, q=q) for q in qs] + measures = [SamplePercentile(member_id=tid, axes=axes, q=q) for q in qs] calcs, _ = get_measure_calculators(measures) assert len(calcs) == 1 calc = calcs[0] From d7e81a72f54131af77ff8c46948ad445da0dbcd6 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 8 Apr 2024 16:09:06 +0200 Subject: [PATCH 179/244] WIP PredictionPipeline done --- bioimageio/core/_op_base.py | 6 +- bioimageio/core/_prediction_pipeline.py | 276 +++++------- bioimageio/core/_resource_tests.py | 8 +- bioimageio/core/block.py | 12 +- bioimageio/core/block_meta.py | 153 +++++-- bioimageio/core/digest_spec.py | 216 ++++++++-- bioimageio/core/io.py | 61 ++- .../model_adapters/_pytorch_model_adapter.py | 4 +- bioimageio/core/prediction.py | 404 +----------------- bioimageio/core/proc_ops.py | 51 ++- bioimageio/core/sample.py | 128 +++++- .../core/weight_converter/torch/_onnx.py | 5 +- setup.py | 2 +- tests/test_prediction.py | 2 +- tests/test_prediction_pipeline.py | 2 +- ...t_prediction_pipeline_device_management.py | 8 +- tests/test_proc_ops.py | 70 +-- tests/test_stat_calculators.py | 12 +- tests/test_stat_measures.py | 13 +- tests/test_tensor.py | 8 +- 20 files changed, 658 insertions(+), 783 deletions(-) diff --git a/bioimageio/core/_op_base.py b/bioimageio/core/_op_base.py index 78d13b52..afc3226d 100644 --- a/bioimageio/core/_op_base.py +++ b/bioimageio/core/_op_base.py @@ -1,15 +1,15 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Collection +from typing import Collection, Union -from .sample import SampleBlock +from .sample import Sample, SampleBlockWithOrigin from .stat_measures import Measure @dataclass class Operator(ABC): @abstractmethod - def __call__(self, sample_block: SampleBlock) -> None: ... + def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None: ... 
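A toy sketch of the widened `Operator` interface; the `required_measures` property name is inferred from the imports in this hunk, and the in-place member assignment is shown for whole samples (a hypothetical op, not part of this patch):

from dataclasses import dataclass
from typing import Collection, Union

from bioimageio.core._op_base import Operator
from bioimageio.core.common import MemberId
from bioimageio.core.sample import Sample, SampleBlockWithOrigin
from bioimageio.core.stat_measures import Measure, SampleMean


@dataclass
class ZeroCenter(Operator):  # hypothetical example operator
    member: MemberId

    @property
    def required_measures(self) -> Collection[Measure]:
        # declare which statistics must be available in `sample.stat`
        return {SampleMean(member_id=self.member, axes=None)}

    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        # look up the precomputed mean (blocks expose their origin's stat)
        mean = sample.stat[SampleMean(member_id=self.member, axes=None)]
        # mutate the sample's member tensor in place
        sample.members[self.member] = sample.members[self.member] - mean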
@property @abstractmethod diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index b9498095..b651befa 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -1,6 +1,5 @@ import collections.abc import warnings -from dataclasses import dataclass from types import MappingProxyType from typing import ( Any, @@ -20,38 +19,27 @@ from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 from bioimageio.spec.model.v0_5 import WeightsFormat -from .axis import AxisId, AxisInfo -from .block import Block -from .common import MemberId, PadMode, PerMember -from .digest_spec import get_axes_infos, get_block_meta +from .axis import AxisId, PerAxis +from .common import Halo, MemberId, PerMember +from .digest_spec import ( + get_block_transform, + get_input_halo, + get_member_ids, +) from .model_adapters import ModelAdapter, create_model_adapter from .model_adapters import get_weight_formats as get_weight_formats from .proc_ops import Processing from .proc_setup import setup_pre_and_postprocessing -from .sample import Sample, SampleBlock +from .sample import Sample, SampleBlock, SampleBlockWithOrigin from .stat_measures import DatasetMeasure, MeasureValue from .tensor import Tensor - -@dataclass -class MemberDescr: - id: MemberId - axes: Sequence[AxisInfo] - optional: bool - - Predict_IO = TypeVar( "Predict_IO", Sample, - SampleBlock, Iterable[Sample], - Iterable[SampleBlock], ) -# NDArray[Any], -# Sequence[Optional[NDArray[Any]]], -# Mapping[Union[MemberId, str], Optional[NDArray[Any]]], - class PredictionPipeline: """ @@ -67,10 +55,11 @@ def __init__( preprocessing: List[Processing], postprocessing: List[Processing], model_adapter: ModelAdapter, - ns: Union[ + default_ns: Union[ v0_5.ParameterizedSize.N, Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize.N], - ], + ] = 10, + default_batch_size: int = 1, ) -> None: super().__init__() if model_description.run_mode: @@ -83,56 +72,41 @@ def __init__( self._postprocessing = postprocessing self.model_description = model_description - if isinstance(ns, int): - if isinstance(model_description, v0_4.ModelDescr): - self.ns = None - else: - self.ns = { - (ipt.id, a.id): ns + if isinstance(model_description, v0_4.ModelDescr): + self._default_input_block_shape = {} + default_ns = {} + self._default_input_halo: PerMember[PerAxis[Halo]] = {} + self._block_transform = {} + else: + if isinstance(default_ns, int): + default_ns = { + (ipt.id, a.id): default_ns for ipt in model_description.inputs for a in ipt.axes if isinstance(a.size, v0_5.ParameterizedSize) } - else: - self.ns = ns - # if isinstance(model_description, v0_4.ModelDescr): - # self.default_sample_block_shape = None - # else: - - # self.default_sample_block_shape = model_description.get_tensor_sizes( - # ns, 1 - # ).inputs - - self.input_ids = tuple( - (MemberId(str(t.name)) if isinstance(t, v0_4.InputTensorDescr) else t.id) - for t in model_description.inputs - ) - self.inputs = collections.OrderedDict( - ( - tid, - MemberDescr( - id=tid, - axes=get_axes_infos(t), - optional=not isinstance(t, v0_4.InputTensorDescr) and t.optional, - ), - ) - for tid, t in zip(self.input_ids, model_description.inputs) - ) - self.output_ids = tuple( - (MemberId(str(t.name)) if isinstance(t, v0_4.OutputTensorDescr) else t.id) - for t in model_description.outputs - ) - self.outputs = collections.OrderedDict( - ( - tid, - MemberDescr( - id=tid, - axes=get_axes_infos(t), - optional=False, - ), + + self._default_input_block_shape = 
model_description.get_tensor_sizes( + default_ns, default_batch_size + ).inputs + + default_output_halo = { + t.id: { + a.id: Halo(a.halo, a.halo) + for a in t.axes + if isinstance(a, v0_5.WithHalo) + } + for t in model_description.outputs + } + self._default_input_halo = get_input_halo( + model_description, default_output_halo ) - for tid, t in zip(self.output_ids, model_description.outputs) - ) + self._block_transform = get_block_transform(model_description) + + self._default_ns = default_ns + + self._input_ids = get_member_ids(model_description.inputs) + self._output_ids = get_member_ids(model_description.outputs) self._adapter: ModelAdapter = model_adapter @@ -147,121 +121,66 @@ def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore self.unload() return False - # def predict_sample( - # self, - # sample: Sample, - # parameterized_size_n: Optional[int] = None, - # pad_mode: PadMode = "reflect", - # ) -> Sample: - # if parameterized_size_n is None: - # # TODO guess n - # parameterized_size_n = 10 - - # return Sample.from_blocks( - # map( - # self.predict_sample_block, - # sample.split_into_blocks( - # block_shapes={m: ipt.axes for m, ipt in self.inputs.items()}, - # halo={ - # m: ipt.axes.halo - # for m, ipt in self.inputs.items() - # if isinstance(ipt.axes, v0_5.WithHalo) - # }, - # pad_mode=pad_mode, - # ), - # ) - # ) - - # def predict_sample_block(self, inputs: SampleBlock) -> SampleBlock: - # self.apply_preprocessing(inputs) - # output = Block( - # data={ - # tid: out - # for tid, out in zip( - # self.output_ids, - # self._adapter.forward( - # *(inputs.data[t] for t in self.input_ids) - # ), - # ) - # if out is not None - # } - # ) - # self.apply_postprocessing(output) - # return output - - # else: - # assert_never(inputs) - - # return output - - def predict(self, inputs: Predict_IO) -> Predict_IO: - """Run model prediction **including** pre/postprocessing.""" - - if isinstance(inputs, Sample): - if isinstance(self.model_description, v0_4.ModelDescr): - raise NotImplementedError( - "predicting `Sample`s no implemented for model" - + f" {self.model_description.format_version}." - + " Please divide the sample into block. using `sample.split_into_blocks()`." 
+ def _predict_sample_block_wo_procs( + self, sample_block: SampleBlockWithOrigin + ) -> SampleBlock: + output_meta = sample_block.get_transformed_meta(self._block_transform) + output = output_meta.with_data( + { + tid: out + for tid, out in zip( + self._output_ids, + self._adapter.forward( + *(sample_block.members[t] for t in self._input_ids) + ), ) + if out is not None + }, + stat=sample_block.stat, + ) + return output + + def predict_sample(self, sample: Sample) -> Sample: + self.apply_preprocessing(sample) + n_blocks, input_blocks = sample.split_into_blocks( + self._default_input_block_shape, + halo=self._default_input_halo, + pad_mode="reflect", + ) + input_blocks = tqdm( + input_blocks, + desc=f"predict sample {sample.id or ''} with {self.model_description.id or self.model_description.name}", + unit="block", + total=n_blocks, + ) + predicted_blocks = map(self._predict_sample_block_wo_procs, input_blocks) + predicted_sample = Sample.from_blocks(predicted_blocks) + self.apply_postprocessing(predicted_sample) + return predicted_sample - assert self.ns is not None - n_blocks, block_metas = get_block_meta( - self.model_description, input_sample_shape=inputs.shape, ns=self.ns - ) + def predict( + self, + inputs: Predict_IO, + ) -> Predict_IO: + """Run model prediction **including** pre/postprocessing.""" - # for block_meta in tqdm(block_metas, desc=f"predict sample {inputs.id or ''} with {self.model_description.id or self.model_description.name}", unit="block", total=n_blocks): - input_halo = - Sample.from_blocks(inputs.split_into_blocks()) - # return Sample.from_blocks( - # map( - # self.predict, - n_blocks, blocks = inputs.split_into_blocks( - block_shapes=self.default_sample_block_shape, - halo={ - m: ipt.axes.halo - for m, ipt in self.inputs.items() - if isinstance(ipt.axes, v0_5.WithHalo) - }, - pad_mode="reflect", - ) - # ) - # ) + if isinstance(inputs, Sample): + return self.predict_sample(inputs) + elif isinstance(inputs, collections.abc.Iterable): + return (self.predict(ipt) for ipt in inputs) else: - return self.predict_sample_block(inputs) - - # if isinstance(inputs, collections.abc.Mapping): - # data = { - # tid: d - # for tid in self.input_ids - # if (d := inputs.get(tid, inputs.get(str(tid)))) is not None - # } - # else: - # if isinstance(inputs, (Tensor, np.ndarray)): - # inputs_seq = [inputs] - # else: - # inputs_seq = inputs - - # assert len(inputs_seq) == len(self.input_ids) - # data = { - # tid: d for tid, d in zip(self.input_ids, inputs_seq) if d is not None - # } - - # sample = Sample( - # data={ - # tid: Tensor.from_numpy(d, dims=self.inputs[tid].axes) - # for tid, d in data.items() - # } - # ) - # output = self.predict(sample) - # return {tid: out.data.data for } - - def apply_preprocessing(self, sample_block: SampleBlock) -> None: + assert_never(inputs) + + def apply_preprocessing( + self, sample_block: Union[Sample, SampleBlockWithOrigin] + ) -> None: """apply preprocessing in-place, also updates sample stats""" for op in self._preprocessing: op(sample_block) - def apply_postprocessing(self, sample_block: SampleBlock) -> None: + def apply_postprocessing( + self, sample_block: Union[Sample, SampleBlockWithOrigin] + ) -> None: """apply postprocessing in-place, also updates samples stats""" for op in self._postprocessing: op(sample_block) @@ -293,7 +212,7 @@ def create_prediction_pipeline( model_adapter: Optional[ModelAdapter] = None, ns: Union[ v0_5.ParameterizedSize.N, - Mapping[Tuple[TensorId, AxisId], v0_5.ParameterizedSize.N], + Mapping[Tuple[MemberId, AxisId], 
v0_5.ParameterizedSize.N], ] = 10, **deprecated_kwargs: Any, ) -> PredictionPipeline: @@ -318,10 +237,7 @@ def create_prediction_pipeline( weight_format_priority_order=weights_format and (weights_format,), ) - if isinstance(bioimageio_model, v0_4.ModelDescr): - input_ids = [MemberId(str(ipt.name)) for ipt in bioimageio_model.inputs] - else: - input_ids = [ipt.id for ipt in bioimageio_model.inputs] + input_ids = get_member_ids(bioimageio_model.inputs) def dataset(): for x in dataset_for_initial_statistics: @@ -343,5 +259,5 @@ def dataset(): model_adapter=model_adapter, preprocessing=preprocessing, postprocessing=postprocessing, - ns=ns, + default_ns=ns, ) diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py index d14e836e..ffb3949a 100644 --- a/bioimageio/core/_resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -4,9 +4,6 @@ import numpy as np -from bioimageio.core._prediction_pipeline import create_prediction_pipeline -from bioimageio.core.axis import AxisId, BatchSize -from bioimageio.core.utils import VERSION, get_test_inputs, get_test_outputs from bioimageio.spec import ( InvalidDescr, ResourceDescr, @@ -25,6 +22,11 @@ ValidationSummary, ) +from ._prediction_pipeline import create_prediction_pipeline +from .axis import AxisId, BatchSize +from .digest_spec import get_test_inputs, get_test_outputs +from .utils import VERSION + def test_model( source: Union[v0_5.ModelDescr, PermissiveFileSource], diff --git a/bioimageio/core/block.py b/bioimageio/core/block.py index 7d78e56c..c57d6955 100644 --- a/bioimageio/core/block.py +++ b/bioimageio/core/block.py @@ -5,12 +5,13 @@ Iterable, Optional, Tuple, + Union, ) from typing_extensions import Self from .axis import PerAxis -from .block_meta import BlockMeta, split_shape_into_blocks +from .block_meta import BlockMeta, LinearAxisTransform, split_shape_into_blocks from .common import ( Halo, HaloLike, @@ -21,7 +22,7 @@ from .tensor import Tensor -@dataclass(init=False) +@dataclass(init=False, frozen=True) class Block(BlockMeta): """A block/tile of a (larger) tensor""" @@ -44,7 +45,7 @@ def __init__( block_number=block_number, blocks_in_sample=blocks_in_sample, ) - self.data = data + object.__setattr__(self, "data", data) @property def inner_data(self): @@ -77,6 +78,11 @@ def from_sample_member( blocks_in_sample=block.blocks_in_sample, ) + def get_transformed( + self, new_axes: PerAxis[Union[LinearAxisTransform, int]] + ) -> Self: + raise NotImplementedError + def split_tensor_into_blocks( tensor: Tensor, diff --git a/bioimageio/core/block_meta.py b/bioimageio/core/block_meta.py index 785abc6c..d2e39be0 100644 --- a/bioimageio/core/block_meta.py +++ b/bioimageio/core/block_meta.py @@ -10,8 +10,11 @@ List, Optional, Tuple, + Union, ) +from typing_extensions import Self + from .axis import AxisId, PerAxis from .common import ( BlockNumber, @@ -26,6 +29,16 @@ @dataclass +class LinearAxisTransform: + axis: AxisId + scale: float + offset: int + + def compute(self, s: int) -> int: + return int(s * self.scale) + self.offset + + +@dataclass(frozen=True) class BlockMeta: """Block meta data of a sample member (a tensor in a sample) @@ -104,7 +117,7 @@ def inner_slice_wo_overlap(self): stiched together trivially to form the original sample. 
This can also be used to calculate statistics - without overrepresenting edge regions.""" + without overrepresenting block edge regions.""" # TODO: update inner_slice_wo_overlap when adding block overlap return self.inner_slice @@ -116,52 +129,110 @@ def __post_init__(self): a in self.inner_slice for a in self.halo ), "halo has axes not present in block" - self.shape = { - a: s.stop - s.start + sum(self.halo[a]) for a, s in self.inner_slice.items() - } + object.__setattr__( #TODO: write as property + self, + "shape", + { + a: s.stop - s.start + sum(self.halo[a]) + for a, s in self.inner_slice.items() + }, + ) assert all( s <= self.sample_shape[a] for a, s in self.shape.items() ), "block larger than sample" - self.inner_shape = {a: s.stop - s.start for a, s in self.inner_slice.items()} - self.outer_slice = { - a: SliceInfo( - max( - 0, + object.__setattr__( #TODO: write as property + self, + "inner_shape", + {a: s.stop - s.start for a, s in self.inner_slice.items()}, + ) + object.__setattr__( #TODO: write as property + self, + "outer_slice", + { + a: SliceInfo( + max( + 0, + min( + self.inner_slice[a].start - self.halo[a].left, + self.sample_shape[a] + - self.inner_shape[a] + - self.halo[a].left, + ), + ), min( - self.inner_slice[a].start - self.halo[a].left, - self.sample_shape[a] - self.inner_shape[a] - self.halo[a].left, + self.sample_shape[a], + self.inner_slice[a].stop + self.halo[a].right, ), - ), - min( - self.sample_shape[a], - self.inner_slice[a].stop + self.halo[a].right, - ), - ) - for a in self.inner_slice - } - self.padding = { - a: PadWidth( - max( - 0, - self.halo[a].left - - (self.inner_slice[a].start + self.outer_slice[a].start), - ), - max( - 0, - self.halo[a].right - - (self.outer_slice[a].stop + self.inner_slice[a].stop), - ), - ) - for a in self.inner_slice - } - self.local_slice = { - a: SliceInfo( - self.padding[a].left, - self.padding[a].left + self.inner_shape[a], - ) - for a in self.inner_slice - } + ) + for a in self.inner_slice + }, + ) + object.__setattr__( #TODO: write as property + self, + "padding", + { + a: PadWidth( + max( + 0, + self.halo[a].left + - (self.inner_slice[a].start + self.outer_slice[a].start), + ), + max( + 0, + self.halo[a].right + - (self.outer_slice[a].stop + self.inner_slice[a].stop), + ), + ) + for a in self.inner_slice + }, + ) + object.__setattr__( #TODO: write as property + self, + "local_slice", + { + a: SliceInfo( + self.padding[a].left, + self.padding[a].left + self.inner_shape[a], + ) + for a in self.inner_slice + }, + ) + + def get_transformed( + self, new_axes: PerAxis[Union[LinearAxisTransform, int]] + ) -> Self: + return self.__class__( + sample_shape={ + a: ( + trf + if isinstance(trf, int) + else trf.compute(self.sample_shape[trf.axis]) + ) + for a, trf in new_axes.items() + }, + inner_slice={ + a: ( + SliceInfo(0, trf) + if isinstance(trf, int) + else SliceInfo( + trf.compute(self.inner_slice[trf.axis].start), + trf.compute(self.inner_slice[trf.axis].stop), + ) + ) + for a, trf in new_axes.items() + }, + halo={ + a: ( + Halo(0, 0) + if isinstance(trf, int) + else Halo(self.halo[trf.axis].left, self.halo[trf.axis].right) + ) + for a, trf in new_axes.items() + }, + block_number=self.block_number, + blocks_in_sample=self.blocks_in_sample, + ) def split_shape_into_blocks( diff --git a/bioimageio/core/digest_spec.py b/bioimageio/core/digest_spec.py index 7594d0c1..46e2e4d2 100644 --- a/bioimageio/core/digest_spec.py +++ b/bioimageio/core/digest_spec.py @@ -2,9 +2,23 @@ import importlib.util from functools import singledispatch 
-from typing import Any, Callable, Dict, Iterable, Mapping, NamedTuple, Tuple, Union +from itertools import chain +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) -from typing_extensions import Unpack +from numpy.typing import NDArray +from typing_extensions import Unpack, assert_never from bioimageio.spec._internal.io_utils import HashKwargs, download from bioimageio.spec.common import FileSource @@ -14,19 +28,25 @@ ArchitectureFromFileDescr, ArchitectureFromLibraryDescr, ParameterizedSize, - TensorId, ) from bioimageio.spec.utils import load_array from .axis import AxisId, AxisInfo, PerAxis from .block_meta import split_multiple_shapes_into_blocks from .common import Halo, MemberId, PerMember, TotalNumberOfBlocks -from .sample import Sample, SampleBlockMeta, sample_block_meta_generator +from .sample import ( + LinearSampleAxisTransform, + Sample, + SampleBlockMeta, + sample_block_meta_generator, +) +from .stat_measures import Stat from .tensor import Tensor @singledispatch def import_callable(node: type, /) -> Callable[..., Any]: + """import a callable (e.g. a torch.nn.Module) from a spec node describing it""" raise TypeError(type(node)) @@ -83,8 +103,8 @@ def get_axes_infos( v0_5.InputTensorDescr, v0_5.OutputTensorDescr, ] -): - """get a unified, simplified axis represenation from spec axes""" +) -> List[AxisInfo]: + """get a unified, simplified axis representation from spec axes""" return [ ( AxisInfo.create("i") @@ -95,13 +115,43 @@ def get_axes_infos( ] -def get_test_inputs(model: AnyModelDescr) -> Sample: - """returns a model's test input sample""" - if isinstance(model, v0_4.ModelDescr): - tensor_ids = [TensorId(t.name) for t in model.inputs] +def get_member_id( + tensor_description: Union[ + v0_4.InputTensorDescr, + v0_4.OutputTensorDescr, + v0_5.InputTensorDescr, + v0_5.OutputTensorDescr, + ] +) -> MemberId: + """get the normalized tensor ID, usable as a sample member ID""" + + if isinstance(tensor_description, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)): + return MemberId(tensor_description.name) + elif isinstance( + tensor_description, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr) + ): + return tensor_description.id else: - tensor_ids = [t.id for t in model.inputs] + assert_never(tensor_description) + + +def get_member_ids( + tensor_descriptions: Sequence[ + Union[ + v0_4.InputTensorDescr, + v0_4.OutputTensorDescr, + v0_5.InputTensorDescr, + v0_5.OutputTensorDescr, + ] + ] +) -> List[MemberId]: + """get normalized tensor IDs to be used as sample member IDs""" + return [get_member_id(descr) for descr in tensor_descriptions] + +def get_test_inputs(model: AnyModelDescr) -> Sample: + """returns a model's test input sample""" + member_ids = get_member_ids(model.inputs) if isinstance(model, v0_4.ModelDescr): arrays = [load_array(tt) for tt in model.test_inputs] else: @@ -110,18 +160,15 @@ def get_test_inputs(model: AnyModelDescr) -> Sample: axes = [get_axes_infos(t) for t in model.inputs] return Sample( members={ - tid: Tensor.from_numpy(arr, dims=ax) - for tid, arr, ax in zip(tensor_ids, arrays, axes) + m: Tensor.from_numpy(arr, dims=ax) + for m, arr, ax in zip(member_ids, arrays, axes) } ) def get_test_outputs(model: AnyModelDescr) -> Sample: """returns a model's test output sample""" - if isinstance(model, v0_4.ModelDescr): - tensor_ids = [TensorId(t.name) for t in model.outputs] - else: - tensor_ids = [t.id for t in model.outputs] + member_ids = get_member_ids(model.outputs) if 
isinstance(model, v0_4.ModelDescr):
         arrays = [load_array(tt) for tt in model.test_outputs]
@@ -132,8 +179,8 @@ def get_test_outputs(model: AnyModelDescr) -> Sample:
 
     return Sample(
         members={
-            tid: Tensor.from_numpy(arr, dims=ax)
-            for tid, arr, ax in zip(tensor_ids, arrays, axes)
+            m: Tensor.from_numpy(arr, dims=ax)
+            for m, arr, ax in zip(member_ids, arrays, axes)
         }
     )
@@ -142,8 +189,11 @@ class IO_SampleBlockMeta(NamedTuple):
     input: SampleBlockMeta
     output: SampleBlockMeta
 
+
 def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]):
-    halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
+    """returns the halo with which input tensors need to be divided into blocks such
+    that `output_halo` can be cropped from their outputs without introducing gaps."""
+    input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {}
     outputs = {t.id: t for t in model.outputs}
     all_tensors = {**{t.id: t for t in model.inputs}, **outputs}
 
@@ -153,30 +203,85 @@ def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]
         for a, ah in th.items():
             s = axes[a].size
             if not isinstance(s, v0_5.SizeReference):
-                raise ValueError(f"Unable to map output halo for {t}.{a} to an input axis")
-
+                raise ValueError(
+                    f"Unable to map output halo for {t}.{a} to an input axis"
+                )
             axis = axes[a]
             ref_axis = {a.id: a for a in all_tensors[s.tensor_id].axes}[s.axis_id]
 
             total_output_halo = sum(ah)
             total_input_halo = total_output_halo * axis.scale / ref_axis.scale
-            if total_input_halo != int(total_input_halo):
-                raise ValueError()
-            for lr in (ah.left, ah.right):
-                input_halo =
-    return halo
-
-def get_block_meta(
+            assert (
+                total_input_halo == int(total_input_halo) and total_input_halo % 2 == 0
+            )
+            input_halo.setdefault(t, {})[a] = Halo(
+                int(total_input_halo // 2), int(total_input_halo // 2)
+            )
+
+    return input_halo
+
+
+def get_block_transform(model: v0_5.ModelDescr):
+    """returns how a model's output tensor shapes relate to its input shapes"""
+    ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {}
+    batch_axis_trf = None
+    for ipt in model.inputs:
+        for a in ipt.axes:
+            if a.type == "batch":
+                batch_axis_trf = LinearSampleAxisTransform(
+                    axis=a.id, scale=1, offset=0, member=ipt.id
+                )
+                break
+        if batch_axis_trf is not None:
+            break
+    axis_scales = {
+        t.id: {a.id: a.scale for a in t.axes}
+        for t in chain(model.inputs, model.outputs)
+    }
+    for out in model.outputs:
+        new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {}
+        for a in out.axes:
+            if a.size is None:
+                assert a.type == "batch"
+                if batch_axis_trf is None:
+                    raise ValueError(
+                        "no batch axis found in any input tensor, but output tensor"
+                        + f" '{out.id}' has one."
+ ) + s = batch_axis_trf + elif isinstance(a.size, int): + s = a.size + elif isinstance(a.size, v0_5.DataDependentSize): + s = -1 + elif isinstance(a.size, v0_5.SizeReference): + s = LinearSampleAxisTransform( + axis=a.size.axis_id, + scale=axis_scales[a.size.tensor_id][a.size.axis_id] / a.scale, + offset=a.size.offset, + member=a.size.tensor_id, + ) + else: + assert_never(a.size) + + new_axes[a.id] = s + + ret[out.id] = new_axes + + return ret + + +def get_io_sample_block_metas( model: v0_5.ModelDescr, input_sample_shape: PerMember[PerAxis[int]], - ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize.N], + ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize.N], + batch_size: int = 1, ) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]: """returns an iterable yielding meta data for corresponding input and output samples""" if not isinstance(model, v0_5.ModelDescr): raise TypeError(f"get_block_meta() not implemented for {type(model)}") - block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=1) + block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size) input_block_shape = { t: {aa: s for (tt, aa), s in block_axis_sizes.inputs.items() if tt == t} for t in {tt for tt, _ in block_axis_sizes.inputs} @@ -189,7 +294,12 @@ def get_block_meta( } for t in {tt for tt, _ in block_axis_sizes.outputs} } - output_halo = {t.id: {a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo)} for t in model.outputs} + output_halo = { + t.id: { + a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo) + } + for t in model.outputs + } input_halo = get_input_halo(model, output_halo) output_sample_shape_data_dep = model.get_output_tensor_sizes(input_sample_shape) output_sample_shape = { @@ -209,7 +319,45 @@ def get_block_meta( return n_input_blocks, ( IO_SampleBlockMeta(ipt, out) for ipt, out in zip( - sample_block_meta_generator(input_blocks, origin=input_sample_shape), - sample_block_meta_generator(output_blocks, origin=output_sample_shape), + sample_block_meta_generator(input_blocks, sample_shape=input_sample_shape), + sample_block_meta_generator( + output_blocks, sample_shape=output_sample_shape + ), ) ) + + +def create_sample( + inputs: Sequence[NDArray[Any]], + model: AnyModelDescr, + stat: Optional[Stat] = None, +) -> Sample: + """Run prediction for a single set of input(s) with a bioimage.io model + + Args: + inputs: the input(s) for this model. + model: a bioimage.io model description + stat: dictionary with sample and dataset statistics (may be updated in-place!) 
+ """ + if len(inputs) > len(model.inputs): + raise ValueError( + f"Got {len(inputs)} inputs, but expected at most {len(model.inputs)}" + ) + + missing_inputs = model.inputs[len(inputs) :] + for missing in missing_inputs: + if isinstance(missing, v0_4.InputTensorDescr): + raise ValueError(f"Missing input tensor '{missing.name}'") + elif isinstance(missing, v0_5.InputTensorDescr): + if not missing.optional: + raise ValueError(f"Missing non-optional input tensor '{missing.id}'") + else: + assert_never(missing) + + return Sample( + members={ + get_member_id(ipt): Tensor.from_numpy(array, dims=get_axes_infos(ipt)) + for ipt, array in zip(model.inputs, inputs) + }, + stat={} if stat is None else stat, + ) diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py index 9bcab722..66ff371f 100644 --- a/bioimageio/core/io.py +++ b/bioimageio/core/io.py @@ -1,29 +1,60 @@ from pathlib import Path -from typing import Optional, Sequence +from typing import Any, Optional, Sequence import imageio +from loguru import logger +from numpy.typing import NDArray -from bioimageio.core.axis import Axis, AxisLike +from bioimageio.core.digest_spec import create_sample, get_axes_infos +from bioimageio.core.stat_measures import Stat +from bioimageio.spec.model import AnyModelDescr from bioimageio.spec.utils import load_array -from .tensor import Tensor, TensorId +from .axis import Axis, AxisLike +from .tensor import Tensor -def load_tensor( - path: Path, axes: Optional[Sequence[AxisLike]] = None, id: Optional[TensorId] = None -) -> Tensor: - +def load_image(path: Path, is_volume: bool) -> NDArray[Any]: + """load a single image as numpy array""" ext = path.suffix if ext == ".npy": - array = load_array(path) + return load_array(path) else: - is_volume = ( - True - if axes is None - else sum(Axis.create(a).type != "channel" for a in axes) > 2 + return imageio.volread(path) if is_volume else imageio.imread(path) + + +def load_tensor(path: Path, axes: Optional[Sequence[AxisLike]] = None) -> Tensor: + array = load_image( + path, + is_volume=( + axes is None or sum(Axis.create(a).type != "channel" for a in axes) > 2 + ), + ) + + return Tensor.from_numpy(array, dims=axes) + + +def load_sample( + *paths: Path, + model: AnyModelDescr, + axes: Optional[Sequence[Sequence[AxisLike]]] = None, + stat: Optional[Stat] = None, +): + """load a single sample from `paths` that can be processed by `model`""" + + if axes is None: + axes = [get_axes_infos(ipt) for ipt in model.inputs[: len(paths)]] + logger.warning( + "loading paths with default input axes: {} (from {})", + axes, + model.id or model.name, ) - array = imageio.volread(path) if is_volume else imageio.imread(path) + elif len(axes) != len(paths): + raise ValueError(f"got {len(paths)} paths, but {len(axes)} axes hints!") - return Tensor.from_numpy( - array, dims=axes, id=TensorId(path.stem) if id is None else id + arrays = [load_image(p, is_volume=True) for p in paths] + return create_sample( + arrays, + model, + stat={} if stat is None else stat, ) diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py index b3454582..8ab8c967 100644 --- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py +++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py @@ -2,11 +2,11 @@ import warnings from typing import Any, List, Optional, Sequence, Tuple, Union -from bioimageio.core.tensor import Tensor -from bioimageio.core.utils import import_callable from bioimageio.spec.model import v0_4, v0_5 from 
bioimageio.spec.utils import download +from ..digest_spec import import_callable +from ..tensor import Tensor from ._model_adapter import ModelAdapter try: diff --git a/bioimageio/core/prediction.py b/bioimageio/core/prediction.py index c3562d82..f984c324 100644 --- a/bioimageio/core/prediction.py +++ b/bioimageio/core/prediction.py @@ -27,168 +27,18 @@ from numpy.typing import NDArray from pydantic import HttpUrl from tqdm import tqdm +from typing_extensions import assert_never +from bioimageio.core.digest_spec import get_axes_infos, get_member_id, get_member_ids +from bioimageio.core.stat_measures import Stat from bioimageio.spec import ResourceDescr, load_description from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import AxisType from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline from .axis import AxisInfo -from .sample import UntiledSample -from .tensor import Tensor, TensorId - -# def _predict_with_tiling_impl( -# prediction_pipeline: PredictionPipeline, -# inputs: Sequence[xr.DataArray], -# outputs: Sequence[xr.DataArray], -# tile_shapes: Sequence[Dict[str, int]], -# halos: Sequence[Dict[str, int]], -# scales: Sequence[Dict[str, Tuple[int, int]]], -# verbose: bool = False, -# ): -# if len(inputs) > 1: -# raise NotImplementedError("Tiling with multiple inputs not implemented yet") - -# if len(outputs) > 1: -# raise NotImplementedError("Tiling with multiple outputs not implemented yet") - -# assert len(tile_shapes) == len(outputs) -# assert len(halos) == len(outputs) - -# input_ = inputs[0] -# output = outputs[0] -# tile_shape = tile_shapes[0] -# halo = halos[0] -# scaling = scales[0] - -# tiles = get_tiling(shape=input_.shape, tile_shape=tile_shape, halo=halo, input_axes=input_.dims, scaling=scaling) - -# def load_tile(tile): -# inp = input_[tile] -# # whether to pad on the right or left of the dim for the spatial dims -# # + placeholders for batch and axis dimension, where we don't pad -# pad_right = [tile[ax].start == 0 if ax in "xyz" else None for ax in input_.dims] -# return inp, pad_right - -# if verbose: -# shape = {ax: sh for ax, sh in zip(prediction_pipeline.input_specs[0].axes, input_.shape)} -# n_tiles = int(np.prod([np.ceil(float(shape[ax]) / (tsh - 2 * halo[ax])) for ax, tsh in tile_shape.items()])) -# tiles = tqdm(tiles, total=n_tiles, desc="prediction with tiling") - -# # we need to use padded prediction for the individual tiles in case the -# # border tiles don't match the requested tile shape -# padding = {ax: tile_shape[ax] for ax in input_axes if ax in "xyz"} -# padding["mode"] = "fixed" -# for outer_tile, inner_tile, local_tile in tiles: -# inp, pad_right = load_tile(outer_tile) -# out = predict_with_padding(prediction_pipeline, inp, padding, pad_right) -# assert len(out) == 1 -# out = out[0] -# output[inner_tile] = out[local_tile] - - -def predict_numpy( - prediction_pipeline: PredictionPipeline, - - """Run prediction for a single set of input(s) with a bioimage.io model - - Args: - prediction_pipeline: the prediction pipeline for the input model. - inputs: the input(s) for this model represented as xarray data or numpy nd array. 
- """ - return prediction_pipeline.forward(*tagged_data) - - -# def _parse_padding(padding, input_specs): -# if padding is None: # no padding -# return padding -# if len(input_specs) > 1: -# raise NotImplementedError("Padding for multiple inputs not yet implemented") - -# input_spec = input_specs[0] -# pad_keys = tuple(input_spec.axes) + ("mode",) - -# def check_padding(padding): -# assert all(k in pad_keys for k in padding.keys()) - -# if isinstance(padding, dict): # pre-defined padding -# check_padding(padding) -# elif isinstance(padding, bool): # determine padding from spec -# if padding: -# axes = input_spec.axes -# shape = input_spec.shape -# if isinstance(shape, list): # fixed padding -# padding = {ax: sh for ax, sh in zip(axes, shape) if ax in "xyz"} -# padding["mode"] = "fixed" -# else: # dynamic padding -# step = shape.step -# padding = {ax: st for ax, st in zip(axes, step) if ax in "xyz"} -# padding["mode"] = "dynamic" -# check_padding(padding) -# else: # no padding -# padding = None -# else: -# raise ValueError(f"Invalid argument for padding: {padding}") -# return padding - - -# def predict_with_padding( -# prediction_pipeline: PredictionPipeline, -# inputs: Union[xr.DataArray, List[xr.DataArray], Tuple[xr.DataArray]], -# padding: Union[bool, Dict[str, int]] = True, -# pad_right: bool = True, -# ) -> List[xr.DataArray]: -# """Run prediction with padding for a single set of input(s) with a bioimage.io model. - -# Args: -# prediction_pipeline: the prediction pipeline for the input model. -# inputs: the input(s) for this model represented as xarray data. -# padding: the padding settings. Pass True to derive from the model spec. -# pad_right: whether to applying padding to the right or left of the input. -# """ -# if not padding: -# raise ValueError -# assert len(inputs) == len(prediction_pipeline.input_specs) - -# output_spec = prediction_pipeline.output_specs[0] -# if hasattr(output_spec.shape, "scale"): -# scale = dict(zip(output_spec.axes, output_spec.shape.scale)) -# offset = dict(zip(output_spec.axes, output_spec.shape.offset)) -# network_resizes = any(sc != 1 for ax, sc in scale.items() if ax in "xyz") or any( -# off != 0 for ax, off in offset.items() if ax in "xyz" -# ) -# else: -# network_resizes = False - -# padding = _parse_padding(padding, prediction_pipeline.input_specs) -# if not isinstance(inputs, (tuple, list)): -# inputs = [inputs] -# if not isinstance(padding, (tuple, list)): -# padding = [padding] -# assert len(padding) == len(prediction_pipeline.input_specs) -# inputs, crops = zip( -# *[ -# image_helper.pad(inp, spec.axes, p, pad_right=pad_right) -# for inp, spec, p in zip(inputs, prediction_pipeline.input_specs, padding) -# ] -# ) -# result = predict(prediction_pipeline, inputs) -# if network_resizes: -# crops = [ -# { -# ax: ( -# slice( -# crp.start if crp.start is None else int(crp.start * scale[ax] + 2 * offset[ax]), -# crp.stop if crp.stop is None else int(crp.stop * scale[ax] + 2 * offset[ax]), -# ) -# if ax in "xyz" -# else crp -# ) -# for ax, crp in crop.items() -# } -# for crop in crops -# ] -# return [res[crop] for res, crop in zip(result, crops)] +from .sample import Sample +from .tensor import Tensor # # simple heuristic to determine suitable shape from min and step @@ -205,247 +55,3 @@ def predict_numpy( # else: # shape.append(min_ax) # return shape - - -# def _parse_tiling(tiling, input_specs, output_specs): -# if tiling is None: # no tiling -# return tiling -# if len(input_specs) > 1: -# raise NotImplementedError("Tiling for multiple inputs not yet 
implemented") -# if len(output_specs) > 1: -# raise NotImplementedError("Tiling for multiple outputs not yet implemented") - -# input_spec = input_specs[0] -# output_spec = output_specs[0] -# if isinstance(output_spec.shape, list): -# assert isinstance(input_spec.shape, list) and input_spec.shape == output_spec.shape, ( -# "When predicting with tiling, output_shape and input_shape must either be specified " -# "explictly and must be identical, or output_shape must be" -# "implicitly defined by input_shape, otherwise relationship between " -# "input and output shapes per tile cannot be known." -# ) -# axes = input_spec.axes - -# def check_tiling(tiling): -# assert "halo" in tiling and "tile" in tiling -# spatial_axes = [ax for ax in axes if ax in "xyz"] -# halo = tiling["halo"] -# tile = tiling["tile"] -# scale = tiling.get("scale", dict()) -# assert all(halo.get(ax, 0) >= 0 for ax in spatial_axes) -# assert all(tile.get(ax, 0) > 0 for ax in spatial_axes) -# assert all(scale.get(ax, 1) > 0 for ax in spatial_axes) - -# if isinstance(tiling, dict) or (isinstance(tiling, bool) and tiling): -# # NOTE we assume here that shape in input and output are the same -# # for different input and output shapes, we should actually tile in the -# # output space and then request the corresponding input tiles -# # so we would need to apply the output scale and offset to the -# # input shape to compute the tile size and halo here -# shape = input_spec.shape -# if not isinstance(shape, list): -# shape = _determine_shape(shape.min, shape.step, axes) -# assert isinstance(shape, list) -# assert len(shape) == len(axes) - -# scale = None -# output_shape = output_spec.shape -# scale = [1.0] * len(output_spec.shape) if isinstance(output_shape, list) else output_shape.scale -# assert len(scale) == len(axes) - -# halo = output_spec.halo -# if not isinstance(halo, list): -# halo = [0] * len(axes) -# assert len(halo) == len(axes) - -# default_tiling = { -# "halo": {ax: ha for ax, ha in zip(axes, halo) if ax in "xyz"}, -# "tile": {ax: sh for ax, sh in zip(axes, shape) if ax in "xyz"}, -# "scale": {ax: sc for ax, sc in zip(axes, scale) if ax in "xyz"}, -# } - -# # override metadata defaults with provided dict -# if isinstance(tiling, dict): -# for key in ["halo", "tile", "scale"]: -# default_tiling[key].update(tiling.get(key, dict())) -# tiling = default_tiling -# check_tiling(tiling) - -# elif isinstance(tiling, bool) and not tiling: -# raise NotImplementedError("Should be unreachable") - -# else: -# raise ValueError(f"Invalid argument for tiling: {tiling}") - -# return tiling - - -# def predict_with_tiling( -# prediction_pipeline: PredictionPipeline, -# inputs: Union[xr.DataArray, List[xr.DataArray], Tuple[xr.DataArray]], -# tiling: Union[bool, Dict[str, Dict[str, int]]] = True, -# verbose: bool = False, -# ) -> List[xr.DataArray]: -# """Run prediction with tiling for a single set of input(s) with a bioimage.io model. - -# Args: -# prediction_pipeline: the prediction pipeline for the input model. -# inputs: the input(s) for this model represented as xarray data. -# tiling: the tiling settings. Pass True to derive from the model spec. -# verbose: whether to print the prediction progress. 
-# """ -# if not tiling: -# raise ValueError("cannot call predict_with_tiling with tiling=False") -# assert len(inputs) == len(prediction_pipeline.input_specs) - -# tiling = _parse_tiling(tiling, prediction_pipeline.input_specs, prediction_pipeline.output_specs) -# if not isinstance(inputs, (list, tuple)): -# inputs = [inputs] -# named_inputs: OrderedDict[str, xr.DataArray] = collections.OrderedDict( -# **{ -# ipt_spec.name: xr.DataArray(ipt_data, dims=tuple(ipt_spec.axes)) -# for ipt_data, ipt_spec in zip(inputs, prediction_pipeline.input_specs) -# } -# ) - -# outputs = [] -# for output_spec in prediction_pipeline.output_specs: -# if isinstance(output_spec.shape, ImplicitOutputShape): -# scale = dict(zip(output_spec.axes, output_spec.shape.scale)) -# offset = dict(zip(output_spec.axes, output_spec.shape.offset)) - -# ref_input = named_inputs[output_spec.shape.reference_tensor] -# ref_input_shape = dict(zip(ref_input.dims, ref_input.shape)) -# output_shape = tuple(int(scale[ax] * ref_input_shape[ax] + 2 * offset[ax]) for ax in output_spec.axes) -# else: -# if len(inputs) > 1: -# raise NotImplementedError -# input_spec = prediction_pipeline.input_specs[0] -# if input_spec.axes != output_spec.axes: -# raise NotImplementedError("Tiling with a different output shape is not yet supported") -# out_axes = output_spec.axes -# fixed_shape = tuple(output_spec.shape) -# if not all(fixed_shape[out_axes.index(ax)] == tile_shape for ax, tile_shape in tiling["tile"].items()): -# raise NotImplementedError("Tiling with a different output shape is not yet supported") - -# output_shape = list(inputs[0].shape) -# chan_id = out_axes.index("c") -# if fixed_shape[chan_id] != output_shape[chan_id]: -# output_shape[chan_id] = fixed_shape[chan_id] -# output_shape = tuple(output_shape) - -# outputs.append(xr.DataArray(np.zeros(output_shape, dtype=output_spec.data_type), dims=tuple(output_spec.axes))) - -# _predict_with_tiling_impl( -# prediction_pipeline, -# list(named_inputs.values()), -# outputs, -# tile_shapes=[tiling["tile"]], # todo: update tiling for multiple inputs/outputs -# halos=[tiling["halo"]], -# scales=[tiling["scale"]], -# verbose=verbose, -# ) - -# return outputs - - -# def _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling): -# if padding and tiling: -# raise ValueError("Only one of padding or tiling is supported") - -# input_data = image_helper.load_tensors(inputs, prediction_pipeline.input_specs) -# if padding is not None: -# result = predict_with_padding(prediction_pipeline, input_data, padding) -# elif tiling is not None: -# result = predict_with_tiling(prediction_pipeline, input_data, tiling) -# else: -# result = predict(prediction_pipeline, input_data) - -# assert isinstance(result, list) -# assert len(result) == len(outputs) -# for res, out in zip(result, outputs): -# image_helper.save_image(out, res) - - -# def predict_image( -# model_rdf: DescriptionSource, -# inputs: Union[Tuple[Path, ...], List[Path], Path], -# outputs: Union[Tuple[Path, ...], List[Path], Path], -# padding: Optional[Union[bool, Dict[str, int]]] = None, -# tiling: Optional[Union[bool, Dict[str, Dict[str, int]]]] = None, -# weight_format: Optional[str] = None, -# devices: Optional[List[str]] = None, -# verbose: bool = False, -# ): -# """Run prediction for a single set of input image(s) with a bioimage.io model. - -# Args: -# model_rdf: the bioimageio model. -# inputs: the filepaths for the input images. -# outputs: the filepaths for saving the input images. -# padding: the padding settings for prediction. 
By default no padding is used. -# tiling: the tiling settings for prediction. By default no tiling is used. -# weight_format: the weight format to use for predictions. -# devices: the devices to use for prediction. -# verbose: run prediction in verbose mode. -# """ -# if not isinstance(inputs, (tuple, list)): -# inputs = [inputs] - -# if not isinstance(outputs, (tuple, list)): -# outputs = [outputs] - -# model = load_description(model_rdf) -# assert isinstance(model, Model) -# if len(model.inputs) != len(inputs): -# raise ValueError -# if len(model.outputs) != len(outputs): -# raise ValueError - -# with create_prediction_pipeline( -# bioimageio_model=model, weight_format=weight_format, devices=devices -# ) as prediction_pipeline: -# _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling) - - -# def predict_images( -# model_rdf: DescriptionSource, -# inputs: Sequence[Union[Tuple[Path, ...], List[Path], Path]], -# outputs: Sequence[Union[Tuple[Path, ...], List[Path], Path]], -# padding: Optional[Union[bool, Dict[str, int]]] = None, -# tiling: Optional[Union[bool, Dict[str, Dict[str, int]]]] = None, -# weight_format: Optional[str] = None, -# devices: Optional[List[str]] = None, -# verbose: bool = False, -# ): -# """Predict multiple input images with a bioimage.io model. - -# Args: -# model_rdf: the bioimageio model. -# inputs: the filepaths for the input images. -# outputs: the filepaths for saving the input images. -# padding: the padding settings for prediction. By default no padding is used. -# tiling: the tiling settings for prediction. By default no tiling is used. -# weight_format: the weight format to use for predictions. -# devices: the devices to use for prediction. -# verbose: run prediction in verbose mode. -# """ - -# model = load_description(model_rdf) -# assert isinstance(model, Model) - -# with create_prediction_pipeline( -# bioimageio_model=model, weight_format=weight_format, devices=devices -# ) as prediction_pipeline: -# prog = zip(inputs, outputs) -# if verbose: -# prog = tqdm(prog, total=len(inputs)) - -# for inp, outp in prog: -# if not isinstance(inp, (tuple, list)): -# inp = [inp] - -# if not isinstance(outp, (tuple, list)): -# outp = [outp] - -# _predict_sample(prediction_pipeline, inp, outp, padding, tiling) diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 0676b87b..90da5942 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -16,7 +16,7 @@ import xarray as xr from typing_extensions import Self, assert_never -from bioimageio.core.sample import SampleBlock +from bioimageio.core.sample import Sample, SampleBlock, SampleBlockWithOrigin from bioimageio.spec.model import v0_4, v0_5 from ._op_base import Operator @@ -76,20 +76,23 @@ def required_measures(self) -> Collection[Measure]: # def produced_tensors(self) -> Set[MemberId]: # return {self.output} - def __call__(self, sample_block: SampleBlock) -> None: - input_tensor = sample_block.members[self.input] - output_tensor = self._apply(input_tensor, sample_block.stat) + def __call__(self, sample: Union[Sample, SampleBlock]) -> None: + input_tensor = sample.members[self.input] + output_tensor = self._apply(input_tensor, sample.stat) - if self.output in sample_block.blocks: + if self.output in sample.members: assert ( - sample_block.blocks[self.output].tagged_shape - == output_tensor.tagged_shape + sample.members[self.output].tagged_shape == output_tensor.tagged_shape ) - sample_block.blocks[self.output].data = output_tensor - else: - 
sample_block.blocks[self.output] = replace( - sample_block.blocks[self.input], data=output_tensor + + if isinstance(sample, Sample): + sample.members[self.output] = output_tensor + elif isinstance(sample, SampleBlock): + sample.blocks[self.output] = replace( + sample.blocks[self.input], data=output_tensor ) + else: + assert_never(sample) @abstractmethod def _apply(self, input: Tensor, stat: Stat) -> Tensor: ... @@ -103,8 +106,8 @@ class AddKnownDatasetStats(Operator): def required_measures(self) -> Set[Measure]: return set() - def __call__(self, sample_block: SampleBlock) -> None: - sample_block.stat.update(self.dataset_stats.items()) + def __call__(self, sample: Union[Sample, SampleBlock]) -> None: + sample.stat.update(self.dataset_stats.items()) # @dataclass @@ -136,7 +139,7 @@ def __call__(self, sample_block: SampleBlock) -> None: # else: # self._keep_updating_dataset_stats = self.keep_updating_dataset_stats -# def __call__(self, sample_block: SampleBlock> None: +# def __call__(self, sample_block: SampleBlockWithOrigin> None: # if self._keep_updating_dataset_stats: # sample.stat.update(self._stats_calculator.update_and_get_all(sample)) # else: @@ -166,18 +169,20 @@ def __post_init__(self): or not self.stats_calculator.has_dataset_measures ) - def __call__(self, sample_block: SampleBlock) -> None: - if sample_block.block_number != 0: - return # update stats with whole sample on first block + def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None: + if isinstance(sample, SampleBlockWithOrigin): + # update stats with whole sample on first block + if sample.block_number != 0: + return + + origin = sample.origin + else: + origin = sample if self._keep_updating_dataset_stats: - sample_block.stat.update( - self.stats_calculator.update_and_get_all(sample_block.origin) - ) + sample.stat.update(self.stats_calculator.update_and_get_all(origin)) else: - sample_block.stat.update( - self.stats_calculator.skip_update_and_get_all(sample_block.origin) - ) + sample.stat.update(self.stats_calculator.skip_update_and_get_all(origin)) @dataclass diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py index c67ef769..8f92033f 100644 --- a/bioimageio/core/sample.py +++ b/bioimageio/core/sample.py @@ -2,7 +2,7 @@ from abc import abstractmethod from dataclasses import dataclass, field -from typing import Dict, Generic, Iterable, Optional, Tuple, TypeVar +from typing import Dict, Generic, Iterable, Optional, Tuple, TypeVar, Union import numpy as np from typing_extensions import Self @@ -10,14 +10,20 @@ from bioimageio.core.block import Block from .axis import PerAxis -from .block_meta import BlockMeta, split_multiple_shapes_into_blocks +from .block_meta import ( + BlockMeta, + LinearAxisTransform, + split_multiple_shapes_into_blocks, +) from .common import ( BlockNumber, + Halo, HaloLike, MemberId, PadMode, PerMember, SampleId, + SliceInfo, TotalNumberOfBlocks, ) from .stat_measures import Stat @@ -49,7 +55,7 @@ def split_into_blocks( halo: PerMember[PerAxis[HaloLike]], pad_mode: PadMode, broadcast: bool = False, - ) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlock]]: + ) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlockWithOrigin]]: assert not ( missing := [m for m in block_shapes if m not in self.members] ), f"`block_shapes` specified for unknown members: {missing}" @@ -76,6 +82,11 @@ def from_blocks( for member_blocks in sample_blocks: for m, block in member_blocks.blocks.items(): if m not in members: + if -1 in block.sample_shape.values(): + raise NotImplementedError( + 
"merging blocks with data dependent axis not yet implemented" + ) + members[m] = Tensor( np.full( tuple(block.sample_shape[a] for a in block.data.dims), @@ -97,7 +108,11 @@ def from_blocks( class SampleBlockBase(Generic[BlockT]): """base class for `SampleBlockMeta` and `SampleBlock`""" + sample_shape: PerMember[PerAxis[int]] + """the sample shape this block represents a part of""" + blocks: Dict[MemberId, BlockT] + """Individual tensor blocks comprising this sample block""" block_number: BlockNumber = field(init=False) """the n-th block of the sample""" @@ -123,48 +138,119 @@ def inner_shape(self) -> PerMember[PerAxis[int]]: def origin_shape(self) -> PerMember[PerAxis[int]]: ... +@dataclass +class LinearSampleAxisTransform(LinearAxisTransform): + member: MemberId + + @dataclass class SampleBlockMeta(SampleBlockBase[BlockMeta]): """Meta data of a dataset sample block""" - origin: PerMember[PerAxis[int]] - """the sampe shape the blocking for this block was based on""" + def get_transformed( + self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]] + ) -> Self: + sample_shape = { + m: { + a: ( + trf + if isinstance(trf, int) + else trf.compute(self.origin_shape[trf.member][trf.axis]) + ) + for a, trf in new_axes[m].items() + } + for m in new_axes + } + return self.__class__( + blocks={ + m: BlockMeta( + sample_shape=sample_shape[m], + inner_slice={ + a: ( + SliceInfo(0, trf) + if isinstance(trf, int) + else SliceInfo( + trf.compute( + self.blocks[trf.member].inner_slice[trf.axis].start + ), + trf.compute( + self.blocks[trf.member].inner_slice[trf.axis].stop + ), + ) + ) + for a, trf in new_axes[m].items() + }, + halo={ + a: ( + Halo(0, 0) + if isinstance(trf, int) + else Halo( + self.blocks[trf.member].halo[trf.axis].left, + self.blocks[trf.member].halo[trf.axis].right, + ) + ) + for a, trf in new_axes[m].items() + }, + block_number=self.block_number, + blocks_in_sample=self.blocks_in_sample, + ) + for m in new_axes + }, + sample_shape=sample_shape, + ) - @property - def origin_shape(self): - return self.origin + def with_data(self, data: PerMember[Tensor], *, stat: Stat) -> SampleBlock: + return SampleBlock( + sample_shape=self.sample_shape, + blocks={ + m: Block( + data[m], + inner_slice=b.inner_slice, + halo=b.halo, + block_number=b.block_number, + blocks_in_sample=b.blocks_in_sample, + ) + for m, b in self.blocks.items() + }, + stat=stat, + ) @dataclass class SampleBlock(SampleBlockBase[Block]): """A block of a dataset sample""" - origin: Sample - """the sample this sample black was taken from""" - - @property - def origin_shape(self): - return self.origin.shape + stat: Stat + """computed statistics""" @property def members(self) -> PerMember[Tensor]: """the sample block's tensors""" return {m: b.data for m, b in self.blocks.items()} - @property - def stat(self): - return self.origin.stat + def get_transformed_meta( + self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]] + ) -> SampleBlockMeta: + return SampleBlockMeta( + blocks=dict(self.blocks), sample_shape=self.sample_shape + ).get_transformed(new_axes) + + +@dataclass +class SampleBlockWithOrigin(SampleBlock): + origin: Sample + """the sample this sample black was taken from""" def sample_block_meta_generator( blocks: Iterable[PerMember[BlockMeta]], *, - origin: PerMember[PerAxis[int]], + sample_shape: PerMember[PerAxis[int]], ): for member_blocks in blocks: yield SampleBlockMeta( blocks=dict(member_blocks), - origin=origin, + sample_shape=sample_shape, ) @@ -175,12 +261,14 @@ def 
sample_block_generator( pad_mode: PadMode, ): for member_blocks in blocks: - yield SampleBlock( + yield SampleBlockWithOrigin( blocks={ m: Block.from_sample_member( origin.members[m], block=member_blocks[m], pad_mode=pad_mode ) for m in origin.members }, + sample_shape=origin.shape, origin=origin, + stat=origin.stat, ) diff --git a/bioimageio/core/weight_converter/torch/_onnx.py b/bioimageio/core/weight_converter/torch/_onnx.py index 50c56fbd..12b31cca 100644 --- a/bioimageio/core/weight_converter/torch/_onnx.py +++ b/bioimageio/core/weight_converter/torch/_onnx.py @@ -7,12 +7,13 @@ import torch from numpy.testing import assert_array_almost_equal -from bioimageio.core.utils import get_test_inputs -from bioimageio.core.weight_converter.torch._utils import load_torch_model from bioimageio.spec import load_description from bioimageio.spec.common import InvalidDescr from bioimageio.spec.model import v0_4, v0_5 +from ...digest_spec import get_test_inputs +from ...weight_converter.torch._utils import load_torch_model + def add_onnx_weights( model_spec: "str | Path | v0_4.ModelDescr | v0_5.ModelDescr", diff --git a/setup.py b/setup.py index 34944b0a..d4a5047a 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ VERSION = json.loads(VERSION_FILE.read_text())["version"] -setup( +_ = setup( name="bioimageio.core", version=VERSION, description="Python functionality for the bioimage model zoo", diff --git a/tests/test_prediction.py b/tests/test_prediction.py index b2547171..822641d7 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -4,7 +4,6 @@ import numpy as np from numpy.testing import assert_array_almost_equal -from bioimageio.core.utils import get_test_inputs from bioimageio.spec import load_description from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensorDescr_v0_4 from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr_v0_4 @@ -52,6 +51,7 @@ def test_predict_image_with_weight_format( def _test_predict_with_padding(any_model: Path, tmp_path: Path): + from bioimageio.core.digest_spec import get_test_inputs from bioimageio.core.prediction import predict_image model = load_description(any_model) diff --git a/tests/test_prediction_pipeline.py b/tests/test_prediction_pipeline.py index 61fde356..33de4bb4 100644 --- a/tests/test_prediction_pipeline.py +++ b/tests/test_prediction_pipeline.py @@ -2,7 +2,6 @@ from numpy.testing import assert_array_almost_equal -from bioimageio.core.utils import get_test_inputs, get_test_outputs from bioimageio.spec import load_description from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr04 from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat @@ -10,6 +9,7 @@ def _test_prediction_pipeline(model_package: Path, weights_format: WeightsFormat): from bioimageio.core._prediction_pipeline import create_prediction_pipeline + from bioimageio.core.digest_spec import get_test_inputs, get_test_outputs bio_model = load_description(model_package) assert isinstance( diff --git a/tests/test_prediction_pipeline_device_management.py b/tests/test_prediction_pipeline_device_management.py index b244ee26..064533d5 100644 --- a/tests/test_prediction_pipeline_device_management.py +++ b/tests/test_prediction_pipeline_device_management.py @@ -2,8 +2,6 @@ from numpy.testing import assert_array_almost_equal -from bioimageio.core import load_description -from bioimageio.core.utils import get_test_inputs, get_test_outputs from bioimageio.core.utils.testing import skip_on from bioimageio.spec.model.v0_4 import ModelDescr as 
ModelDescr04 from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat @@ -16,11 +14,13 @@ class TooFewDevicesException(Exception): def _test_device_management(model_package: Path, weight_format: WeightsFormat): import torch + from bioimageio.core import load_description + from bioimageio.core._prediction_pipeline import create_prediction_pipeline + from bioimageio.core.digest_spec import get_test_inputs, get_test_outputs + if torch.cuda.device_count() == 0: raise TooFewDevicesException("Need at least one cuda device for this test") - from bioimageio.core._prediction_pipeline import create_prediction_pipeline - bio_model = load_description(model_package) assert isinstance(bio_model, (ModelDescr, ModelDescr04)) pred_pipe = create_prediction_pipeline( diff --git a/tests/test_proc_ops.py b/tests/test_proc_ops.py index 9ec34ec6..431f2a79 100644 --- a/tests/test_proc_ops.py +++ b/tests/test_proc_ops.py @@ -6,24 +6,24 @@ from typing_extensions import TypeGuard from bioimageio.core.axis import AxisId -from bioimageio.core.sample import UntiledSample +from bioimageio.core.common import MemberId +from bioimageio.core.sample import Sample from bioimageio.core.stat_calculators import compute_measures from bioimageio.core.stat_measures import SampleMean, SamplePercentile, SampleStd -from bioimageio.core.tensor import TensorId @pytest.fixture(scope="module") def tid(): - return TensorId("data123") + return MemberId("data123") -def test_scale_linear(tid: TensorId): +def test_scale_linear(tid: MemberId): from bioimageio.core.proc_ops import ScaleLinear offset = xr.DataArray([1, 2, 42], dims=("c")) gain = xr.DataArray([1, 2, 3], dims=("c")) data = xr.DataArray(np.arange(6).reshape((1, 2, 3)), dims=("x", "y", "c")) - sample = UntiledSample(data={tid: data}) + sample = Sample(data={tid: data}) op = ScaleLinear(input=tid, output=tid, offset=offset, gain=gain) op(sample) @@ -32,12 +32,12 @@ def test_scale_linear(tid: TensorId): xr.testing.assert_allclose(expected, sample.data[tid]) -def test_scale_linear_no_channel(tid: TensorId): +def test_scale_linear_no_channel(tid: MemberId): from bioimageio.core.proc_ops import ScaleLinear op = ScaleLinear(tid, tid, offset=1, gain=2) data = xr.DataArray(np.arange(6).reshape(2, 3), dims=("x", "y")) - sample = UntiledSample(data={tid: data}) + sample = Sample(data={tid: data}) op(sample) expected = xr.DataArray(np.array([[1, 3, 5], [7, 9, 11]]), dims=("x", "y")) @@ -52,11 +52,11 @@ def is_iterable(val: Iterable[T], inner: Type[T]) -> TypeGuard[Iterable[T]]: return all(isinstance(x, inner) for x in val) -def test_zero_mean_unit_variance(tid: TensorId): +def test_zero_mean_unit_variance(tid: MemberId): from bioimageio.core.proc_ops import ZeroMeanUnitVariance data = xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y")) - sample = UntiledSample(data={tid: data}) + sample = Sample(data={tid: data}) m = SampleMean(tid) std = SampleStd(tid) op = ZeroMeanUnitVariance(tid, tid, m, std) @@ -77,7 +77,7 @@ def test_zero_mean_unit_variance(tid: TensorId): xr.testing.assert_allclose(expected, sample.data[tid]) -def test_zero_mean_unit_variance_fixed(tid: TensorId): +def test_zero_mean_unit_variance_fixed(tid: MemberId): from bioimageio.core.proc_ops import FixedZeroMeanUnitVariance op = FixedZeroMeanUnitVariance( @@ -99,12 +99,12 @@ def test_zero_mean_unit_variance_fixed(tid: TensorId): ), dims=("b", "c", "x"), ) - sample = UntiledSample(data={tid: data}) + sample = Sample(data={tid: data}) op(sample) xr.testing.assert_allclose(expected, sample.data[tid]) -def 
test_zero_mean_unit_across_axes(tid: TensorId): +def test_zero_mean_unit_across_axes(tid: MemberId): from bioimageio.core.proc_ops import ZeroMeanUnitVariance data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) @@ -115,7 +115,7 @@ def test_zero_mean_unit_across_axes(tid: TensorId): SampleMean(tid, (AxisId("x"), AxisId("y"))), SampleStd(tid, (AxisId("x"), AxisId("y"))), ) - sample = UntiledSample(data={tid: data}) + sample = Sample(data={tid: data}) sample.stat = compute_measures(op.required_measures, [sample]) expected = xr.concat( @@ -125,7 +125,7 @@ def test_zero_mean_unit_across_axes(tid: TensorId): xr.testing.assert_allclose(expected, sample.data[tid]) -def test_zero_mean_unit_variance_fixed2(tid: TensorId): +def test_zero_mean_unit_variance_fixed2(tid: MemberId): from bioimageio.core.proc_ops import FixedZeroMeanUnitVariance np_data = np.arange(9).reshape(3, 3) @@ -135,25 +135,25 @@ def test_zero_mean_unit_variance_fixed2(tid: TensorId): op = FixedZeroMeanUnitVariance(tid, tid, mean=mean, std=std, eps=eps) data = xr.DataArray(np_data, dims=("x", "y")) - sample = UntiledSample(data={tid: data}) + sample = Sample(data={tid: data}) expected = xr.DataArray((np_data - mean) / (std + eps), dims=("x", "y")) op(sample) xr.testing.assert_allclose(expected, sample.data[tid]) -def test_binarize(tid: TensorId): +def test_binarize(tid: MemberId): from bioimageio.core.proc_ops import Binarize op = Binarize(tid, tid, threshold=14) data = xr.DataArray(np.arange(30).reshape((2, 3, 5)), dims=("x", "y", "c")) - sample = UntiledSample(data={tid: data}) + sample = Sample(data={tid: data}) expected = xr.zeros_like(data) expected[{"x": slice(1, None)}] = 1 op(sample) xr.testing.assert_allclose(expected, sample.data[tid]) -def test_binarize2(tid: TensorId): +def test_binarize2(tid: MemberId): from bioimageio.core.proc_ops import Binarize shape = (3, 32, 32) @@ -164,18 +164,18 @@ def test_binarize2(tid: TensorId): threshold = 0.5 exp = xr.DataArray(np_data > threshold, dims=axes) - sample = UntiledSample(data={tid: data}) + sample = Sample(data={tid: data}) binarize = Binarize(tid, tid, threshold=threshold) binarize(sample) xr.testing.assert_allclose(exp, sample.data[tid]) -def test_clip(tid: TensorId): +def test_clip(tid: MemberId): from bioimageio.core.proc_ops import Clip op = Clip(tid, tid, min=3, max=5) data = xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y")) - sample = UntiledSample(data={tid: data}) + sample = Sample(data={tid: data}) expected = xr.DataArray( np.array([[3, 3, 3], [3, 4, 5], [5, 5, 5]]), dims=("x", "y") @@ -184,11 +184,11 @@ def test_clip(tid: TensorId): xr.testing.assert_equal(expected, sample.data[tid]) -def test_combination_of_op_steps_with_dims_specified(tid: TensorId): +def test_combination_of_op_steps_with_dims_specified(tid: MemberId): from bioimageio.core.proc_ops import ZeroMeanUnitVariance data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) - sample = UntiledSample(data={tid: data}) + sample = Sample(data={tid: data}) op = ZeroMeanUnitVariance( tid, tid, @@ -234,7 +234,7 @@ def test_combination_of_op_steps_with_dims_specified(tid: TensorId): tuple(map(AxisId, "x")), ], ) -def test_scale_mean_variance(tid: TensorId, axes: Optional[Tuple[AxisId, ...]]): +def test_scale_mean_variance(tid: MemberId, axes: Optional[Tuple[AxisId, ...]]): from bioimageio.core.proc_ops import ScaleMeanVariance shape = (3, 32, 46) @@ -243,8 +243,8 @@ def test_scale_mean_variance(tid: TensorId, axes: Optional[Tuple[AxisId, ...]]): ipt_data = 
xr.DataArray(np_data, dims=ipt_axes) ref_data = xr.DataArray((np_data * 2) + 3, dims=ipt_axes) - op = ScaleMeanVariance(tid, tid, reference_tensor=TensorId("ref_name"), axes=axes) - sample = UntiledSample(data={tid: ipt_data, TensorId("ref_name"): ref_data}) + op = ScaleMeanVariance(tid, tid, reference_tensor=MemberId("ref_name"), axes=axes) + sample = Sample(data={tid: ipt_data, MemberId("ref_name"): ref_data}) sample.stat = compute_measures(op.required_measures, [sample]) op(sample) xr.testing.assert_allclose(ref_data, sample.data[tid]) @@ -254,7 +254,7 @@ def test_scale_mean_variance(tid: TensorId, axes: Optional[Tuple[AxisId, ...]]): "axes_str", [None, "cy", "y", "yx"], ) -def test_scale_mean_variance_per_channel(tid: TensorId, axes_str: Optional[str]): +def test_scale_mean_variance_per_channel(tid: MemberId, axes_str: Optional[str]): from bioimageio.core.proc_ops import ScaleMeanVariance axes = None if axes_str is None else tuple(map(AxisId, axes_str)) @@ -268,8 +268,8 @@ def test_scale_mean_variance_per_channel(tid: TensorId, axes_str: Optional[str]) np_ref_data = np.stack([d * i + i for i, d in enumerate(np_data, start=2)]) ref_data = xr.DataArray(np_ref_data, dims=ipt_axes) - op = ScaleMeanVariance(tid, tid, reference_tensor=TensorId("ref_name"), axes=axes) - sample = UntiledSample(data={tid: ipt_data, TensorId("ref_name"): ref_data}) + op = ScaleMeanVariance(tid, tid, reference_tensor=MemberId("ref_name"), axes=axes) + sample = Sample(data={tid: ipt_data, MemberId("ref_name"): ref_data}) sample.stat = compute_measures(op.required_measures, [sample]) op(sample) @@ -282,13 +282,13 @@ def test_scale_mean_variance_per_channel(tid: TensorId, axes_str: Optional[str]) xr.testing.assert_allclose(ref_data, sample.data[tid]) -def test_scale_range(tid: TensorId): +def test_scale_range(tid: MemberId): from bioimageio.core.proc_ops import ScaleRange op = ScaleRange(tid, tid) np_data = np.arange(9).reshape(3, 3).astype("float32") data = xr.DataArray(np_data, dims=("x", "y")) - sample = UntiledSample(data={tid: data}) + sample = Sample(data={tid: data}) sample.stat = compute_measures(op.required_measures, [sample]) eps = 1.0e-6 @@ -301,7 +301,7 @@ def test_scale_range(tid: TensorId): np.testing.assert_allclose(expected, sample.data[tid]) -def test_scale_range_axes(tid: TensorId): +def test_scale_range_axes(tid: MemberId): from bioimageio.core.proc_ops import ScaleRange lower_percentile = SamplePercentile(tid, 1, axes=(AxisId("x"), AxisId("y"))) @@ -310,7 +310,7 @@ def test_scale_range_axes(tid: TensorId): np_data = np.arange(18).reshape((2, 3, 3)).astype("float32") data = xr.DataArray(np_data, dims=("c", "x", "y")) - sample = UntiledSample(data={tid: data}) + sample = Sample(data={tid: data}) sample.stat = compute_measures(op.required_measures, [sample]) eps = 1.0e-6 @@ -324,14 +324,14 @@ def test_scale_range_axes(tid: TensorId): np.testing.assert_allclose(expected, sample.data[tid]) -def test_sigmoid(tid: TensorId): +def test_sigmoid(tid: MemberId): from bioimageio.core.proc_ops import Sigmoid shape = (3, 32, 32) axes = ("c", "y", "x") np_data = np.random.rand(*shape) data = xr.DataArray(np_data, dims=axes) - sample = UntiledSample(data={tid: data}) + sample = Sample(data={tid: data}) sigmoid = Sigmoid(tid, tid) sigmoid(sample) diff --git a/tests/test_stat_calculators.py b/tests/test_stat_calculators.py index b1468609..bd93b282 100644 --- a/tests/test_stat_calculators.py +++ b/tests/test_stat_calculators.py @@ -5,23 +5,23 @@ from xarray.testing import assert_allclose # pyright: 
ignore[reportUnknownVariableType] from bioimageio.core.axis import AxisId -from bioimageio.core.sample import UntiledSample +from bioimageio.core.common import MemberId +from bioimageio.core.sample import Sample from bioimageio.core.stat_calculators import MeanVarStdCalculator from bioimageio.core.stat_measures import ( DatasetMean, DatasetStd, DatasetVar, ) -from bioimageio.core.tensor import Tensor, TensorId +from bioimageio.core.tensor import Tensor -def create_random_dataset(tid: TensorId, axes: Tuple[AxisId, ...]): +def create_random_dataset(tid: MemberId, axes: Tuple[AxisId, ...]): n = 3 sizes = list(range(n, len(axes) + 1)) data = np.asarray(np.random.rand(*sizes)) ds = [ - UntiledSample(data={tid: Tensor(data[i : i + 1], dims=axes, id=tid)}) - for i in range(n) + Sample(data={tid: Tensor(data[i : i + 1], dims=axes, id=tid)}) for i in range(n) ] return Tensor(data, dims=axes), ds @@ -35,7 +35,7 @@ def create_random_dataset(tid: TensorId, axes: Tuple[AxisId, ...]): ], ) def test_mean_var_std_calculator(axes: Union[None, str, Tuple[str, ...]]): - tid = TensorId("tensor") + tid = MemberId("tensor") axes = tuple(map(AxisId, ("batch", "channel", "x", "y"))) data, ds = create_random_dataset(tid, axes) expected_mean = data.mean(axes) diff --git a/tests/test_stat_measures.py b/tests/test_stat_measures.py index 6986e6e5..53de1017 100644 --- a/tests/test_stat_measures.py +++ b/tests/test_stat_measures.py @@ -7,13 +7,14 @@ from bioimageio.core import stat_measures from bioimageio.core.axis import AxisId -from bioimageio.core.sample import UntiledSample +from bioimageio.core.common import MemberId +from bioimageio.core.sample import Sample from bioimageio.core.stat_calculators import ( SamplePercentilesCalculator, get_measure_calculators, ) from bioimageio.core.stat_measures import SamplePercentile -from bioimageio.core.tensor import Tensor, TensorId +from bioimageio.core.tensor import Tensor @pytest.mark.parametrize( @@ -27,7 +28,7 @@ def test_individual_normal_measure( name: str, axes: Optional[Tuple[AxisId, AxisId]], ): - data_id = TensorId("test_data") + data_id = MemberId("test_data") measure = getattr(stat_measures, "Sample" + name.title())( axes=axes, member_id=data_id ) @@ -36,7 +37,7 @@ def test_individual_normal_measure( ) expected = getattr(data, name)(dim=axes) - sample = UntiledSample(data={data_id: data}) + sample = Sample(data={data_id: data}) actual = measure.compute(sample) xr.testing.assert_allclose(expected, actual) @@ -44,7 +45,7 @@ def test_individual_normal_measure( @pytest.mark.parametrize("axes", [None, (AxisId("x"), AxisId("y"))]) def test_individual_percentile_measure(axes: Optional[Tuple[AxisId, ...]]): qs = [0, 0.1, 0.5, 1.0] - tid = TensorId("tensor") + tid = MemberId("tensor") measures = [SamplePercentile(member_id=tid, axes=axes, q=q) for q in qs] calcs, _ = get_measure_calculators(measures) @@ -55,7 +56,7 @@ def test_individual_percentile_measure(axes: Optional[Tuple[AxisId, ...]]): data = Tensor( np.random.random((5, 6, 3)), dims=(AxisId("x"), AxisId("y"), AxisId("c")) ) - actual = calc.compute(UntiledSample(data={tid: data})) + actual = calc.compute(Sample(data={tid: data})) for m in measures: expected = data.quantile(q=m.q, dim=m.axes) xr.testing.assert_allclose(expected, actual[m]) diff --git a/tests/test_tensor.py b/tests/test_tensor.py index 076e0961..7f3d67cc 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -3,7 +3,7 @@ import xarray as xr from xarray.testing import assert_equal # pyright: ignore[reportUnknownVariableType] -from 
bioimageio.core import AxisId, Tensor, TensorId +from bioimageio.core import AxisId, Tensor @pytest.mark.parametrize( @@ -12,7 +12,7 @@ ) def test_transpose_tensor_2d(axes: str): - tensor = Tensor.from_numpy(np.random.rand(256, 256), dims=None, id=TensorId("id")) + tensor = Tensor.from_numpy(np.random.rand(256, 256), dims=None) transposed = tensor.transpose([AxisId(a) for a in axes]) assert transposed.ndim == len(axes) @@ -22,7 +22,7 @@ def test_transpose_tensor_2d(axes: str): ["zyx", "cyzx", "yzixc", "bczyx", "xyz", "xyzc", "bzyxtc"], ) def test_transpose_tensor_3d(axes: str): - tensor = Tensor.from_numpy(np.random.rand(64, 64, 64), dims=None, id=TensorId("id")) + tensor = Tensor.from_numpy(np.random.rand(64, 64, 64), dims=None) transposed = tensor.transpose([AxisId(a) for a in axes]) assert transposed.ndim == len(axes) @@ -37,5 +37,5 @@ def test_crop_and_pad(): def test_some_magic_ops(): - tensor = Tensor.from_numpy(np.random.rand(256, 256), dims=None, id=TensorId("id")) + tensor = Tensor.from_numpy(np.random.rand(256, 256), dims=None) assert tensor + 2 == 2 + tensor From f42141d6bcfe698fd81420cf51805736d02121cf Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 8 Apr 2024 16:09:25 +0200 Subject: [PATCH 180/244] add test_digest_spec --- tests/test_digest_spec.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 tests/test_digest_spec.py diff --git a/tests/test_digest_spec.py b/tests/test_digest_spec.py new file mode 100644 index 00000000..eb810b4a --- /dev/null +++ b/tests/test_digest_spec.py @@ -0,0 +1,39 @@ +from bioimageio.spec import load_description +from bioimageio.spec.model import v0_5 + + +# TODO: don't just test with unet2d_nuclei_broad_model +def test_get_block_transform(unet2d_nuclei_broad_model: str): + from bioimageio.core.axis import AxisId + from bioimageio.core.common import MemberId + from bioimageio.core.digest_spec import ( + get_block_transform, + get_io_sample_block_metas, + ) + + model = load_description(unet2d_nuclei_broad_model) + assert isinstance(model, v0_5.ModelDescr) + block_transform = get_block_transform(model) + + ns = { + (ipt.id, a.id): 1 + for ipt in model.inputs + for a in ipt.axes + if isinstance(a.size, v0_5.ParameterizedSize) + } + + _, blocks = get_io_sample_block_metas( + model, + input_sample_shape={ + MemberId("raw"): { + AxisId("batch"): 3, + AxisId("channel"): 1, + AxisId("x"): 4000, + AxisId("y"): 3000, + } + }, + ns=ns, + ) + for ipt_block, out_block in blocks: + trf_block = ipt_block.get_transformed(block_transform) + assert out_block == trf_block From f0197688e78c7cfeb8310bbc508cc172f13f5566 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 9 Apr 2024 11:41:53 +0200 Subject: [PATCH 181/244] rename some util funcs --- bioimageio/core/digest_spec.py | 6 ++-- bioimageio/core/io.py | 8 ++--- bioimageio/core/prediction.py | 57 ---------------------------------- 3 files changed, 7 insertions(+), 64 deletions(-) diff --git a/bioimageio/core/digest_spec.py b/bioimageio/core/digest_spec.py index 46e2e4d2..0dc92670 100644 --- a/bioimageio/core/digest_spec.py +++ b/bioimageio/core/digest_spec.py @@ -327,15 +327,15 @@ def get_io_sample_block_metas( ) -def create_sample( +def create_sample_for_model( inputs: Sequence[NDArray[Any]], model: AnyModelDescr, stat: Optional[Stat] = None, ) -> Sample: - """Run prediction for a single set of input(s) with a bioimage.io model + """Create a sample from a single set of input(s) for a specific bioimage.io model Args: - inputs: the input(s) for this model. 
+ inputs: the input(s) constituting a single sample. model: a bioimage.io model description stat: dictionary with sample and dataset statistics (may be updated in-place!) """ diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py index 66ff371f..8ca8b02f 100644 --- a/bioimageio/core/io.py +++ b/bioimageio/core/io.py @@ -5,12 +5,12 @@ from loguru import logger from numpy.typing import NDArray -from bioimageio.core.digest_spec import create_sample, get_axes_infos -from bioimageio.core.stat_measures import Stat from bioimageio.spec.model import AnyModelDescr from bioimageio.spec.utils import load_array from .axis import Axis, AxisLike +from .digest_spec import create_sample_for_model, get_axes_infos +from .stat_measures import Stat from .tensor import Tensor @@ -34,7 +34,7 @@ def load_tensor(path: Path, axes: Optional[Sequence[AxisLike]] = None) -> Tensor return Tensor.from_numpy(array, dims=axes) -def load_sample( +def load_sample_for_model( *paths: Path, model: AnyModelDescr, axes: Optional[Sequence[Sequence[AxisLike]]] = None, @@ -53,7 +53,7 @@ def load_sample( raise ValueError(f"got {len(paths)} paths, but {len(axes)} axes hints!") arrays = [load_image(p, is_volume=True) for p in paths] - return create_sample( + return create_sample_for_model( arrays, model, stat={} if stat is None else stat, diff --git a/bioimageio/core/prediction.py b/bioimageio/core/prediction.py index f984c324..e69de29b 100644 --- a/bioimageio/core/prediction.py +++ b/bioimageio/core/prediction.py @@ -1,57 +0,0 @@ -"""coming soon""" - -# TODO: update -import collections.abc -import os -from fractions import Fraction -from itertools import product -from pathlib import Path -from typing import ( - Any, - Collection, - Dict, - Hashable, - Iterator, - List, - Mapping, - NamedTuple, - Optional, - OrderedDict, - Sequence, - Tuple, - Union, -) - -import numpy as np -import xarray as xr -from numpy.typing import NDArray -from pydantic import HttpUrl -from tqdm import tqdm -from typing_extensions import assert_never - -from bioimageio.core.digest_spec import get_axes_infos, get_member_id, get_member_ids -from bioimageio.core.stat_measures import Stat -from bioimageio.spec import ResourceDescr, load_description -from bioimageio.spec.model import v0_4, v0_5 -from bioimageio.spec.model.v0_5 import AxisType - -from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline -from .axis import AxisInfo -from .sample import Sample -from .tensor import Tensor - - -# # simple heuristic to determine suitable shape from min and step -# def _determine_shape(min_shape, step, axes): -# is3d = "z" in axes -# min_len = 64 if is3d else 256 -# shape = [] -# for ax, min_ax, step_ax in zip(axes, min_shape, step): -# if ax in "zyx" and step_ax > 0: -# len_ax = min_ax -# while len_ax < min_len: -# len_ax += step_ax -# shape.append(len_ax) -# else: -# shape.append(min_ax) -# return shape From acfab41063e86fd4459f062442cfc0f3b9731f05 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 9 Apr 2024 11:42:31 +0200 Subject: [PATCH 182/244] add BlockedOperator --- bioimageio/core/_op_base.py | 14 ++++- bioimageio/core/_prediction_pipeline.py | 70 ++++++++++++++++++++----- bioimageio/core/prediction.py | 7 +++ bioimageio/core/proc_ops.py | 6 +-- bioimageio/core/proc_setup.py | 16 +++--- 5 files changed, 88 insertions(+), 25 deletions(-) diff --git a/bioimageio/core/_op_base.py b/bioimageio/core/_op_base.py index afc3226d..55c961bc 100644 --- a/bioimageio/core/_op_base.py +++ b/bioimageio/core/_op_base.py @@ -2,7 +2,7 @@ from dataclasses 
import dataclass from typing import Collection, Union -from .sample import Sample, SampleBlockWithOrigin +from .sample import Sample, SampleBlock, SampleBlockWithOrigin from .stat_measures import Measure @@ -14,3 +14,15 @@ def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None: ... @property @abstractmethod def required_measures(self) -> Collection[Measure]: ... + + +@dataclass +class BlockedOperator(Operator, ABC): + @abstractmethod + def __call__( + self, sample: Union[Sample, SampleBlock, SampleBlockWithOrigin] + ) -> None: ... + + @property + @abstractmethod + def required_measures(self) -> Collection[Measure]: ... diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index b651befa..2307b2ad 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -19,6 +19,7 @@ from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 from bioimageio.spec.model.v0_5 import WeightsFormat +from ._op_base import BlockedOperator from .axis import AxisId, PerAxis from .common import Halo, MemberId, PerMember from .digest_spec import ( @@ -121,9 +122,15 @@ def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore self.unload() return False - def _predict_sample_block_wo_procs( - self, sample_block: SampleBlockWithOrigin + def predict_sample_block( + self, + sample_block: SampleBlockWithOrigin, + skip_preprocessing: bool = False, + skip_postprocessing: bool = False, ) -> SampleBlock: + if not skip_preprocessing: + self.apply_preprocessing(sample_block) + output_meta = sample_block.get_transformed_meta(self._block_transform) output = output_meta.with_data( { @@ -138,10 +145,20 @@ def _predict_sample_block_wo_procs( }, stat=sample_block.stat, ) + if not skip_postprocessing: + self.apply_postprocessing(output) + return output - def predict_sample(self, sample: Sample) -> Sample: - self.apply_preprocessing(sample) + def predict_sample( + self, + sample: Sample, + skip_preprocessing: bool = False, + skip_postprocessing: bool = False, + ) -> Sample: + if not skip_preprocessing: + self.apply_preprocessing(sample) + n_blocks, input_blocks = sample.split_into_blocks( self._default_input_block_shape, halo=self._default_input_halo, @@ -153,37 +170,62 @@ def predict_sample(self, sample: Sample) -> Sample: unit="block", total=n_blocks, ) - predicted_blocks = map(self._predict_sample_block_wo_procs, input_blocks) + predicted_blocks = ( + self.predict_sample_block( + b, skip_preprocessing=True, skip_postprocessing=True + ) + for b in input_blocks + ) predicted_sample = Sample.from_blocks(predicted_blocks) - self.apply_postprocessing(predicted_sample) + if not skip_postprocessing: + self.apply_postprocessing(predicted_sample) + return predicted_sample def predict( self, inputs: Predict_IO, + skip_preprocessing: bool = False, + skip_postprocessing: bool = False, ) -> Predict_IO: """Run model prediction **including** pre/postprocessing.""" if isinstance(inputs, Sample): - return self.predict_sample(inputs) + return self.predict_sample( + inputs, + skip_preprocessing=skip_preprocessing, + skip_postprocessing=skip_postprocessing, + ) elif isinstance(inputs, collections.abc.Iterable): - return (self.predict(ipt) for ipt in inputs) + return ( + self.predict( + ipt, + skip_preprocessing=skip_preprocessing, + skip_postprocessing=skip_postprocessing, + ) + for ipt in inputs + ) else: assert_never(inputs) - def apply_preprocessing( - self, sample_block: Union[Sample, SampleBlockWithOrigin] - ) -> None: + def 
apply_preprocessing(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
         """apply preprocessing in-place, also updates sample stats"""
         for op in self._preprocessing:
-            op(sample_block)
+            op(sample)
 
     def apply_postprocessing(
-        self, sample_block: Union[Sample, SampleBlockWithOrigin]
+        self, sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
     ) -> None:
         """apply postprocessing in-place, also updates samples stats"""
         for op in self._postprocessing:
-            op(sample_block)
+            if isinstance(sample, (Sample, SampleBlockWithOrigin)):
+                op(sample)
+            elif not isinstance(op, BlockedOperator):
+                raise NotImplementedError(
+                    "block wise update of output statistics not yet implemented"
+                )
+            else:
+                op(sample)
 
     def load(self):
         """
diff --git a/bioimageio/core/prediction.py b/bioimageio/core/prediction.py
index e69de29b..3d10d31d 100644
--- a/bioimageio/core/prediction.py
+++ b/bioimageio/core/prediction.py
@@ -0,0 +1,7 @@
+"""convenience functions for prediction coming soon.
+For now, please use `create_prediction_pipeline` to get a `PredictionPipeline`
+and then `PredictionPipeline.predict(sample)`,
+e.g. load samples with core.io.load_sample_for_model()
+"""
+
+# TODO: add convenience functions for predictions
diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py
index 90da5942..9880a818 100644
--- a/bioimageio/core/proc_ops.py
+++ b/bioimageio/core/proc_ops.py
@@ -19,7 +19,7 @@
 from bioimageio.core.sample import Sample, SampleBlock, SampleBlockWithOrigin
 from bioimageio.spec.model import v0_4, v0_5
 
-from ._op_base import Operator
+from ._op_base import BlockedOperator, Operator
 from .axis import AxisId
 from .common import DTypeStr, MemberId
 from .stat_calculators import StatsCalculator
@@ -60,7 +60,7 @@ def convert_axis_ids(
 
 
 @dataclass
-class _SimpleOperator(Operator, ABC):
+class _SimpleOperator(BlockedOperator, ABC):
     input: MemberId
     output: MemberId
 
@@ -99,7 +99,7 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor: ...
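# ---8<--- editor's note (illustration, not part of the patch series) ---8<---
# The interim workflow the new prediction.py docstring above points to,
# sketched under the assumption of a valid model description `model` and a
# `sample` loaded e.g. with core.io.load_sample_for_model:
from bioimageio.core import create_prediction_pipeline  # assumed re-export

with create_prediction_pipeline(bioimageio_model=model) as pipeline:
    prediction = pipeline.predict(sample)
# ---8<-------------------------------------------------------------->8---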
@dataclass -class AddKnownDatasetStats(Operator): +class AddKnownDatasetStats(BlockedOperator): dataset_stats: Mapping[DatasetMeasure, MeasureValue] @property diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py index 64168ce9..947ea0c2 100644 --- a/bioimageio/core/proc_setup.py +++ b/bioimageio/core/proc_setup.py @@ -67,13 +67,15 @@ def setup_pre_and_postprocessing( keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats, ), ) - post.insert( - 0, - UpdateStats( - StatsCalculator(post_meas, initial_stats), - keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats, - ), - ) + if post_meas: + post.insert( + 0, + UpdateStats( + StatsCalculator(post_meas, initial_stats), + keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats, + ), + ) + if fixed_dataset_stats: prep.insert(0, AddKnownDatasetStats(fixed_dataset_stats)) post.insert(0, AddKnownDatasetStats(fixed_dataset_stats)) From fac789fa321f4e1c149ca871ee8331ef074cf669 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 9 Apr 2024 11:42:48 +0200 Subject: [PATCH 183/244] fix some tests --- bioimageio/core/stat_calculators.py | 8 ++++---- bioimageio/core/tensor.py | 25 +++++++++++++++++-------- tests/test_stat_calculators.py | 6 ++---- tests/test_stat_measures.py | 12 ++++++++---- tests/test_tensor.py | 2 +- 5 files changed, 32 insertions(+), 21 deletions(-) diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py index 9d6717e4..9845138e 100644 --- a/bioimageio/core/stat_calculators.py +++ b/bioimageio/core/stat_calculators.py @@ -105,7 +105,7 @@ def _update_impl(self, tensor: Tensor, tensor_mean: Tensor): mean_old = self._mean self._n = n_a + n_b self._mean = (n_a * mean_old + n_b * tensor_mean) / self._n - assert self._mean.dtype == np.float64 + assert self._mean.dtype == "float64" def finalize(self) -> Dict[DatasetMean, MeasureValue]: if self._mean is None: @@ -169,10 +169,10 @@ def update(self, sample: Sample): m2_a = self._m2 self._n = n = n_a + n_b self._mean = (n_a * mean_a + n_b * mean_b) / n - assert self._mean.dtype == np.float64 + assert self._mean.dtype == "float64" d = mean_b - mean_a self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n - assert self._m2.dtype == np.float64 + assert self._m2.dtype == "float64" def finalize( self, @@ -252,7 +252,7 @@ def update(self, sample: Sample): self._estimates = (self._n * self._estimates + n * sample_estimates) / ( self._n + n ) - assert self._estimates.dtype == np.float64 + assert self._estimates.dtype == "float64" self._n += n diff --git a/bioimageio/core/tensor.py b/bioimageio/core/tensor.py index 2ed28424..c93bd31a 100644 --- a/bioimageio/core/tensor.py +++ b/bioimageio/core/tensor.py @@ -1,5 +1,6 @@ from __future__ import annotations +import collections.abc from typing import ( TYPE_CHECKING, Any, @@ -65,17 +66,25 @@ def __init__( def __array__(self, dtype: DTypeLike = None): return np.asarray(self._data, dtype=dtype) - def __getitem__(self, key: PerAxis[Union[SliceInfo, slice, int]]) -> Self: - key = { - a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s) - for a, s in key.items() - } + def __getitem__( + self, key: Union[SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]] + ) -> Self: + if isinstance(key, SliceInfo): + key = slice(*key) + elif isinstance(key, collections.abc.Mapping): + key = { + a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s) + for a, s in key.items() + } return self.__class__.from_xarray(self._data[key]) def 
__setitem__(self, key: PerAxis[Union[SliceInfo, slice]], value: Tensor) -> None: key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()} self._data[key] = value._data + def __len__(self) -> int: + return len(self.data) + def _iter(self: Any) -> Iterator[Any]: for n in range(len(self)): yield self[n] @@ -290,13 +299,13 @@ def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self: return self.__class__.from_xarray(self._data.expand_dims(dims=dims)) def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: - return self.__class__.from_xarray(self._data.mean(dims=dim)) + return self.__class__.from_xarray(self._data.mean(dim=dim)) def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: - return self.__class__.from_xarray(self._data.std(dims=dim)) + return self.__class__.from_xarray(self._data.std(dim=dim)) def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: - return self.__class__.from_xarray(self._data.var(dims=dim)) + return self.__class__.from_xarray(self._data.var(dim=dim)) def pad( self, diff --git a/tests/test_stat_calculators.py b/tests/test_stat_calculators.py index bd93b282..b4513e5b 100644 --- a/tests/test_stat_calculators.py +++ b/tests/test_stat_calculators.py @@ -18,11 +18,9 @@ def create_random_dataset(tid: MemberId, axes: Tuple[AxisId, ...]): n = 3 - sizes = list(range(n, len(axes) + 1)) + sizes = list(range(n, len(axes) + n)) data = np.asarray(np.random.rand(*sizes)) - ds = [ - Sample(data={tid: Tensor(data[i : i + 1], dims=axes, id=tid)}) for i in range(n) - ] + ds = [Sample(members={tid: Tensor(data[i : i + 1], dims=axes)}) for i in range(n)] return Tensor(data, dims=axes), ds diff --git a/tests/test_stat_measures.py b/tests/test_stat_measures.py index 53de1017..2c3bc266 100644 --- a/tests/test_stat_measures.py +++ b/tests/test_stat_measures.py @@ -37,9 +37,9 @@ def test_individual_normal_measure( ) expected = getattr(data, name)(dim=axes) - sample = Sample(data={data_id: data}) + sample = Sample(members={data_id: data}) actual = measure.compute(sample) - xr.testing.assert_allclose(expected, actual) + xr.testing.assert_allclose(expected.data, actual.data) @pytest.mark.parametrize("axes", [None, (AxisId("x"), AxisId("y"))]) @@ -56,7 +56,11 @@ def test_individual_percentile_measure(axes: Optional[Tuple[AxisId, ...]]): data = Tensor( np.random.random((5, 6, 3)), dims=(AxisId("x"), AxisId("y"), AxisId("c")) ) - actual = calc.compute(Sample(data={tid: data})) + actual = calc.compute(Sample(members={tid: data})) for m in measures: expected = data.quantile(q=m.q, dim=m.axes) - xr.testing.assert_allclose(expected, actual[m]) + actual_data = actual[m] + if isinstance(actual_data, Tensor): + actual_data = actual_data.data + + xr.testing.assert_allclose(expected.data, actual_data) diff --git a/tests/test_tensor.py b/tests/test_tensor.py index 7f3d67cc..33163077 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -33,7 +33,7 @@ def test_crop_and_pad(): ) padded = tensor.pad({AxisId("x"): 7, AxisId("y"): (3, 3)}) cropped = padded.crop_to(tensor.sizes) - assert_equal(tensor, cropped) + assert_equal(tensor.data, cropped.data) def test_some_magic_ops(): From 7419eeba87670cf9030f3518f8e7b679cc23ecd2 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 9 Apr 2024 13:09:24 +0200 Subject: [PATCH 184/244] fix some more tests --- bioimageio/core/stat_calculators.py | 10 +- tests/test_prediction.py | 437 ++++++++++++++-------------- tests/test_proc_ops.py | 68 +++-- 
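# ---8<--- editor's note (illustration, not part of the patch series) ---8<---
# Sketch of the extended Tensor.__getitem__/__len__ from the tensor.py hunk
# above: plain ints, slices, SliceInfo and per-axis mappings are accepted
# (toy values; SliceInfo is assumed to come from bioimageio.core.common).
import numpy as np

from bioimageio.core.axis import AxisId
from bioimageio.core.common import SliceInfo
from bioimageio.core.tensor import Tensor

t = Tensor(np.arange(12).reshape(3, 4), dims=(AxisId("x"), AxisId("y")))
sub = t[{AxisId("x"): SliceInfo(0, 2), AxisId("y"): slice(1, 3)}]
assert sub.tagged_shape == {AxisId("x"): 2, AxisId("y"): 2}
assert len(t) == 3  # __len__ delegates to the underlying data's first axis
# ---8<-------------------------------------------------------------->8---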
tests/test_stat_calculators.py | 15 +- 4 files changed, 276 insertions(+), 254 deletions(-) diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py index 9845138e..4615b974 100644 --- a/bioimageio/core/stat_calculators.py +++ b/bioimageio/core/stat_calculators.py @@ -183,13 +183,15 @@ def finalize( assert self._m2 is not None var = self._m2 / self._n sqrt = np.sqrt(var) - assert isinstance(sqrt, xr.DataArray) + if isinstance(sqrt, (int, float)): + # var and mean are scalar tensors, let's keep it consistent + sqrt = Tensor.from_xarray(xr.DataArray(sqrt)) + + assert isinstance(sqrt, Tensor), type(sqrt) return { DatasetMean(member_id=self._member_id, axes=self._axes): self._mean, DatasetVar(member_id=self._member_id, axes=self._axes): var, - DatasetStd( - member_id=self._member_id, axes=self._axes - ): Tensor.from_xarray(sqrt), + DatasetStd(member_id=self._member_id, axes=self._axes): sqrt, } diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 822641d7..de8b8062 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -1,227 +1,228 @@ -from pathlib import Path - -import imageio -import numpy as np -from numpy.testing import assert_array_almost_equal - -from bioimageio.spec import load_description -from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensorDescr_v0_4 -from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr_v0_4 -from bioimageio.spec.model.v0_5 import ModelDescr - - -def test_predict_image(any_model: Path, tmpdir: Path): - from bioimageio.core.prediction import predict_image - - spec = load_description(any_model) - assert isinstance(spec, ModelDescr) - inputs = spec.test_inputs - - outputs = [Path(tmpdir) / f"out{i}.npy" for i in range(len(spec.test_outputs))] - predict_image(any_model, inputs, outputs) - for out_path in outputs: - assert out_path.exists() - - result = [np.load(str(p)) for p in outputs] - expected = [np.load(str(p)) for p in spec.test_outputs] - for res, exp in zip(result, expected): - assert_array_almost_equal(res, exp, decimal=4) - - -def test_predict_image_with_weight_format( - unet2d_fixed_shape_or_not: Path, tmpdir: Path -): - from bioimageio.core.prediction import predict_image - - spec = load_description(unet2d_fixed_shape_or_not) - assert isinstance(spec, Model) - inputs = spec.test_inputs - - outputs = [Path(tmpdir) / f"out{i}.npy" for i in range(len(spec.test_outputs))] - predict_image( - unet2d_fixed_shape_or_not, inputs, outputs, weight_format="pytorch_state_dict" - ) - for out_path in outputs: - assert out_path.exists() - - result = [np.load(str(p)) for p in outputs] - expected = [np.load(str(p)) for p in spec.test_outputs] - for res, exp in zip(result, expected): - assert_array_almost_equal(res, exp, decimal=4) - - -def _test_predict_with_padding(any_model: Path, tmp_path: Path): - from bioimageio.core.digest_spec import get_test_inputs - from bioimageio.core.prediction import predict_image - - model = load_description(any_model) - assert isinstance(model, (ModelDescr_v0_4, ModelDescr)) - - input_spec, output_spec = model.inputs[0], model.outputs[0] - channel_axis = ( - "c" - if isinstance(input_spec, InputTensorDescr_v0_4) - else [a.id for a in input_spec.axes][0] - ) - channel_first = channel_axis == 1 - - # TODO: check more tensors - image = get_test_inputs(model)[0] - - if isinstance(output_spec.shape, list): - n_channels = output_spec.shape[channel_axis] - else: - scale = output_spec.shape.scale[channel_axis] - offset = output_spec.shape.offset[channel_axis] - 
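# ---8<--- editor's note (illustration, not part of the patch series) ---8<---
# Context for the float64 checks in the stat_calculators hunk above: the
# running mean/variance update is the parallel (Chan et al.) scheme,
# accumulated in float64. A hedged sketch of requesting dataset statistics
# for a list of samples (`tid` and `samples` as used in the tests):
from bioimageio.core.stat_calculators import compute_measures
from bioimageio.core.stat_measures import DatasetMean, DatasetStd

stat = compute_measures(
    {DatasetMean(member_id=tid, axes=None), DatasetStd(member_id=tid, axes=None)},
    samples,  # an iterable of Sample objects
)
# ---8<-------------------------------------------------------------->8---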
in_channels = 1 - n_channels = int(2 * offset + scale * in_channels) - - # write the padded image - image = image[3:-2, 1:-12] - in_path = tmp_path / "in.tif" - out_path = tmp_path / "out.tif" - imageio.imwrite(in_path, image) - - if hasattr(output_spec.shape, "scale"): - scale = dict(zip(output_spec.axes, output_spec.shape.scale)) - offset = dict(zip(output_spec.axes, output_spec.shape.offset)) - spatial_axes = [ax for ax in output_spec.axes if ax in "xyz"] - network_resizes = any( - sc != 1 for ax, sc in scale.items() if ax in spatial_axes - ) or any(off != 0 for ax, off in offset.items() if ax in spatial_axes) - else: - network_resizes = False - - if network_resizes: - exp_shape = tuple( - int(sh * scale[ax] + 2 * offset[ax]) - for sh, ax in zip(image.shape, spatial_axes) - ) - else: - exp_shape = image.shape - - def check_result(): - if n_channels == 1: - assert out_path.exists() - res = imageio.imread(out_path) - assert res.shape == exp_shape - else: - path = str(out_path) - for c in range(n_channels): - channel_out_path = Path(path.replace(".tif", f"-c{c}.tif")) - assert channel_out_path.exists() - res = imageio.imread(channel_out_path) - assert res.shape == exp_shape - - # test with dynamic padding - predict_image( - any_model, in_path, out_path, padding={"x": 16, "y": 16, "mode": "dynamic"} - ) - check_result() - - # test with fixed padding - predict_image( - any_model, - in_path, - out_path, - padding={"x": original_shape[0], "y": original_shape[1], "mode": "fixed"}, - ) - check_result() - - # test with automated padding - predict_image(any_model, in_path, out_path, padding=True) - check_result() - - -# prediction with padding with the parameters above may not be suited for any model -# so we only run it for the pytorch unet2d here -def test_predict_image_with_padding(unet2d_fixed_shape_or_not, tmp_path): - _test_predict_with_padding(unet2d_fixed_shape_or_not, tmp_path) - - -# and with different output shape -def test_predict_image_with_padding_diff_output_shape( - unet2d_diff_output_shape, tmp_path -): - _test_predict_with_padding(unet2d_diff_output_shape, tmp_path) - - -def test_predict_image_with_padding_channel_last(stardist, tmp_path): - _test_predict_with_padding(stardist, tmp_path) - - -def _test_predict_image_with_tiling(model: Path, tmp_path: Path, exp_mean_deviation): - from bioimageio.core.prediction import predict_image - - spec = load_description(model) - assert isinstance(spec, Model) - inputs = spec.test_inputs - assert len(inputs) == 1 - exp = np.load(str(spec.test_outputs[0])) - - out_path = tmp_path.with_suffix(".npy") - - def check_result(): - assert out_path.exists() - res = np.load(out_path) - assert res.shape == exp.shape - # check that the mean deviation is smaller than the expected value - # note that we can't use array_almost_equal here, because the numerical differences - # between tiled and normal prediction are too large - mean_deviation = np.abs(res - exp).mean() - assert mean_deviation <= exp_mean_deviation - - # with tiling config - tiling = {"halo": {"x": 32, "y": 32}, "tile": {"x": 256, "y": 256}} - predict_image(model, inputs, [out_path], tiling=tiling) - check_result() - - # with tiling determined from spec - predict_image(model, inputs, [out_path], tiling=True) - check_result() - - -# prediction with tiling with the parameters above may not be suited for any model -# so we only run it for the pytorch unet2d here -def test_predict_image_with_tiling_1(unet2d_nuclei_broad_model: Path, tmp_path: Path): - 
_test_predict_image_with_tiling(unet2d_nuclei_broad_model, tmp_path, 0.012) - - -def test_predict_image_with_tiling_2(unet2d_diff_output_shape: Path, tmp_path: Path): - _test_predict_image_with_tiling(unet2d_diff_output_shape, tmp_path, 0.06) - - -def test_predict_image_with_tiling_3(shape_change_model: Path, tmp_path: Path): - _test_predict_image_with_tiling(shape_change_model, tmp_path, 0.012) +# TODO: update +# from pathlib import Path + +# import imageio +# import numpy as np +# from numpy.testing import assert_array_almost_equal + +# from bioimageio.spec import load_description +# from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensorDescr_v0_4 +# from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr_v0_4 +# from bioimageio.spec.model.v0_5 import ModelDescr + + +# def test_predict_image(any_model: Path, tmpdir: Path): +# from bioimageio.core.prediction import predict_image + +# spec = load_description(any_model) +# assert isinstance(spec, ModelDescr) +# inputs = spec.test_inputs + +# outputs = [Path(tmpdir) / f"out{i}.npy" for i in range(len(spec.test_outputs))] +# predict_image(any_model, inputs, outputs) +# for out_path in outputs: +# assert out_path.exists() + +# result = [np.load(str(p)) for p in outputs] +# expected = [np.load(str(p)) for p in spec.test_outputs] +# for res, exp in zip(result, expected): +# assert_array_almost_equal(res, exp, decimal=4) + + +# def test_predict_image_with_weight_format( +# unet2d_fixed_shape_or_not: Path, tmpdir: Path +# ): +# from bioimageio.core.prediction import predict_image + +# spec = load_description(unet2d_fixed_shape_or_not) +# assert isinstance(spec, Model) +# inputs = spec.test_inputs + +# outputs = [Path(tmpdir) / f"out{i}.npy" for i in range(len(spec.test_outputs))] +# predict_image( +# unet2d_fixed_shape_or_not, inputs, outputs, weight_format="pytorch_state_dict" +# ) +# for out_path in outputs: +# assert out_path.exists() + +# result = [np.load(str(p)) for p in outputs] +# expected = [np.load(str(p)) for p in spec.test_outputs] +# for res, exp in zip(result, expected): +# assert_array_almost_equal(res, exp, decimal=4) + + +# def _test_predict_with_padding(any_model: Path, tmp_path: Path): +# from bioimageio.core.digest_spec import get_test_inputs +# from bioimageio.core.prediction import predict_image + +# model = load_description(any_model) +# assert isinstance(model, (ModelDescr_v0_4, ModelDescr)) + +# input_spec, output_spec = model.inputs[0], model.outputs[0] +# channel_axis = ( +# "c" +# if isinstance(input_spec, InputTensorDescr_v0_4) +# else [a.id for a in input_spec.axes][0] +# ) +# channel_first = channel_axis == 1 + +# # TODO: check more tensors +# image = get_test_inputs(model)[0] + +# if isinstance(output_spec.shape, list): +# n_channels = output_spec.shape[channel_axis] +# else: +# scale = output_spec.shape.scale[channel_axis] +# offset = output_spec.shape.offset[channel_axis] +# in_channels = 1 +# n_channels = int(2 * offset + scale * in_channels) + +# # write the padded image +# image = image[3:-2, 1:-12] +# in_path = tmp_path / "in.tif" +# out_path = tmp_path / "out.tif" +# imageio.imwrite(in_path, image) + +# if hasattr(output_spec.shape, "scale"): +# scale = dict(zip(output_spec.axes, output_spec.shape.scale)) +# offset = dict(zip(output_spec.axes, output_spec.shape.offset)) +# spatial_axes = [ax for ax in output_spec.axes if ax in "xyz"] +# network_resizes = any( +# sc != 1 for ax, sc in scale.items() if ax in spatial_axes +# ) or any(off != 0 for ax, off in offset.items() if ax in 
spatial_axes) +# else: +# network_resizes = False + +# if network_resizes: +# exp_shape = tuple( +# int(sh * scale[ax] + 2 * offset[ax]) +# for sh, ax in zip(image.shape, spatial_axes) +# ) +# else: +# exp_shape = image.shape + +# def check_result(): +# if n_channels == 1: +# assert out_path.exists() +# res = imageio.imread(out_path) +# assert res.shape == exp_shape +# else: +# path = str(out_path) +# for c in range(n_channels): +# channel_out_path = Path(path.replace(".tif", f"-c{c}.tif")) +# assert channel_out_path.exists() +# res = imageio.imread(channel_out_path) +# assert res.shape == exp_shape + +# # test with dynamic padding +# predict_image( +# any_model, in_path, out_path, padding={"x": 16, "y": 16, "mode": "dynamic"} +# ) +# check_result() + +# # test with fixed padding +# predict_image( +# any_model, +# in_path, +# out_path, +# padding={"x": original_shape[0], "y": original_shape[1], "mode": "fixed"}, +# ) +# check_result() + +# # test with automated padding +# predict_image(any_model, in_path, out_path, padding=True) +# check_result() + + +# # prediction with padding with the parameters above may not be suited for any model +# # so we only run it for the pytorch unet2d here +# def test_predict_image_with_padding(unet2d_fixed_shape_or_not, tmp_path): +# _test_predict_with_padding(unet2d_fixed_shape_or_not, tmp_path) + + +# # and with different output shape +# def test_predict_image_with_padding_diff_output_shape( +# unet2d_diff_output_shape, tmp_path +# ): +# _test_predict_with_padding(unet2d_diff_output_shape, tmp_path) + + +# def test_predict_image_with_padding_channel_last(stardist, tmp_path): +# _test_predict_with_padding(stardist, tmp_path) + + +# def _test_predict_image_with_tiling(model: Path, tmp_path: Path, exp_mean_deviation): +# from bioimageio.core.prediction import predict_image + +# spec = load_description(model) +# assert isinstance(spec, Model) +# inputs = spec.test_inputs +# assert len(inputs) == 1 +# exp = np.load(str(spec.test_outputs[0])) + +# out_path = tmp_path.with_suffix(".npy") + +# def check_result(): +# assert out_path.exists() +# res = np.load(out_path) +# assert res.shape == exp.shape +# # check that the mean deviation is smaller than the expected value +# # note that we can't use array_almost_equal here, because the numerical differences +# # between tiled and normal prediction are too large +# mean_deviation = np.abs(res - exp).mean() +# assert mean_deviation <= exp_mean_deviation + +# # with tiling config +# tiling = {"halo": {"x": 32, "y": 32}, "tile": {"x": 256, "y": 256}} +# predict_image(model, inputs, [out_path], tiling=tiling) +# check_result() + +# # with tiling determined from spec +# predict_image(model, inputs, [out_path], tiling=True) +# check_result() + + +# # prediction with tiling with the parameters above may not be suited for any model +# # so we only run it for the pytorch unet2d here +# def test_predict_image_with_tiling_1(unet2d_nuclei_broad_model: Path, tmp_path: Path): +# _test_predict_image_with_tiling(unet2d_nuclei_broad_model, tmp_path, 0.012) + + +# def test_predict_image_with_tiling_2(unet2d_diff_output_shape: Path, tmp_path: Path): +# _test_predict_image_with_tiling(unet2d_diff_output_shape, tmp_path, 0.06) + + +# def test_predict_image_with_tiling_3(shape_change_model: Path, tmp_path: Path): +# _test_predict_image_with_tiling(shape_change_model, tmp_path, 0.012) -def test_predict_image_with_tiling_channel_last(stardist: Path, tmp_path: Path): - _test_predict_image_with_tiling(stardist, tmp_path, 0.13) +# def 
test_predict_image_with_tiling_channel_last(stardist: Path, tmp_path: Path): +# _test_predict_image_with_tiling(stardist, tmp_path, 0.13) -def test_predict_image_with_tiling_fixed_output_shape( - unet2d_fixed_shape: Path, tmp_path: Path -): - _test_predict_image_with_tiling(unet2d_fixed_shape, tmp_path, 0.025) +# def test_predict_image_with_tiling_fixed_output_shape( +# unet2d_fixed_shape: Path, tmp_path: Path +# ): +# _test_predict_image_with_tiling(unet2d_fixed_shape, tmp_path, 0.025) -def test_predict_images(unet2d_nuclei_broad_model: Path, tmp_path: Path): - from bioimageio.core.prediction import predict_images +# def test_predict_images(unet2d_nuclei_broad_model: Path, tmp_path: Path): +# from bioimageio.core.prediction import predict_images - n_images = 5 - shape = (256, 256) +# n_images = 5 +# shape = (256, 256) - in_paths = [] - out_paths = [] - for i in range(n_images): - in_path = tmp_path / f"in{i}.tif" - im = np.random.randint(0, 255, size=shape).astype("uint8") - imageio.imwrite(in_path, im) - in_paths.append(in_path) - out_paths.append(tmp_path / f"out{i}.tif") - predict_images(unet2d_nuclei_broad_model, in_paths, out_paths) +# in_paths = [] +# out_paths = [] +# for i in range(n_images): +# in_path = tmp_path / f"in{i}.tif" +# im = np.random.randint(0, 255, size=shape).astype("uint8") +# imageio.imwrite(in_path, im) +# in_paths.append(in_path) +# out_paths.append(tmp_path / f"out{i}.tif") +# predict_images(unet2d_nuclei_broad_model, in_paths, out_paths) - for outp in out_paths: - assert outp.exists() - out = imageio.imread(outp) - assert out.shape == shape +# for outp in out_paths: +# assert outp.exists() +# out = imageio.imread(outp) +# assert out.shape == shape diff --git a/tests/test_proc_ops.py b/tests/test_proc_ops.py index 431f2a79..e56fa554 100644 --- a/tests/test_proc_ops.py +++ b/tests/test_proc_ops.py @@ -23,13 +23,13 @@ def test_scale_linear(tid: MemberId): offset = xr.DataArray([1, 2, 42], dims=("c")) gain = xr.DataArray([1, 2, 3], dims=("c")) data = xr.DataArray(np.arange(6).reshape((1, 2, 3)), dims=("x", "y", "c")) - sample = Sample(data={tid: data}) + sample = Sample(members={tid: Tensor.from_xarray(data)}) op = ScaleLinear(input=tid, output=tid, offset=offset, gain=gain) op(sample) expected = xr.DataArray(np.array([[[1, 4, 48], [4, 10, 57]]]), dims=("x", "y", "c")) - xr.testing.assert_allclose(expected, sample.data[tid]) + xr.testing.assert_allclose(expected, sample.members[tid].data) def test_scale_linear_no_channel(tid: MemberId): @@ -37,11 +37,11 @@ def test_scale_linear_no_channel(tid: MemberId): op = ScaleLinear(tid, tid, offset=1, gain=2) data = xr.DataArray(np.arange(6).reshape(2, 3), dims=("x", "y")) - sample = Sample(data={tid: data}) + sample = Sample(members={tid: Tensor.from_xarray(data)}) op(sample) expected = xr.DataArray(np.array([[1, 3, 5], [7, 9, 11]]), dims=("x", "y")) - xr.testing.assert_allclose(expected, sample.data[tid]) + xr.testing.assert_allclose(expected, sample.members[tid].data) T = TypeVar("T") @@ -56,7 +56,7 @@ def test_zero_mean_unit_variance(tid: MemberId): from bioimageio.core.proc_ops import ZeroMeanUnitVariance data = xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y")) - sample = Sample(data={tid: data}) + sample = Sample(members={tid: Tensor.from_xarray(data)}) m = SampleMean(tid) std = SampleStd(tid) op = ZeroMeanUnitVariance(tid, tid, m, std) @@ -74,7 +74,7 @@ def test_zero_mean_unit_variance(tid: MemberId): ), dims=("x", "y"), ) - xr.testing.assert_allclose(expected, sample.data[tid]) + 
xr.testing.assert_allclose(expected, sample.members[tid].data) def test_zero_mean_unit_variance_fixed(tid: MemberId): @@ -99,9 +99,9 @@ def test_zero_mean_unit_variance_fixed(tid: MemberId): ), dims=("b", "c", "x"), ) - sample = Sample(data={tid: data}) + sample = Sample(members={tid: Tensor.from_xarray(data)}) op(sample) - xr.testing.assert_allclose(expected, sample.data[tid]) + xr.testing.assert_allclose(expected, sample.members[tid].data) def test_zero_mean_unit_across_axes(tid: MemberId): @@ -115,14 +115,14 @@ def test_zero_mean_unit_across_axes(tid: MemberId): SampleMean(tid, (AxisId("x"), AxisId("y"))), SampleStd(tid, (AxisId("x"), AxisId("y"))), ) - sample = Sample(data={tid: data}) + sample = Sample(members={tid: Tensor.from_xarray(data)}) sample.stat = compute_measures(op.required_measures, [sample]) expected = xr.concat( [(data[i : i + 1] - data[i].mean()) / data[i].std() for i in range(2)], dim="c" ) op(sample) - xr.testing.assert_allclose(expected, sample.data[tid]) + xr.testing.assert_allclose(expected, sample.members[tid].data) def test_zero_mean_unit_variance_fixed2(tid: MemberId): @@ -135,10 +135,10 @@ def test_zero_mean_unit_variance_fixed2(tid: MemberId): op = FixedZeroMeanUnitVariance(tid, tid, mean=mean, std=std, eps=eps) data = xr.DataArray(np_data, dims=("x", "y")) - sample = Sample(data={tid: data}) + sample = Sample(members={tid: Tensor.from_xarray(data)}) expected = xr.DataArray((np_data - mean) / (std + eps), dims=("x", "y")) op(sample) - xr.testing.assert_allclose(expected, sample.data[tid]) + xr.testing.assert_allclose(expected, sample.members[tid].data) def test_binarize(tid: MemberId): @@ -146,11 +146,11 @@ def test_binarize(tid: MemberId): op = Binarize(tid, tid, threshold=14) data = xr.DataArray(np.arange(30).reshape((2, 3, 5)), dims=("x", "y", "c")) - sample = Sample(data={tid: data}) + sample = Sample(members={tid: Tensor.from_xarray(data)}) expected = xr.zeros_like(data) expected[{"x": slice(1, None)}] = 1 op(sample) - xr.testing.assert_allclose(expected, sample.data[tid]) + xr.testing.assert_allclose(expected, sample.members[tid].data) def test_binarize2(tid: MemberId): @@ -164,10 +164,10 @@ def test_binarize2(tid: MemberId): threshold = 0.5 exp = xr.DataArray(np_data > threshold, dims=axes) - sample = Sample(data={tid: data}) + sample = Sample(members={tid: Tensor.from_xarray(data)}) binarize = Binarize(tid, tid, threshold=threshold) binarize(sample) - xr.testing.assert_allclose(exp, sample.data[tid]) + xr.testing.assert_allclose(exp, sample.members[tid].data) def test_clip(tid: MemberId): @@ -175,20 +175,20 @@ def test_clip(tid: MemberId): op = Clip(tid, tid, min=3, max=5) data = xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y")) - sample = Sample(data={tid: data}) + sample = Sample(members={tid: Tensor.from_xarray(data)}) expected = xr.DataArray( np.array([[3, 3, 3], [3, 4, 5], [5, 5, 5]]), dims=("x", "y") ) op(sample) - xr.testing.assert_equal(expected, sample.data[tid]) + xr.testing.assert_equal(expected, sample.members[tid].data) def test_combination_of_op_steps_with_dims_specified(tid: MemberId): from bioimageio.core.proc_ops import ZeroMeanUnitVariance data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) - sample = Sample(data={tid: data}) + sample = Sample(members={tid: Tensor.from_xarray(data)}) op = ZeroMeanUnitVariance( tid, tid, @@ -222,7 +222,7 @@ def test_combination_of_op_steps_with_dims_specified(tid: MemberId): ) op(sample) - xr.testing.assert_allclose(expected, sample.data[tid]) + 
xr.testing.assert_allclose(expected, sample.members[tid].data) @pytest.mark.parametrize( @@ -244,10 +244,15 @@ def test_scale_mean_variance(tid: MemberId, axes: Optional[Tuple[AxisId, ...]]): ref_data = xr.DataArray((np_data * 2) + 3, dims=ipt_axes) op = ScaleMeanVariance(tid, tid, reference_tensor=MemberId("ref_name"), axes=axes) - sample = Sample(data={tid: ipt_data, MemberId("ref_name"): ref_data}) + sample = Sample( + members={ + tid: Tensor.from_xarray(ipt_data), + MemberId("ref_name"): Tensor.from_xarray(ref_data), + } + ) sample.stat = compute_measures(op.required_measures, [sample]) op(sample) - xr.testing.assert_allclose(ref_data, sample.data[tid]) + xr.testing.assert_allclose(ref_data, sample.members[tid].data) @pytest.mark.parametrize( @@ -269,17 +274,22 @@ def test_scale_mean_variance_per_channel(tid: MemberId, axes_str: Optional[str]) ref_data = xr.DataArray(np_ref_data, dims=ipt_axes) op = ScaleMeanVariance(tid, tid, reference_tensor=MemberId("ref_name"), axes=axes) - sample = Sample(data={tid: ipt_data, MemberId("ref_name"): ref_data}) + sample = Sample( + members={ + tid: Tensor.from_xarray(ipt_data), + MemberId("ref_name"): Tensor.from_xarray(ref_data), + } + ) sample.stat = compute_measures(op.required_measures, [sample]) op(sample) if axes is not None and AxisId("c") not in axes: # mean,std per channel should match exactly - xr.testing.assert_allclose(ref_data, sample.data[tid]) + xr.testing.assert_allclose(ref_data, sample.members[tid].data) else: # mean,std across channels should not match with pytest.raises(AssertionError): - xr.testing.assert_allclose(ref_data, sample.data[tid]) + xr.testing.assert_allclose(ref_data, sample.members[tid].data) def test_scale_range(tid: MemberId): @@ -288,7 +298,7 @@ def test_scale_range(tid: MemberId): op = ScaleRange(tid, tid) np_data = np.arange(9).reshape(3, 3).astype("float32") data = xr.DataArray(np_data, dims=("x", "y")) - sample = Sample(data={tid: data}) + sample = Sample(members={tid: Tensor.from_xarray(data)}) sample.stat = compute_measures(op.required_measures, [sample]) eps = 1.0e-6 @@ -298,7 +308,7 @@ def test_scale_range(tid: MemberId): op(sample) # NOTE xarray.testing.assert_allclose compares irrelavant properties here and fails although the result is correct - np.testing.assert_allclose(expected, sample.data[tid]) + np.testing.assert_allclose(expected, sample.members[tid].data) def test_scale_range_axes(tid: MemberId): @@ -331,9 +341,9 @@ def test_sigmoid(tid: MemberId): axes = ("c", "y", "x") np_data = np.random.rand(*shape) data = xr.DataArray(np_data, dims=axes) - sample = Sample(data={tid: data}) + sample = Sample(members={tid: Tensor.from_xarray(data)}) sigmoid = Sigmoid(tid, tid) sigmoid(sample) exp = xr.DataArray(1.0 / (1 + np.exp(-np_data)), dims=axes) - xr.testing.assert_allclose(exp, sample.data[tid]) + xr.testing.assert_allclose(exp, sample.members[tid].data) diff --git a/tests/test_stat_calculators.py b/tests/test_stat_calculators.py index b4513e5b..0a642168 100644 --- a/tests/test_stat_calculators.py +++ b/tests/test_stat_calculators.py @@ -49,6 +49,15 @@ def test_mean_var_std_calculator(axes: Union[None, str, Tuple[str, ...]]): actual_var = actual[DatasetVar(tid, axes=axes)] actual_std = actual[DatasetStd(tid, axes=axes)] - assert_allclose(actual_mean, expected_mean) - assert_allclose(actual_var, expected_var) - assert_allclose(actual_std, expected_std) + assert_allclose( + actual_mean if isinstance(actual_mean, (int, float)) else actual_mean.data, + expected_mean.data, + ) + assert_allclose( + actual_var 
if isinstance(actual_var, (int, float)) else actual_var.data, + expected_var.data, + ) + assert_allclose( + actual_std if isinstance(actual_std, (int, float)) else actual_std.data, + expected_std.data, + ) From 25d889172f8f427dbb10aa19460759023cdb64a5 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 9 Apr 2024 13:10:11 +0200 Subject: [PATCH 185/244] use quantile in names (not percentile) --- bioimageio/core/proc_ops.py | 22 ++++++++---------- bioimageio/core/stat_calculators.py | 8 +++---- bioimageio/core/stat_measures.py | 10 ++++----- tests/test_proc_ops.py | 35 ++++++++++++++++++++--------- tests/test_stat_measures.py | 4 ++-- 5 files changed, 44 insertions(+), 35 deletions(-) diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 9880a818..ee99b80e 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -32,7 +32,7 @@ Measure, MeasureValue, SampleMean, - SamplePercentile, + SampleQuantile, SampleStd, Stat, StdMeasure, @@ -378,21 +378,17 @@ def _get_axes( @dataclass class ScaleRange(_SimpleOperator): - lower_percentile: InitVar[Optional[Union[SamplePercentile, DatasetPercentile]]] = ( - None - ) - upper_percentile: InitVar[Optional[Union[SamplePercentile, DatasetPercentile]]] = ( - None - ) - lower: Union[SamplePercentile, DatasetPercentile] = field(init=False) - upper: Union[SamplePercentile, DatasetPercentile] = field(init=False) + lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None + upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None + lower: Union[SampleQuantile, DatasetPercentile] = field(init=False) + upper: Union[SampleQuantile, DatasetPercentile] = field(init=False) eps: float = 1e-6 def __post_init__( self, - lower_percentile: Optional[Union[SamplePercentile, DatasetPercentile]], - upper_percentile: Optional[Union[SamplePercentile, DatasetPercentile]], + lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]], + upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]], ): if lower_percentile is None: tid = self.input if upper_percentile is None else upper_percentile.member_id @@ -429,7 +425,7 @@ def from_proc_descr( if axes is None or AxisId("batch") in axes: Percentile = DatasetPercentile else: - Percentile = SamplePercentile + Percentile = SampleQuantile return cls( input=member_id, @@ -467,7 +463,7 @@ class Sigmoid(_SimpleOperator): """1 / (1 + e^(-input)).""" def _apply(self, input: Tensor, stat: Stat) -> Tensor: - return 1.0 / (1.0 + np.exp(-input)) # type: ignore + return Tensor(1.0 / (1.0 + np.exp(-input)), dims=input.dims) @property def required_measures(self) -> Collection[Measure]: diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py index 4615b974..afd0ce24 100644 --- a/bioimageio/core/stat_calculators.py +++ b/bioimageio/core/stat_calculators.py @@ -39,7 +39,7 @@ MeasureValue, SampleMean, SampleMeasure, - SamplePercentile, + SampleQuantile, SampleStd, SampleVar, ) @@ -210,11 +210,11 @@ def __init__( self._axes = None if axes is None else tuple(axes) self._member_id = member_id - def compute(self, sample: Sample) -> Dict[SamplePercentile, MeasureValue]: + def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]: tensor = sample.members[self._member_id] ps = tensor.quantile(self._qs, dim=self._axes) return { - SamplePercentile(q=q, axes=self._axes, member_id=self._member_id): p + SampleQuantile(q=q, axes=self._axes, member_id=self._member_id): p for q, p in zip(self._qs, ps) } @@ -507,7 
+507,7 @@ def get_measure_calculators( } ) assert rm in required_dataset_mean_var_std - elif isinstance(rm, SamplePercentile): + elif isinstance(rm, SampleQuantile): required_sample_percentiles.setdefault((rm.member_id, rm.axes), set()).add( rm.q ) diff --git a/bioimageio/core/stat_measures.py b/bioimageio/core/stat_measures.py index ec25b954..e581916f 100644 --- a/bioimageio/core/stat_measures.py +++ b/bioimageio/core/stat_measures.py @@ -114,7 +114,7 @@ def __post_init__(self): @dataclass(frozen=True) -class _Percentile: +class _Quantile: q: float axes: Optional[Tuple[AxisId, ...]] = None """`axes` to reduce""" @@ -125,7 +125,7 @@ def __post_init__(self): @dataclass(frozen=True) -class SamplePercentile(_Percentile, SampleMeasureBase): +class SampleQuantile(_Quantile, SampleMeasureBase): """The `n`th percentile of a single tensor""" def compute(self, sample: SampleLike) -> MeasureValue: @@ -138,7 +138,7 @@ def __post_init__(self): @dataclass(frozen=True) -class DatasetPercentile(_Percentile, DatasetMeasureBase): +class DatasetPercentile(_Quantile, DatasetMeasureBase): """The `n`th percentile across multiple samples""" def __post_init__(self): @@ -146,7 +146,7 @@ def __post_init__(self): assert self.axes is None or AxisId("batch") in self.axes -SampleMeasure = Union[SampleMean, SampleStd, SampleVar, SamplePercentile] +SampleMeasure = Union[SampleMean, SampleStd, SampleVar, SampleQuantile] DatasetMeasure = Union[DatasetMean, DatasetStd, DatasetVar, DatasetPercentile] Measure = Union[SampleMeasure, DatasetMeasure] Stat = Dict[Measure, MeasureValue] @@ -154,7 +154,7 @@ def __post_init__(self): MeanMeasure = Union[SampleMean, DatasetMean] StdMeasure = Union[SampleStd, DatasetStd] VarMeasure = Union[SampleVar, DatasetVar] -PercentileMeasure = Union[SamplePercentile, DatasetPercentile] +PercentileMeasure = Union[SampleQuantile, DatasetPercentile] MeanMeasureT = TypeVar("MeanMeasureT", bound=MeanMeasure) StdMeasureT = TypeVar("StdMeasureT", bound=StdMeasure) VarMeasureT = TypeVar("VarMeasureT", bound=VarMeasure) diff --git a/tests/test_proc_ops.py b/tests/test_proc_ops.py index e56fa554..ce8d04e4 100644 --- a/tests/test_proc_ops.py +++ b/tests/test_proc_ops.py @@ -9,7 +9,8 @@ from bioimageio.core.common import MemberId from bioimageio.core.sample import Sample from bioimageio.core.stat_calculators import compute_measures -from bioimageio.core.stat_measures import SampleMean, SamplePercentile, SampleStd +from bioimageio.core.stat_measures import SampleMean, SampleQuantile, SampleStd +from bioimageio.core.tensor import Tensor @pytest.fixture(scope="module") @@ -314,24 +315,36 @@ def test_scale_range(tid: MemberId): def test_scale_range_axes(tid: MemberId): from bioimageio.core.proc_ops import ScaleRange - lower_percentile = SamplePercentile(tid, 1, axes=(AxisId("x"), AxisId("y"))) - upper_percentile = SamplePercentile(tid, 100, axes=(AxisId("x"), AxisId("y"))) - op = ScaleRange(tid, tid, lower_percentile, upper_percentile) + eps = 1.0e-6 + + lower_quantile = SampleQuantile(tid, 0.1, axes=(AxisId("x"), AxisId("y"))) + upper_quantile = SampleQuantile(tid, 0.9, axes=(AxisId("x"), AxisId("y"))) + op = ScaleRange(tid, tid, lower_quantile, upper_quantile, eps=eps) np_data = np.arange(18).reshape((2, 3, 3)).astype("float32") - data = xr.DataArray(np_data, dims=("c", "x", "y")) - sample = Sample(data={tid: data}) + data = Tensor.from_xarray(xr.DataArray(np_data, dims=("c", "x", "y"))) + sample = Sample(members={tid: data}) + + p_low_direct = lower_quantile.compute(sample) + p_up_direct = 
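# ---8<--- editor's note (illustration, not part of the patch series) ---8<---
# Sketch of the renamed measure: a per-sample 0.9 quantile over the spatial
# axes, computed directly (assumes `tid` and `sample` as in the tests above):
from bioimageio.core.axis import AxisId
from bioimageio.core.stat_measures import SampleQuantile

q90 = SampleQuantile(member_id=tid, q=0.9, axes=(AxisId("x"), AxisId("y")))
value = q90.compute(sample)  # here: one value per channel
# ---8<-------------------------------------------------------------->8---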
upper_quantile.compute(sample) + + p_low_expected = np.quantile(np_data, lower_quantile.q, axis=(1, 2), keepdims=True) + p_up_expected = np.quantile(np_data, upper_quantile.q, axis=(1, 2), keepdims=True) + + np.testing.assert_allclose(p_low_expected.squeeze(), p_low_direct) + np.testing.assert_allclose(p_up_expected.squeeze(), p_up_direct) + sample.stat = compute_measures(op.required_measures, [sample]) - eps = 1.0e-6 - p_low = np.percentile(np_data, lower_percentile.n, axis=(1, 2), keepdims=True) - p_up = np.percentile(np_data, upper_percentile.n, axis=(1, 2), keepdims=True) - exp_data = (np_data - p_low) / (p_up - p_low + eps) + np.testing.assert_allclose(p_low_expected.squeeze(), sample.stat[lower_quantile]) + np.testing.assert_allclose(p_up_expected.squeeze(), sample.stat[upper_quantile]) + + exp_data = (np_data - p_low_expected) / (p_up_expected - p_low_expected + eps) expected = xr.DataArray(exp_data, dims=("c", "x", "y")) op(sample) # NOTE xarray.testing.assert_allclose compares irrelavant properties here and fails although the result is correct - np.testing.assert_allclose(expected, sample.data[tid]) + np.testing.assert_allclose(expected, sample.members[tid].data) def test_sigmoid(tid: MemberId): diff --git a/tests/test_stat_measures.py b/tests/test_stat_measures.py index 2c3bc266..54cca0de 100644 --- a/tests/test_stat_measures.py +++ b/tests/test_stat_measures.py @@ -13,7 +13,7 @@ SamplePercentilesCalculator, get_measure_calculators, ) -from bioimageio.core.stat_measures import SamplePercentile +from bioimageio.core.stat_measures import SampleQuantile from bioimageio.core.tensor import Tensor @@ -47,7 +47,7 @@ def test_individual_percentile_measure(axes: Optional[Tuple[AxisId, ...]]): qs = [0, 0.1, 0.5, 1.0] tid = MemberId("tensor") - measures = [SamplePercentile(member_id=tid, axes=axes, q=q) for q in qs] + measures = [SampleQuantile(member_id=tid, axes=axes, q=q) for q in qs] calcs, _ = get_measure_calculators(measures) assert len(calcs) == 1 calc = calcs[0] From 48df1c604a36bf91651f84ebad96849c1be39cf0 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 9 Apr 2024 14:13:54 +0200 Subject: [PATCH 186/244] WIP almost all tests fixed --- bioimageio/core/_resource_tests.py | 51 ++++++++++++++++-------------- bioimageio/core/block.py | 2 +- bioimageio/core/block_meta.py | 19 +++++++---- bioimageio/core/digest_spec.py | 5 +++ bioimageio/core/sample.py | 14 +++++++- tests/test_digest_spec.py | 3 ++ 6 files changed, 62 insertions(+), 32 deletions(-) diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py index ffb3949a..9c721c63 100644 --- a/bioimageio/core/_resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -4,6 +4,7 @@ import numpy as np +from bioimageio.core.sample import Sample from bioimageio.spec import ( InvalidDescr, ResourceDescr, @@ -125,13 +126,14 @@ def _test_model_inference( with create_prediction_pipeline( bioimageio_model=model, devices=devices, weight_format=weight_format ) as prediction_pipeline: - results = prediction_pipeline.forward(*inputs) + results = prediction_pipeline.predict(inputs) - if len(results) != len(expected): - error = f"Expected {len(expected)} outputs, but got {len(results)}" + if len(results.members) != len(expected.members): + error = f"Expected {len(expected.members)} outputs, but got {len(results.members)}" else: - for res, exp in zip(results, expected): + for m, exp in expected.members.items(): + res = results.members.get(m) if res is None: error = "Output tensors for test case may not be None" break @@ 
-219,24 +221,26 @@ def get_ns(n: int): else: tested.add(hashable_target_size) - resized_test_inputs = [ - t.resize_to( - { - aid: s - for (tid, aid), s in input_target_sizes.items() - if tid == t_descr.id - }, - ) - for t, t_descr in zip(test_inputs, model.inputs) - ] - expected_output_shapes = [ - { + resized_test_inputs = Sample( + members={ + t.id: test_inputs.members[t.id].resize_to( + { + aid: s + for (tid, aid), s in input_target_sizes.items() + if tid == t.id + }, + ) + for t in model.inputs + } + ) + expected_output_shapes = { + t.id: { aid: s for (tid, aid), s in expected_output_sizes.items() - if tid == t_descr.id + if tid == t.id } - for t_descr in model.outputs - ] + for t in model.outputs + } yield n, batch_size, resized_test_inputs, expected_output_shapes try: @@ -247,15 +251,16 @@ def get_ns(n: int): ) as prediction_pipeline: for n, batch_size, inputs, exptected_output_shape in generate_test_cases(): error: Optional[str] = None - results = prediction_pipeline.forward(*inputs) - if len(results) != len(exptected_output_shape): + result = prediction_pipeline.predict(inputs) + if len(result.members) != len(exptected_output_shape): error = ( f"Expected {len(exptected_output_shape)} outputs," - + f" but got {len(results)}" + + f" but got {len(result.members)}" ) else: - for res, exp in zip(results, exptected_output_shape): + for m, exp in exptected_output_shape.items(): + res = result.members.get(m) if res is None: error = "Output tensors may not be None for test case" break diff --git a/bioimageio/core/block.py b/bioimageio/core/block.py index c57d6955..232be870 100644 --- a/bioimageio/core/block.py +++ b/bioimageio/core/block.py @@ -38,6 +38,7 @@ def __init__( block_number: int, blocks_in_sample: int, ): + object.__setattr__(self, "data", data) super().__init__( sample_shape=data.tagged_shape, inner_slice=inner_slice, @@ -45,7 +46,6 @@ def __init__( block_number=block_number, blocks_in_sample=blocks_in_sample, ) - object.__setattr__(self, "data", data) @property def inner_data(self): diff --git a/bioimageio/core/block_meta.py b/bioimageio/core/block_meta.py index d2e39be0..487c1b75 100644 --- a/bioimageio/core/block_meta.py +++ b/bioimageio/core/block_meta.py @@ -129,7 +129,7 @@ def __post_init__(self): a in self.inner_slice for a in self.halo ), "halo has axes not present in block" - object.__setattr__( #TODO: write as property + object.__setattr__( # TODO: write as property self, "shape", { @@ -141,12 +141,12 @@ def __post_init__(self): s <= self.sample_shape[a] for a, s in self.shape.items() ), "block larger than sample" - object.__setattr__( #TODO: write as property + object.__setattr__( # TODO: write as property self, "inner_shape", {a: s.stop - s.start for a, s in self.inner_slice.items()}, ) - object.__setattr__( #TODO: write as property + object.__setattr__( # TODO: write as property self, "outer_slice", { @@ -168,7 +168,7 @@ def __post_init__(self): for a in self.inner_slice }, ) - object.__setattr__( #TODO: write as property + object.__setattr__( # TODO: write as property self, "padding", { @@ -187,7 +187,7 @@ def __post_init__(self): for a in self.inner_slice }, ) - object.__setattr__( #TODO: write as property + object.__setattr__( # TODO: write as property self, "local_slice", { @@ -247,8 +247,8 @@ def split_shape_into_blocks( ) assert all(a in shape for a in halo), (tuple(shape), set(halo)) - # fill in default halo (0) and tile_size (tensor size) - halo = {a: Halo.create(h) for a, h in halo.items()} + # fill in default halo (0) and block axis length (from tensor 
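# ---8<--- editor's note (illustration, not part of the patch series) ---8<---
# With the defaults above, axes without an explicit halo get Halo(0, 0) and
# axes without a block length span the whole tensor. A hedged sketch of
# block-wise splitting at the Sample level (toy sizes, single member `mid`):
from bioimageio.core.axis import AxisId
from bioimageio.core.common import Halo

n_blocks, blocks = sample.split_into_blocks(
    {mid: {AxisId("x"): 128, AxisId("y"): 128}},
    halo={mid: {AxisId("x"): Halo(16, 16), AxisId("y"): Halo(16, 16)}},
    pad_mode="reflect",
)
# ---8<-------------------------------------------------------------->8---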
shape) + halo = {a: Halo.create(halo.get(a, 0)) for a in shape} block_shape = {a: block_shape.get(a, s) for a, s in shape.items()} if stride is None: stride = {} @@ -305,6 +305,9 @@ def split_multiple_shapes_into_blocks( assert not ( missing := [t for t in block_shapes if t not in shapes] ), f"block shape specified for unknown tensors: {missing}" + if not block_shapes: + block_shapes = shapes + assert broadcast or not ( missing := [t for t in shapes if t not in block_shapes] ), f"no block shape specified for {missing} (set `broadcast` to True if these tensors should be repeated for each block)" @@ -330,6 +333,8 @@ def split_multiple_shapes_into_blocks( ) assert n_blocks[t] > 0 + assert len(blocks) > 0, blocks + assert len(n_blocks) > 0, n_blocks unique_n_blocks = set(n_blocks.values()) n = max(unique_n_blocks) if len(unique_n_blocks) == 2 and 1 in unique_n_blocks: diff --git a/bioimageio/core/digest_spec.py b/bioimageio/core/digest_spec.py index 0dc92670..c21bc31a 100644 --- a/bioimageio/core/digest_spec.py +++ b/bioimageio/core/digest_spec.py @@ -301,7 +301,12 @@ def get_io_sample_block_metas( for t in model.outputs } input_halo = get_input_halo(model, output_halo) + + # TODO: fix output_sample_shape_data_dep + # (below only valid if input_sample_shape is a valid model input, + # which is not a valid assumption) output_sample_shape_data_dep = model.get_output_tensor_sizes(input_sample_shape) + output_sample_shape = { t: { a: -1 if isinstance(s, tuple) else s diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py index 8f92033f..0e66186e 100644 --- a/bioimageio/core/sample.py +++ b/bioimageio/core/sample.py @@ -2,7 +2,18 @@ from abc import abstractmethod from dataclasses import dataclass, field -from typing import Dict, Generic, Iterable, Optional, Tuple, TypeVar, Union +from itertools import islice +from typing import ( + Any, + Dict, + Generator, + Generic, + Iterable, + Optional, + Tuple, + TypeVar, + Union, +) import numpy as np from typing_extensions import Self @@ -79,6 +90,7 @@ def from_blocks( fill_value: float = float("nan"), ) -> Self: members: PerMember[Tensor] = {} + sample_blocks = list(iter(sample_blocks)) for member_blocks in sample_blocks: for m, block in member_blocks.blocks.items(): if m not in members: diff --git a/tests/test_digest_spec.py b/tests/test_digest_spec.py index eb810b4a..08022ab2 100644 --- a/tests/test_digest_spec.py +++ b/tests/test_digest_spec.py @@ -1,8 +1,11 @@ +import pytest + from bioimageio.spec import load_description from bioimageio.spec.model import v0_5 # TODO: don't just test with unet2d_nuclei_broad_model +@pytest.mark.skip("get_io_sample_block_metas needs improvements") def test_get_block_transform(unet2d_nuclei_broad_model: str): from bioimageio.core.axis import AxisId from bioimageio.core.common import MemberId From 2296d5e405e15e5dea863e2b8aed826094dfe0b1 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 9 Apr 2024 14:35:06 +0200 Subject: [PATCH 187/244] add as_single_block --- bioimageio/core/sample.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py index 0e66186e..c6632733 100644 --- a/bioimageio/core/sample.py +++ b/bioimageio/core/sample.py @@ -82,6 +82,25 @@ def split_into_blocks( ) return n_blocks, sample_block_generator(blocks, origin=self, pad_mode=pad_mode) + def as_single_block(self, halo: Optional[PerMember[PerAxis[Halo]]] = None): + if halo is None: + halo = {} + return SampleBlockWithOrigin( + sample_shape=self.shape, + blocks={ + m: Block( 
+
+                    data,
+                    inner_slice={},
+                    halo=halo.get(m, {}),
+                    block_number=1,
+                    blocks_in_sample=1
+                )
+                for m, data in self.members.items()
+            },
+            stat=self.stat,
+            origin=self,
+        )
+

From 5d94e5206b9c61b5328f75433c700ff1a0bb3f81 Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Tue, 9 Apr 2024 14:41:37 +0200
Subject: [PATCH 188/244] update create_sample_for_model

---
 bioimageio/core/digest_spec.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/bioimageio/core/digest_spec.py b/bioimageio/core/digest_spec.py
index c21bc31a..23441b37 100644
--- a/bioimageio/core/digest_spec.py
+++ b/bioimageio/core/digest_spec.py
@@ -333,9 +333,9 @@ def get_io_sample_block_metas(
 
 
 def create_sample_for_model(
-    inputs: Sequence[NDArray[Any]],
     model: AnyModelDescr,
     stat: Optional[Stat] = None,
+    **inputs: NDArray[Any],
 ) -> Sample:
     """Create a sample from a single set of input(s) for a specific bioimage.io model
 
@@ -349,20 +349,20 @@ def create_sample_for_model(
             f"Got {len(inputs)} inputs, but expected at most {len(model.inputs)}"
         )
 
-    missing_inputs = model.inputs[len(inputs) :]
-    for missing in missing_inputs:
-        if isinstance(missing, v0_4.InputTensorDescr):
-            raise ValueError(f"Missing input tensor '{missing.name}'")
-        elif isinstance(missing, v0_5.InputTensorDescr):
-            if not missing.optional:
-                raise ValueError(f"Missing non-optional input tensor '{missing.id}'")
-        else:
-            assert_never(missing)
+    missing_inputs = {
+        get_member_id(ipt)
+        for ipt in model.inputs
+        if str(get_member_id(ipt) not in inputs)
+        and not (isinstance(ipt, v0_5.InputTensorDescr) and ipt.optional)
+    }
+    if missing_inputs:
+        raise ValueError(f"Missing non-optional input tensors {missing_inputs}")
 
     return Sample(
         members={
-            get_member_id(ipt): Tensor.from_numpy(array, dims=get_axes_infos(ipt))
-            for ipt, array in zip(model.inputs, inputs)
+            m: Tensor.from_numpy(inputs[str(m)], dims=get_axes_infos(ipt))
+            for ipt in model.inputs
+            if str((m := get_member_id(ipt))) in inputs
         },
         stat={} if stat is None else stat,
     )

From 768a83528cc7da572138b203363a990d8d78627e Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Tue, 9 Apr 2024 14:42:16 +0200
Subject: [PATCH 189/244] update docstring

---
 bioimageio/core/digest_spec.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bioimageio/core/digest_spec.py b/bioimageio/core/digest_spec.py
index 23441b37..c79f09d5 100644
--- a/bioimageio/core/digest_spec.py
+++ b/bioimageio/core/digest_spec.py
@@ -340,9 +340,9 @@ def create_sample_for_model(
     """Create a sample from a single set of input(s) for a specific bioimage.io model
 
     Args:
-        inputs: the input(s) constituting a single sample.
         model: a bioimage.io model description
         stat: dictionary with sample and dataset statistics (may be updated in-place!)
+        inputs: the input(s) constituting a single sample.
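# ---8<--- editor's note (illustration, not part of the patch series) ---8<---
# After the switch to keyword-only inputs above, a call is sketched like this
# (the member id "raw" and the array shape are hypothetical):
import numpy as np

from bioimageio.core.digest_spec import create_sample_for_model

sample = create_sample_for_model(
    model=model,  # an AnyModelDescr with a single input tensor "raw"
    raw=np.zeros((1, 1, 256, 256), dtype="float32"),
)
# ---8<-------------------------------------------------------------->8---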
""" if len(inputs) > len(model.inputs): raise ValueError( From 89cfeda33c3a40a511376630471e7b8d6f061ead Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 9 Apr 2024 14:44:56 +0200 Subject: [PATCH 190/244] fix create_sample_for_model --- bioimageio/core/digest_spec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bioimageio/core/digest_spec.py b/bioimageio/core/digest_spec.py index c79f09d5..12901fab 100644 --- a/bioimageio/core/digest_spec.py +++ b/bioimageio/core/digest_spec.py @@ -352,7 +352,7 @@ def create_sample_for_model( missing_inputs = { get_member_id(ipt) for ipt in model.inputs - if str(get_member_id(ipt) not in inputs) + if str(get_member_id(ipt)) not in inputs and not (isinstance(ipt, v0_5.InputTensorDescr) and ipt.optional) } if missing_inputs: From a3f1f9a999e2a0e18a188c15156fc274ef112420 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 9 Apr 2024 14:47:55 +0200 Subject: [PATCH 191/244] fix as_single_block --- bioimageio/core/sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py index c6632733..bd5b74fd 100644 --- a/bioimageio/core/sample.py +++ b/bioimageio/core/sample.py @@ -90,7 +90,7 @@ def as_single_block(self, halo: Optional[PerMember[PerAxis[Halo]]] = None): blocks={ m: Block( data, - inner_slice={}, + inner_slice={a: SliceInfo(0, s) for a, s in data.tagged_shape.items()}, halo=halo.get(m, {}), block_number=1, blocks_in_sample=1 From ed1477374053e155aa0948eb4a407469af33b5f3 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 9 Apr 2024 14:52:24 +0200 Subject: [PATCH 192/244] fix BlockMeta --- bioimageio/core/block_meta.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/bioimageio/core/block_meta.py b/bioimageio/core/block_meta.py index 487c1b75..4940bb6a 100644 --- a/bioimageio/core/block_meta.py +++ b/bioimageio/core/block_meta.py @@ -133,7 +133,7 @@ def __post_init__(self): self, "shape", { - a: s.stop - s.start + sum(self.halo[a]) + a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0) for a, s in self.inner_slice.items() }, ) @@ -154,15 +154,17 @@ def __post_init__(self): max( 0, min( - self.inner_slice[a].start - self.halo[a].left, + self.inner_slice[a].start + - (self.halo[a].left if a in self.halo else 0), self.sample_shape[a] - self.inner_shape[a] - - self.halo[a].left, + - (self.halo[a].left if a in self.halo else 0), ), ), min( self.sample_shape[a], - self.inner_slice[a].stop + self.halo[a].right, + self.inner_slice[a].stop + + (self.halo[a].right if a in self.halo else 0), ), ) for a in self.inner_slice From 8234d311beca4f555453e90c03490ec40623a03d Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 9 Apr 2024 14:54:13 +0200 Subject: [PATCH 193/244] fix BlockMeta again --- bioimageio/core/block_meta.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bioimageio/core/block_meta.py b/bioimageio/core/block_meta.py index 4940bb6a..ff7d38f3 100644 --- a/bioimageio/core/block_meta.py +++ b/bioimageio/core/block_meta.py @@ -177,12 +177,12 @@ def __post_init__(self): a: PadWidth( max( 0, - self.halo[a].left + (self.halo[a].left if a in self.halo else 0) - (self.inner_slice[a].start + self.outer_slice[a].start), ), max( 0, - self.halo[a].right + (self.halo[a].right if a in self.halo else 0) - (self.outer_slice[a].stop + self.inner_slice[a].stop), ), ) From 20b8bce06a57fe60f70bdd7cb7fce9e7a34ade68 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 9 Apr 2024 15:00:22 +0200 Subject: [PATCH 194/244] Treat halo 
correctly for single block --- bioimageio/core/block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bioimageio/core/block.py b/bioimageio/core/block.py index 232be870..b6b6919b 100644 --- a/bioimageio/core/block.py +++ b/bioimageio/core/block.py @@ -55,7 +55,7 @@ def __post_init__(self): super().__post_init__() for a, s in self.data.sizes.items(): slice_ = self.inner_slice[a] - halo = self.halo[a] + halo = self.halo.get(a, Halo(0, 0)) assert s == slice_.stop - slice_.start + halo.left + halo.right, ( s, slice_, From eb13f801e23a6b2f243dff48f597fd1f02ac4c50 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 9 Apr 2024 15:00:17 +0200 Subject: [PATCH 195/244] fix _SimpleOperator --- bioimageio/core/proc_ops.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index ee99b80e..1fe016b3 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -1,6 +1,6 @@ import collections.abc from abc import ABC, abstractmethod -from dataclasses import InitVar, dataclass, field, replace +from dataclasses import InitVar, dataclass, field from typing import ( Collection, Literal, @@ -16,6 +16,7 @@ import xarray as xr from typing_extensions import Self, assert_never +from bioimageio.core.block import Block from bioimageio.core.sample import Sample, SampleBlock, SampleBlockWithOrigin from bioimageio.spec.model import v0_4, v0_5 @@ -88,8 +89,13 @@ def __call__(self, sample: Union[Sample, SampleBlock]) -> None: if isinstance(sample, Sample): sample.members[self.output] = output_tensor elif isinstance(sample, SampleBlock): - sample.blocks[self.output] = replace( - sample.blocks[self.input], data=output_tensor + b = sample.blocks[self.input] + sample.blocks[self.output] = Block( + data=output_tensor, + inner_slice=b.inner_slice, + halo=b.halo, + block_number=b.block_number, + blocks_in_sample=b.blocks_in_sample, ) else: assert_never(sample) From 8696ffce54ba623bfdb61f2e6b7686b7eeacc132 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 9 Apr 2024 15:38:30 +0200 Subject: [PATCH 196/244] add predict_sample_without_blocking --- bioimageio/core/_prediction_pipeline.py | 118 +++++++++++++++--------- 1 file changed, 72 insertions(+), 46 deletions(-) diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index 2307b2ad..2eb53379 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -1,4 +1,3 @@ -import collections.abc import warnings from types import MappingProxyType from typing import ( @@ -14,7 +13,6 @@ ) from tqdm import tqdm -from typing_extensions import assert_never from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 from bioimageio.spec.model.v0_5 import WeightsFormat @@ -74,23 +72,9 @@ def __init__( self.model_description = model_description if isinstance(model_description, v0_4.ModelDescr): - self._default_input_block_shape = {} - default_ns = {} self._default_input_halo: PerMember[PerAxis[Halo]] = {} self._block_transform = {} else: - if isinstance(default_ns, int): - default_ns = { - (ipt.id, a.id): default_ns - for ipt in model_description.inputs - for a in ipt.axes - if isinstance(a.size, v0_5.ParameterizedSize) - } - - self._default_input_block_shape = model_description.get_tensor_sizes( - default_ns, default_batch_size - ).inputs - default_output_halo = { t.id: { a.id: Halo(a.halo, a.halo) @@ -105,15 +89,13 @@ def __init__( self._block_transform = get_block_transform(model_description) 
self._default_ns = default_ns + self._default_batch_size = default_batch_size self._input_ids = get_member_ids(model_description.inputs) self._output_ids = get_member_ids(model_description.outputs) self._adapter: ModelAdapter = model_adapter - def __call__(self, data: Predict_IO) -> Predict_IO: - return self.predict(data) - def __enter__(self): self.load() return self @@ -150,17 +132,61 @@ def predict_sample_block( return output - def predict_sample( + def predict_sample_without_blocking( + self, + sample: Sample, + skip_preprocessing: bool = False, + skip_postprocessing: bool = False, + ) -> Sample: + """predict a sample. + The sample's tensor shapes have to match the model's input tensor description. + If that is not the case, consider `predict_sample_with_blocking`""" + + block = sample.as_single_block() + predicted_block = self.predict_sample_block( + block, + skip_preprocessing=skip_preprocessing, + skip_postprocessing=skip_postprocessing, + ) + predicted_sample = Sample.from_blocks([predicted_block]) + return predicted_sample + + def predict_sample_with_blocking( self, sample: Sample, skip_preprocessing: bool = False, skip_postprocessing: bool = False, + ns: Optional[ + Union[ + v0_5.ParameterizedSize.N, + Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize.N], + ] + ] = None, + batch_size: Optional[int] = None, ) -> Sample: + """predict a sample by splitting it into blocks according to the model and the `ns` parameter""" if not skip_preprocessing: self.apply_preprocessing(sample) + if isinstance(self.model_description, v0_4.ModelDescr): + raise NotImplementedError( + "predict with blocking not implemented for v0_4.ModelDescr {self.model_description.name}" + ) + + ns = ns or self._default_ns + if isinstance(ns, int): + ns = { + (ipt.id, a.id): ns + for ipt in self.model_description.inputs + for a in ipt.axes + if isinstance(a.size, v0_5.ParameterizedSize) + } + input_block_shape = self.model_description.get_tensor_sizes( + ns, batch_size or self._default_batch_size + ).inputs + n_blocks, input_blocks = sample.split_into_blocks( - self._default_input_block_shape, + input_block_shape, halo=self._default_input_halo, pad_mode="reflect", ) @@ -182,31 +208,31 @@ def predict_sample( return predicted_sample - def predict( - self, - inputs: Predict_IO, - skip_preprocessing: bool = False, - skip_postprocessing: bool = False, - ) -> Predict_IO: - """Run model prediction **including** pre/postprocessing.""" - - if isinstance(inputs, Sample): - return self.predict_sample( - inputs, - skip_preprocessing=skip_preprocessing, - skip_postprocessing=skip_postprocessing, - ) - elif isinstance(inputs, collections.abc.Iterable): - return ( - self.predict( - ipt, - skip_preprocessing=skip_preprocessing, - skip_postprocessing=skip_postprocessing, - ) - for ipt in inputs - ) - else: - assert_never(inputs) + # def predict( + # self, + # inputs: Predict_IO, + # skip_preprocessing: bool = False, + # skip_postprocessing: bool = False, + # ) -> Predict_IO: + # """Run model prediction **including** pre/postprocessing.""" + + # if isinstance(inputs, Sample): + # return self.predict_sample_with_blocking( + # inputs, + # skip_preprocessing=skip_preprocessing, + # skip_postprocessing=skip_postprocessing, + # ) + # elif isinstance(inputs, collections.abc.Iterable): + # return ( + # self.predict( + # ipt, + # skip_preprocessing=skip_preprocessing, + # skip_postprocessing=skip_postprocessing, + # ) + # for ipt in inputs + # ) + # else: + # assert_never(inputs) def apply_preprocessing(self, sample: Union[Sample, 
SampleBlockWithOrigin]) -> None: """apply preprocessing in-place, also updates sample stats""" From 374bb528eb2895d0985d8c94e44588ce64eb3860 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 9 Apr 2024 17:20:30 +0200 Subject: [PATCH 197/244] rename block_number -> block_index --- bioimageio/core/block.py | 6 +++--- bioimageio/core/block_meta.py | 10 +++++----- bioimageio/core/common.py | 2 +- bioimageio/core/proc_ops.py | 4 ++-- bioimageio/core/sample.py | 18 ++++++++++-------- 5 files changed, 21 insertions(+), 19 deletions(-) diff --git a/bioimageio/core/block.py b/bioimageio/core/block.py index b6b6919b..36a94bb2 100644 --- a/bioimageio/core/block.py +++ b/bioimageio/core/block.py @@ -35,7 +35,7 @@ def __init__( *, inner_slice: PerAxis[SliceInfo], halo: PerAxis[Halo], - block_number: int, + block_index: int, blocks_in_sample: int, ): object.__setattr__(self, "data", data) @@ -43,7 +43,7 @@ def __init__( sample_shape=data.tagged_shape, inner_slice=inner_slice, halo=halo, - block_number=block_number, + block_index=block_index, blocks_in_sample=blocks_in_sample, ) @@ -74,7 +74,7 @@ def from_sample_member( data=sample_member[block.outer_slice].pad(block.padding, pad_mode), inner_slice=block.inner_slice, halo=block.halo, - block_number=block.block_number, + block_index=block.block_index, blocks_in_sample=block.blocks_in_sample, ) diff --git a/bioimageio/core/block_meta.py b/bioimageio/core/block_meta.py index ff7d38f3..d250e978 100644 --- a/bioimageio/core/block_meta.py +++ b/bioimageio/core/block_meta.py @@ -17,7 +17,7 @@ from .axis import AxisId, PerAxis from .common import ( - BlockNumber, + BlockIndex, Halo, HaloLike, MemberId, @@ -80,8 +80,8 @@ class BlockMeta: halo: PerAxis[Halo] """halo enlarging the inner region to the block's sizes""" - block_number: BlockNumber - """the n-th block of the sample""" + block_index: BlockIndex + """the i-th block of the sample""" blocks_in_sample: TotalNumberOfBlocks """total number of blocks in the sample""" @@ -232,7 +232,7 @@ def get_transformed( ) for a, trf in new_axes.items() }, - block_number=self.block_number, + block_index=self.block_index, blocks_in_sample=self.blocks_in_sample, ) @@ -291,7 +291,7 @@ def _block_meta_generator( sample_shape=sample_shape, inner_slice=inner_slice, halo=halo, - block_number=i, + block_index=i, blocks_in_sample=blocks_in_sample, ) diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index d5c825c1..268ac616 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -82,5 +82,5 @@ class SliceInfo(NamedTuple): T = TypeVar("T") PerMember = Mapping[MemberId, T] -BlockNumber = int +BlockIndex = int TotalNumberOfBlocks = int diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 1fe016b3..1e5a07ba 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -94,7 +94,7 @@ def __call__(self, sample: Union[Sample, SampleBlock]) -> None: data=output_tensor, inner_slice=b.inner_slice, halo=b.halo, - block_number=b.block_number, + block_index=b.block_index, blocks_in_sample=b.blocks_in_sample, ) else: @@ -178,7 +178,7 @@ def __post_init__(self): def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None: if isinstance(sample, SampleBlockWithOrigin): # update stats with whole sample on first block - if sample.block_number != 0: + if sample.block_index != 0: return origin = sample.origin diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py index bd5b74fd..73759997 100644 --- a/bioimageio/core/sample.py +++ b/bioimageio/core/sample.py 
@@ -27,7 +27,7 @@ split_multiple_shapes_into_blocks, ) from .common import ( - BlockNumber, + BlockIndex, Halo, HaloLike, MemberId, @@ -90,10 +90,12 @@ def as_single_block(self, halo: Optional[PerMember[PerAxis[Halo]]] = None): blocks={ m: Block( data, - inner_slice={a: SliceInfo(0, s) for a, s in data.tagged_shape.items()}, + inner_slice={ + a: SliceInfo(0, s) for a, s in data.tagged_shape.items() + }, halo=halo.get(m, {}), - block_number=1, - blocks_in_sample=1 + block_index=0, + blocks_in_sample=1, ) for m, data in self.members.items() }, @@ -145,7 +147,7 @@ class SampleBlockBase(Generic[BlockT]): blocks: Dict[MemberId, BlockT] """Individual tensor blocks comprising this sample block""" - block_number: BlockNumber = field(init=False) + block_index: BlockIndex = field(init=False) """the n-th block of the sample""" blocks_in_sample: TotalNumberOfBlocks = field(init=False) @@ -153,7 +155,7 @@ class SampleBlockBase(Generic[BlockT]): def __post_init__(self): a_block = next(iter(self.blocks.values())) - self.block_number = a_block.block_number + self.block_index = a_block.block_index self.blocks_in_sample = a_block.blocks_in_sample @property @@ -222,7 +224,7 @@ def get_transformed( ) for a, trf in new_axes[m].items() }, - block_number=self.block_number, + block_index=self.block_index, blocks_in_sample=self.blocks_in_sample, ) for m in new_axes @@ -238,7 +240,7 @@ def with_data(self, data: PerMember[Tensor], *, stat: Stat) -> SampleBlock: data[m], inner_slice=b.inner_slice, halo=b.halo, - block_number=b.block_number, + block_index=b.block_index, blocks_in_sample=b.blocks_in_sample, ) for m, b in self.blocks.items() From 4c833868a38067d440fc36454780275c0bd70407 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 10 Apr 2024 13:32:36 +0200 Subject: [PATCH 198/244] some small fixes --- bioimageio/core/_resource_tests.py | 4 ++-- bioimageio/core/digest_spec.py | 2 +- .../core/model_adapters/_pytorch_model_adapter.py | 5 +++-- .../core/model_adapters/_torchscript_model_adapter.py | 11 ++++------- bioimageio/core/sample.py | 6 +----- bioimageio/core/weight_converter/torch/_onnx.py | 7 ++++--- 6 files changed, 15 insertions(+), 20 deletions(-) diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py index 9c721c63..54588435 100644 --- a/bioimageio/core/_resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -126,7 +126,7 @@ def _test_model_inference( with create_prediction_pipeline( bioimageio_model=model, devices=devices, weight_format=weight_format ) as prediction_pipeline: - results = prediction_pipeline.predict(inputs) + results = prediction_pipeline.predict_sample_without_blocking(inputs) if len(results.members) != len(expected.members): error = f"Expected {len(expected.members)} outputs, but got {len(results.members)}" @@ -251,7 +251,7 @@ def get_ns(n: int): ) as prediction_pipeline: for n, batch_size, inputs, exptected_output_shape in generate_test_cases(): error: Optional[str] = None - result = prediction_pipeline.predict(inputs) + result = prediction_pipeline.predict_sample_with_blocking(inputs) if len(result.members) != len(exptected_output_shape): error = ( f"Expected {len(exptected_output_shape)} outputs," diff --git a/bioimageio/core/digest_spec.py b/bioimageio/core/digest_spec.py index 12901fab..6ffab907 100644 --- a/bioimageio/core/digest_spec.py +++ b/bioimageio/core/digest_spec.py @@ -215,7 +215,7 @@ def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]] assert ( total_input_halo == int(total_input_halo) and total_input_halo % 
2 == 0 ) - input_halo.setdefault(t, {})[a] = Halo( + input_halo.setdefault(s.tensor_id, {})[a] = Halo( int(total_input_halo // 2), int(total_input_halo // 2) ) diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py index 8ab8c967..eaf03fcc 100644 --- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py +++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py @@ -5,6 +5,7 @@ from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.utils import download +from ..axis import AxisId from ..digest_spec import import_callable from ..tensor import Tensor from ._model_adapter import ModelAdapter @@ -31,7 +32,7 @@ def __init__( raise ImportError("torch") super().__init__() self.output_dims = [ - tuple(a if isinstance(a, str) else a.id for a in out.axes) + tuple(AxisId(a) if isinstance(a, str) else a.id for a in out.axes) for out in outputs ] self._network = self.get_network(weights) @@ -52,7 +53,7 @@ def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: raise ImportError("torch") with torch.no_grad(): tensors = [ - None if ipt is None else torch.from_numpy(ipt.data) + None if ipt is None else torch.from_numpy(ipt.data.data) for ipt in input_tensors ] tensors = [ diff --git a/bioimageio/core/model_adapters/_torchscript_model_adapter.py b/bioimageio/core/model_adapters/_torchscript_model_adapter.py index d9454854..d7cee1a3 100644 --- a/bioimageio/core/model_adapters/_torchscript_model_adapter.py +++ b/bioimageio/core/model_adapters/_torchscript_model_adapter.py @@ -5,10 +5,11 @@ import numpy as np from numpy.typing import NDArray -from bioimageio.core.tensor import Tensor from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.utils import download +from ..axis import AxisId +from ..tensor import Tensor from ._model_adapter import ModelAdapter try: @@ -49,11 +50,7 @@ def __init__( ) self._model.to(self.devices[0]) self._internal_output_axes = [ - ( - tuple(out.axes) - if isinstance(out.axes, str) - else tuple(a.id for a in out.axes) - ) + tuple(AxisId(a) if isinstance(a, str) else a.id for a in out.axes) for out in model_description.outputs ] @@ -61,7 +58,7 @@ def forward(self, *batch: Optional[Tensor]) -> List[Optional[Tensor]]: assert torch is not None with torch.no_grad(): torch_tensor = [ - None if b is None else torch.from_numpy(b.data).to(self.devices[0]) + None if b is None else torch.from_numpy(b.data.data).to(self.devices[0]) for b in batch ] _result: Union[ # pyright: ignore[reportUnknownVariableType] diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py index 73759997..071525e6 100644 --- a/bioimageio/core/sample.py +++ b/bioimageio/core/sample.py @@ -166,10 +166,6 @@ def shape(self) -> PerMember[PerAxis[int]]: def inner_shape(self) -> PerMember[PerAxis[int]]: return {mid: b.inner_shape for mid, b in self.blocks.items()} - @property - @abstractmethod - def origin_shape(self) -> PerMember[PerAxis[int]]: ... 
- @dataclass class LinearSampleAxisTransform(LinearAxisTransform): @@ -188,7 +184,7 @@ def get_transformed( a: ( trf if isinstance(trf, int) - else trf.compute(self.origin_shape[trf.member][trf.axis]) + else trf.compute(self.sample_shape[trf.member][trf.axis]) ) for a, trf in new_axes[m].items() } diff --git a/bioimageio/core/weight_converter/torch/_onnx.py b/bioimageio/core/weight_converter/torch/_onnx.py index 12b31cca..1e1e68ae 100644 --- a/bioimageio/core/weight_converter/torch/_onnx.py +++ b/bioimageio/core/weight_converter/torch/_onnx.py @@ -11,7 +11,7 @@ from bioimageio.spec.common import InvalidDescr from bioimageio.spec.model import v0_4, v0_5 -from ...digest_spec import get_test_inputs +from ...digest_spec import get_member_id, get_test_inputs from ...weight_converter.torch._utils import load_torch_model @@ -50,8 +50,9 @@ def add_onnx_weights( with torch.no_grad(): - input_data = [t.data for t in get_test_inputs(model_spec)] - input_tensors = [torch.from_numpy(d) for d in input_data] + sample = get_test_inputs(model_spec) + input_data = [sample[get_member_id(ipt)].data.data for ipt in model_spec.inputs] + input_tensors = [torch.from_numpy(ipt) for ipt in input_data] model = load_torch_model(state_dict_weights_descr) expected_tensors = model(*input_tensors) From 292e38d2bc40d9305e7a4b50185888b7ae217c52 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 10 Apr 2024 13:33:10 +0200 Subject: [PATCH 199/244] use cached_property in BlockMeta --- bioimageio/core/block_meta.py | 166 +++++++++++++++++----------------- bioimageio/core/common.py | 44 ++++++++- 2 files changed, 127 insertions(+), 83 deletions(-) diff --git a/bioimageio/core/block_meta.py b/bioimageio/core/block_meta.py index d250e978..999406d1 100644 --- a/bioimageio/core/block_meta.py +++ b/bioimageio/core/block_meta.py @@ -1,5 +1,6 @@ import itertools -from dataclasses import dataclass, field +from dataclasses import dataclass +from functools import cached_property from math import prod from typing import ( Any, @@ -13,11 +14,13 @@ Union, ) +from loguru import logger from typing_extensions import Self from .axis import AxisId, PerAxis from .common import ( BlockIndex, + Frozen, Halo, HaloLike, MemberId, @@ -86,69 +89,42 @@ class BlockMeta: blocks_in_sample: TotalNumberOfBlocks """total number of blocks in the sample""" - shape: PerAxis[int] = field(init=False) - """axis lengths of the block""" - - padding: PerAxis[PadWidth] = field(init=False) - """padding to realize the halo at the sample edge - where we cannot simply enlarge the inner slice""" - - outer_slice: PerAxis[SliceInfo] = field(init=False) - """slice of the outer block (without padding) wrt the sample""" - - inner_shape: PerAxis[int] = field(init=False) - """axis lengths of the inner region (without halo)""" - - local_slice: PerAxis[SliceInfo] = field(init=False) - """inner slice wrt the block, **not** the sample""" - - @property - def dims(self) -> Collection[AxisId]: - return set(self.inner_shape) - - @property - def tagged_shape(self) -> PerAxis[int]: - """alias for shape""" - return self.shape - - @property - def inner_slice_wo_overlap(self): - """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be - stiched together trivially to form the original sample. 
- - This can also be used to calculate statistics - without overrepresenting block edge regions.""" - # TODO: update inner_slice_wo_overlap when adding block overlap - return self.inner_slice - - def __post_init__(self): - assert all( - a in self.sample_shape for a in self.inner_slice - ), "block has axes not present in sample" - assert all( - a in self.inner_slice for a in self.halo - ), "halo has axes not present in block" - - object.__setattr__( # TODO: write as property - self, - "shape", + @cached_property + def shape(self) -> PerAxis[int]: + """axis lengths of the block""" + return Frozen( { a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0) for a, s in self.inner_slice.items() - }, + } ) - assert all( - s <= self.sample_shape[a] for a, s in self.shape.items() - ), "block larger than sample" - object.__setattr__( # TODO: write as property - self, - "inner_shape", - {a: s.stop - s.start for a, s in self.inner_slice.items()}, + @cached_property + def padding(self) -> PerAxis[PadWidth]: + """padding to realize the halo at the sample edge + where we cannot simply enlarge the inner slice""" + return Frozen( + { + a: PadWidth( + max( + 0, + (self.halo[a].left if a in self.halo else 0) + - (self.inner_slice[a].start + self.outer_slice[a].start), + ), + max( + 0, + (self.halo[a].right if a in self.halo else 0) + - (self.outer_slice[a].stop + self.inner_slice[a].stop), + ), + ) + for a in self.inner_slice + } ) - object.__setattr__( # TODO: write as property - self, - "outer_slice", + + @cached_property + def outer_slice(self) -> PerAxis[SliceInfo]: + """slice of the outer block (without padding) wrt the sample""" + return Frozen( { a: SliceInfo( max( @@ -168,39 +144,65 @@ def __post_init__(self): ), ) for a in self.inner_slice - }, - ) - object.__setattr__( # TODO: write as property - self, - "padding", - { - a: PadWidth( - max( - 0, - (self.halo[a].left if a in self.halo else 0) - - (self.inner_slice[a].start + self.outer_slice[a].start), - ), - max( - 0, - (self.halo[a].right if a in self.halo else 0) - - (self.outer_slice[a].stop + self.inner_slice[a].stop), - ), - ) - for a in self.inner_slice - }, + } ) - object.__setattr__( # TODO: write as property - self, - "local_slice", + + @cached_property + def inner_shape(self) -> PerAxis[int]: + """axis lengths of the inner region (without halo)""" + return Frozen({a: s.stop - s.start for a, s in self.inner_slice.items()}) + + @cached_property + def local_slice(self) -> PerAxis[SliceInfo]: + """inner slice wrt the block, **not** the sample""" + return Frozen( { a: SliceInfo( self.padding[a].left, self.padding[a].left + self.inner_shape[a], ) for a in self.inner_slice - }, + } ) + @property + def dims(self) -> Collection[AxisId]: + return set(self.inner_shape) + + @property + def tagged_shape(self) -> PerAxis[int]: + """alias for shape""" + return self.shape + + @property + def inner_slice_wo_overlap(self): + """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be + stiched together trivially to form the original sample. 
+ + This can also be used to calculate statistics + without overrepresenting block edge regions.""" + # TODO: update inner_slice_wo_overlap when adding block overlap + return self.inner_slice + + def __post_init__(self): + # freeze mutable inputs + object.__setattr__(self, "sample_shape", Frozen(self.sample_shape)) + object.__setattr__(self, "inner_slice", Frozen(self.inner_slice)) + object.__setattr__(self, "halo", Frozen(self.halo)) + + assert all( + a in self.sample_shape for a in self.inner_slice + ), "block has axes not present in sample" + + assert all( + a in self.inner_slice for a in self.halo + ), "halo has axes not present in block" + + if any(s > self.sample_shape[a] for a, s in self.shape.items()): + logger.warning( + "block {} larger than sample {}", self.shape, self.sample_shape + ) + def get_transformed( self, new_axes: PerAxis[Union[LinearAxisTransform, int]] ) -> Self: diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index 268ac616..4e4c493a 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -1,6 +1,16 @@ from __future__ import annotations -from typing import Hashable, Literal, Mapping, NamedTuple, Tuple, TypeVar, Union +from copy import deepcopy +from typing import ( + Hashable, + Iterator, + Literal, + Mapping, + NamedTuple, + Tuple, + TypeVar, + Union, +) from typing_extensions import Self, assert_never @@ -84,3 +94,35 @@ class SliceInfo(NamedTuple): BlockIndex = int TotalNumberOfBlocks = int + + +K = TypeVar("K", bound=Hashable) +V = TypeVar("V") + + +class Frozen(Mapping[K, V]): # adapted from xarray.core.utils.Frozen + """Wrapper around an object implementing the mapping interface to make it + immutable.""" + + __slots__ = ("mapping",) + + def __init__(self, mapping: Mapping[K, V]): + super().__init__() + self.mapping = deepcopy( + mapping + ) # added deepcopy (compared to xarray.core.utils.Frozen) + + def __getitem__(self, key: K) -> V: + return self.mapping[key] + + def __iter__(self) -> Iterator[K]: + return iter(self.mapping) + + def __len__(self) -> int: + return len(self.mapping) + + def __contains__(self, key: object) -> bool: + return key in self.mapping + + def __repr__(self) -> str: + return f"{type(self).__name__}({self.mapping!r})" From 250bee697b1c09d2a800b9abc3de34277722427b Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 10 Apr 2024 13:47:11 +0200 Subject: [PATCH 200/244] add ruff to dev envs (now available for py 3.12) --- dev/env-py38.yaml | 2 +- dev/env-tf.yaml | 2 +- dev/env-wo-python.yaml | 2 +- dev/env.yaml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dev/env-py38.yaml b/dev/env-py38.yaml index 0fe96b73..726ce341 100644 --- a/dev/env-py38.yaml +++ b/dev/env-py38.yaml @@ -27,7 +27,7 @@ dependencies: - python-dotenv - python=3.8 # changed - pytorch>=2.1 - - ruff # uncommented + - ruff - ruyaml - torchvision - tqdm diff --git a/dev/env-tf.yaml b/dev/env-tf.yaml index 53a18ae0..304c0193 100644 --- a/dev/env-tf.yaml +++ b/dev/env-tf.yaml @@ -27,7 +27,7 @@ dependencies: - python-dotenv # - python=3.9 # removed # - pytorch>=2.1 # removed - # - ruff # removed + - ruff - ruyaml - tensorflow>=2.15 # added # - torchvision # removed diff --git a/dev/env-wo-python.yaml b/dev/env-wo-python.yaml index 40fe27b6..8816ea48 100644 --- a/dev/env-wo-python.yaml +++ b/dev/env-wo-python.yaml @@ -27,7 +27,7 @@ dependencies: - python-dotenv # - python=3.9 # removed - pytorch>=2.1 - # - ruff # requires python < 3.11 + - ruff - ruyaml - torchvision - tqdm diff --git a/dev/env.yaml b/dev/env.yaml index 
6f6b8059..0aa1660e 100644
--- a/dev/env.yaml
+++ b/dev/env.yaml
@@ -26,7 +26,7 @@ dependencies:
   - python-dotenv
   - python=3.9
   - pytorch>=2.1
-  # - ruff # requires python < 3.11
+  - ruff
   - ruyaml
   - torchvision
   - tqdm

From 9217276c6dd93344e4d166bbe3c9ebb818971b4f Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Thu, 11 Apr 2024 15:41:43 +0200
Subject: [PATCH 201/244] make prediction pipeline easier to debug

---
 bioimageio/core/_prediction_pipeline.py | 17 ++++++++++-------
 bioimageio/core/block.py | 2 +-
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py
index 2eb53379..17e6bf25 100644
--- a/bioimageio/core/_prediction_pipeline.py
+++ b/bioimageio/core/_prediction_pipeline.py
@@ -190,18 +190,21 @@ def predict_sample_with_blocking(
             halo=self._default_input_halo,
             pad_mode="reflect",
         )
-        input_blocks = tqdm(
+        input_blocks = list(input_blocks)
+        predicted_blocks: List[SampleBlock] = []
+        for b in tqdm(
             input_blocks,
             desc=f"predict sample {sample.id or ''} with {self.model_description.id or self.model_description.name}",
             unit="block",
+            unit_divisor=1,
             total=n_blocks,
-        )
-        predicted_blocks = (
-            self.predict_sample_block(
-                b, skip_preprocessing=True, skip_postprocessing=True
+        ):
+            predicted_blocks.append(
+                self.predict_sample_block(
+                    b, skip_preprocessing=True, skip_postprocessing=True
+                )
             )
-            for b in input_blocks
-        )
+
         predicted_sample = Sample.from_blocks(predicted_blocks)
         if not skip_postprocessing:
             self.apply_postprocessing(predicted_sample)

diff --git a/bioimageio/core/block.py b/bioimageio/core/block.py
index 36a94bb2..05adca9e 100644
--- a/bioimageio/core/block.py
+++ b/bioimageio/core/block.py
@@ -56,7 +56,7 @@ def __post_init__(self):
         for a, s in self.data.sizes.items():
             slice_ = self.inner_slice[a]
             halo = self.halo.get(a, Halo(0, 0))
-            assert s == slice_.stop - slice_.start + halo.left + halo.right, (
+            assert s == halo.left + (slice_.stop - slice_.start) + halo.right, (
                 s,
                 slice_,
                 halo,

From 38b6cb0d3ca6b991db1f2593bb3f5d527aa07db0 Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Thu, 11 Apr 2024 15:42:20 +0200
Subject: [PATCH 202/244] fix BlockMeta.padding

---
 bioimageio/core/block_meta.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/bioimageio/core/block_meta.py b/bioimageio/core/block_meta.py
index 999406d1..417cd002 100644
--- a/bioimageio/core/block_meta.py
+++ b/bioimageio/core/block_meta.py
@@ -106,15 +106,17 @@ def padding(self) -> PerAxis[PadWidth]:
         return Frozen(
             {
                 a: PadWidth(
-                    max(
-                        0,
-                        (self.halo[a].left if a in self.halo else 0)
-                        - (self.inner_slice[a].start + self.outer_slice[a].start),
+                    (
+                        self.halo[a].left
+                        - (self.inner_slice[a].start - self.outer_slice[a].start)
+                        if a in self.halo
+                        else 0
                     ),
-                    max(
-                        0,
-                        (self.halo[a].right if a in self.halo else 0)
-                        - (self.outer_slice[a].stop + self.inner_slice[a].stop),
+                    (
+                        self.halo[a].right
+                        - (self.outer_slice[a].stop - self.inner_slice[a].stop)
+                        if a in self.halo
+                        else 0
                     ),
                 )
                 for a in self.inner_slice
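
A worked example of the corrected formula: for an edge block with inner_slice
SliceInfo(0, 64), halo Halo(8, 8) and a sample length of 128, the outer slice clamps to
SliceInfo(0, 72), so only the right halo can be realized by slicing and the missing left
halo must come from padding. A sketch (values and import paths are illustrative
assumptions):

    from bioimageio.core.axis import AxisId
    from bioimageio.core.block_meta import BlockMeta
    from bioimageio.core.common import Halo, PadWidth, SliceInfo

    x = AxisId("x")
    edge_block = BlockMeta(
        sample_shape={x: 128},
        inner_slice={x: SliceInfo(0, 64)},
        halo={x: Halo(8, 8)},
        block_index=0,
        blocks_in_sample=2,
    )
    # the left halo is clamped at the sample edge ...
    assert edge_block.outer_slice[x] == SliceInfo(0, 72)
    # ... and made up for by padding: 8 - (0 - 0) on the left, 8 - (72 - 64) on the right
    assert edge_block.padding[x] == PadWidth(8, 0)

From aa60cd6e19a6a688d744f06f7d035e46f7dcd3bb Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Thu, 11 Apr 2024 15:43:34 +0200
Subject: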
[PATCH 203/244] scale halo in SampleBlockMeta.get_transformed --- bioimageio/core/sample.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py index 071525e6..488353cc 100644 --- a/bioimageio/core/sample.py +++ b/bioimageio/core/sample.py @@ -1,12 +1,9 @@ from __future__ import annotations -from abc import abstractmethod from dataclasses import dataclass, field -from itertools import islice +from math import ceil from typing import ( - Any, Dict, - Generator, Generic, Iterable, Optional, @@ -111,7 +108,6 @@ def from_blocks( fill_value: float = float("nan"), ) -> Self: members: PerMember[Tensor] = {} - sample_blocks = list(iter(sample_blocks)) for member_blocks in sample_blocks: for m, block in member_blocks.blocks.items(): if m not in members: @@ -213,9 +209,16 @@ def get_transformed( a: ( Halo(0, 0) if isinstance(trf, int) + or trf.axis not in self.blocks[trf.member].halo else Halo( - self.blocks[trf.member].halo[trf.axis].left, - self.blocks[trf.member].halo[trf.axis].right, + ceil( + self.blocks[trf.member].halo[trf.axis].left + * trf.scale + ), + ceil( + self.blocks[trf.member].halo[trf.axis].right + * trf.scale + ), ) ) for a, trf in new_axes[m].items() From ed3453eab8ed42489b27e5085180ded8aa8a2685 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 11 Apr 2024 17:38:35 +0200 Subject: [PATCH 204/244] clean up nested Frozen instances --- bioimageio/core/block.py | 1 - bioimageio/core/block_meta.py | 11 ++++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/bioimageio/core/block.py b/bioimageio/core/block.py index 05adca9e..f22c89fc 100644 --- a/bioimageio/core/block.py +++ b/bioimageio/core/block.py @@ -16,7 +16,6 @@ Halo, HaloLike, PadMode, - SliceInfo, TotalNumberOfBlocks, ) from .tensor import Tensor diff --git a/bioimageio/core/block_meta.py b/bioimageio/core/block_meta.py index 417cd002..4de01baf 100644 --- a/bioimageio/core/block_meta.py +++ b/bioimageio/core/block_meta.py @@ -188,9 +188,14 @@ def inner_slice_wo_overlap(self): def __post_init__(self): # freeze mutable inputs - object.__setattr__(self, "sample_shape", Frozen(self.sample_shape)) - object.__setattr__(self, "inner_slice", Frozen(self.inner_slice)) - object.__setattr__(self, "halo", Frozen(self.halo)) + if not isinstance(self.sample_shape, Frozen): + object.__setattr__(self, "sample_shape", Frozen(self.sample_shape)) + + if not isinstance(self.inner_slice, Frozen): + object.__setattr__(self, "inner_slice", Frozen(self.inner_slice)) + + if not isinstance(self.halo, Frozen): + object.__setattr__(self, "halo", Frozen(self.halo)) assert all( a in self.sample_shape for a in self.inner_slice From d45040e228ab738e299c5d794b755310255176d2 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 11 Apr 2024 17:39:00 +0200 Subject: [PATCH 205/244] fix Block.sample_shape --- bioimageio/core/block.py | 21 ++------------------- bioimageio/core/sample.py | 10 ++++++---- 2 files changed, 8 insertions(+), 23 deletions(-) diff --git a/bioimageio/core/block.py b/bioimageio/core/block.py index f22c89fc..3355f9b1 100644 --- a/bioimageio/core/block.py +++ b/bioimageio/core/block.py @@ -21,31 +21,13 @@ from .tensor import Tensor -@dataclass(init=False, frozen=True) +@dataclass(frozen=True) class Block(BlockMeta): """A block/tile of a (larger) tensor""" data: Tensor """the block's tensor, e.g. 
a (padded) slice of some larger, original tensor""" - def __init__( - self, - data: Tensor, - *, - inner_slice: PerAxis[SliceInfo], - halo: PerAxis[Halo], - block_index: int, - blocks_in_sample: int, - ): - object.__setattr__(self, "data", data) - super().__init__( - sample_shape=data.tagged_shape, - inner_slice=inner_slice, - halo=halo, - block_index=block_index, - blocks_in_sample=blocks_in_sample, - ) - @property def inner_data(self): return self.data[self.local_slice] @@ -71,6 +53,7 @@ def from_sample_member( ) -> Self: return cls( data=sample_member[block.outer_slice].pad(block.padding, pad_mode), + sample_shape=sample_member.tagged_shape, inner_slice=block.inner_slice, halo=block.halo, block_index=block.block_index, diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py index 488353cc..046e9e04 100644 --- a/bioimageio/core/sample.py +++ b/bioimageio/core/sample.py @@ -86,7 +86,8 @@ def as_single_block(self, halo: Optional[PerMember[PerAxis[Halo]]] = None): sample_shape=self.shape, blocks={ m: Block( - data, + sample_shape=self.shape[m], + data=data, inner_slice={ a: SliceInfo(0, s) for a, s in data.tagged_shape.items() }, @@ -108,8 +109,8 @@ def from_blocks( fill_value: float = float("nan"), ) -> Self: members: PerMember[Tensor] = {} - for member_blocks in sample_blocks: - for m, block in member_blocks.blocks.items(): + for sample_block in sample_blocks: + for m, block in sample_block.blocks.items(): if m not in members: if -1 in block.sample_shape.values(): raise NotImplementedError( @@ -236,11 +237,12 @@ def with_data(self, data: PerMember[Tensor], *, stat: Stat) -> SampleBlock: sample_shape=self.sample_shape, blocks={ m: Block( - data[m], + sample_shape=self.sample_shape[m], inner_slice=b.inner_slice, halo=b.halo, block_index=b.block_index, blocks_in_sample=b.blocks_in_sample, + data=data[m], ) for m, b in self.blocks.items() }, From fdf4bed2e276442cba02f90ed931b65e8028a44c Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 12 Apr 2024 04:19:19 +0200 Subject: [PATCH 206/244] be strict about halo scaling --- bioimageio/core/block_meta.py | 7 ++-- bioimageio/core/sample.py | 78 +++++++++++++++++++---------------- 2 files changed, 47 insertions(+), 38 deletions(-) diff --git a/bioimageio/core/block_meta.py b/bioimageio/core/block_meta.py index 4de01baf..0fe5b6c5 100644 --- a/bioimageio/core/block_meta.py +++ b/bioimageio/core/block_meta.py @@ -1,9 +1,10 @@ import itertools from dataclasses import dataclass from functools import cached_property -from math import prod +from math import floor, prod from typing import ( Any, + Callable, Collection, Dict, Generator, @@ -37,8 +38,8 @@ class LinearAxisTransform: scale: float offset: int - def compute(self, s: int) -> int: - return int(s * self.scale) + self.offset + def compute(self, s: int, round: Callable[[float], int] = floor) -> int: + return round(s * self.scale) + self.offset @dataclass(frozen=True) diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py index 046e9e04..b4b2464a 100644 --- a/bioimageio/core/sample.py +++ b/bioimageio/core/sample.py @@ -1,8 +1,9 @@ from __future__ import annotations from dataclasses import dataclass, field -from math import ceil +from math import ceil, floor from typing import ( + Callable, Dict, Generic, Iterable, @@ -17,7 +18,7 @@ from bioimageio.core.block import Block -from .axis import PerAxis +from .axis import AxisId, PerAxis from .block_meta import ( BlockMeta, LinearAxisTransform, @@ -187,43 +188,50 @@ def get_transformed( } for m in new_axes } + + def get_member_halo(m: 
MemberId, round: Callable[[float], int]): + return { + a: ( + Halo(0, 0) + if isinstance(trf, int) + or trf.axis not in self.blocks[trf.member].halo + else Halo( + ceil(self.blocks[trf.member].halo[trf.axis].left * trf.scale), + ceil(self.blocks[trf.member].halo[trf.axis].right * trf.scale), + ) + ) + for a, trf in new_axes[m].items() + } + + halo: Dict[MemberId, Dict[AxisId, Halo]] = {} + for m in new_axes: + halo[m] = get_member_halo(m, floor) + assert halo[m] == get_member_halo( + m, ceil + ), f"failed to unambiguously scale halo {halo[m]} with {new_axes[m]}" + + inner_slice = { + m: { + a: ( + SliceInfo(0, trf) + if isinstance(trf, int) + else SliceInfo( + trf.compute( + self.blocks[trf.member].inner_slice[trf.axis].start + ), + trf.compute(self.blocks[trf.member].inner_slice[trf.axis].stop), + ) + ) + for a, trf in new_axes[m].items() + } + for m in new_axes + } return self.__class__( blocks={ m: BlockMeta( sample_shape=sample_shape[m], - inner_slice={ - a: ( - SliceInfo(0, trf) - if isinstance(trf, int) - else SliceInfo( - trf.compute( - self.blocks[trf.member].inner_slice[trf.axis].start - ), - trf.compute( - self.blocks[trf.member].inner_slice[trf.axis].stop - ), - ) - ) - for a, trf in new_axes[m].items() - }, - halo={ - a: ( - Halo(0, 0) - if isinstance(trf, int) - or trf.axis not in self.blocks[trf.member].halo - else Halo( - ceil( - self.blocks[trf.member].halo[trf.axis].left - * trf.scale - ), - ceil( - self.blocks[trf.member].halo[trf.axis].right - * trf.scale - ), - ) - ) - for a, trf in new_axes[m].items() - }, + inner_slice=inner_slice[m], + halo=halo[m], block_index=self.block_index, blocks_in_sample=self.blocks_in_sample, ) From 2413067f66fdfbd61eabbe7fc0b2b78ad9bc8da2 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 12 Apr 2024 04:21:07 +0200 Subject: [PATCH 207/244] set sample_shape for _SimpleOperator output block --- bioimageio/core/proc_ops.py | 54 +++++++++++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 1e5a07ba..481fa302 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -21,7 +21,7 @@ from bioimageio.spec.model import v0_4, v0_5 from ._op_base import BlockedOperator, Operator -from .axis import AxisId +from .axis import AxisId, PerAxis from .common import DTypeStr, MemberId from .stat_calculators import StatsCalculator from .stat_measures import ( @@ -69,13 +69,10 @@ class _SimpleOperator(BlockedOperator, ABC): def required_measures(self) -> Collection[Measure]: return set() - # @property - # def required_tensors(self) -> Set[MemberId]: - # return {self.input} - - # @property - # def produced_tensors(self) -> Set[MemberId]: - # return {self.output} + @abstractmethod + def get_sample_output_shape( + self, sample_input_shape: PerAxis[int] + ) -> PerAxis[int]: ... 
def __call__(self, sample: Union[Sample, SampleBlock]) -> None: input_tensor = sample.members[self.input] @@ -91,6 +88,7 @@ def __call__(self, sample: Union[Sample, SampleBlock]) -> None: elif isinstance(sample, SampleBlock): b = sample.blocks[self.input] sample.blocks[self.output] = Block( + sample_shape=self.get_sample_output_shape(sample.shape[self.input]), data=output_tensor, inner_slice=b.inner_slice, halo=b.halo, @@ -201,6 +199,11 @@ class Binarize(_SimpleOperator): def _apply(self, input: Tensor, stat: Stat) -> Tensor: return input > self.threshold + def get_sample_output_shape( + self, sample_input_shape: Mapping[AxisId, int] + ) -> Mapping[AxisId, int]: + return sample_input_shape + @classmethod def from_proc_descr( cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId @@ -236,6 +239,11 @@ def __post_init__(self): def _apply(self, input: Tensor, stat: Stat) -> Tensor: return input.clip(self.min, self.max) + def get_sample_output_shape( + self, sample_input_shape: Mapping[AxisId, int] + ) -> Mapping[AxisId, int]: + return sample_input_shape + @classmethod def from_proc_descr( cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId @@ -274,6 +282,11 @@ class ScaleLinear(_SimpleOperator): def _apply(self, input: Tensor, stat: Stat) -> Tensor: return input * self.gain + self.offset + def get_sample_output_shape( + self, sample_input_shape: Mapping[AxisId, int] + ) -> Mapping[AxisId, int]: + return sample_input_shape + @classmethod def from_proc_descr( cls, @@ -342,6 +355,11 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor: ref_std = stat[self.ref_std] + self.eps return (input - mean) / std * ref_std + ref_mean + def get_sample_output_shape( + self, sample_input_shape: Mapping[AxisId, int] + ) -> Mapping[AxisId, int]: + return sample_input_shape + @classmethod def from_proc_descr( cls, @@ -415,6 +433,11 @@ def __post_init__( def required_measures(self): return {self.lower, self.upper} + def get_sample_output_shape( + self, sample_input_shape: Mapping[AxisId, int] + ) -> Mapping[AxisId, int]: + return sample_input_shape + @classmethod def from_proc_descr( cls, @@ -475,6 +498,11 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor: def required_measures(self) -> Collection[Measure]: return {} + def get_sample_output_shape( + self, sample_input_shape: Mapping[AxisId, int] + ) -> Mapping[AxisId, int]: + return sample_input_shape + @classmethod def from_proc_descr( cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId @@ -502,6 +530,11 @@ def __post_init__(self): def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]: return {self.mean, self.std} + def get_sample_output_shape( + self, sample_input_shape: Mapping[AxisId, int] + ) -> Mapping[AxisId, int]: + return sample_input_shape + @classmethod def from_proc_descr( cls, @@ -551,6 +584,11 @@ def __post_init__(self): or self.mean.dims == self.std.dims ) + def get_sample_output_shape( + self, sample_input_shape: Mapping[AxisId, int] + ) -> Mapping[AxisId, int]: + return sample_input_shape + @classmethod def from_proc_descr( cls, From a8d65914f3cfa9f205302d878a5fa1ac3d515c57 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 12 Apr 2024 04:21:27 +0200 Subject: [PATCH 208/244] update some tests --- tests/test_bioimageio_spec_version.py | 2 +- tests/test_prediction_pipeline.py | 11 +++++----- ...t_prediction_pipeline_device_management.py | 20 ++++++++++--------- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/tests/test_bioimageio_spec_version.py 
b/tests/test_bioimageio_spec_version.py
index 2a0ae2c2..719796ef 100644
--- a/tests/test_bioimageio_spec_version.py
+++ b/tests/test_bioimageio_spec_version.py
@@ -14,7 +14,7 @@ def test_bioimageio_spec_version(mamba_cmd: Optional[str]):

     # get latest released bioimageio.spec version
     mamba_repoquery = subprocess.run(
-        f"{pytest.mamba_cmd} repoquery search -c conda-forge --json bioimageio.spec".split(
+        f"{mamba_cmd} repoquery search -c conda-forge --json bioimageio.spec".split(
             " "
         ),
         encoding="utf-8",

diff --git a/tests/test_prediction_pipeline.py b/tests/test_prediction_pipeline.py
index 33de4bb4..a0a85f5d 100644
--- a/tests/test_prediction_pipeline.py
+++ b/tests/test_prediction_pipeline.py
@@ -20,13 +20,14 @@ def _test_prediction_pipeline(model_package: Path, weights_format: WeightsFormat
     )
     inputs = get_test_inputs(bio_model)
-    outputs = pp.forward(*inputs)
-    assert isinstance(outputs, list)
+    outputs = pp.predict_sample_without_blocking(inputs)

     expected_outputs = get_test_outputs(bio_model)
-    assert len(outputs) == len(expected_outputs)
-
-    for out, exp in zip(outputs, expected_outputs):
+    assert len(outputs.shape) == len(expected_outputs.shape)
+    for m in expected_outputs.members:
+        out = outputs.members[m].data
+        assert out is not None
+        exp = expected_outputs.members[m].data
         assert_array_almost_equal(out, exp, decimal=4)

diff --git a/tests/test_prediction_pipeline_device_management.py b/tests/test_prediction_pipeline_device_management.py
index 064533d5..bada06ae 100644
--- a/tests/test_prediction_pipeline_device_management.py
+++ b/tests/test_prediction_pipeline_device_management.py
@@ -29,24 +29,26 @@ def _test_device_management(model_package: Path, weight_format: WeightsFormat):
     inputs = get_test_inputs(bio_model)

     with pred_pipe as pp:
-        outputs = pp.forward(*inputs)
-
-    assert isinstance(outputs, list)
+        outputs = pp.predict_sample_without_blocking(inputs)

     expected_outputs = get_test_outputs(bio_model)
-    assert len(outputs) == len(expected_outputs)
-    for out, exp in zip(outputs, expected_outputs):
+    assert len(outputs.shape) == len(expected_outputs.shape)
+    for m in expected_outputs.members:
+        out = outputs.members[m].data
         assert out is not None
+        exp = expected_outputs.members[m].data
         assert_array_almost_equal(out, exp, decimal=4)

-    # repeat inference with context manager to test load/unload/load/forward
+    # repeat inference with context manager to test load/predict/unload/load/predict
     with pred_pipe as pp:
-        outputs = pp.forward(*inputs)
+        outputs = pp.predict_sample_without_blocking(inputs)

-    assert len(outputs) == len(expected_outputs)
-    for out, exp in zip(outputs, expected_outputs):
+    assert len(outputs.shape) == len(expected_outputs.shape)
+    for m in expected_outputs.members:
+        out = outputs.members[m].data
         assert out is not None
+        exp = expected_outputs.members[m].data
         assert_array_almost_equal(out, exp, decimal=4)
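
The updated tests double as usage documentation for the sample-based API. Condensed into a
sketch (assuming `bio_model` holds an already loaded model description and that
`create_prediction_pipeline` is importable from the package root):

    from bioimageio.core import create_prediction_pipeline
    from bioimageio.core.digest_spec import get_test_inputs

    pp = create_prediction_pipeline(bioimageio_model=bio_model)
    with pp:
        # one Sample in, one Sample out; per-tensor data lives in outputs.members
        outputs = pp.predict_sample_without_blocking(get_test_inputs(bio_model))

From 124660c646be31eedb33010139767beeaf7e1688 Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Fri, 12 Apr 2024 04:23:38 +0200
Subject: [PATCH 209/244] rename get_sample_output_shape -> get_output_shape

---
 bioimageio/core/proc_ops.py | 59 +++++++++++++++++++------------------
 1 file changed, 31 insertions(+), 28 deletions(-)

diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py
index 481fa302..3d9afe3a 100644
--- a/bioimageio/core/proc_ops.py
+++ b/bioimageio/core/proc_ops.py
@@ -70,9 +70,7 @@ def required_measures(self) -> Collection[Measure]:
         return set()

     @abstractmethod
-    def get_sample_output_shape(
-        self, sample_input_shape: PerAxis[int]
-    )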
-> PerAxis[int]: ... + def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]: ... def __call__(self, sample: Union[Sample, SampleBlock]) -> None: input_tensor = sample.members[self.input] @@ -88,7 +86,7 @@ def __call__(self, sample: Union[Sample, SampleBlock]) -> None: elif isinstance(sample, SampleBlock): b = sample.blocks[self.input] sample.blocks[self.output] = Block( - sample_shape=self.get_sample_output_shape(sample.shape[self.input]), + sample_shape=self.get_output_shape(sample.shape[self.input]), data=output_tensor, inner_slice=b.inner_slice, halo=b.halo, @@ -199,10 +197,10 @@ class Binarize(_SimpleOperator): def _apply(self, input: Tensor, stat: Stat) -> Tensor: return input > self.threshold - def get_sample_output_shape( - self, sample_input_shape: Mapping[AxisId, int] + def get_output_shape( + self, input_shape: Mapping[AxisId, int] ) -> Mapping[AxisId, int]: - return sample_input_shape + return input_shape @classmethod def from_proc_descr( @@ -239,10 +237,10 @@ def __post_init__(self): def _apply(self, input: Tensor, stat: Stat) -> Tensor: return input.clip(self.min, self.max) - def get_sample_output_shape( - self, sample_input_shape: Mapping[AxisId, int] + def get_output_shape( + self, input_shape: Mapping[AxisId, int] ) -> Mapping[AxisId, int]: - return sample_input_shape + return input_shape @classmethod def from_proc_descr( @@ -267,6 +265,11 @@ def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId): def get_descr(self): return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype)) + def get_output_shape( + self, input_shape: Mapping[AxisId, int] + ) -> Mapping[AxisId, int]: + return input_shape + def _apply(self, input: Tensor, stat: Stat) -> Tensor: return input.astype(self.dtype) @@ -282,10 +285,10 @@ class ScaleLinear(_SimpleOperator): def _apply(self, input: Tensor, stat: Stat) -> Tensor: return input * self.gain + self.offset - def get_sample_output_shape( - self, sample_input_shape: Mapping[AxisId, int] + def get_output_shape( + self, input_shape: Mapping[AxisId, int] ) -> Mapping[AxisId, int]: - return sample_input_shape + return input_shape @classmethod def from_proc_descr( @@ -355,10 +358,10 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor: ref_std = stat[self.ref_std] + self.eps return (input - mean) / std * ref_std + ref_mean - def get_sample_output_shape( - self, sample_input_shape: Mapping[AxisId, int] + def get_output_shape( + self, input_shape: Mapping[AxisId, int] ) -> Mapping[AxisId, int]: - return sample_input_shape + return input_shape @classmethod def from_proc_descr( @@ -433,10 +436,10 @@ def __post_init__( def required_measures(self): return {self.lower, self.upper} - def get_sample_output_shape( - self, sample_input_shape: Mapping[AxisId, int] + def get_output_shape( + self, input_shape: Mapping[AxisId, int] ) -> Mapping[AxisId, int]: - return sample_input_shape + return input_shape @classmethod def from_proc_descr( @@ -498,10 +501,10 @@ def _apply(self, input: Tensor, stat: Stat) -> Tensor: def required_measures(self) -> Collection[Measure]: return {} - def get_sample_output_shape( - self, sample_input_shape: Mapping[AxisId, int] + def get_output_shape( + self, input_shape: Mapping[AxisId, int] ) -> Mapping[AxisId, int]: - return sample_input_shape + return input_shape @classmethod def from_proc_descr( @@ -530,10 +533,10 @@ def __post_init__(self): def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]: return {self.mean, self.std} - def get_sample_output_shape( - self, 
sample_input_shape: Mapping[AxisId, int] + def get_output_shape( + self, input_shape: Mapping[AxisId, int] ) -> Mapping[AxisId, int]: - return sample_input_shape + return input_shape @classmethod def from_proc_descr( @@ -584,10 +587,10 @@ def __post_init__(self): or self.mean.dims == self.std.dims ) - def get_sample_output_shape( - self, sample_input_shape: Mapping[AxisId, int] + def get_output_shape( + self, input_shape: Mapping[AxisId, int] ) -> Mapping[AxisId, int]: - return sample_input_shape + return input_shape @classmethod def from_proc_descr( From 939e0ba35792220a9ccc3fc0a687bfd5d4327eaa Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 12 Apr 2024 04:48:00 +0200 Subject: [PATCH 210/244] fix get_member_halo --- bioimageio/core/sample.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py index b4b2464a..70a11877 100644 --- a/bioimageio/core/sample.py +++ b/bioimageio/core/sample.py @@ -196,8 +196,8 @@ def get_member_halo(m: MemberId, round: Callable[[float], int]): if isinstance(trf, int) or trf.axis not in self.blocks[trf.member].halo else Halo( - ceil(self.blocks[trf.member].halo[trf.axis].left * trf.scale), - ceil(self.blocks[trf.member].halo[trf.axis].right * trf.scale), + round(self.blocks[trf.member].halo[trf.axis].left * trf.scale), + round(self.blocks[trf.member].halo[trf.axis].right * trf.scale), ) ) for a, trf in new_axes[m].items() From 417fe4265879089a913e5732bb42cb5883955ae9 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 12 Apr 2024 04:49:42 +0200 Subject: [PATCH 211/244] remove Sample.__post_init__ --- bioimageio/core/sample.py | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py index 70a11877..1b314622 100644 --- a/bioimageio/core/sample.py +++ b/bioimageio/core/sample.py @@ -100,6 +100,8 @@ def as_single_block(self, halo: Optional[PerMember[PerAxis[Halo]]] = None): }, stat=self.stat, origin=self, + block_index=0, + blocks_in_sample=1, ) @classmethod @@ -145,17 +147,12 @@ class SampleBlockBase(Generic[BlockT]): blocks: Dict[MemberId, BlockT] """Individual tensor blocks comprising this sample block""" - block_index: BlockIndex = field(init=False) + block_index: BlockIndex """the n-th block of the sample""" - blocks_in_sample: TotalNumberOfBlocks = field(init=False) + blocks_in_sample: TotalNumberOfBlocks """total number of blocks in the sample""" - def __post_init__(self): - a_block = next(iter(self.blocks.values())) - self.block_index = a_block.block_index - self.blocks_in_sample = a_block.blocks_in_sample - @property def shape(self) -> PerMember[PerAxis[int]]: return {mid: b.shape for mid, b in self.blocks.items()} @@ -238,6 +235,8 @@ def get_member_halo(m: MemberId, round: Callable[[float], int]): for m in new_axes }, sample_shape=sample_shape, + block_index=self.block_index, + blocks_in_sample=self.blocks_in_sample, ) def with_data(self, data: PerMember[Tensor], *, stat: Stat) -> SampleBlock: @@ -255,6 +254,8 @@ def with_data(self, data: PerMember[Tensor], *, stat: Stat) -> SampleBlock: for m, b in self.blocks.items() }, stat=stat, + block_index=self.block_index, + blocks_in_sample=self.blocks_in_sample, ) @@ -274,7 +275,10 @@ def get_transformed_meta( self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]] ) -> SampleBlockMeta: return SampleBlockMeta( - blocks=dict(self.blocks), sample_shape=self.sample_shape + blocks=dict(self.blocks), + sample_shape=self.sample_shape, + 
block_index=self.block_index, + blocks_in_sample=self.blocks_in_sample, ).get_transformed(new_axes) @@ -290,9 +294,16 @@ def sample_block_meta_generator( sample_shape: PerMember[PerAxis[int]], ): for member_blocks in blocks: + block_indices = {block.block_index for block in member_blocks.values()} + assert len(block_indices) == 1 + blocks_in_samples = {block.blocks_in_sample for block in member_blocks.values()} + assert len(blocks_in_samples) == 1 + yield SampleBlockMeta( blocks=dict(member_blocks), sample_shape=sample_shape, + block_index=block_indices.pop(), + blocks_in_sample=blocks_in_samples.pop(), ) @@ -303,6 +314,10 @@ def sample_block_generator( pad_mode: PadMode, ): for member_blocks in blocks: + block_indices = {block.block_index for block in member_blocks.values()} + assert len(block_indices) == 1 + blocks_in_samples = {block.blocks_in_sample for block in member_blocks.values()} + assert len(blocks_in_samples) == 1 yield SampleBlockWithOrigin( blocks={ m: Block.from_sample_member( @@ -313,4 +328,6 @@ def sample_block_generator( sample_shape=origin.shape, origin=origin, stat=origin.stat, + block_index=block_indices.pop(), + blocks_in_sample=blocks_in_samples.pop(), ) From 866a868a89a3865c7dfbd082eb855f4d88ad6e3a Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 12 Apr 2024 05:33:52 +0200 Subject: [PATCH 212/244] add sample_id to SampleBlockBase --- bioimageio/core/digest_spec.py | 8 ++++-- bioimageio/core/sample.py | 48 ++++++++++++++++++++++++---------- 2 files changed, 40 insertions(+), 16 deletions(-) diff --git a/bioimageio/core/digest_spec.py b/bioimageio/core/digest_spec.py index 6ffab907..75f1c5c1 100644 --- a/bioimageio/core/digest_spec.py +++ b/bioimageio/core/digest_spec.py @@ -324,9 +324,13 @@ def get_io_sample_block_metas( return n_input_blocks, ( IO_SampleBlockMeta(ipt, out) for ipt, out in zip( - sample_block_meta_generator(input_blocks, sample_shape=input_sample_shape), sample_block_meta_generator( - output_blocks, sample_shape=output_sample_shape + input_blocks, sample_shape=input_sample_shape, sample_id=None + ), + sample_block_meta_generator( + output_blocks, + sample_shape=output_sample_shape, + sample_id=None, ), ) ) diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py index 1b314622..4da3f6b6 100644 --- a/bioimageio/core/sample.py +++ b/bioimageio/core/sample.py @@ -6,6 +6,7 @@ Callable, Dict, Generic, + Hashable, Iterable, Optional, Tuple, @@ -85,6 +86,7 @@ def as_single_block(self, halo: Optional[PerMember[PerAxis[Halo]]] = None): halo = {} return SampleBlockWithOrigin( sample_shape=self.shape, + sample_id=self.id, blocks={ m: Block( sample_shape=self.shape[m], @@ -112,7 +114,12 @@ def from_blocks( fill_value: float = float("nan"), ) -> Self: members: PerMember[Tensor] = {} + stat: Stat = {} + sample_id = None for sample_block in sample_blocks: + assert sample_id is None or sample_id == sample_block.sample_id + sample_id = sample_block.sample_id + stat = sample_block.stat for m, block in sample_block.blocks.items(): if m not in members: if -1 in block.sample_shape.values(): @@ -131,7 +138,7 @@ def from_blocks( members[m][block.inner_slice] = block.inner_data - return cls(members=members) + return cls(members=members, stat=stat, id=sample_id) BlockT = TypeVar("BlockT", Block, BlockMeta) @@ -144,6 +151,9 @@ class SampleBlockBase(Generic[BlockT]): sample_shape: PerMember[PerAxis[int]] """the sample shape this block represents a part of""" + sample_id: Optional[Hashable] + """identifier for the sample within its dataset""" + blocks: Dict[MemberId, 
BlockT] """Individual tensor blocks comprising this sample block""" @@ -235,6 +245,7 @@ def get_member_halo(m: MemberId, round: Callable[[float], int]): for m in new_axes }, sample_shape=sample_shape, + sample_id=self.sample_id, block_index=self.block_index, blocks_in_sample=self.blocks_in_sample, ) @@ -242,6 +253,7 @@ def get_member_halo(m: MemberId, round: Callable[[float], int]): def with_data(self, data: PerMember[Tensor], *, stat: Stat) -> SampleBlock: return SampleBlock( sample_shape=self.sample_shape, + sample_id=self.sample_id, blocks={ m: Block( sample_shape=self.sample_shape[m], @@ -275,6 +287,7 @@ def get_transformed_meta( self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]] ) -> SampleBlockMeta: return SampleBlockMeta( + sample_id=self.sample_id, blocks=dict(self.blocks), sample_shape=self.sample_shape, block_index=self.block_index, @@ -288,22 +301,31 @@ class SampleBlockWithOrigin(SampleBlock): """the sample this sample black was taken from""" +class _ConsolidatedMemberBlocks: + def __init__(self, blocks: PerMember[BlockMeta]): + super().__init__() + block_indices = {b.block_index for b in blocks.values()} + assert len(block_indices) == 1 + self.block_index = block_indices.pop() + blocks_in_samples = {b.blocks_in_sample for b in blocks.values()} + assert len(blocks_in_samples) == 1 + self.blocks_in_sample = blocks_in_samples.pop() + + def sample_block_meta_generator( blocks: Iterable[PerMember[BlockMeta]], *, sample_shape: PerMember[PerAxis[int]], + sample_id: Optional[Hashable], ): for member_blocks in blocks: - block_indices = {block.block_index for block in member_blocks.values()} - assert len(block_indices) == 1 - blocks_in_samples = {block.blocks_in_sample for block in member_blocks.values()} - assert len(blocks_in_samples) == 1 - + cons = _ConsolidatedMemberBlocks(member_blocks) yield SampleBlockMeta( blocks=dict(member_blocks), sample_shape=sample_shape, - block_index=block_indices.pop(), - blocks_in_sample=blocks_in_samples.pop(), + sample_id=sample_id, + block_index=cons.block_index, + blocks_in_sample=cons.blocks_in_sample, ) @@ -314,10 +336,7 @@ def sample_block_generator( pad_mode: PadMode, ): for member_blocks in blocks: - block_indices = {block.block_index for block in member_blocks.values()} - assert len(block_indices) == 1 - blocks_in_samples = {block.blocks_in_sample for block in member_blocks.values()} - assert len(blocks_in_samples) == 1 + cons = _ConsolidatedMemberBlocks(member_blocks) yield SampleBlockWithOrigin( blocks={ m: Block.from_sample_member( @@ -328,6 +347,7 @@ def sample_block_generator( sample_shape=origin.shape, origin=origin, stat=origin.stat, - block_index=block_indices.pop(), - blocks_in_sample=blocks_in_samples.pop(), + sample_id=origin.id, + block_index=cons.block_index, + blocks_in_sample=cons.blocks_in_sample, ) From 31d59289de4c489c8d7b4c5623669b4ec53d0b69 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 12 Apr 2024 06:38:23 +0200 Subject: [PATCH 213/244] reimpl predict_sample_without_blocking independently of predict_sample_block --- bioimageio/core/_prediction_pipeline.py | 55 +++++++++++++++++++------ 1 file changed, 42 insertions(+), 13 deletions(-) diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index 17e6bf25..b9034d05 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -19,7 +19,7 @@ from ._op_base import BlockedOperator from .axis import AxisId, PerAxis -from .common import Halo, MemberId, PerMember +from .common 
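# Aside: `_ConsolidatedMemberBlocks` above encodes a single invariant -- all
# per-member blocks of one sample block share the same position within the
# sample. A standalone sketch with hypothetical stand-in types (not the
# library's own classes):
from dataclasses import dataclass
from typing import Dict, Tuple

@dataclass
class _MiniBlockMeta:
    block_index: int
    blocks_in_sample: int

def consolidate(member_blocks: Dict[str, _MiniBlockMeta]) -> Tuple[int, int]:
    block_indices = {b.block_index for b in member_blocks.values()}
    assert len(block_indices) == 1, "members disagree on block_index"
    blocks_in_samples = {b.blocks_in_sample for b in member_blocks.values()}
    assert len(blocks_in_samples) == 1, "members disagree on blocks_in_sample"
    return block_indices.pop(), blocks_in_samples.pop()

assert consolidate(
    {"raw": _MiniBlockMeta(0, 4), "mask": _MiniBlockMeta(0, 4)}
) == (0, 4)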
import Halo, MemberId, PerMember, SampleId from .digest_spec import ( get_block_transform, get_input_halo, @@ -30,7 +30,7 @@ from .proc_ops import Processing from .proc_setup import setup_pre_and_postprocessing from .sample import Sample, SampleBlock, SampleBlockWithOrigin -from .stat_measures import DatasetMeasure, MeasureValue +from .stat_measures import DatasetMeasure, MeasureValue, Stat from .tensor import Tensor Predict_IO = TypeVar( @@ -73,7 +73,7 @@ def __init__( self.model_description = model_description if isinstance(model_description, v0_4.ModelDescr): self._default_input_halo: PerMember[PerAxis[Halo]] = {} - self._block_transform = {} + self._block_transform = None else: default_output_halo = { t.id: { @@ -110,6 +110,13 @@ def predict_sample_block( skip_preprocessing: bool = False, skip_postprocessing: bool = False, ) -> SampleBlock: + if isinstance(self.model_description, v0_4.ModelDescr): + raise NotImplementedError( + f"predict_sample_block not implemented for model {self.model_description.format_version}" + ) + else: + assert self._block_transform is not None + if not skip_preprocessing: self.apply_preprocessing(sample_block) @@ -120,7 +127,7 @@ def predict_sample_block( for tid, out in zip( self._output_ids, self._adapter.forward( - *(sample_block.members[t] for t in self._input_ids) + *(sample_block.members.get(t) for t in self._input_ids) ), ) if out is not None @@ -142,14 +149,35 @@ def predict_sample_without_blocking( The sample's tensor shapes have to match the model's input tensor description. If that is not the case, consider `predict_sample_with_blocking`""" - block = sample.as_single_block() - predicted_block = self.predict_sample_block( - block, - skip_preprocessing=skip_preprocessing, - skip_postprocessing=skip_postprocessing, + if not skip_preprocessing: + self.apply_preprocessing(sample) + + output = Sample( + members={ + out_id: out + for out_id, out in zip( + self._output_ids, + self._adapter.forward( + *(sample.members.get(in_id) for in_id in self._input_ids) + ), + ) + if out is not None + }, + stat=sample.stat, + id=self.get_output_sample_id(sample.id), ) - predicted_sample = Sample.from_blocks([predicted_block]) - return predicted_sample + if not skip_postprocessing: + self.apply_postprocessing(output) + + return output + + def get_output_sample_id(self, input_sample_id: SampleId): + if input_sample_id is None: + return None + else: + return f"{input_sample_id}_" + ( + self.model_description.id or self.model_description.name + ) def predict_sample_with_blocking( self, @@ -311,11 +339,12 @@ def create_prediction_pipeline( input_ids = get_member_ids(bioimageio_model.inputs) def dataset(): - for x in dataset_for_initial_statistics: + common_stat: Stat = {} + for i, x in enumerate(dataset_for_initial_statistics): if isinstance(x, Sample): yield x else: - yield Sample(members=dict(zip(input_ids, x))) + yield Sample(members=dict(zip(input_ids, x)), stat=common_stat, id=i) preprocessing, postprocessing = setup_pre_and_postprocessing( bioimageio_model, From ca61ca1fd2b8c32701023c82d02f8a2f0dc0e94e Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 12 Apr 2024 06:39:16 +0200 Subject: [PATCH 214/244] update create_sample_for_model --- bioimageio/core/digest_spec.py | 35 ++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/bioimageio/core/digest_spec.py b/bioimageio/core/digest_spec.py index 75f1c5c1..9ee5873a 100644 --- a/bioimageio/core/digest_spec.py +++ b/bioimageio/core/digest_spec.py @@ -338,8 +338,11 @@ def 
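# Aside: the naming scheme `get_output_sample_id` implements above, reduced to
# a free function; `model_name` stands in for
# `model_description.id or model_description.name`. Illustrative only.
from typing import Optional

def output_sample_id(input_sample_id: Optional[str], model_name: str) -> Optional[str]:
    if input_sample_id is None:
        return None
    return f"{input_sample_id}_{model_name}"

assert output_sample_id("img-0", "unet2d") == "img-0_unet2d"
assert output_sample_id(None, "unet2d") is None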
get_io_sample_block_metas( def create_sample_for_model( model: AnyModelDescr, + *, stat: Optional[Stat] = None, - **inputs: NDArray[Any], + sample_id: SampleId = None, + inputs: Optional[PerMember[NDArray[Any]]] = None, # TODO: make non-optional + **kwargs: NDArray[Any], # TODO: deprecate in favor of `inputs` ) -> Sample: """Create a sample from a single set of input(s) for a specific bioimage.io model @@ -348,25 +351,25 @@ def create_sample_for_model( stat: dictionary with sample and dataset statistics (may be updated in-place!) inputs: the input(s) constituting a single sample. """ - if len(inputs) > len(model.inputs): - raise ValueError( - f"Got {len(inputs)} inputs, but expected at most {len(model.inputs)}" - ) + inputs = {MemberId(k): v for k, v in {**kwargs, **(inputs or {})}.items()} - missing_inputs = { - get_member_id(ipt) - for ipt in model.inputs - if str(get_member_id(ipt)) not in inputs - and not (isinstance(ipt, v0_5.InputTensorDescr) and ipt.optional) - } - if missing_inputs: - raise ValueError(f"Missing non-optional input tensors {missing_inputs}") + model_inputs = {get_member_id(d): d for d in model.inputs} + if unknown := {k for k in inputs if k not in model_inputs}: + raise ValueError(f"Got unexpected inputs: {unknown}") + + if missing := { + k + for k, v in model_inputs.items() + if k not in inputs and not (isinstance(v, v0_5.InputTensorDescr) and v.optional) + }: + raise ValueError(f"Missing non-optional model inputs: {missing}") return Sample( members={ - m: Tensor.from_numpy(inputs[str(m)], dims=get_axes_infos(ipt)) - for ipt in model.inputs - if str((m := get_member_id(ipt))) in inputs + m: Tensor.from_numpy(inputs[m], dims=get_axes_infos(ipt)) + for m, ipt in model_inputs.items() + if m in inputs }, stat={} if stat is None else stat, + id=sample_id, ) From 216910a1b68f90482ec0fcf8ae89c870cbb58fab Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 12 Apr 2024 06:39:47 +0200 Subject: [PATCH 215/244] update load_sample_for_model --- bioimageio/core/io.py | 54 +++++++++++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py index 8ca8b02f..f053077e 100644 --- a/bioimageio/core/io.py +++ b/bioimageio/core/io.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Optional, Sequence +from typing import Any, Dict, Optional, Sequence import imageio from loguru import logger @@ -9,7 +9,9 @@ from bioimageio.spec.utils import load_array from .axis import Axis, AxisLike -from .digest_spec import create_sample_for_model, get_axes_infos +from .common import MemberId, PerMember, SampleId +from .digest_spec import get_axes_infos, get_member_id +from .sample import Sample from .stat_measures import Stat from .tensor import Tensor @@ -35,26 +37,44 @@ def load_tensor(path: Path, axes: Optional[Sequence[AxisLike]] = None) -> Tensor def load_sample_for_model( - *paths: Path, + *, model: AnyModelDescr, - axes: Optional[Sequence[Sequence[AxisLike]]] = None, + paths: PerMember[Path], + axes: Optional[PerMember[Sequence[AxisLike]]] = None, stat: Optional[Stat] = None, + sample_id: Optional[SampleId] = None, ): """load a single sample from `paths` that can be processed by `model`""" if axes is None: - axes = [get_axes_infos(ipt) for ipt in model.inputs[: len(paths)]] - logger.warning( - "loading paths with default input axes: {} (from {})", - axes, - model.id or model.name, - ) - elif len(axes) != len(paths): - raise ValueError(f"got {len(paths)} paths, but {len(axes)} axes hints!") - - arrays 
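# Aside: the unknown/missing-input validation the reworked
# `create_sample_for_model` performs, reduced to plain dicts. Here a model
# input is represented only by an `optional` flag; a sketch, not the library
# API.
from typing import Dict

def validate_inputs(inputs: Dict[str, object], model_inputs: Dict[str, bool]) -> None:
    if unknown := {k for k in inputs if k not in model_inputs}:
        raise ValueError(f"Got unexpected inputs: {unknown}")
    if missing := {
        k for k, optional in model_inputs.items() if k not in inputs and not optional
    }:
        raise ValueError(f"Missing non-optional model inputs: {missing}")

validate_inputs({"raw": 0}, {"raw": False, "mask": True})  # ok: 'mask' is optional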
= [load_image(p, is_volume=True) for p in paths]
-    return create_sample_for_model(
-        arrays,
-        model,
+        axes = {}
+
+    # make sure members are keyed by MemberId, not string
+    paths = {MemberId(k): v for k, v in paths.items()}
+    axes = {MemberId(k): v for k, v in axes.items()}
+
+    model_inputs = {get_member_id(d): d for d in model.inputs}
+
+    if unknown := {k for k in paths if k not in model_inputs}:
+        raise ValueError(f"Got unexpected paths for {unknown}")
+
+    if unknown := {k for k in axes if k not in model_inputs}:
+        raise ValueError(f"Got unexpected axes hints for: {unknown}")
+
+    members: Dict[MemberId, Tensor] = {}
+    for m, p in paths.items():
+        if m not in axes:
+            axes[m] = get_axes_infos(model_inputs[m])
+            logger.warning(
+                "loading paths with {}'s default input axes {} for input '{}'",
+                model.id or model.name,
+                axes[m],
+                m,
+            )
+        members[m] = load_tensor(p, axes[m])
+
+    return Sample(
+        members=members,
         stat={} if stat is None else stat,
+        id=sample_id or tuple(sorted(paths.values())),
     )

From c637c8e883604668b0c6b4074cf6b8379a7a5c74 Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Fri, 12 Apr 2024 06:40:30 +0200
Subject: [PATCH 216/244] make stat and id non-optional

---
 bioimageio/core/digest_spec.py | 10 +++++++---
 bioimageio/core/sample.py      | 11 +++++------
 tests/test_proc_ops.py         | 34 +++++++++++++++++-----------------
 tests/test_stat_calculators.py |  5 ++++-
 tests/test_stat_measures.py    |  4 ++--
 5 files changed, 37 insertions(+), 27 deletions(-)

diff --git a/bioimageio/core/digest_spec.py b/bioimageio/core/digest_spec.py
index 9ee5873a..b3a693f8 100644
--- a/bioimageio/core/digest_spec.py
+++ b/bioimageio/core/digest_spec.py
@@ -33,7 +33,7 @@
 
 from .axis import AxisId, AxisInfo, PerAxis
 from .block_meta import split_multiple_shapes_into_blocks
-from .common import Halo, MemberId, PerMember, TotalNumberOfBlocks
+from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
 from .sample import (
     LinearSampleAxisTransform,
     Sample,
@@ -162,7 +162,9 @@ def get_test_inputs(model: AnyModelDescr) -> Sample:
         members={
             m: Tensor.from_numpy(arr, dims=ax)
             for m, arr, ax in zip(member_ids, arrays, axes)
-        }
+        },
+        stat={},
+        id="test-input",
     )
 
 
@@ -181,7 +183,9 @@ def get_test_outputs(model: AnyModelDescr) -> Sample:
         members={
             m: Tensor.from_numpy(arr, dims=ax)
             for m, arr, ax in zip(member_ids, arrays, axes)
-        }
+        },
+        stat={},
+        id="test-output",
     )
 
 
diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py
index 4da3f6b6..8cec90fa 100644
--- a/bioimageio/core/sample.py
+++ b/bioimageio/core/sample.py
@@ -1,12 +1,11 @@
 from __future__ import annotations
 
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from math import ceil, floor
 from typing import (
     Callable,
     Dict,
     Generic,
-    Hashable,
     Iterable,
     Optional,
     Tuple,
@@ -49,10 +48,10 @@ class Sample:
     members: Dict[MemberId, Tensor]
     """the sample's tensors"""
 
-    stat: Stat = field(default_factory=dict)
+    stat: Stat
     """sample and dataset statistics"""
 
-    id: Optional[SampleId] = None
+    id: SampleId
     """identifier within the sample's dataset"""
 
     @property
@@ -151,7 +150,7 @@ class SampleBlockBase(Generic[BlockT]):
     sample_shape: PerMember[PerAxis[int]]
    """the sample shape this block represents a part of"""
 
-    sample_id: Optional[Hashable]
+    sample_id: SampleId
    """identifier for the sample within its dataset"""
 
     blocks: Dict[MemberId, BlockT]
@@ -316,7 +315,7 @@ def sample_block_meta_generator(
     blocks: Iterable[PerMember[BlockMeta]],
     *,
     sample_shape: PerMember[PerAxis[int]],
-    sample_id: Optional[Hashable],
+
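# Aside: the fallback sample id used by the reworked `load_sample_for_model`
# -- when no explicit id is given, the sorted input paths serve as a hashable
# identifier. Runnable sketch with illustrative names:
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

def default_sample_id(
    paths: Dict[str, Path], sample_id: Optional[str] = None
) -> Union[str, Tuple[Path, ...]]:
    return sample_id or tuple(sorted(paths.values()))

assert default_sample_id({"raw": Path("b.tif"), "mask": Path("a.tif")}) == (
    Path("a.tif"),
    Path("b.tif"),
)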
sample_id: SampleId, ): for member_blocks in blocks: cons = _ConsolidatedMemberBlocks(member_blocks) diff --git a/tests/test_proc_ops.py b/tests/test_proc_ops.py index ce8d04e4..7ef1a8fe 100644 --- a/tests/test_proc_ops.py +++ b/tests/test_proc_ops.py @@ -24,7 +24,7 @@ def test_scale_linear(tid: MemberId): offset = xr.DataArray([1, 2, 42], dims=("c")) gain = xr.DataArray([1, 2, 3], dims=("c")) data = xr.DataArray(np.arange(6).reshape((1, 2, 3)), dims=("x", "y", "c")) - sample = Sample(members={tid: Tensor.from_xarray(data)}) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) op = ScaleLinear(input=tid, output=tid, offset=offset, gain=gain) op(sample) @@ -38,7 +38,7 @@ def test_scale_linear_no_channel(tid: MemberId): op = ScaleLinear(tid, tid, offset=1, gain=2) data = xr.DataArray(np.arange(6).reshape(2, 3), dims=("x", "y")) - sample = Sample(members={tid: Tensor.from_xarray(data)}) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) op(sample) expected = xr.DataArray(np.array([[1, 3, 5], [7, 9, 11]]), dims=("x", "y")) @@ -57,7 +57,7 @@ def test_zero_mean_unit_variance(tid: MemberId): from bioimageio.core.proc_ops import ZeroMeanUnitVariance data = xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y")) - sample = Sample(members={tid: Tensor.from_xarray(data)}) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) m = SampleMean(tid) std = SampleStd(tid) op = ZeroMeanUnitVariance(tid, tid, m, std) @@ -100,7 +100,7 @@ def test_zero_mean_unit_variance_fixed(tid: MemberId): ), dims=("b", "c", "x"), ) - sample = Sample(members={tid: Tensor.from_xarray(data)}) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) op(sample) xr.testing.assert_allclose(expected, sample.members[tid].data) @@ -116,7 +116,7 @@ def test_zero_mean_unit_across_axes(tid: MemberId): SampleMean(tid, (AxisId("x"), AxisId("y"))), SampleStd(tid, (AxisId("x"), AxisId("y"))), ) - sample = Sample(members={tid: Tensor.from_xarray(data)}) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) sample.stat = compute_measures(op.required_measures, [sample]) expected = xr.concat( @@ -136,7 +136,7 @@ def test_zero_mean_unit_variance_fixed2(tid: MemberId): op = FixedZeroMeanUnitVariance(tid, tid, mean=mean, std=std, eps=eps) data = xr.DataArray(np_data, dims=("x", "y")) - sample = Sample(members={tid: Tensor.from_xarray(data)}) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) expected = xr.DataArray((np_data - mean) / (std + eps), dims=("x", "y")) op(sample) xr.testing.assert_allclose(expected, sample.members[tid].data) @@ -147,7 +147,7 @@ def test_binarize(tid: MemberId): op = Binarize(tid, tid, threshold=14) data = xr.DataArray(np.arange(30).reshape((2, 3, 5)), dims=("x", "y", "c")) - sample = Sample(members={tid: Tensor.from_xarray(data)}) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) expected = xr.zeros_like(data) expected[{"x": slice(1, None)}] = 1 op(sample) @@ -165,7 +165,7 @@ def test_binarize2(tid: MemberId): threshold = 0.5 exp = xr.DataArray(np_data > threshold, dims=axes) - sample = Sample(members={tid: Tensor.from_xarray(data)}) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) binarize = Binarize(tid, tid, threshold=threshold) binarize(sample) xr.testing.assert_allclose(exp, sample.members[tid].data) @@ -176,7 +176,7 @@ def test_clip(tid: MemberId): op = Clip(tid, tid, min=3, max=5) data = 
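# Aside: why every test below now spells out `stat={}, id=None` -- once a
# dataclass field loses its default, all call sites must pass it. Minimal
# analogue of the change to `Sample` (hypothetical class, not the library's):
from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class MiniSample:
    members: Dict[str, Any]
    stat: Dict[str, Any]  # was: field(default_factory=dict)
    id: Optional[str]  # was: None by default

_ = MiniSample(members={}, stat={}, id=None)  # ok
# MiniSample(members={})  # TypeError: missing required arguments 'stat', 'id'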
xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y")) - sample = Sample(members={tid: Tensor.from_xarray(data)}) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) expected = xr.DataArray( np.array([[3, 3, 3], [3, 4, 5], [5, 5, 5]]), dims=("x", "y") @@ -189,7 +189,7 @@ def test_combination_of_op_steps_with_dims_specified(tid: MemberId): from bioimageio.core.proc_ops import ZeroMeanUnitVariance data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) - sample = Sample(members={tid: Tensor.from_xarray(data)}) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) op = ZeroMeanUnitVariance( tid, tid, @@ -249,7 +249,9 @@ def test_scale_mean_variance(tid: MemberId, axes: Optional[Tuple[AxisId, ...]]): members={ tid: Tensor.from_xarray(ipt_data), MemberId("ref_name"): Tensor.from_xarray(ref_data), - } + }, + stat={}, + id=None, ) sample.stat = compute_measures(op.required_measures, [sample]) op(sample) @@ -279,7 +281,9 @@ def test_scale_mean_variance_per_channel(tid: MemberId, axes_str: Optional[str]) members={ tid: Tensor.from_xarray(ipt_data), MemberId("ref_name"): Tensor.from_xarray(ref_data), - } + }, + stat={}, + id=None, ) sample.stat = compute_measures(op.required_measures, [sample]) op(sample) @@ -299,7 +303,7 @@ def test_scale_range(tid: MemberId): op = ScaleRange(tid, tid) np_data = np.arange(9).reshape(3, 3).astype("float32") data = xr.DataArray(np_data, dims=("x", "y")) - sample = Sample(members={tid: Tensor.from_xarray(data)}) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) sample.stat = compute_measures(op.required_measures, [sample]) eps = 1.0e-6 @@ -323,7 +327,7 @@ def test_scale_range_axes(tid: MemberId): np_data = np.arange(18).reshape((2, 3, 3)).astype("float32") data = Tensor.from_xarray(xr.DataArray(np_data, dims=("c", "x", "y"))) - sample = Sample(members={tid: data}) + sample = Sample(members={tid: data}, stat={}, id=None) p_low_direct = lower_quantile.compute(sample) p_up_direct = upper_quantile.compute(sample) @@ -354,7 +358,7 @@ def test_sigmoid(tid: MemberId): axes = ("c", "y", "x") np_data = np.random.rand(*shape) data = xr.DataArray(np_data, dims=axes) - sample = Sample(members={tid: Tensor.from_xarray(data)}) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) sigmoid = Sigmoid(tid, tid) sigmoid(sample) diff --git a/tests/test_stat_calculators.py b/tests/test_stat_calculators.py index 0a642168..115b8556 100644 --- a/tests/test_stat_calculators.py +++ b/tests/test_stat_calculators.py @@ -20,7 +20,10 @@ def create_random_dataset(tid: MemberId, axes: Tuple[AxisId, ...]): n = 3 sizes = list(range(n, len(axes) + n)) data = np.asarray(np.random.rand(*sizes)) - ds = [Sample(members={tid: Tensor(data[i : i + 1], dims=axes)}) for i in range(n)] + ds = [ + Sample(members={tid: Tensor(data[i : i + 1], dims=axes)}, stat={}, id=None) + for i in range(n) + ] return Tensor(data, dims=axes), ds diff --git a/tests/test_stat_measures.py b/tests/test_stat_measures.py index 54cca0de..49c87609 100644 --- a/tests/test_stat_measures.py +++ b/tests/test_stat_measures.py @@ -37,7 +37,7 @@ def test_individual_normal_measure( ) expected = getattr(data, name)(dim=axes) - sample = Sample(members={data_id: data}) + sample = Sample(members={data_id: data}, stat={}, id=None) actual = measure.compute(sample) xr.testing.assert_allclose(expected.data, actual.data) @@ -56,7 +56,7 @@ def test_individual_percentile_measure(axes: Optional[Tuple[AxisId, ...]]): data = Tensor( 
np.random.random((5, 6, 3)), dims=(AxisId("x"), AxisId("y"), AxisId("c")) ) - actual = calc.compute(Sample(members={tid: data})) + actual = calc.compute(Sample(members={tid: data}, stat={}, id=None)) for m in measures: expected = data.quantile(q=m.q, dim=m.axes) actual_data = actual[m] From f88f2938ca468c3f27f970369d26cb482996911e Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 12 Apr 2024 06:40:52 +0200 Subject: [PATCH 217/244] add TODO --- bioimageio/core/model_adapters/_model_adapter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py index 633ee342..1d3c2b95 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ b/bioimageio/core/model_adapters/_model_adapter.py @@ -143,6 +143,7 @@ def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: """ Run forward pass of model to get model predictions """ + # TODO: handle tensor.transpose in here and make _forward_impl the abstract impl @abstractmethod def unload(self): From 9cd039e09d3dc41f0948d60b7863aa3b25f37148 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 12 Apr 2024 09:41:03 +0200 Subject: [PATCH 218/244] clean up conftest.py --- tests/conftest.py | 236 +++++++++++++++++++++++++--------------------- 1 file changed, 127 insertions(+), 109 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 5e54b17e..c4fa5ff7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import subprocess import warnings +from itertools import chain from typing import Dict, List from loguru import logger @@ -39,90 +40,124 @@ keras = None skip_tensorflow = tensorflow is None -skip_tensorflow_js = True # TODO: add a tensorflow_js example model warnings.warn(f"testing with bioimageio.spec {bioimageio_spec_version}") +# TODO: use models from new collection on S3 +MODEL_SOURCES: Dict[str, str] = { + "hpa_densenet": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/hpa-densenet/rdf.yaml" + ), + "stardist": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models" + "/stardist_example_model/v0_4.bioimageio.yaml" + ), + "shape_change": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "upsample_test_model/v0_4.bioimageio.yaml" + ), + "stardist_wrong_shape": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "stardist_example_model/rdf_wrong_shape.yaml" + ), + "stardist_wrong_shape2": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "stardist_example_model/rdf_wrong_shape2_v0_4.yaml" + ), + "unet2d_diff_output_shape": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_diff_output_shape/v0_4.bioimageio.yaml" + ), + "unet2d_expand_output_shape": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_nuclei_broad/expand_output_shape_v0_4.bioimageio.yaml" + ), + "unet2d_fixed_shape": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_fixed_shape/v0_4.bioimageio.yaml" + ), + "unet2d_keras_tf2": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_keras_tf2/v0_4.bioimageio.yaml" + ), + "unet2d_keras": ( + 
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_keras_tf/v0_4.bioimageio.yaml" + ), + "unet2d_multi_tensor": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_multi_tensor/v0_4.bioimageio.yaml" + ), + "unet2d_nuclei_broad_model": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_nuclei_broad/bioimageio.yaml" + ), +} + # test models for various frameworks -TORCH_MODELS = [ - "unet2d_fixed_shape", - "unet2d_multi_tensor", - "unet2d_nuclei_broad_model", - "unet2d_diff_output_shape", - "shape_change", -] -TORCHSCRIPT_MODELS = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model"] -ONNX_MODELS = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model", "hpa_densenet"] -TENSORFLOW1_MODELS = ["stardist"] -TENSORFLOW2_MODELS = ["unet2d_keras_tf2"] -KERAS_TF1_MODELS = ["unet2d_keras"] -KERAS_TF2_MODELS = ["unet2d_keras_tf2"] -TENSORFLOW_JS_MODELS: List[str] = [] - - -MODEL_SOURCES: Dict[str, str] = {} -if keras is not None: - MODEL_SOURCES.update( - { - "unet2d_keras": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_keras_tf/v0_4.bioimageio.yaml" - ), - "unet2d_keras_tf2": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_keras_tf2/v0_4.bioimageio.yaml" - ), - } - ) -if torch is not None: - MODEL_SOURCES.update( - { - "unet2d_nuclei_broad_model": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_nuclei_broad/bioimageio.yaml" - ), - "unet2d_expand_output_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_nuclei_broad/expand_output_shape_v0_4.bioimageio.yaml" - ), - "unet2d_fixed_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_fixed_shape/v0_4.bioimageio.yaml" - ), - "unet2d_multi_tensor": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_multi_tensor/v0_4.bioimageio.yaml" - ), - "unet2d_diff_output_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "unet2d_diff_output_shape/v0_4.bioimageio.yaml" - ), - "shape_change": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "upsample_test_model/v0_4.bioimageio.yaml" - ), - } - ) -if tensorflow is not None: - MODEL_SOURCES.update( - { - "hpa_densenet": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/hpa-densenet/rdf.yaml" - ), - "stardist": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models" - "/stardist_example_model/v0_4.bioimageio.yaml" - ), - "stardist_wrong_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "stardist_example_model/rdf_wrong_shape.yaml" - ), - "stardist_wrong_shape2": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" - "stardist_example_model/rdf_wrong_shape2_v0_4.yaml" - ), - } +TORCH_MODELS = ( + [] + if torch is None + else [ + "shape_change", + "unet2d_diff_output_shape", + "unet2d_expand_output_shape", + "unet2d_fixed_shape", + 
"unet2d_multi_tensor", + "unet2d_nuclei_broad_model", + ] +) +TORCHSCRIPT_MODELS = ( + [] + if torch is None + else [ + "unet2d_multi_tensor", + "unet2d_nuclei_broad_model", + ] +) +ONNX_MODELS = ( + [] + if onnxruntime is None + else [ + "hpa_densenet", + "unet2d_multi_tensor", + "unet2d_nuclei_broad_model", + ] +) +TENSORFLOW_MODELS = ( + [] + if tensorflow is None + else ( + [ + "hpa_densenet", + "stardist", + ] + if tf_major_version == 1 + else [ + "unet2d_keras_tf2", + ] ) +) +KERAS_MODELS = ( + [] + if keras is None + else ["unet2d_keras"] if tf_major_version == 1 else ["unet2d_keras_tf2"] +) +TENSORFLOW_JS_MODELS: List[str] = [] # TODO: add a tensorflow_js example model + +ALL_MODELS = sorted( + { + m + for m in chain( + TORCH_MODELS, + TORCHSCRIPT_MODELS, + ONNX_MODELS, + TENSORFLOW_MODELS, + KERAS_MODELS, + TENSORFLOW_JS_MODELS, + ) + } +) @fixture(scope="session") @@ -145,52 +180,39 @@ def mamba_cmd(): # -@fixture(params=[] if skip_torch else TORCH_MODELS) +@fixture(params=TORCH_MODELS) def any_torch_model(request: FixtureRequest): return MODEL_SOURCES[request.param] -@fixture(params=[] if skip_torch else TORCHSCRIPT_MODELS) +@fixture(params=TORCHSCRIPT_MODELS) def any_torchscript_model(request: FixtureRequest): return MODEL_SOURCES[request.param] -@fixture(params=[] if skip_onnx else ONNX_MODELS) +@fixture(params=ONNX_MODELS) def any_onnx_model(request: FixtureRequest): return MODEL_SOURCES[request.param] -@fixture( - params=( - [] - if skip_tensorflow - else TENSORFLOW1_MODELS if tf_major_version == 1 else TENSORFLOW2_MODELS - ) -) +@fixture(params=TENSORFLOW_MODELS) def any_tensorflow_model(request: FixtureRequest): return MODEL_SOURCES[request.param] -@fixture( - params=( - [] - if skip_tensorflow - else KERAS_TF1_MODELS if tf_major_version == 1 else KERAS_TF2_MODELS - ) -) +@fixture(params=KERAS_MODELS) def any_keras_model(request: FixtureRequest): return MODEL_SOURCES[request.param] -@fixture(params=[] if skip_tensorflow_js else TENSORFLOW_JS_MODELS) +@fixture(params=TENSORFLOW_JS_MODELS) def any_tensorflow_js_model(request: FixtureRequest): return MODEL_SOURCES[request.param] # fixture to test with all models that should run in the current environment -# we exclude stardist_wrong_shape here because it is not a valid model -# and included only to test that validation for this model fails -@fixture(params=set(MODEL_SOURCES) - {"stardist_wrong_shape", "stardist_wrong_shape2"}) +# we exclude any 'wrong' model here +@fixture(params=sorted({m for m in ALL_MODELS if "wrong" not in m})) def any_model(request: FixtureRequest): return MODEL_SOURCES[request.param] @@ -223,8 +245,8 @@ def convert_to_onnx(request: FixtureRequest): @fixture( params=( [] - if skip_tensorflow - else ["unet2d_keras" if tf_major_version == 1 else "unet2d_keras_tf2"] + if tf_major_version is None + else ["unet2d_keras"] if tf_major_version == 1 else ["unet2d_keras_tf2"] ) ) def unet2d_keras(request: FixtureRequest): @@ -262,22 +284,18 @@ def shape_change_model(request: FixtureRequest): # written as model group to automatically skip on missing tensorflow 1 -@fixture( - params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape"] -) +@fixture(params=["stardist_wrong_shape"] if tf_major_version == 1 else []) def stardist_wrong_shape(request: FixtureRequest): return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing tensorflow 1 -@fixture( - params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape2"] -) 
+@fixture(params=["stardist_wrong_shape2"] if tf_major_version == 1 else []) def stardist_wrong_shape2(request: FixtureRequest): return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing tensorflow 1 -@fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist"]) +@fixture(params=["stardist"] if tf_major_version == 1 else []) def stardist(request: FixtureRequest): return MODEL_SOURCES[request.param] From 60a1c1b7d552849193f860bd0b05111506bac538 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 12 Apr 2024 09:41:22 +0200 Subject: [PATCH 219/244] fix _test_model_inference_parametrized --- bioimageio/core/_resource_tests.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py index 54588435..c84f0027 100644 --- a/bioimageio/core/_resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -231,7 +231,9 @@ def get_ns(n: int): }, ) for t in model.inputs - } + }, + stat=test_inputs.stat, + id=test_inputs.id, ) expected_output_shapes = { t.id: { From 55519c9d66e071ed394d5958c8a057e039f4e1b5 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 12 Apr 2024 10:56:34 +0200 Subject: [PATCH 220/244] fix ONNXModelAdapter --- bioimageio/core/model_adapters/_onnx_model_adapter.py | 5 ++++- bioimageio/core/proc_ops.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/bioimageio/core/model_adapters/_onnx_model_adapter.py b/bioimageio/core/model_adapters/_onnx_model_adapter.py index e42b8912..de1b3b76 100644 --- a/bioimageio/core/model_adapters/_onnx_model_adapter.py +++ b/bioimageio/core/model_adapters/_onnx_model_adapter.py @@ -5,6 +5,7 @@ from bioimageio.core.tensor import Tensor from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.utils import download from ._model_adapter import ModelAdapter @@ -36,7 +37,9 @@ def __init__( if model_description.weights.onnx is None: raise ValueError("No ONNX weights specified for {model_description.name}") - self._session = rt.InferenceSession(str(model_description.weights.onnx.source)) + self._session = rt.InferenceSession( + str(download(model_description.weights.onnx.source).path) + ) onnx_inputs = self._session.get_inputs() # type: ignore self._input_names: List[str] = [ipt.name for ipt in onnx_inputs] # type: ignore diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 3d9afe3a..b05d7c8b 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -308,7 +308,9 @@ def from_proc_descr( gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis) offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis) else: - assert isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1 + assert ( + isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1 + ), kwargs.gain gain = ( kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0] ) From 1b65ca8bafdb86eb5425aa79911bb2aebae5036f Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 12 Apr 2024 22:46:14 +0200 Subject: [PATCH 221/244] update onnx model adapter --- bioimageio/core/model_adapters/_onnx_model_adapter.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/bioimageio/core/model_adapters/_onnx_model_adapter.py b/bioimageio/core/model_adapters/_onnx_model_adapter.py index de1b3b76..e7bdfc05 100644 --- a/bioimageio/core/model_adapters/_onnx_model_adapter.py +++ b/bioimageio/core/model_adapters/_onnx_model_adapter.py @@ -3,10 +3,11 @@ from numpy.typing 
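# Aside: the onnxruntime calls the adapter builds on, shown standalone.
# "model.onnx" is a placeholder path, so this sketch only runs where
# onnxruntime and such a weights file are actually available:
import numpy as np
import onnxruntime as rt

session = rt.InferenceSession("model.onnx")  # expects a local weights file
input_names = [ipt.name for ipt in session.get_inputs()]
dummy = {name: np.zeros((1, 1), dtype=np.float32) for name in input_names}
# outputs = session.run(None, dummy)  # None -> return all model outputs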
import NDArray -from bioimageio.core.tensor import Tensor from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.utils import download +from ..digest_spec import get_axes_infos +from ..tensor import Tensor from ._model_adapter import ModelAdapter try: @@ -27,11 +28,7 @@ def __init__( super().__init__() self._internal_output_axes = [ - ( - tuple(out.axes) - if isinstance(out.axes, str) - else tuple(a.id for a in out.axes) - ) + tuple(a.id for a in get_axes_infos(out)) for out in model_description.outputs ] if model_description.weights.onnx is None: @@ -50,7 +47,7 @@ def __init__( def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: assert len(input_tensors) == len(self._input_names) - input_arrays = [None if ipt is None else ipt.data for ipt in input_tensors] + input_arrays = [None if ipt is None else ipt.data.data for ipt in input_tensors] result: Union[Sequence[Optional[NDArray[Any]]], Optional[NDArray[Any]]] result = self._session.run( # pyright: ignore[reportUnknownVariableType] None, dict(zip(self._input_names, input_arrays)) From 8bc43f77616f96d0504e85c0e99425453d3dc2da Mon Sep 17 00:00:00 2001 From: fynnbe Date: Sat, 13 Apr 2024 22:32:50 +0200 Subject: [PATCH 222/244] Frozen = MappingProxyType --- bioimageio/core/common.py | 43 +++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index 4e4c493a..78a85886 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -1,9 +1,8 @@ from __future__ import annotations -from copy import deepcopy +from types import MappingProxyType from typing import ( Hashable, - Iterator, Literal, Mapping, NamedTuple, @@ -99,30 +98,30 @@ class SliceInfo(NamedTuple): K = TypeVar("K", bound=Hashable) V = TypeVar("V") +Frozen = MappingProxyType +# class Frozen(Mapping[K, V]): # adapted from xarray.core.utils.Frozen +# """Wrapper around an object implementing the mapping interface to make it +# immutable.""" -class Frozen(Mapping[K, V]): # adapted from xarray.core.utils.Frozen - """Wrapper around an object implementing the mapping interface to make it - immutable.""" +# __slots__ = ("mapping",) - __slots__ = ("mapping",) +# def __init__(self, mapping: Mapping[K, V]): +# super().__init__() +# self.mapping = deepcopy( +# mapping +# ) # added deepcopy (compared to xarray.core.utils.Frozen) - def __init__(self, mapping: Mapping[K, V]): - super().__init__() - self.mapping = deepcopy( - mapping - ) # added deepcopy (compared to xarray.core.utils.Frozen) +# def __getitem__(self, key: K) -> V: +# return self.mapping[key] - def __getitem__(self, key: K) -> V: - return self.mapping[key] +# def __iter__(self) -> Iterator[K]: +# return iter(self.mapping) - def __iter__(self) -> Iterator[K]: - return iter(self.mapping) +# def __len__(self) -> int: +# return len(self.mapping) - def __len__(self) -> int: - return len(self.mapping) +# def __contains__(self, key: object) -> bool: +# return key in self.mapping - def __contains__(self, key: object) -> bool: - return key in self.mapping - - def __repr__(self) -> str: - return f"{type(self).__name__}({self.mapping!r})" +# def __repr__(self) -> str: +# return f"{type(self).__name__}({self.mapping!r})" From f0f0a94022ea4ad570c9124c647e3027d9c94b43 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 15 Apr 2024 21:57:36 +0200 Subject: [PATCH 223/244] use pytest-xdist --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml 
index e4356c86..083aaf84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ typeCheckingMode = "strict" useLibraryCodeForTypes = true [tool.pytest.ini_options] -addopts = " -n 0 --capture=no --doctest-modules --failed-first" +addopts = " -n auto --capture=no --doctest-modules --failed-first" [tool.ruff] line-length = 88 From bc95dfe3d34bbe683c2ae7ba0e8e33a9210a1ce0 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 15 Apr 2024 23:32:55 +0200 Subject: [PATCH 224/244] fix CLI tests --- tests/test_cli.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index 5844b115..6b9622f2 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,5 +1,4 @@ import subprocess -from pathlib import Path from typing import Any, List, Sequence import pytest @@ -38,7 +37,6 @@ def run_subprocess( ], ) def test_cli(args: List[str], unet2d_nuclei_broad_model: str): - assert Path(unet2d_nuclei_broad_model).exists() resolved_args = [ str(unet2d_nuclei_broad_model) if arg == "unet2d_nuclei_broad_model" else arg for arg in args From a1ce637048549e07102f7effad3c02e22f9c19cf Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 16 Apr 2024 17:27:03 +0200 Subject: [PATCH 225/244] add 'weight-format' option alias --- bioimageio/core/__main__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bioimageio/core/__main__.py b/bioimageio/core/__main__.py index 4d769cd4..022f9576 100644 --- a/bioimageio/core/__main__.py +++ b/bioimageio/core/__main__.py @@ -65,7 +65,8 @@ def package( Optional[List[WeightsFormatEnum]], typer.Option( "--weights-priority-order", - "-wpo", + "--weight-format", + "-w", help="For model packages only. " + "If given, only the first matching weights entry is included. " + "Defaults to including all weights present in source.", @@ -73,7 +74,7 @@ def package( ), ] = None, ): - # typer bug: typer returns empty tuple instead of None if weights_order_priority is not given + # typer bug: typer returns empty tuple instead of None if weights_priority_order is not given weights_priority_order = ( weights_priority_order or None ) # TODO: check if this is still the case From 1f27b5065a50f1911be6b944e9c648b87e52e601 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 16 Apr 2024 21:22:47 +0200 Subject: [PATCH 226/244] avoid ruff and py3.12 issues --- dev/env-tf.yaml | 2 +- dev/env-wo-python.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/env-tf.yaml b/dev/env-tf.yaml index 304c0193..53a18ae0 100644 --- a/dev/env-tf.yaml +++ b/dev/env-tf.yaml @@ -27,7 +27,7 @@ dependencies: - python-dotenv # - python=3.9 # removed # - pytorch>=2.1 # removed - - ruff + # - ruff # removed - ruyaml - tensorflow>=2.15 # added # - torchvision # removed diff --git a/dev/env-wo-python.yaml b/dev/env-wo-python.yaml index 8816ea48..9717bd14 100644 --- a/dev/env-wo-python.yaml +++ b/dev/env-wo-python.yaml @@ -27,7 +27,7 @@ dependencies: - python-dotenv # - python=3.9 # removed - pytorch>=2.1 - - ruff + # - ruff # removed - ruyaml - torchvision - tqdm From 63ba4fea6a64b51051e8209f464236890932049d Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 16 Apr 2024 23:19:23 +0200 Subject: [PATCH 227/244] typer -> fire and simplify cli a bit --- .github/workflows/build.yaml | 2 +- bioimageio/core/__main__.py | 219 +++++++++++------------------------ dev/env-py38.yaml | 2 +- dev/env-tf.yaml | 2 +- dev/env-wo-python.yaml | 2 +- dev/env.yaml | 2 +- setup.py | 2 +- tests/test_cli.py | 6 +- 8 files changed, 76 insertions(+), 161 deletions(-) diff --git 
a/.github/workflows/build.yaml b/.github/workflows/build.yaml index d5dea6f1..e797e1e3 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -89,7 +89,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.8', '3.12'] + python-version: ['3.9', '3.11'] steps: - uses: actions/checkout@v4 - name: Install Conda environment with Micromamba diff --git a/bioimageio/core/__main__.py b/bioimageio/core/__main__.py index 022f9576..d1b41975 100644 --- a/bioimageio/core/__main__.py +++ b/bioimageio/core/__main__.py @@ -1,164 +1,80 @@ -import enum import sys from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Union -import typer -from typing_extensions import Annotated +import fire -from bioimageio.core import __version__ -from bioimageio.core import test_description as _test_description -from bioimageio.core import test_model as _test_model +from bioimageio.core import __version__, test_description from bioimageio.spec import save_bioimageio_package from bioimageio.spec.collection import CollectionDescr from bioimageio.spec.dataset import DatasetDescr from bioimageio.spec.model import ModelDescr +from bioimageio.spec.model.v0_5 import WeightsFormat from bioimageio.spec.notebook import NotebookDescr -help_version = f"""bioimageio.core {__version__} -bioimageio.spec {__version__} -implementing: -\tcollection RDF {CollectionDescr.implemented_format_version} -\tdataset RDF {DatasetDescr.implemented_format_version} -\tmodel RDF {ModelDescr.implemented_format_version} -\tnotebook RDF {NotebookDescr.implemented_format_version}""" - - -# prevent rewrapping with \b\n: https://click.palletsprojects.com/en/7.x/documentation/#preventing-rewrapping -app = typer.Typer( - help="\b\n" + help_version, - context_settings={ - "help_option_names": ["-h", "--help", "--version"] - }, # make --version display help with version -) # https://typer.tiangolo.com/ - - -@app.callback() -def callback(): - typer.echo(help_version) - - -# if we want to use something like "choice" for the weight formats, we need to use an enum, see: -# https://github.com/tiangolo/typer/issues/182 - - -class WeightsFormatEnum(enum.Enum): - keras_hdf5 = "keras_hdf5" - onnx = "onnx" - pytorch_state_dict = "pytorch_state_dict" - tensorflow_js = "tensorflow_js" - tensorflow_saved_model_bundle = "tensorflow_saved_model_bundle" - torchscript = "torchscript" - - -# Enum with int values does not work with click.Choice: https://github.com/pallets/click/issues/784 -# so a simple Enum with auto int values is not an option. - - -@app.command() -def package( - source: Annotated[str, typer.Argument(help="path or url to a bioimageio RDF")], - path: Annotated[Path, typer.Argument(help="Save package as")] = Path( - "bioimageio-package.zip" - ), - weights_priority_order: Annotated[ - Optional[List[WeightsFormatEnum]], - typer.Option( - "--weights-priority-order", - "--weight-format", - "-w", - help="For model packages only. " - + "If given, only the first matching weights entry is included. 
" - + "Defaults to including all weights present in source.", - show_default=False, - ), - ] = None, -): - # typer bug: typer returns empty tuple instead of None if weights_priority_order is not given - weights_priority_order = ( - weights_priority_order or None - ) # TODO: check if this is still the case - - _ = save_bioimageio_package( - source, - output_path=path, - weights_priority_order=( - None - if weights_priority_order is None - else [wpo.name for wpo in weights_priority_order] - ), - ) - - -@app.command() -def test_model( - model_rdf: Annotated[ - str, - typer.Argument( - help="Path or URL to the model resource description file (rdf.yaml) or zipped model." - ), - ], - weight_format: Annotated[ - Optional[WeightsFormatEnum], typer.Option(help="The weight format to use.") - ] = None, - devices: Annotated[ - Optional[List[str]], typer.Option(help="Devices for running the model.") - ] = None, - decimal: Annotated[int, typer.Option(help="The test precision.")] = 4, -): - # this is a weird typer bug: default devices are empty tuple although they should be None - devices = devices or None - - summary = _test_model( - model_rdf, - weight_format=None if weight_format is None else weight_format.value, - devices=devices, - decimal=decimal, - ) - print(f"\ntesting model {model_rdf}...") - print(summary.format()) - sys.exit(0 if summary.status == "passed" else 1) - - -test_model.__doc__ = _test_model.__doc__ - - -@app.command() -def test_resource( - rdf: Annotated[ - str, - typer.Argument( - help="Path or URL to the resource description file (rdf.yaml) or zipped resource package." - ), - ], - weight_format: Annotated[ - Optional[WeightsFormatEnum], - typer.Option(help="(for model only) The weight format to use."), - ] = None, - devices: Annotated[ - Optional[List[str]], - typer.Option(help="(for model only) Devices for running the model."), - ] = None, - decimal: Annotated[ - int, typer.Option(help="(for model only) The test precision.") - ] = 4, -): - # this is a weird typer bug: default devices are empty tuple although they should be None - if devices is None or len(devices) == 0: - devices = None - - summary = _test_description( - rdf, - weight_format=None if weight_format is None else weight_format.value, - devices=devices, - decimal=decimal, - ) - print(summary.format()) - sys.exit(0 if summary.status == "passed" else 1) - - -test_resource.__doc__ = _test_description.__doc__ +class Bioimageio: + def package( + self, + source: str, + path: Path = Path("bioimageio-package.zip"), + weight_format: Optional[WeightsFormat] = None, + ): + """Package a bioimageio resource as a zip file + + Args: + source: RDF source e.g. 
`bioimageio.yaml` or `http://example.com/rdf.yaml` + path: output path + weight-format: include only this single weight-format + """ + _ = save_bioimageio_package( + source, + output_path=path, + weights_priority_order=None if weight_format is None else (weight_format,), + ) + + def test( + self, + source: str, + weight_format: Optional[WeightsFormat] = None, + *, + devices: Optional[Union[str, List[str]]] = None, + decimal: int = 4, + ): + """test a bioimageio resource + + Args: + source: Path or URL to the bioimageio resource description file + (bioimageio.yaml or rdf.yaml) or to a zipped resource + weight_format: (model only) The weight format to use + devices: Device(s) to use for testing + decimal: Precision for numerical comparisons + """ + summary = test_description( + source, + weight_format=None if weight_format is None else weight_format, + devices=[devices] if isinstance(devices, str) else devices, + decimal=decimal, + ) + print(f"\ntesting model {source}...") + print(summary.format()) + sys.exit(0 if summary.status == "passed" else 1) + + +Bioimageio.__doc__ = f""" +work with resources shared on bioimage.io + +library versions: + bioimageio.core {__version__} + bioimageio.spec {__version__} + +spec format versions: + model RDF {ModelDescr.implemented_format_version} + dataset RDF {DatasetDescr.implemented_format_version} + notebook RDF {NotebookDescr.implemented_format_version} + collection RDF {CollectionDescr.implemented_format_version} + +""" # TODO: add predict commands # @app.command() @@ -302,6 +218,5 @@ def test_resource( # keras_converter.convert_weights_to_tensorflow_saved_model_bundle.__doc__ # ) - if __name__ == "__main__": - app() + fire.Fire(Bioimageio, name="bioimageio") diff --git a/dev/env-py38.yaml b/dev/env-py38.yaml index 726ce341..c14ef880 100644 --- a/dev/env-py38.yaml +++ b/dev/env-py38.yaml @@ -8,6 +8,7 @@ dependencies: - black - crick # uncommented - filelock + - fire - imageio>=2.5 - jupyter - jupyter-black @@ -31,7 +32,6 @@ dependencies: - ruyaml - torchvision - tqdm - - typer - typing-extensions - xarray - pip: diff --git a/dev/env-tf.yaml b/dev/env-tf.yaml index 53a18ae0..e371b558 100644 --- a/dev/env-tf.yaml +++ b/dev/env-tf.yaml @@ -8,6 +8,7 @@ dependencies: - black # - crick # currently requires python<=3.9 - filelock + - fire - imageio>=2.5 - jupyter - jupyter-black @@ -32,7 +33,6 @@ dependencies: - tensorflow>=2.15 # added # - torchvision # removed - tqdm - - typer - typing-extensions - xarray - pip: diff --git a/dev/env-wo-python.yaml b/dev/env-wo-python.yaml index 9717bd14..21cb85ca 100644 --- a/dev/env-wo-python.yaml +++ b/dev/env-wo-python.yaml @@ -8,6 +8,7 @@ dependencies: - black # - crick # currently requires python<=3.9 - filelock + - fire - imageio>=2.5 - jupyter - jupyter-black @@ -31,7 +32,6 @@ dependencies: - ruyaml - torchvision - tqdm - - typer - typing-extensions - xarray - pip: diff --git a/dev/env.yaml b/dev/env.yaml index 0aa1660e..0624979f 100644 --- a/dev/env.yaml +++ b/dev/env.yaml @@ -7,6 +7,7 @@ dependencies: - black # - crick # currently requires python<=3.9 - filelock + - fire - imageio>=2.5 - jupyter - jupyter-black @@ -30,7 +31,6 @@ dependencies: - ruyaml - torchvision - tqdm - - typer - typing-extensions - xarray - pip: diff --git a/setup.py b/setup.py index d4a5047a..c0110480 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ packages=find_namespace_packages(exclude=["tests"]), install_requires=[ "bioimageio.spec==0.5.1.*", + "fire", "imageio>=2.5", "loguru", "numpy", @@ -38,7 +39,6 @@ "python-dotenv", "ruyaml", 
"tqdm", - "typer", "typing-extensions", "xarray", ], diff --git a/tests/test_cli.py b/tests/test_cli.py index 6b9622f2..b9a8246f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -28,12 +28,12 @@ def run_subprocess( ], ["package", "unet2d_nuclei_broad_model"], [ - "test-model", + "test", "unet2d_nuclei_broad_model", "--weight-format", "pytorch_state_dict", ], - ["test-model", "unet2d_nuclei_broad_model"], + ["test", "unet2d_nuclei_broad_model"], ], ) def test_cli(args: List[str], unet2d_nuclei_broad_model: str): @@ -45,7 +45,7 @@ def test_cli(args: List[str], unet2d_nuclei_broad_model: str): assert ret.returncode == 0, ret.stdout -@pytest.mark.parametrize("args", [["test-model", "stardist_wrong_shape"]]) +@pytest.mark.parametrize("args", [["test", "stardist_wrong_shape"]]) def test_cli_fails(args: List[str], stardist_wrong_shape: FilePath): resolved_args = [ str(stardist_wrong_shape) if arg == "stardist_wrong_shape" else arg From 5e9d37a7075a214a818d3ec26981ffb5babd9f72 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 16 Apr 2024 23:19:54 +0200 Subject: [PATCH 228/244] add test_save_bioimageio_package --- tests/test_package.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 tests/test_package.py diff --git a/tests/test_package.py b/tests/test_package.py new file mode 100644 index 00000000..7f4b2093 --- /dev/null +++ b/tests/test_package.py @@ -0,0 +1,11 @@ +def test_save_bioimageio_package(unet2d_nuclei_broad_model: str): + from bioimageio.spec._package import save_bioimageio_package + + _ = save_bioimageio_package( + unet2d_nuclei_broad_model, + weights_priority_order=( + None + if weights_priority_order is None + else [wpo.name for wpo in weights_priority_order] + ), + ) From c09cc00a208051163a4fc3784ba980622f46b1ce Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 16 Apr 2024 23:28:38 +0200 Subject: [PATCH 229/244] add rich to conda envs --- dev/env-py38.yaml | 1 + dev/env-tf.yaml | 1 + dev/env-wo-python.yaml | 1 + dev/env.yaml | 1 + 4 files changed, 4 insertions(+) diff --git a/dev/env-py38.yaml b/dev/env-py38.yaml index c14ef880..decc164b 100644 --- a/dev/env-py38.yaml +++ b/dev/env-py38.yaml @@ -28,6 +28,7 @@ dependencies: - python-dotenv - python=3.8 # changed - pytorch>=2.1 + - rich - ruff - ruyaml - torchvision diff --git a/dev/env-tf.yaml b/dev/env-tf.yaml index e371b558..4f954acd 100644 --- a/dev/env-tf.yaml +++ b/dev/env-tf.yaml @@ -28,6 +28,7 @@ dependencies: - python-dotenv # - python=3.9 # removed # - pytorch>=2.1 # removed + - rich # - ruff # removed - ruyaml - tensorflow>=2.15 # added diff --git a/dev/env-wo-python.yaml b/dev/env-wo-python.yaml index 21cb85ca..185cd85c 100644 --- a/dev/env-wo-python.yaml +++ b/dev/env-wo-python.yaml @@ -28,6 +28,7 @@ dependencies: - python-dotenv # - python=3.9 # removed - pytorch>=2.1 + - rich # - ruff # removed - ruyaml - torchvision diff --git a/dev/env.yaml b/dev/env.yaml index 0624979f..d2061c98 100644 --- a/dev/env.yaml +++ b/dev/env.yaml @@ -27,6 +27,7 @@ dependencies: - python-dotenv - python=3.9 - pytorch>=2.1 + - rich - ruff - ruyaml - torchvision From 8139d27dfbad4de371b9dfd22bfb023393b7c405 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 16 Apr 2024 23:36:58 +0200 Subject: [PATCH 230/244] lower keras pins for tf envs --- dev/env-tf.yaml | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/env-tf.yaml b/dev/env-tf.yaml index 4f954acd..1cdd82a3 100644 --- a/dev/env-tf.yaml +++ b/dev/env-tf.yaml @@ -12,7 +12,7 @@ dependencies: - imageio>=2.5 - jupyter - jupyter-black 
- - keras>=3.0 + - keras>=2.15 - loguru - numpy - onnxruntime diff --git a/setup.py b/setup.py index c0110480..8678689b 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,7 @@ include_package_data=True, extras_require={ "pytorch": ["torch>=1.6", "torchvision", "keras>=3.0"], - "tensorflow": ["tensorflow", "keras>=3.0"], + "tensorflow": ["tensorflow", "keras>=2.15"], "onnx": ["onnxruntime"], "dev": [ "black", From e0a73a4da6c94d79ecb7860840540578a01d29a6 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 16 Apr 2024 23:43:42 +0200 Subject: [PATCH 231/244] update entry_points --- bioimageio/core/__main__.py | 7 ++++++- setup.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/bioimageio/core/__main__.py b/bioimageio/core/__main__.py index d1b41975..6c944725 100644 --- a/bioimageio/core/__main__.py +++ b/bioimageio/core/__main__.py @@ -218,5 +218,10 @@ def test( # keras_converter.convert_weights_to_tensorflow_saved_model_bundle.__doc__ # ) -if __name__ == "__main__": + +def main(): fire.Fire(Bioimageio, name="bioimageio") + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 8678689b..d720bb05 100644 --- a/setup.py +++ b/setup.py @@ -69,5 +69,5 @@ "Bug Reports": "https://github.com/bioimage-io/core-bioimage-io-python/issues", "Source": "https://github.com/bioimage-io/core-bioimage-io-python", }, - entry_points={"console_scripts": ["bioimageio = bioimageio.core.__main__:app"]}, + entry_points={"console_scripts": ["bioimageio = bioimageio.core.__main__:main"]}, ) From 7dfde5c4e71327295f8550f308a687e27abd7571 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 22 Apr 2024 22:10:48 +0200 Subject: [PATCH 232/244] update test_save_bioimageio_package --- tests/test_package.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/test_package.py b/tests/test_package.py index 7f4b2093..bb087375 100644 --- a/tests/test_package.py +++ b/tests/test_package.py @@ -3,9 +3,5 @@ def test_save_bioimageio_package(unet2d_nuclei_broad_model: str): _ = save_bioimageio_package( unet2d_nuclei_broad_model, - weights_priority_order=( - None - if weights_priority_order is None - else [wpo.name for wpo in weights_priority_order] - ), + weights_priority_order=("pytorch_state_dict",), ) From e9ac975b285a7deb6ed6d85dcfce41bdc1a54ed3 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 22 Apr 2024 22:11:07 +0200 Subject: [PATCH 233/244] update env-tf.yaml --- dev/env-tf.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/env-tf.yaml b/dev/env-tf.yaml index 1cdd82a3..8dc523d5 100644 --- a/dev/env-tf.yaml +++ b/dev/env-tf.yaml @@ -12,7 +12,7 @@ dependencies: - imageio>=2.5 - jupyter - jupyter-black - - keras>=2.15 + - keras>=2.15 # changed - loguru - numpy - onnxruntime From 748c518b9d8da13899d39b8df711f85a369c8467 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 22 Apr 2024 22:11:24 +0200 Subject: [PATCH 234/244] bump spec --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d720bb05..72a4bce4 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ ], packages=find_namespace_packages(exclude=["tests"]), install_requires=[ - "bioimageio.spec==0.5.1.*", + "bioimageio.spec==0.5.2.*", "fire", "imageio>=2.5", "loguru", From 1601a05c34e2ac86598283239120c9968a999b3b Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 23 Apr 2024 10:06:01 +0200 Subject: [PATCH 235/244] try micromamba shell --- .github/workflows/build.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/.github/workflows/build.yaml b/.github/workflows/build.yaml index e797e1e3..b59b4e2f 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -8,7 +8,7 @@ on: defaults: run: - shell: bash -el {0} + shell: micromamba-shell {0} jobs: black: From 44538d9dc7999bff10bdb6c90c680786df7f5b30 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 23 Apr 2024 10:21:42 +0200 Subject: [PATCH 236/244] use bioimageio.spec from conda for tf tests --- .github/workflows/build.yaml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index b59b4e2f..8a9b5bdc 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -104,10 +104,7 @@ jobs: python=${{ matrix.python-version }} post-cleanup: 'all' - name: additional setup - run: | - conda remove --yes --force bioimageio.spec || true # allow failure for cached env - pip install --no-deps git+https://github.com/bioimage-io/spec-bioimage-io - pip install --no-deps -e . + run: pip install --no-deps -e . - name: pytest-spec-tf run: pytest --disable-pytest-warnings From edb99b96ec234cc151ee8345bf20007106242712 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Tue, 23 Apr 2024 11:21:11 +0200 Subject: [PATCH 237/244] fix _internal_output_axes --- .../core/model_adapters/_tensorflow_model_adapter.py | 9 +++------ .../core/model_adapters/_torchscript_model_adapter.py | 8 +++----- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py index 390b1b05..8a44ba5c 100644 --- a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py +++ b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py @@ -4,11 +4,12 @@ import numpy as np -from bioimageio.core.tensor import Tensor from bioimageio.spec.common import FileSource from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.utils import download +from ..digest_spec import get_axes_infos +from ..tensor import Tensor from ._model_adapter import ModelAdapter try: @@ -73,11 +74,7 @@ def __init__( weight_file = self.require_unzipped(weights.source) self._network = self._get_network(weight_file) self._internal_output_axes = [ - ( - tuple(out.axes) - if isinstance(out.axes, str) - else tuple(a.id for a in out.axes) - ) + tuple(a.id for a in get_axes_infos(out)) for out in model_description.outputs ] diff --git a/bioimageio/core/model_adapters/_torchscript_model_adapter.py b/bioimageio/core/model_adapters/_torchscript_model_adapter.py index d7cee1a3..0d28f019 100644 --- a/bioimageio/core/model_adapters/_torchscript_model_adapter.py +++ b/bioimageio/core/model_adapters/_torchscript_model_adapter.py @@ -8,7 +8,7 @@ from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.utils import download -from ..axis import AxisId +from ..digest_spec import get_axes_infos from ..tensor import Tensor from ._model_adapter import ModelAdapter @@ -45,12 +45,10 @@ def __init__( "Multiple devices for single torchscript model not yet implemented" ) - self._model = torch.jit.load( # pyright: ignore[reportPrivateImportUsage] - weight_path - ) + self._model = torch.jit.load(weight_path) self._model.to(self.devices[0]) self._internal_output_axes = [ - tuple(AxisId(a) if isinstance(a, str) else a.id for a in out.axes) + tuple(a.id for a in get_axes_infos(out)) for out in model_description.outputs ] From e41e87fb505ae5032cc26a8952ab2233e1beda88 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 24 Apr
2024 00:17:16 +0200 Subject: [PATCH 238/244] fix _test_device_management --- tests/test_prediction_pipeline_device_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_prediction_pipeline_device_management.py b/tests/test_prediction_pipeline_device_management.py index bada06ae..447eb698 100644 --- a/tests/test_prediction_pipeline_device_management.py +++ b/tests/test_prediction_pipeline_device_management.py @@ -18,7 +18,7 @@ def _test_device_management(model_package: Path, weight_format: WeightsFormat): from bioimageio.core._prediction_pipeline import create_prediction_pipeline from bioimageio.core.digest_spec import get_test_inputs, get_test_outputs - if torch.cuda.device_count() == 0: + if not hasattr(torch, "cuda") or torch.cuda.device_count() == 0: raise TooFewDevicesException("Need at least one cuda device for this test") bio_model = load_description(model_package) From 17e9aa67ec4dc7f65d0e6cac3b296974c457ff6f Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 24 Apr 2024 00:22:53 +0200 Subject: [PATCH 239/244] bump spec in dev envs --- dev/env-py38.yaml | 2 +- dev/env-tf.yaml | 2 +- dev/env-wo-python.yaml | 2 +- dev/env.yaml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dev/env-py38.yaml b/dev/env-py38.yaml index decc164b..ade6b2b4 100644 --- a/dev/env-py38.yaml +++ b/dev/env-py38.yaml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: - - bioimageio.spec>=0.5.1 + - bioimageio.spec>=0.5.2post1 - black - crick # uncommented - filelock diff --git a/dev/env-tf.yaml b/dev/env-tf.yaml index 8dc523d5..06451185 100644 --- a/dev/env-tf.yaml +++ b/dev/env-tf.yaml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: - - bioimageio.spec>=0.5.1 + - bioimageio.spec>=0.5.2post1 - black # - crick # currently requires python<=3.9 - filelock diff --git a/dev/env-wo-python.yaml b/dev/env-wo-python.yaml index 185cd85c..86dad127 100644 --- a/dev/env-wo-python.yaml +++ b/dev/env-wo-python.yaml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: - - bioimageio.spec>=0.5.1 + - bioimageio.spec>=0.5.2 - black # - crick # currently requires python<=3.9 - filelock diff --git a/dev/env.yaml b/dev/env.yaml index d2061c98..e0634abe 100644 --- a/dev/env.yaml +++ b/dev/env.yaml @@ -3,7 +3,7 @@ channels: - conda-forge - defaults dependencies: - - bioimageio.spec>=0.5.1 + - bioimageio.spec>=0.5.2post1 - black # - crick # currently requires python<=3.9 - filelock From fe45db8e38ba3a932ae94a4384f63fadaffc78dc Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 24 Apr 2024 00:49:05 +0200 Subject: [PATCH 240/244] add dot before post1 --- dev/env-py38.yaml | 2 +- dev/env-tf.yaml | 2 +- dev/env-wo-python.yaml | 2 +- dev/env.yaml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dev/env-py38.yaml b/dev/env-py38.yaml index ade6b2b4..760d2f97 100644 --- a/dev/env-py38.yaml +++ b/dev/env-py38.yaml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: - - bioimageio.spec>=0.5.2post1 + - bioimageio.spec>=0.5.2.post1 - black - crick # uncommented - filelock diff --git a/dev/env-tf.yaml b/dev/env-tf.yaml index 06451185..d51c5ad3 100644 --- a/dev/env-tf.yaml +++ b/dev/env-tf.yaml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: - - bioimageio.spec>=0.5.2post1 + - bioimageio.spec>=0.5.2.post1 - black # - crick # currently requires python<=3.9 - filelock diff --git a/dev/env-wo-python.yaml b/dev/env-wo-python.yaml index 86dad127..fedd86d3 100644 --- a/dev/env-wo-python.yaml +++
b/dev/env-wo-python.yaml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: - - bioimageio.spec>=0.5.2 + - bioimageio.spec>=0.5.2.post1 - black # - crick # currently requires python<=3.9 - filelock diff --git a/dev/env.yaml b/dev/env.yaml index e0634abe..4ac24ecf 100644 --- a/dev/env.yaml +++ b/dev/env.yaml @@ -3,7 +3,7 @@ channels: - conda-forge - defaults dependencies: - - bioimageio.spec>=0.5.2post1 + - bioimageio.spec>=0.5.2.post1 - black # - crick # currently requires python<=3.9 - filelock From b966b340fb91a2f1e3824dbc38026e659cbf8732 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 24 Apr 2024 01:01:09 +0200 Subject: [PATCH 241/244] fix keras network call --- bioimageio/core/model_adapters/_keras_model_adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bioimageio/core/model_adapters/_keras_model_adapter.py b/bioimageio/core/model_adapters/_keras_model_adapter.py index 5e74b084..ef77c117 100644 --- a/bioimageio/core/model_adapters/_keras_model_adapter.py +++ b/bioimageio/core/model_adapters/_keras_model_adapter.py @@ -79,7 +79,7 @@ def __init__( def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: _result: Union[Sequence[NDArray[Any]], NDArray[Any]] _result = self._network.predict( # pyright: ignore[reportUnknownVariableType] - *input_tensors + *[None if t is None else t.data.data for t in input_tensors] ) if isinstance(_result, (tuple, list)): result: Sequence[NDArray[Any]] = _result From 73063bdac14938b5dd93f6bbfef32dac76c98fc5 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 24 Apr 2024 01:04:50 +0200 Subject: [PATCH 242/244] update conda recipe template --- conda-recipe/meta.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index a4c7aca6..7654f5ab 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -49,7 +49,11 @@ test: - tests requires: {% for dep in setup_py_data['extras_require']['dev'] %} + {% if dep.startswith('torch>=') %} # pip: torch -> conda: pytorch + - py{{ dep.lower() }} + {% else %} - {{ dep.lower() }} + {% endif %} {% endfor %} commands: - pytest From 2be9913ba4e196e6dd8457d5a1cc2011cc69b6d2 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 24 Apr 2024 01:09:02 +0200 Subject: [PATCH 243/244] fix _output_axes --- .../core/model_adapters/_keras_model_adapter.py | 14 +++++++++++--- .../core/model_adapters/_pytorch_model_adapter.py | 7 ++----- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/bioimageio/core/model_adapters/_keras_model_adapter.py b/bioimageio/core/model_adapters/_keras_model_adapter.py index ef77c117..c5d74132 100644 --- a/bioimageio/core/model_adapters/_keras_model_adapter.py +++ b/bioimageio/core/model_adapters/_keras_model_adapter.py @@ -4,12 +4,13 @@ from loguru import logger from numpy.typing import NDArray -from bioimageio.core.tensor import Tensor from bioimageio.spec._internal.io_utils import download from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import Version from .._settings import settings +from ..digest_spec import get_axes_infos +from ..tensor import Tensor from ._model_adapter import ModelAdapter os.environ["KERAS_BACKEND"] = settings.keras_backend @@ -74,7 +75,10 @@ def __init__( weight_path = download(model_description.weights.keras_hdf5.source).path self._network = keras.models.load_model(weight_path) - self._output_axes = [tuple(out.axes) for out in model_description.outputs] + self._output_axes = [ + tuple(a.id for a in get_axes_infos(out)) + for out in
model_description.outputs + ] def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: _result: Union[Sequence[NDArray[Any]], NDArray[Any]] @@ -87,7 +91,11 @@ def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: result = [_result] # type: ignore assert len(result) == len(self._output_axes) - return [Tensor(r, dims=axes) for r, axes, in zip(result, self._output_axes)] + ret: List[Optional[Tensor]] = [] + ret.extend( + [Tensor(r, dims=axes) for r, axes, in zip(result, self._output_axes)] + ) + return ret def unload(self) -> None: logger.warning( diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py index eaf03fcc..b647aeff 100644 --- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py +++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py @@ -6,7 +6,7 @@ from bioimageio.spec.utils import download from ..axis import AxisId -from ..digest_spec import import_callable +from ..digest_spec import get_axes_infos, import_callable from ..tensor import Tensor from ._model_adapter import ModelAdapter @@ -31,10 +31,7 @@ def __init__( if torch is None: raise ImportError("torch") super().__init__() - self.output_dims = [ - tuple(AxisId(a) if isinstance(a, str) else a.id for a in out.axes) - for out in outputs - ] + self.output_dims = [tuple(a.id for a in get_axes_infos(out)) for out in outputs] self._network = self.get_network(weights) self._devices = self.get_devices(devices) self._network = self._network.to(self._devices[0]) From b84d33b321b1f4e755aebe900f0bd74ed74eb012 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 24 Apr 2024 01:16:22 +0200 Subject: [PATCH 244/244] torch is optional dep --- bioimageio/core/weight_converter/torch/_onnx.py | 7 ++++++- .../core/weight_converter/torch/_torchscript.py | 10 ++++++++-- bioimageio/core/weight_converter/torch/_utils.py | 12 +++++++----- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/bioimageio/core/weight_converter/torch/_onnx.py b/bioimageio/core/weight_converter/torch/_onnx.py index 1e1e68ae..3935e1d1 100644 --- a/bioimageio/core/weight_converter/torch/_onnx.py +++ b/bioimageio/core/weight_converter/torch/_onnx.py @@ -4,7 +4,6 @@ from typing import Any, List, Sequence, cast import numpy as np -import torch from numpy.testing import assert_array_almost_equal from bioimageio.spec import load_description @@ -14,6 +13,11 @@ from ...digest_spec import get_member_id, get_test_inputs from ...weight_converter.torch._utils import load_torch_model +try: + import torch +except ImportError: + torch = None + def add_onnx_weights( model_spec: "str | Path | v0_4.ModelDescr | v0_5.ModelDescr", @@ -48,6 +52,7 @@ def add_onnx_weights( "The provided model does not have weights in the pytorch state dict format" ) + assert torch is not None with torch.no_grad(): sample = get_test_inputs(model_spec) diff --git a/bioimageio/core/weight_converter/torch/_torchscript.py b/bioimageio/core/weight_converter/torch/_torchscript.py index 0d226563..5ca16069 100644 --- a/bioimageio/core/weight_converter/torch/_torchscript.py +++ b/bioimageio/core/weight_converter/torch/_torchscript.py @@ -3,7 +3,6 @@ from typing import List, Sequence, Union import numpy as np -import torch from numpy.testing import assert_array_almost_equal from typing_extensions import Any, assert_never @@ -12,14 +11,21 @@ from ._utils import load_torch_model +try: + import torch +except ImportError: + torch = None + # FIXME: remove Any def _check_predictions( model:
Any, scripted_model: Any, model_spec: "v0_4.ModelDescr | v0_5.ModelDescr", - input_data: Sequence[torch.Tensor], + input_data: Sequence["torch.Tensor"], ): + assert torch is not None + def _check(input_: Sequence[torch.Tensor]) -> None: expected_tensors = model(*input_) if isinstance(expected_tensors, torch.Tensor): diff --git a/bioimageio/core/weight_converter/torch/_utils.py b/bioimageio/core/weight_converter/torch/_utils.py index d3908f61..01df0747 100644 --- a/bioimageio/core/weight_converter/torch/_utils.py +++ b/bioimageio/core/weight_converter/torch/_utils.py @@ -1,22 +1,24 @@ from typing import Union -import torch - from bioimageio.core.model_adapters._pytorch_model_adapter import PytorchModelAdapter from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.utils import download +try: + import torch +except ImportError: + torch = None + # additional convenience for pytorch state dict, eventually we want this in python-bioimageio too # and for each weight format def load_torch_model( # pyright: ignore[reportUnknownParameterType] node: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr], ): + assert torch is not None model = ( # pyright: ignore[reportUnknownVariableType] PytorchModelAdapter.get_network(node) ) - state = torch.load( # pyright: ignore[reportUnknownVariableType] - download(node.source).path, map_location="cpu" - ) + state = torch.load(download(node.source).path, map_location="cpu") model.load_state_dict(state) # FIXME: check incompatible keys? return model.eval() # pyright: ignore[reportUnknownVariableType]
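
Note on patch 231 (update entry_points): a console_scripts entry point must name a callable, and the previous target bioimageio.core.__main__:app did not exist; wrapping the fire.Fire(...) call in a main() function gives the generated `bioimageio` script something to call. A minimal sketch of the same pattern, with a reduced stand-in for the real CLI class:

import fire


class Bioimageio:
    """reduced stand-in for the CLI class in bioimageio.core.__main__"""

    def test(self, source: str):
        print(f"would test {source}")


def main():
    # pip generates a console script that imports this module and calls main()
    fire.Fire(Bioimageio, name="bioimageio")


if __name__ == "__main__":
    main()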
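
Note on patches 237 and 243 (fix _internal_output_axes / _output_axes): format 0.4 model descriptions store axes as a plain string such as "bcyx", while format 0.5 stores a sequence of axis objects with an `id` attribute, which is why ad-hoc isinstance checks kept reappearing in the adapters. The shared `tuple(a.id for a in get_axes_infos(out))` idiom normalizes both. A rough sketch of the normalization idea only, not the actual get_axes_infos implementation:

from dataclasses import dataclass
from typing import Sequence, Tuple, Union


@dataclass
class AxisInfo:
    id: str  # e.g. "b", "c", "y", "x"


def normalize_axes(axes: Union[str, Sequence[AxisInfo]]) -> Tuple[str, ...]:
    if isinstance(axes, str):  # format 0.4: axes given as a string like "bcyx"
        return tuple(axes)
    return tuple(a.id for a in axes)  # format 0.5: axis objects with an id


assert normalize_axes("bcyx") == ("b", "c", "y", "x")
assert normalize_axes([AxisInfo("b"), AxisInfo("x")]) == ("b", "x")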
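
Note on patch 241 (fix keras network call): keras' predict expects plain arrays, not the core Tensor wrapper, hence the unwrapping to `t.data.data`. Assuming Tensor.data holds an xarray.DataArray whose .data attribute is the underlying numpy array, the attribute chain behaves like this reduced sketch:

import numpy as np
import xarray as xr


class Tensor:
    """reduced stand-in for bioimageio.core.tensor.Tensor"""

    def __init__(self, array: np.ndarray, dims: tuple):
        self.data = xr.DataArray(array, dims=dims)


t = Tensor(np.zeros((1, 3, 64, 64), dtype="float32"), dims=("b", "c", "y", "x"))
raw = t.data.data  # the bare numpy array that predict() accepts
assert isinstance(raw, np.ndarray)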
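
Note on patch 244 (torch is optional dep): moving `import torch` into a try/except with a `torch = None` fallback keeps the weight converter modules importable without pytorch installed, while each function that really needs it guards with `assert torch is not None`. The same pattern in isolation; the `require_torch` helper below is hypothetical, not part of the codebase:

from types import ModuleType

try:
    import torch
except ImportError:
    torch = None


def require_torch() -> ModuleType:
    # fail only when torch functionality is actually requested,
    # not at import time of the surrounding module
    if torch is None:
        raise ImportError(
            "pytorch is required for this functionality; "
            "install it e.g. with `conda install -c conda-forge pytorch`"
        )
    return torch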