diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 844cbb5e..2ba5afac 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -49,14 +49,29 @@ jobs: post-cleanup: 'all' - name: additional setup run: pip install --no-deps -e . + - name: Get Date + id: get-date + run: | + echo "date=$(date +'%Y-%b')" + echo "date=$(date +'%Y-%b')" >> $GITHUB_OUTPUT + shell: bash + - uses: actions/cache@v4 + with: + path: bioimageio_cache + key: "test-spec-conda-${{ steps.get-date.outputs.date }}" - name: pytest-spec-conda run: pytest --disable-pytest-warnings + env: + BIOIMAGEIO_CACHE_PATH: bioimageio_cache test-spec-main: runs-on: ubuntu-latest strategy: matrix: python-version: ['3.8', '3.12'] + include: + - python-version: '3.12' + is-dev-version: true steps: - uses: actions/checkout@v4 - name: Install Conda environment with Micromamba @@ -83,19 +98,31 @@ jobs: pip install --no-deps git+https://github.com/bioimage-io/spec-bioimage-io - name: additional setup core run: pip install --no-deps -e . + - name: Get Date + id: get-date + run: | + echo "date=$(date +'%Y-%b')" + echo "date=$(date +'%Y-%b')" >> $GITHUB_OUTPUT + shell: bash + - uses: actions/cache@v4 + with: + path: bioimageio_cache + key: "test-spec-main-${{ steps.get-date.outputs.date }}" - name: pytest-spec-main run: pytest --disable-pytest-warnings - - if: matrix.python-version == '3.12' && github.event_name == 'pull_request' + env: + BIOIMAGEIO_CACHE_PATH: bioimageio_cache + - if: matrix.is-dev-version && github.event_name == 'pull_request' uses: orgoro/coverage@v3.2 with: coverageFile: coverage.xml token: ${{ secrets.GITHUB_TOKEN }} - - if: matrix.python-version == '3.12' && github.ref == 'refs/heads/main' + - if: matrix.is-dev-version && github.ref == 'refs/heads/main' run: | pip install genbadge[coverage] genbadge coverage --input-file coverage.xml --output-file ./dist/coverage/coverage-badge.svg coverage html -d dist/coverage - - if: matrix.python-version == '3.12' && github.ref == 'refs/heads/main' + - if: matrix.is-dev-version && github.ref == 'refs/heads/main' uses: actions/upload-artifact@v4 with: name: coverage @@ -103,15 +130,51 @@ jobs: path: dist - test-tf: + test-spec-main-tf: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.9', '3.11'] + python-version: ['3.9', '3.12'] steps: - uses: actions/checkout@v4 - - name: Install Conda environment with Micromamba - uses: mamba-org/setup-micromamba@v1 + - uses: mamba-org/setup-micromamba@v1 + with: + cache-downloads: true + cache-environment: true + environment-file: dev/env-tf.yaml + condarc: | + channel-priority: flexible + create-args: >- + python=${{ matrix.python-version }} + post-cleanup: 'all' + - name: additional setup spec + run: | + conda remove --yes --force bioimageio.spec || true # allow failure for cached env + pip install --no-deps git+https://github.com/bioimage-io/spec-bioimage-io + - name: additional setup core + run: pip install --no-deps -e . + - name: Get Date + id: get-date + run: | + echo "date=$(date +'%Y-%b')" + echo "date=$(date +'%Y-%b')" >> $GITHUB_OUTPUT + shell: bash + - uses: actions/cache@v4 + with: + path: bioimageio_cache + key: "test-spec-main-tf-${{ steps.get-date.outputs.date }}" + - run: pytest --disable-pytest-warnings + env: + BIOIMAGEIO_CACHE_PATH: bioimageio_cache + + test-spec-conda-tf: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.9', '3.12'] + steps: + - uses: actions/checkout@v4 + - uses: mamba-org/setup-micromamba@v1 with: cache-downloads: true cache-environment: true @@ -123,8 +186,20 @@ jobs: post-cleanup: 'all' - name: additional setup run: pip install --no-deps -e . + - name: Get Date + id: get-date + run: | + echo "date=$(date +'%Y-%b')" + echo "date=$(date +'%Y-%b')" >> $GITHUB_OUTPUT + shell: bash + - uses: actions/cache@v4 + with: + path: bioimageio_cache + key: "test-spec-conda-tf-${{ steps.get-date.outputs.date }}" - name: pytest-spec-tf run: pytest --disable-pytest-warnings + env: + BIOIMAGEIO_CACHE_PATH: bioimageio_cache conda-build: runs-on: ubuntu-latest @@ -164,20 +239,11 @@ jobs: path: dist - uses: actions/setup-python@v5 with: - python-version: '3.12' + python-version: '3.13' cache: 'pip' - run: pip install -e .[dev] - - id: get_version - run: python -c 'import bioimageio.core;print(f"version={bioimageio.core.__version__}")' >> $GITHUB_OUTPUT - name: Generate developer docs - run: | - pdoc \ - --docformat google \ - --logo https://bioimage.io/static/img/bioimage-io-logo.svg \ - --logo-link https://bioimage.io/ \ - --favicon https://bioimage.io/static/img/bioimage-io-icon-small.svg \ - --footer-text 'bioimageio.core ${{steps.get_version.outputs.version}}' \ - -o ./dist bioimageio.core + run: ./scripts/pdoc/run.sh - run: cp README.md ./dist/README.md - name: copy rendered presentations run: | diff --git a/README.md b/README.md index 333b5fab..71349572 100644 --- a/README.md +++ b/README.md @@ -375,6 +375,24 @@ The model specification and its validation tools can be found at ValidationSummary: """Test model inference""" - # NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference` return test_description( source, weight_format=weight_format, @@ -48,6 +111,7 @@ def test_model( absolute_tolerance=absolute_tolerance, relative_tolerance=relative_tolerance, decimal=decimal, + determinism=determinism, expected_type="model", ) @@ -61,10 +125,10 @@ def test_description( absolute_tolerance: float = 1.5e-4, relative_tolerance: float = 1e-4, decimal: Optional[int] = None, + determinism: Literal["seed_only", "full"] = "seed_only", expected_type: Optional[str] = None, ) -> ValidationSummary: """Test a bioimage.io resource dynamically, e.g. prediction of test tensors for models""" - # NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference` rd = load_description_and_test( source, format_version=format_version, @@ -73,6 +137,7 @@ def test_description( absolute_tolerance=absolute_tolerance, relative_tolerance=relative_tolerance, decimal=decimal, + determinism=determinism, expected_type=expected_type, ) return rd.validation_summary @@ -87,10 +152,10 @@ def load_description_and_test( absolute_tolerance: float = 1.5e-4, relative_tolerance: float = 1e-4, decimal: Optional[int] = None, + determinism: Literal["seed_only", "full"] = "seed_only", expected_type: Optional[str] = None, ) -> Union[ResourceDescr, InvalidDescr]: """Test RDF dynamically, e.g. model inference of test inputs""" - # NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference` if ( isinstance(source, ResourceDescrBase) and format_version != "discover" @@ -108,7 +173,7 @@ def load_description_and_test( else: rd = load_description(source, format_version=format_version) - rd.validation_summary.env.append( + rd.validation_summary.env.add( InstalledPackage(name="bioimageio.core", version=VERSION) ) @@ -122,10 +187,25 @@ def load_description_and_test( ] # pyright: ignore[reportAssignmentType] else: weight_formats = [weight_format] - for w in weight_formats: - _test_model_inference( - rd, w, devices, absolute_tolerance, relative_tolerance, decimal + + if decimal is None: + atol = absolute_tolerance + rtol = relative_tolerance + else: + warnings.warn( + "The argument `decimal` has been deprecated in favour of" + + " `relative_tolerance` and `absolute_tolerance`, with different" + + " validation logic, using `numpy.testing.assert_allclose, see" + + " 'https://numpy.org/doc/stable/reference/generated/" + + " numpy.testing.assert_allclose.html'. Passing a value for `decimal`" + + " will cause validation to revert to the old behaviour." ) + atol = 1.5 * 10 ** (-decimal) + rtol = 0 + + enable_determinism(determinism) + for w in weight_formats: + _test_model_inference(rd, w, devices, atol, rtol) if not isinstance(rd, v0_4.ModelDescr): _test_model_inference_parametrized(rd, w, devices) @@ -139,21 +219,14 @@ def _test_model_inference( model: Union[v0_4.ModelDescr, v0_5.ModelDescr], weight_format: WeightsFormat, devices: Optional[Sequence[str]], - absolute_tolerance: float, - relative_tolerance: float, - decimal: Optional[int], + atol: float, + rtol: float, ) -> None: test_name = f"Reproduce test outputs from test inputs ({weight_format})" logger.info("starting '{}'", test_name) error: Optional[str] = None tb: List[str] = [] - precision_args = _handle_legacy_precision_args( - absolute_tolerance=absolute_tolerance, - relative_tolerance=relative_tolerance, - decimal=decimal, - ) - try: inputs = get_test_inputs(model) expected = get_test_outputs(model) @@ -176,8 +249,8 @@ def _test_model_inference( np.testing.assert_allclose( res.data, exp.data, - rtol=precision_args["relative_tolerance"], - atol=precision_args["absolute_tolerance"], + rtol=rtol, + atol=atol, ) except AssertionError as e: error = f"Output and expected output disagree:\n {e}" @@ -189,7 +262,9 @@ def _test_model_inference( model.validation_summary.add_detail( ValidationDetail( name=test_name, + loc=("weights", weight_format), status="passed" if error is None else "failed", + recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]), errors=( [] if error is None @@ -332,6 +407,7 @@ def get_ns(n: int): ValidationDetail( name=f"Run {weight_format} inference for inputs with" + f" batch_size: {batch_size} and size parameter n: {n}", + loc=("weights", weight_format), status="passed" if error is None else "failed", errors=( [] @@ -353,6 +429,7 @@ def get_ns(n: int): ValidationDetail( name=f"Run {weight_format} inference for parametrized inputs", status="failed", + loc=("weights", weight_format), errors=[ ErrorEntry( loc=("weights", weight_format), @@ -373,6 +450,7 @@ def _test_expected_resource_type( ValidationDetail( name="Has expected resource type", status="passed" if has_expected_type else "failed", + loc=("type",), errors=( [] if has_expected_type @@ -388,38 +466,7 @@ def _test_expected_resource_type( ) -def _handle_legacy_precision_args( - absolute_tolerance: float, relative_tolerance: float, decimal: Optional[int] -) -> Dict[str, float]: - """ - Transform the precision arguments to conform with the current implementation. - - If the deprecated `decimal` argument is used it overrides the new behaviour with - the old behaviour. - """ - # Already conforms with current implementation - if decimal is None: - return { - "absolute_tolerance": absolute_tolerance, - "relative_tolerance": relative_tolerance, - } - - warnings.warn( - "The argument `decimal` has been depricated in favour of " - + "`relative_tolerance` and `absolute_tolerance`, with different validation " - + "logic, using `numpy.testing.assert_allclose, see " - + "'https://numpy.org/doc/stable/reference/generated/" - + "numpy.testing.assert_allclose.html'. Passing a value for `decimal` will " - + "cause validation to revert to the old behaviour." - ) - # decimal overrides new behaviour, - # have to convert the params to emulate old behaviour - return { - "absolute_tolerance": 1.5 * 10 ** (-decimal), - "relative_tolerance": 0, - } - - +# TODO: Implement `debug_model()` # def debug_model( # model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str], # *, diff --git a/bioimageio/core/_settings.py b/bioimageio/core/_settings.py index d09f3b8b..c95cf55d 100644 --- a/bioimageio/core/_settings.py +++ b/bioimageio/core/_settings.py @@ -10,7 +10,7 @@ class Settings(SpecSettings): - """environment variables""" + """environment variables for bioimageio.spec and bioimageio.core""" keras_backend: Annotated[ Literal["torch", "tensorflow", "jax"], Field(alias="KERAS_BACKEND") @@ -18,3 +18,4 @@ class Settings(SpecSettings): settings = Settings() +"""parsed environment variables for bioimageio.spec and bioimageio.core""" diff --git a/bioimageio/core/block_meta.py b/bioimageio/core/block_meta.py index 22f29ded..f7740092 100644 --- a/bioimageio/core/block_meta.py +++ b/bioimageio/core/block_meta.py @@ -317,25 +317,34 @@ def split_multiple_shapes_into_blocks( strides: Optional[PerMember[PerAxis[int]]] = None, broadcast: bool = False, ) -> Tuple[TotalNumberOfBlocks, Iterable[PerMember[BlockMeta]]]: - assert not ( - missing := [t for t in block_shapes if t not in shapes] - ), f"block shape specified for unknown tensors: {missing}" + if unknown_blocks := [t for t in block_shapes if t not in shapes]: + raise ValueError( + f"block shape specified for unknown tensors: {unknown_blocks}." + ) + if not block_shapes: block_shapes = shapes - assert broadcast or not ( - missing := [t for t in shapes if t not in block_shapes] - ), f"no block shape specified for {missing} (set `broadcast` to True if these tensors should be repeated for each block)" - assert not ( - missing := [t for t in halo if t not in block_shapes] - ), f"`halo` specified for tensors without block shape: {missing}" + if not broadcast and ( + missing_blocks := [t for t in shapes if t not in block_shapes] + ): + raise ValueError( + f"no block shape specified for {missing_blocks}." + + " Set `broadcast` to True if these tensors should be repeated" + + " as a whole for each block." + ) + + if extra_halo := [t for t in halo if t not in block_shapes]: + raise ValueError( + f"`halo` specified for tensors without block shape: {extra_halo}." + ) if strides is None: strides = {} assert not ( - missing := [t for t in strides if t not in block_shapes] - ), f"`stride` specified for tensors without block shape: {missing}" + unknown_block := [t for t in strides if t not in block_shapes] + ), f"`stride` specified for tensors without block shape: {unknown_block}" blocks: Dict[MemberId, Iterable[BlockMeta]] = {} n_blocks: Dict[MemberId, TotalNumberOfBlocks] = {} @@ -346,7 +355,7 @@ def split_multiple_shapes_into_blocks( halo=halo.get(t, {}), stride=strides.get(t), ) - assert n_blocks[t] > 0 + assert n_blocks[t] > 0, n_blocks assert len(blocks) > 0, blocks assert len(n_blocks) > 0, n_blocks @@ -355,8 +364,9 @@ def split_multiple_shapes_into_blocks( if len(unique_n_blocks) == 2 and 1 in unique_n_blocks: if not broadcast: raise ValueError( - f"Mismatch for total number of blocks due to unsplit (single block) tensors: {n_blocks}." - + " Set `broadcast` to True if you want to repeat unsplit (single block) tensors." + "Mismatch for total number of blocks due to unsplit (single block)" + + f" tensors: {n_blocks}. Set `broadcast` to True if you want to" + + " repeat unsplit (single block) tensors." ) blocks = { diff --git a/bioimageio/core/cli.py b/bioimageio/core/cli.py index 4127dd17..fad44ab3 100644 --- a/bioimageio/core/cli.py +++ b/bioimageio/core/cli.py @@ -28,7 +28,7 @@ ) from loguru import logger -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, Field, model_validator from pydantic_settings import ( BaseSettings, CliPositionalArg, @@ -41,50 +41,46 @@ ) from ruyaml import YAML from tqdm import tqdm +from typing_extensions import assert_never -from bioimageio.core import ( - MemberId, - Sample, - __version__, - create_prediction_pipeline, -) -from bioimageio.core.commands import ( +from bioimageio.spec import AnyModelDescr, InvalidDescr, load_description +from bioimageio.spec._internal.io_basics import ZipPath +from bioimageio.spec._internal.types import NotEmpty +from bioimageio.spec.dataset import DatasetDescr +from bioimageio.spec.model import ModelDescr, v0_4, v0_5 +from bioimageio.spec.notebook import NotebookDescr +from bioimageio.spec.utils import download, ensure_description_is_model + +from .commands import ( WeightFormatArgAll, WeightFormatArgAny, package, test, validate_format, ) -from bioimageio.core.common import SampleId -from bioimageio.core.digest_spec import get_member_ids, load_sample_for_model -from bioimageio.core.io import load_dataset_stat, save_dataset_stat, save_sample -from bioimageio.core.proc_setup import ( +from .common import MemberId, SampleId +from .digest_spec import get_member_ids, load_sample_for_model +from .io import load_dataset_stat, save_dataset_stat, save_sample +from .prediction import create_prediction_pipeline +from .proc_setup import ( DatasetMeasure, Measure, MeasureValue, StatsCalculator, get_required_dataset_measures, ) -from bioimageio.core.stat_measures import Stat -from bioimageio.spec import ( - AnyModelDescr, - InvalidDescr, - load_description, -) -from bioimageio.spec._internal.types import NotEmpty -from bioimageio.spec.dataset import DatasetDescr -from bioimageio.spec.model import ModelDescr, v0_4, v0_5 -from bioimageio.spec.notebook import NotebookDescr -from bioimageio.spec.utils import download, ensure_description_is_model +from .sample import Sample +from .stat_measures import Stat +from .utils import VERSION yaml = YAML(typ="safe") -class CmdBase(BaseModel, use_attribute_docstrings=True): +class CmdBase(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True): pass -class ArgMixin(BaseModel, use_attribute_docstrings=True): +class ArgMixin(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True): pass @@ -117,14 +113,14 @@ def descr_id(self) -> str: class ValidateFormatCmd(CmdBase, WithSource): - """bioimageio-validate-format - validate the meta data format of a bioimageio resource.""" + """validate the meta data format of a bioimageio resource.""" def run(self): - validate_format(self.descr) + sys.exit(validate_format(self.descr)) class TestCmd(CmdBase, WithSource): - """bioimageio-test - Test a bioimageio resource (beyond meta data formatting)""" + """Test a bioimageio resource (beyond meta data formatting)""" weight_format: WeightFormatArgAll = "all" """The weight format to limit testing to. @@ -138,16 +134,18 @@ class TestCmd(CmdBase, WithSource): """Precision for numerical comparisons""" def run(self): - test( - self.descr, - weight_format=self.weight_format, - devices=self.devices, - decimal=self.decimal, + sys.exit( + test( + self.descr, + weight_format=self.weight_format, + devices=self.devices, + decimal=self.decimal, + ) ) class PackageCmd(CmdBase, WithSource): - """bioimageio-package - save a resource's metadata with its associated files.""" + """save a resource's metadata with its associated files.""" path: CliPositionalArg[Path] """The path to write the (zipped) package to. @@ -162,10 +160,12 @@ def run(self): self.descr.validation_summary.display() raise ValueError("resource description is invalid") - package( - self.descr, - self.path, - weight_format=self.weight_format, + sys.exit( + package( + self.descr, + self.path, + weight_format=self.weight_format, + ) ) @@ -204,7 +204,7 @@ def _get_stat( class PredictCmd(CmdBase, WithSource): - """bioimageio-predict - Run inference on your data with a bioimage.io model.""" + """Run inference on your data with a bioimage.io model.""" inputs: NotEmpty[Sequence[Union[str, NotEmpty[Tuple[str, ...]]]]] = ( "{input_id}/001.tif", @@ -305,7 +305,12 @@ def _example(self): dst = Path(f"{example_path}/{t}/001{''.join(local.suffixes)}") dst.parent.mkdir(parents=True, exist_ok=True) inputs001.append(dst.as_posix()) - shutil.copy(local, dst) + if isinstance(local, Path): + shutil.copy(local, dst) + elif isinstance(local, ZipPath): + _ = local.root.extract(local.at, path=dst) + else: + assert_never(local) inputs = [tuple(inputs001)] output_pattern = f"{example_path}/outputs/{{output_id}}/{{sample_id}}.tif" @@ -336,9 +341,10 @@ def get_example_command(preview: bool, escape: bool = False): return [ "bioimageio", "predict", - f"--preview={preview}", # update once we use implicit flags, see `class Bioimageio` below - "--overwrite=True", - f"--blockwise={self.blockwise}", + # --no-preview not supported for py=3.8 + *(["--preview"] if preview else []), + "--overwrite", + *(["--blockwise"] if self.blockwise else []), f"--stats={q}{stats}{q}", f"--inputs={q}{inputs_escaped if escape else inputs_json}{q}", f"--outputs={q}{output_pattern}{q}", @@ -545,22 +551,20 @@ def input_dataset(stat: Stat): class Bioimageio( BaseSettings, - # alias_generator=AliasGenerator( - # validation_alias=lambda s: AliasChoices(s, to_snake(s).replace("_", "-")) - # ), - # TODO: investigate how to allow a validation alias for subcommands - # ('validate-format' vs 'validate_format') cli_parse_args=True, cli_prog_name="bioimageio", cli_use_class_docs_for_groups=True, - # cli_implicit_flags=True, # TODO: make flags implicit, see https://github.com/pydantic/pydantic-settings/issues/361 + cli_implicit_flags=True, use_attribute_docstrings=True, ): """bioimageio - CLI for bioimage.io resources 🦒""" - model_config = SettingsConfigDict(json_file=JSON_FILE, yaml_file=YAML_FILE) + model_config = SettingsConfigDict( + json_file=JSON_FILE, + yaml_file=YAML_FILE, + ) - validate_format: CliSubCommand[ValidateFormatCmd] + validate_format: CliSubCommand[ValidateFormatCmd] = Field(alias="validate-format") "Check a resource's metadata format" test: CliSubCommand[TestCmd] @@ -618,8 +622,8 @@ def run(self): Bioimageio.__doc__ += f""" library versions: - bioimageio.core {__version__} - bioimageio.spec {__version__} + bioimageio.core {VERSION} + bioimageio.spec {VERSION} spec format versions: model RDF {ModelDescr.implemented_format_version} @@ -630,7 +634,7 @@ def run(self): def _get_sample_ids( - input_paths: Sequence[Mapping[MemberId, Path]] + input_paths: Sequence[Mapping[MemberId, Path]], ) -> Sequence[SampleId]: """Get sample ids for given input paths, based on the common path per sample. diff --git a/bioimageio/core/commands.py b/bioimageio/core/commands.py index a7cfc97c..c71d495f 100644 --- a/bioimageio/core/commands.py +++ b/bioimageio/core/commands.py @@ -1,13 +1,11 @@ """These functions implement the logic of the bioimageio command line interface -defined in the `cli` module.""" +defined in `bioimageio.core.cli`.""" -import sys from pathlib import Path -from typing import List, Optional, Sequence, Union +from typing import Optional, Sequence, Union from typing_extensions import Literal -from bioimageio.core import test_description from bioimageio.spec import ( InvalidDescr, ResourceDescr, @@ -16,6 +14,8 @@ ) from bioimageio.spec.model.v0_5 import WeightsFormat +from ._resource_tests import test_description + WeightFormatArgAll = Literal[WeightsFormat, "all"] WeightFormatArgAny = Literal[WeightsFormat, "any"] @@ -26,7 +26,7 @@ def test( weight_format: WeightFormatArgAll = "all", devices: Optional[Union[str, Sequence[str]]] = None, decimal: int = 4, -): +) -> int: """test a bioimageio resource Args: @@ -38,7 +38,7 @@ def test( """ if isinstance(descr, InvalidDescr): descr.validation_summary.display() - sys.exit(1) + return 1 summary = test_description( descr, @@ -47,7 +47,7 @@ def test( decimal=decimal, ) summary.display() - sys.exit(0 if summary.status == "passed" else 1) + return 0 if summary.status == "passed" else 1 def validate_format( @@ -59,7 +59,7 @@ def validate_format( descr: a bioimageio resource description """ descr.validation_summary.display() - sys.exit(0 if descr.validation_summary.status == "passed" else 1) + return 0 if descr.validation_summary.status == "passed" else 1 def package( @@ -96,53 +96,4 @@ def package( output_path=path, weights_priority_order=weights_priority_order, ) - - -# TODO: add convert command(s) -# if torch_converter is not None: - -# @app.command() -# def convert_torch_weights_to_onnx( -# model_rdf: Path = typer.Argument( -# ..., help="Path to the model resource description file (rdf.yaml) or zipped model." -# ), -# output_path: Path = typer.Argument(..., help="Where to save the onnx weights."), -# opset_version: Optional[int] = typer.Argument(12, help="Onnx opset version."), -# use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."), -# verbose: bool = typer.Option(True, help="Verbosity"), -# ): -# ret_code = torch_converter.convert_weights_to_onnx(model_rdf, output_path, opset_version, use_tracing, verbose) -# sys.exit(ret_code) - -# convert_torch_weights_to_onnx.__doc__ = torch_converter.convert_weights_to_onnx.__doc__ - -# @app.command() -# def convert_torch_weights_to_torchscript( -# model_rdf: Path = typer.Argument( -# ..., help="Path to the model resource description file (rdf.yaml) or zipped model." -# ), -# output_path: Path = typer.Argument(..., help="Where to save the torchscript weights."), -# use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."), -# ): -# torch_converter.convert_weights_to_torchscript(model_rdf, output_path, use_tracing) -# sys.exit(0) - -# convert_torch_weights_to_torchscript.__doc__ = torch_converter.convert_weights_to_torchscript.__doc__ - - -# if keras_converter is not None: - -# @app.command() -# def convert_keras_weights_to_tensorflow( -# model_rdf: Annotated[ -# Path, typer.Argument(help="Path to the model resource description file (rdf.yaml) or zipped model.") -# ], -# output_path: Annotated[Path, typer.Argument(help="Where to save the tensorflow weights.")], -# ): -# rd = load_description(model_rdf) -# ret_code = keras_converter.convert_weights_to_tensorflow_saved_model_bundle(rd, output_path) -# sys.exit(ret_code) - -# convert_keras_weights_to_tensorflow.__doc__ = ( -# keras_converter.convert_weights_to_tensorflow_saved_model_bundle.__doc__ -# ) + return 0 diff --git a/bioimageio/core/dataset.py b/bioimageio/core/dataset.py index 59361b2d..e7740504 100644 --- a/bioimageio/core/dataset.py +++ b/bioimageio/core/dataset.py @@ -1,5 +1,5 @@ from typing import Iterable -from bioimageio.core.sample import Sample +from .sample import Sample Dataset = Iterable[Sample] diff --git a/bioimageio/core/digest_spec.py b/bioimageio/core/digest_spec.py index 1e229e53..edb5a45d 100644 --- a/bioimageio/core/digest_spec.py +++ b/bioimageio/core/digest_spec.py @@ -23,10 +23,8 @@ from numpy.typing import NDArray from typing_extensions import Unpack, assert_never -from bioimageio.core.common import MemberId, PerMember, SampleId -from bioimageio.core.io import load_tensor -from bioimageio.core.sample import Sample -from bioimageio.spec._internal.io_utils import HashKwargs, download +from bioimageio.spec._internal.io import resolve_and_extract +from bioimageio.spec._internal.io_utils import HashKwargs from bioimageio.spec.common import FileSource from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile @@ -40,6 +38,7 @@ from .axis import AxisId, AxisInfo, AxisLike, PerAxis from .block_meta import split_multiple_shapes_into_blocks from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks +from .io import load_tensor from .sample import ( LinearSampleAxisTransform, Sample, @@ -79,7 +78,7 @@ def import_callable( def _import_from_file_impl( source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs] ): - local_file = download(source, **kwargs) + local_file = resolve_and_extract(source, **kwargs) module_name = local_file.path.stem importlib_spec = importlib.util.spec_from_file_location( module_name, local_file.path @@ -98,7 +97,7 @@ def get_axes_infos( v0_4.OutputTensorDescr, v0_5.InputTensorDescr, v0_5.OutputTensorDescr, - ] + ], ) -> List[AxisInfo]: """get a unified, simplified axis representation from spec axes""" return [ @@ -117,7 +116,7 @@ def get_member_id( v0_4.OutputTensorDescr, v0_5.InputTensorDescr, v0_5.OutputTensorDescr, - ] + ], ) -> MemberId: """get the normalized tensor ID, usable as a sample member ID""" @@ -139,7 +138,7 @@ def get_member_ids( v0_5.InputTensorDescr, v0_5.OutputTensorDescr, ] - ] + ], ) -> List[MemberId]: """get normalized tensor IDs to be used as sample member IDs""" return [get_member_id(descr) for descr in tensor_descriptions] @@ -160,7 +159,7 @@ def get_test_inputs(model: AnyModelDescr) -> Sample: for m, arr, ax in zip(member_ids, arrays, axes) }, stat={}, - id="test-input", + id="test-sample", ) @@ -181,7 +180,7 @@ def get_test_outputs(model: AnyModelDescr) -> Sample: for m, arr, ax in zip(member_ids, arrays, axes) }, stat={}, - id="test-output", + id="test-sample", ) @@ -191,8 +190,8 @@ class IO_SampleBlockMeta(NamedTuple): def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]): - """returns which halo input tensors need to be divided into blocks with such that - `output_halo` can be cropped from their outputs without intorducing gaps.""" + """returns which halo input tensors need to be divided into blocks with, such that + `output_halo` can be cropped from their outputs without introducing gaps.""" input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {} outputs = {t.id: t for t in model.outputs} all_tensors = {**{t.id: t for t in model.inputs}, **outputs} @@ -222,8 +221,10 @@ def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]] return input_halo -def get_block_transform(model: v0_5.ModelDescr): - """returns how a model's output tensor shapes relate to its input shapes""" +def get_block_transform( + model: v0_5.ModelDescr, +) -> PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]]: + """returns how a model's output tensor shapes relates to its input shapes""" ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {} batch_axis_trf = None for ipt in model.inputs: @@ -286,14 +287,6 @@ def get_io_sample_block_metas( t: {aa: s for (tt, aa), s in block_axis_sizes.inputs.items() if tt == t} for t in {tt for tt, _ in block_axis_sizes.inputs} } - output_block_shape = { - t: { - aa: s - for (tt, aa), s in block_axis_sizes.outputs.items() - if tt == t and not isinstance(s, tuple) - } - for t in {tt for tt, _ in block_axis_sizes.outputs} - } output_halo = { t.id: { a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo) @@ -302,36 +295,14 @@ def get_io_sample_block_metas( } input_halo = get_input_halo(model, output_halo) - # TODO: fix output_sample_shape_data_dep - # (below only valid if input_sample_shape is a valid model input, - # which is not a valid assumption) - output_sample_shape_data_dep = model.get_output_tensor_sizes(input_sample_shape) - - output_sample_shape = { - t: { - a: -1 if isinstance(s, tuple) else s - for a, s in output_sample_shape_data_dep[t].items() - } - for t in output_sample_shape_data_dep - } n_input_blocks, input_blocks = split_multiple_shapes_into_blocks( input_sample_shape, input_block_shape, halo=input_halo ) - n_output_blocks, output_blocks = split_multiple_shapes_into_blocks( - output_sample_shape, output_block_shape, halo=output_halo - ) - assert n_input_blocks == n_output_blocks + block_transform = get_block_transform(model) return n_input_blocks, ( - IO_SampleBlockMeta(ipt, out) - for ipt, out in zip( - sample_block_meta_generator( - input_blocks, sample_shape=input_sample_shape, sample_id=None - ), - sample_block_meta_generator( - output_blocks, - sample_shape=output_sample_shape, - sample_id=None, - ), + IO_SampleBlockMeta(ipt, ipt.get_transformed(block_transform)) + for ipt in sample_block_meta_generator( + input_blocks, sample_shape=input_sample_shape, sample_id=None ) ) diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py index a1dec452..ee60a67a 100644 --- a/bioimageio/core/io.py +++ b/bioimageio/core/io.py @@ -1,22 +1,25 @@ import collections.abc -from os import PathLike -from pathlib import Path -from typing import Any, Mapping, Optional, Sequence, Union +import warnings +from pathlib import Path, PurePosixPath +from typing import Any, Mapping, Optional, Sequence, Tuple, Union -import imageio +import h5py +import numpy as np from imageio.v3 import imread, imwrite from loguru import logger from numpy.typing import NDArray from pydantic import BaseModel, ConfigDict, TypeAdapter -from bioimageio.core.common import PerMember -from bioimageio.core.stat_measures import DatasetMeasure, MeasureValue from bioimageio.spec.utils import load_array, save_array -from .axis import Axis, AxisLike +from .axis import AxisLike +from .common import PerMember from .sample import Sample +from .stat_measures import DatasetMeasure, MeasureValue from .tensor import Tensor +DEFAULT_H5_DATASET_PATH = "data" + def load_image(path: Path, is_volume: Optional[bool] = None) -> NDArray[Any]: """load a single image as numpy array @@ -25,9 +28,38 @@ def load_image(path: Path, is_volume: Optional[bool] = None) -> NDArray[Any]: path: image path is_volume: deprecated """ - ext = path.suffix - if ext == ".npy": + if is_volume is not None: + warnings.warn("**is_volume** is deprecated and will be removed soon.") + + file_path, subpath = _split_dataset_path(Path(path)) + + if file_path.suffix == ".npy": + if subpath is not None: + raise ValueError(f"Unexpected subpath {subpath} for .npy path {path}") return load_array(path) + elif file_path.suffix in (".h5", ".hdf", ".hdf5"): + if subpath is None: + dataset_path = DEFAULT_H5_DATASET_PATH + else: + dataset_path = str(subpath) + + with h5py.File(file_path, "r") as f: + h5_dataset = f.get( # pyright: ignore[reportUnknownVariableType] + dataset_path + ) + if not isinstance(h5_dataset, h5py.Dataset): + raise ValueError( + f"{path} is not of type {h5py.Dataset}, but has type " + + str( + type(h5_dataset) # pyright: ignore[reportUnknownArgumentType] + ) + ) + image: NDArray[Any] + image = h5_dataset[:] # pyright: ignore[reportUnknownVariableType] + assert isinstance(image, np.ndarray), type( + image # pyright: ignore[reportUnknownArgumentType] + ) + return image # pyright: ignore[reportUnknownVariableType] else: return imread(path) # pyright: ignore[reportUnknownVariableType] @@ -39,14 +71,53 @@ def load_tensor(path: Path, axes: Optional[Sequence[AxisLike]] = None) -> Tensor return Tensor.from_numpy(array, dims=axes) +def _split_dataset_path(path: Path) -> Tuple[Path, Optional[PurePosixPath]]: + """Split off subpath (e.g. internal h5 dataset path) + from a file path following a file extension. + + Examples: + >>> _split_dataset_path(Path("my_file.h5/dataset")) + (PosixPath('my_file.h5'), PurePosixPath('dataset')) + + If no suffix is detected the path is returned with + >>> _split_dataset_path(Path("my_plain_file")) + (PosixPath('my_plain_file'), None) + + """ + if path.suffix: + return path, None + + for p in path.parents: + if p.suffix: + return p, PurePosixPath(path.relative_to(p)) + + return path, None + + def save_tensor(path: Path, tensor: Tensor) -> None: # TODO: save axis meta data data: NDArray[Any] = tensor.data.to_numpy() - path = Path(path) - path.parent.mkdir(exist_ok=True, parents=True) - if path.suffix == ".npy": - save_array(path, data) + file_path, subpath = _split_dataset_path(Path(path)) + if not file_path.suffix: + raise ValueError(f"No suffix (needed to decide file format) found in {path}") + + file_path.parent.mkdir(exist_ok=True, parents=True) + if file_path.suffix == ".npy": + if subpath is not None: + raise ValueError(f"Unexpected subpath {subpath} found in .npy path {path}") + save_array(file_path, data) + elif file_path.suffix in (".h5", ".hdf", ".hdf5"): + if subpath is None: + dataset_path = DEFAULT_H5_DATASET_PATH + else: + dataset_path = str(subpath) + + with h5py.File(file_path, "a") as f: + if dataset_path in f: + del f[dataset_path] + + _ = f.create_dataset(dataset_path, data=data, chunks=True) else: # if singleton_axes := [a for a, s in tensor.tagged_shape.items() if s == 1]: # tensor = tensor[{a: 0 for a in singleton_axes}] diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py index 082514d6..a5178d74 100644 --- a/bioimageio/core/model_adapters/_pytorch_model_adapter.py +++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py @@ -102,7 +102,7 @@ def unload(self) -> None: def get_network( # pyright: ignore[reportUnknownParameterType] weight_spec: Union[ v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr - ] + ], ) -> "torch.nn.Module": # pyright: ignore[reportInvalidTypeForm] if torch is None: raise ImportError("torch") diff --git a/bioimageio/core/prediction.py b/bioimageio/core/prediction.py index 6e01d8f4..27a4129c 100644 --- a/bioimageio/core/prediction.py +++ b/bioimageio/core/prediction.py @@ -16,15 +16,15 @@ from numpy.typing import NDArray from tqdm import tqdm -from bioimageio.core.axis import AxisId -from bioimageio.core.io import save_sample from bioimageio.spec import load_description from bioimageio.spec.common import PermissiveFileSource from bioimageio.spec.model import v0_4, v0_5 from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline +from .axis import AxisId from .common import MemberId, PerMember from .digest_spec import create_sample_for_model +from .io import save_sample from .sample import Sample from .tensor import Tensor diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 35a160f5..eecf47b1 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -16,13 +16,13 @@ import xarray as xr from typing_extensions import Self, assert_never -from bioimageio.core.block import Block -from bioimageio.core.sample import Sample, SampleBlock, SampleBlockWithOrigin from bioimageio.spec.model import v0_4, v0_5 from ._op_base import BlockedOperator, Operator from .axis import AxisId, PerAxis +from .block import Block from .common import DTypeStr, MemberId +from .sample import Sample, SampleBlock, SampleBlockWithOrigin from .stat_calculators import StatsCalculator from .stat_measures import ( DatasetMean, @@ -393,7 +393,7 @@ def _get_axes( v0_5.ScaleRangeKwargs, v0_4.ScaleMeanVarianceKwargs, v0_5.ScaleMeanVarianceKwargs, - ] + ], ) -> Tuple[bool, Optional[Tuple[AxisId, ...]]]: if kwargs.axes is None: return True, None diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py index 6a9bcbf6..b9afb711 100644 --- a/bioimageio/core/proc_setup.py +++ b/bioimageio/core/proc_setup.py @@ -11,14 +11,12 @@ from typing_extensions import assert_never -from bioimageio.core.common import MemberId -from bioimageio.core.digest_spec import get_member_ids from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 from bioimageio.spec.model.v0_5 import TensorId +from .digest_spec import get_member_ids from .proc_ops import ( AddKnownDatasetStats, - EnsureDtype, Processing, UpdateStats, get_proc_class, diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py index b4297a5b..0620282d 100644 --- a/bioimageio/core/sample.py +++ b/bioimageio/core/sample.py @@ -16,9 +16,8 @@ import numpy as np from typing_extensions import Self -from bioimageio.core.block import Block - from .axis import AxisId, PerAxis +from .block import Block from .block_meta import ( BlockMeta, LinearAxisTransform, @@ -212,9 +211,11 @@ def get_member_halo(m: MemberId, round: Callable[[float], int]): halo: Dict[MemberId, Dict[AxisId, Halo]] = {} for m in new_axes: halo[m] = get_member_halo(m, floor) - assert halo[m] == get_member_halo( - m, ceil - ), f"failed to unambiguously scale halo {halo[m]} with {new_axes[m]}" + if halo[m] != get_member_halo(m, ceil): + raise ValueError( + f"failed to unambiguously scale halo {halo[m]} with {new_axes[m]}" + + f" for {m}." + ) inner_slice = { m: { @@ -294,8 +295,10 @@ def get_transformed_meta( @dataclass class SampleBlockWithOrigin(SampleBlock): + """A `SampleBlock` with a reference (`origin`) to the whole `Sample`""" + origin: Sample - """the sample this sample black was taken from""" + """the sample this sample block was taken from""" class _ConsolidatedMemberBlocks: @@ -331,7 +334,7 @@ def sample_block_generator( *, origin: Sample, pad_mode: PadMode, -): +) -> Iterable[SampleBlockWithOrigin]: for member_blocks in blocks: cons = _ConsolidatedMemberBlocks(member_blocks) yield SampleBlockWithOrigin( diff --git a/dev/env-py38.yaml b/dev/env-py38.yaml index fd4aa65a..22353103 100644 --- a/dev/env-py38.yaml +++ b/dev/env-py38.yaml @@ -2,13 +2,13 @@ name: core38 channels: - conda-forge - - defaults + - nodefaults dependencies: - - bioimageio.spec>=0.5.3.3 + - bioimageio.spec>=0.5.3.5 - black - crick # uncommented - filelock - - fire + - h5py - imageio>=2.5 - jupyter - jupyter-black @@ -27,7 +27,6 @@ dependencies: - pytest - pytest-cov - pytest-xdist - - python-dotenv - python=3.8 # changed - pytorch>=2.1 - requests diff --git a/dev/env-tf.yaml b/dev/env-tf.yaml index 7f8ae3a7..0df6fd07 100644 --- a/dev/env-tf.yaml +++ b/dev/env-tf.yaml @@ -2,13 +2,13 @@ name: core-tf # changed channels: - conda-forge - - defaults + - nodefaults dependencies: - - bioimageio.spec>=0.5.3.3 + - bioimageio.spec>=0.5.3.5 - black # - crick # currently requires python<=3.9 - filelock - - fire + - h5py - imageio>=2.5 - jupyter - jupyter-black @@ -27,7 +27,6 @@ dependencies: - pytest - pytest-cov - pytest-xdist - - python-dotenv # - python=3.9 # removed # - pytorch>=2.1 # removed - requests diff --git a/dev/env-wo-python.yaml b/dev/env-wo-python.yaml index 6cb0a18e..d8cba289 100644 --- a/dev/env-wo-python.yaml +++ b/dev/env-wo-python.yaml @@ -2,13 +2,14 @@ name: core channels: - conda-forge - - defaults + - nodefaults + - pytorch # added dependencies: - - bioimageio.spec>=0.5.3.3 + - bioimageio.spec>=0.5.3.5 - black # - crick # currently requires python<=3.9 - filelock - - fire + - h5py - imageio>=2.5 - jupyter - jupyter-black @@ -27,12 +28,11 @@ dependencies: - pytest - pytest-cov - pytest-xdist - - python-dotenv # - python=3.9 # removed - pytorch>=2.1 - requests - rich - # - ruff # removed + - ruff - ruyaml - torchvision - tqdm diff --git a/dev/env.yaml b/dev/env.yaml index 7fa2ab7b..20d60a18 100644 --- a/dev/env.yaml +++ b/dev/env.yaml @@ -2,11 +2,11 @@ name: core channels: - conda-forge dependencies: - - bioimageio.spec>=0.5.3.3 + - bioimageio.spec>=0.5.3.5 - black # - crick # currently requires python<=3.9 - filelock - - fire + - h5py - imageio>=2.5 - jupyter - jupyter-black @@ -27,7 +27,6 @@ dependencies: - pytest - pytest-cov - pytest-xdist - - python-dotenv - python=3.9 - pytorch>=2.1 - requests diff --git a/pyproject.toml b/pyproject.toml index 59421385..91cd2cbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,21 @@ [tool.black] line-length = 88 -extend_exclude = "^/presentations/" +extend_exclude = "/presentations/" target-version = ["py38", "py39", "py310", "py311", "py312"] +preview = true [tool.pyright] -exclude = ["**/node_modules", "**/__pycache__", "tests/old_*", "presentations"] +exclude = [ + "**/__pycache__", + "**/node_modules", + "presentations", + "scripts/pdoc/original.py", + "scripts/pdoc/patched.py", + "tests/old_*", +] include = ["bioimageio", "scripts", "tests"] pythonPlatform = "All" -pythonVersion = "3.8" +pythonVersion = "3.12" reportDuplicateImport = "error" reportImplicitStringConcatenation = "error" reportIncompatibleMethodOverride = true @@ -35,6 +43,13 @@ addopts = "--cov=bioimageio --cov-report=xml -n auto --capture=no --doctest-modu [tool.ruff] line-length = 88 +target-version = "py312" include = ["*.py", "*.pyi", "**/pyproject.toml", "*.ipynb"] -exclude = ["presentations"] -target-version = "py38" +exclude = [ + "presentations", + "scripts/pdoc/original.py", + "scripts/pdoc/patched.py", +] + +[tool.coverage.report] +exclude_also = ["if TYPE_CHECKING:", "assert_never\\("] diff --git a/scripts/pdoc/create_pydantic_patch.sh b/scripts/pdoc/create_pydantic_patch.sh new file mode 100644 index 00000000..05b6da6b --- /dev/null +++ b/scripts/pdoc/create_pydantic_patch.sh @@ -0,0 +1,25 @@ +pydantic_root=$(python -c "import pydantic;from pathlib import Path;print(Path(pydantic.__file__).parent)") +main=$pydantic_root'/main.py' +original="$(dirname "$0")/original.py" +patched="$(dirname "$0")/patched.py" + +if [ -e $original ] +then + echo "found existing $original" +else + cp --verbose $main $original +fi + +if [ -e $patched ] +then + echo "found existing $patched" +else + cp --verbose $main $patched + echo "Please update $patched, then press enter to continue" + read +fi + +patch_file="$(dirname "$0")/mark_pydantic_attrs_private.patch" +diff -au $original $patched > $patch_file +echo "content of $patch_file:" +cat $patch_file diff --git a/scripts/pdoc/mark_pydantic_attrs_private.patch b/scripts/pdoc/mark_pydantic_attrs_private.patch new file mode 100644 index 00000000..722d4fbb --- /dev/null +++ b/scripts/pdoc/mark_pydantic_attrs_private.patch @@ -0,0 +1,28 @@ +--- ./original.py 2024-11-08 15:18:37.493768700 +0100 ++++ ./patched.py 2024-11-08 15:13:54.288887700 +0100 +@@ -121,14 +121,14 @@ + # `GenerateSchema.model_schema` to work for a plain `BaseModel` annotation. + + model_config: ClassVar[ConfigDict] = ConfigDict() +- """ ++ """@private + Configuration for the model, should be a dictionary conforming to [`ConfigDict`][pydantic.config.ConfigDict]. + """ + + # Because `dict` is in the local namespace of the `BaseModel` class, we use `Dict` for annotations. + # TODO v3 fallback to `dict` when the deprecated `dict` method gets removed. + model_fields: ClassVar[Dict[str, FieldInfo]] = {} # noqa: UP006 +- """ ++ """@private + Metadata about the fields defined on the model, + mapping of field names to [`FieldInfo`][pydantic.fields.FieldInfo] objects. + +@@ -136,7 +136,7 @@ + """ + + model_computed_fields: ClassVar[Dict[str, ComputedFieldInfo]] = {} # noqa: UP006 +- """A dictionary of computed field names and their corresponding `ComputedFieldInfo` objects.""" ++ """@private A dictionary of computed field names and their corresponding `ComputedFieldInfo` objects.""" + + __class_vars__: ClassVar[set[str]] + """The names of the class variables defined on the model.""" diff --git a/scripts/pdoc/run.sh b/scripts/pdoc/run.sh new file mode 100644 index 00000000..74981aa5 --- /dev/null +++ b/scripts/pdoc/run.sh @@ -0,0 +1,16 @@ +cd "$(dirname "$0")" # cd to folder this script is in + +# patch pydantic to hide pydantic attributes that somehow show up in the docs +# (not even as inherited, but as if the documented class itself would define them) +pydantic_main=$(python -c "import pydantic;from pathlib import Path;print(Path(pydantic.__file__).parent / 'main.py')") + +patch --verbose --forward -p1 $pydantic_main < mark_pydantic_attrs_private.patch + +cd ../.. # cd to repo root +pdoc \ + --docformat google \ + --logo "https://bioimage.io/static/img/bioimage-io-logo.svg" \ + --logo-link "https://bioimage.io/" \ + --favicon "https://bioimage.io/static/img/bioimage-io-icon-small.svg" \ + --footer-text "bioimageio.core $(python -c 'import bioimageio.core;print(bioimageio.core.__version__)')" \ + -o ./dist bioimageio.core bioimageio.spec # generate bioimageio.spec as well for references diff --git a/setup.py b/setup.py index 98650cf1..065cc1b5 100644 --- a/setup.py +++ b/setup.py @@ -26,16 +26,17 @@ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ], packages=find_namespace_packages(exclude=["tests"]), install_requires=[ - "bioimageio.spec ==0.5.3.3", + "bioimageio.spec ==0.5.3.5", + "h5py", "imageio>=2.10", "loguru", "numpy", - "pydantic-settings >=2.3", + "pydantic-settings >=2.5", "pydantic", - "python-dotenv", "requests", "ruyaml", "tqdm", diff --git a/tests/conftest.py b/tests/conftest.py index c1da1ab7..253ade2f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -150,18 +150,14 @@ @fixture(scope="session") -def mamba_cmd(): - mamba_cmd = "micromamba" +def conda_cmd(): + conda_cmd = "conda" try: - _ = subprocess.run(["which", mamba_cmd], check=True) + _ = subprocess.run(["which", conda_cmd], check=True) except (subprocess.CalledProcessError, FileNotFoundError): - mamba_cmd = "mamba" - try: - _ = subprocess.run(["which", mamba_cmd], check=True) - except (subprocess.CalledProcessError, FileNotFoundError): - mamba_cmd = None + conda_cmd = None - return mamba_cmd + return conda_cmd # @@ -169,39 +165,39 @@ def mamba_cmd(): # -@fixture(params=TORCH_MODELS) +@fixture(scope="session", params=TORCH_MODELS) def any_torch_model(request: FixtureRequest): return MODEL_SOURCES[request.param] -@fixture(params=TORCHSCRIPT_MODELS) +@fixture(scope="session", params=TORCHSCRIPT_MODELS) def any_torchscript_model(request: FixtureRequest): return MODEL_SOURCES[request.param] -@fixture(params=ONNX_MODELS) +@fixture(scope="session", params=ONNX_MODELS) def any_onnx_model(request: FixtureRequest): return MODEL_SOURCES[request.param] -@fixture(params=TENSORFLOW_MODELS) +@fixture(scope="session", params=TENSORFLOW_MODELS) def any_tensorflow_model(request: FixtureRequest): return MODEL_SOURCES[request.param] -@fixture(params=KERAS_MODELS) +@fixture(scope="session", params=KERAS_MODELS) def any_keras_model(request: FixtureRequest): return MODEL_SOURCES[request.param] -@fixture(params=TENSORFLOW_JS_MODELS) +@fixture(scope="session", params=TENSORFLOW_JS_MODELS) def any_tensorflow_js_model(request: FixtureRequest): return MODEL_SOURCES[request.param] # fixture to test with all models that should run in the current environment # we exclude any 'wrong' model here -@fixture(params=sorted({m for m in ALL_MODELS if "wrong" not in m})) +@fixture(scope="session", params=sorted({m for m in ALL_MODELS if "wrong" not in m})) def any_model(request: FixtureRequest): return MODEL_SOURCES[request.param] @@ -243,48 +239,52 @@ def unet2d_keras(request: FixtureRequest): # written as model group to automatically skip on missing torch -@fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model"]) +@fixture(scope="session", params=[] if skip_torch else ["unet2d_nuclei_broad_model"]) def unet2d_nuclei_broad_model(request: FixtureRequest): return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing torch -@fixture(params=[] if skip_torch else ["unet2d_diff_output_shape"]) +@fixture(scope="session", params=[] if skip_torch else ["unet2d_diff_output_shape"]) def unet2d_diff_output_shape(request: FixtureRequest): return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing torch -@fixture(params=[] if skip_torch else ["unet2d_expand_output_shape"]) +@fixture(scope="session", params=[] if skip_torch else ["unet2d_expand_output_shape"]) def unet2d_expand_output_shape(request: FixtureRequest): return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing torch -@fixture(params=[] if skip_torch else ["unet2d_fixed_shape"]) +@fixture(scope="session", params=[] if skip_torch else ["unet2d_fixed_shape"]) def unet2d_fixed_shape(request: FixtureRequest): return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing torch -@fixture(params=[] if skip_torch else ["shape_change"]) +@fixture(scope="session", params=[] if skip_torch else ["shape_change"]) def shape_change_model(request: FixtureRequest): return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing tensorflow 1 -@fixture(params=["stardist_wrong_shape"] if tf_major_version == 1 else []) +@fixture( + scope="session", params=["stardist_wrong_shape"] if tf_major_version == 1 else [] +) def stardist_wrong_shape(request: FixtureRequest): return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing tensorflow 1 -@fixture(params=["stardist_wrong_shape2"] if tf_major_version == 1 else []) +@fixture( + scope="session", params=["stardist_wrong_shape2"] if tf_major_version == 1 else [] +) def stardist_wrong_shape2(request: FixtureRequest): return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing tensorflow 1 -@fixture(params=["stardist"] if tf_major_version == 1 else []) +@fixture(scope="session", params=["stardist"] if tf_major_version == 1 else []) def stardist(request: FixtureRequest): return MODEL_SOURCES[request.param] diff --git a/tests/test_bioimageio_spec_version.py b/tests/test_bioimageio_spec_version.py index 75c1303d..921ecd9c 100644 --- a/tests/test_bioimageio_spec_version.py +++ b/tests/test_bioimageio_spec_version.py @@ -6,26 +6,22 @@ from packaging.version import Version -def test_bioimageio_spec_version(mamba_cmd: Optional[str]): - if mamba_cmd is None: +def test_bioimageio_spec_version(conda_cmd: Optional[str]): + if conda_cmd is None: pytest.skip("requires mamba") from importlib.metadata import metadata # get latest released bioimageio.spec version - mamba_repoquery = subprocess.run( - f"{mamba_cmd} repoquery search -c conda-forge --json bioimageio.spec".split( - " " - ), + conda_search = subprocess.run( + f"{conda_cmd} search --json -f conda-forge::bioimageio.spec>=0.5.3.2".split(), encoding="utf-8", capture_output=True, check=True, ) - full_out = mamba_repoquery.stdout # full output includes mamba banner - search = json.loads(full_out[full_out.find("{") :]) # json output starts at '{' - latest_spec = max(search["result"]["pkgs"], key=lambda entry: entry["timestamp"]) - rmaj, rmin, rpatch, *_ = latest_spec["version"].split(".") - released = Version(f"{rmaj}.{rmin}.{rpatch}") + result = json.loads(conda_search.stdout) + latest_spec = max(result["bioimageio.spec"], key=lambda entry: entry["timestamp"]) + released = Version(latest_spec["version"]) # get currently pinned bioimageio.spec version meta = metadata("bioimageio.core") @@ -41,10 +37,5 @@ def test_bioimageio_spec_version(mamba_cmd: Optional[str]): ) assert spec_ver.count(".") == 3 - pmaj, pmin, ppatch, _ = spec_ver.split(".") - assert ( - pmaj.isdigit() and pmin.isdigit() and ppatch.isdigit() - ), "bioimageio.spec version should be pinned down to patch, e.g. '0.4.9.*'" - - pinned = Version(f"{pmaj}.{pmin}.{ppatch}") + pinned = Version(spec_ver) assert pinned == released, "bioimageio.spec not pinned to the latest version" diff --git a/tests/test_cli.py b/tests/test_cli.py index 0ecd7528..e0828ac6 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -35,6 +35,7 @@ def run_subprocess( "pytorch_state_dict", ], ["test", "unet2d_nuclei_broad_model"], + ["predict", "--example", "unet2d_nuclei_broad_model"], ], ) def test_cli(args: List[str], unet2d_nuclei_broad_model: str): diff --git a/tests/test_commands.py b/tests/test_commands.py new file mode 100644 index 00000000..54981f00 --- /dev/null +++ b/tests/test_commands.py @@ -0,0 +1,44 @@ +from pathlib import Path +from typing import Literal, Optional + +import pytest + +from bioimageio.core import load_model +from bioimageio.core.commands import package, validate_format +from bioimageio.core.commands import test as command_tst +from bioimageio.spec import AnyModelDescr + + +@pytest.fixture(scope="module") +def model(unet2d_nuclei_broad_model: str): + return load_model(unet2d_nuclei_broad_model, perform_io_checks=False) + + +@pytest.mark.parametrize( + "weight_format", + [ + "all", + "pytorch_state_dict", + ], +) +def test_package( + weight_format: Literal["all", "pytorch_state_dict"], + model: AnyModelDescr, + tmp_path: Path, +): + assert package(model, weight_format=weight_format, path=tmp_path / "out.zip") == 0 + + +def test_validate_format(model: AnyModelDescr): + assert validate_format(model) == 0 + + +@pytest.mark.parametrize( + "weight_format,devices", [("all", None), ("pytorch_state_dict", "cpu")] +) +def test_test( + weight_format: Literal["all", "pytorch_state_dict"], + devices: Optional[str], + model: AnyModelDescr, +): + assert command_tst(model, weight_format=weight_format, devices=devices) == 0 diff --git a/tests/test_digest_spec.py b/tests/test_digest_spec.py index 08022ab2..8c7b8bb5 100644 --- a/tests/test_digest_spec.py +++ b/tests/test_digest_spec.py @@ -1,11 +1,8 @@ -import pytest - from bioimageio.spec import load_description from bioimageio.spec.model import v0_5 # TODO: don't just test with unet2d_nuclei_broad_model -@pytest.mark.skip("get_io_sample_block_metas needs improvements") def test_get_block_transform(unet2d_nuclei_broad_model: str): from bioimageio.core.axis import AxisId from bioimageio.core.common import MemberId @@ -25,18 +22,21 @@ def test_get_block_transform(unet2d_nuclei_broad_model: str): if isinstance(a.size, v0_5.ParameterizedSize) } + input_sample_shape = { + MemberId("raw"): { + AxisId("batch"): 3, + AxisId("channel"): 1, + AxisId("x"): 4000, + AxisId("y"): 3000, + } + } + _, blocks = get_io_sample_block_metas( model, - input_sample_shape={ - MemberId("raw"): { - AxisId("batch"): 3, - AxisId("channel"): 1, - AxisId("x"): 4000, - AxisId("y"): 3000, - } - }, + input_sample_shape=input_sample_shape, ns=ns, ) + for ipt_block, out_block in blocks: trf_block = ipt_block.get_transformed(block_transform) assert out_block == trf_block diff --git a/tests/test_io.py b/tests/test_io.py new file mode 100644 index 00000000..a45dfe51 --- /dev/null +++ b/tests/test_io.py @@ -0,0 +1,38 @@ +from pathlib import Path +from typing import Tuple + +import numpy as np +import pytest + + +@pytest.mark.parametrize( + "name", + [ + "img.png", + "img.tiff", + "img.h5", + "img.h5/img", + "img.npy", + ], +) +@pytest.mark.parametrize( + "shape", + [ + (4, 5), + (3, 4, 5), + (1, 4, 5), + (5, 4, 3), + (5, 3, 4), + ], +) +def test_image_io(name: str, shape: Tuple[int, ...], tmp_path: Path): + from bioimageio.core import Tensor + from bioimageio.core.io import load_tensor, save_tensor + + path = tmp_path / name + data = Tensor.from_numpy( + np.arange(np.prod(shape), dtype=np.uint8).reshape(shape), dims=None + ) + save_tensor(path, data) + actual = load_tensor(path) + assert actual == data diff --git a/tests/test_prediction.py b/tests/test_prediction.py index de8b8062..bd30f064 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -1,228 +1,127 @@ -# TODO: update -# from pathlib import Path - -# import imageio -# import numpy as np -# from numpy.testing import assert_array_almost_equal - -# from bioimageio.spec import load_description -# from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensorDescr_v0_4 -# from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr_v0_4 -# from bioimageio.spec.model.v0_5 import ModelDescr - - -# def test_predict_image(any_model: Path, tmpdir: Path): -# from bioimageio.core.prediction import predict_image - -# spec = load_description(any_model) -# assert isinstance(spec, ModelDescr) -# inputs = spec.test_inputs - -# outputs = [Path(tmpdir) / f"out{i}.npy" for i in range(len(spec.test_outputs))] -# predict_image(any_model, inputs, outputs) -# for out_path in outputs: -# assert out_path.exists() - -# result = [np.load(str(p)) for p in outputs] -# expected = [np.load(str(p)) for p in spec.test_outputs] -# for res, exp in zip(result, expected): -# assert_array_almost_equal(res, exp, decimal=4) - - -# def test_predict_image_with_weight_format( -# unet2d_fixed_shape_or_not: Path, tmpdir: Path -# ): -# from bioimageio.core.prediction import predict_image - -# spec = load_description(unet2d_fixed_shape_or_not) -# assert isinstance(spec, Model) -# inputs = spec.test_inputs - -# outputs = [Path(tmpdir) / f"out{i}.npy" for i in range(len(spec.test_outputs))] -# predict_image( -# unet2d_fixed_shape_or_not, inputs, outputs, weight_format="pytorch_state_dict" -# ) -# for out_path in outputs: -# assert out_path.exists() - -# result = [np.load(str(p)) for p in outputs] -# expected = [np.load(str(p)) for p in spec.test_outputs] -# for res, exp in zip(result, expected): -# assert_array_almost_equal(res, exp, decimal=4) - - -# def _test_predict_with_padding(any_model: Path, tmp_path: Path): -# from bioimageio.core.digest_spec import get_test_inputs -# from bioimageio.core.prediction import predict_image - -# model = load_description(any_model) -# assert isinstance(model, (ModelDescr_v0_4, ModelDescr)) - -# input_spec, output_spec = model.inputs[0], model.outputs[0] -# channel_axis = ( -# "c" -# if isinstance(input_spec, InputTensorDescr_v0_4) -# else [a.id for a in input_spec.axes][0] -# ) -# channel_first = channel_axis == 1 - -# # TODO: check more tensors -# image = get_test_inputs(model)[0] - -# if isinstance(output_spec.shape, list): -# n_channels = output_spec.shape[channel_axis] -# else: -# scale = output_spec.shape.scale[channel_axis] -# offset = output_spec.shape.offset[channel_axis] -# in_channels = 1 -# n_channels = int(2 * offset + scale * in_channels) - -# # write the padded image -# image = image[3:-2, 1:-12] -# in_path = tmp_path / "in.tif" -# out_path = tmp_path / "out.tif" -# imageio.imwrite(in_path, image) - -# if hasattr(output_spec.shape, "scale"): -# scale = dict(zip(output_spec.axes, output_spec.shape.scale)) -# offset = dict(zip(output_spec.axes, output_spec.shape.offset)) -# spatial_axes = [ax for ax in output_spec.axes if ax in "xyz"] -# network_resizes = any( -# sc != 1 for ax, sc in scale.items() if ax in spatial_axes -# ) or any(off != 0 for ax, off in offset.items() if ax in spatial_axes) -# else: -# network_resizes = False - -# if network_resizes: -# exp_shape = tuple( -# int(sh * scale[ax] + 2 * offset[ax]) -# for sh, ax in zip(image.shape, spatial_axes) -# ) -# else: -# exp_shape = image.shape - -# def check_result(): -# if n_channels == 1: -# assert out_path.exists() -# res = imageio.imread(out_path) -# assert res.shape == exp_shape -# else: -# path = str(out_path) -# for c in range(n_channels): -# channel_out_path = Path(path.replace(".tif", f"-c{c}.tif")) -# assert channel_out_path.exists() -# res = imageio.imread(channel_out_path) -# assert res.shape == exp_shape - -# # test with dynamic padding -# predict_image( -# any_model, in_path, out_path, padding={"x": 16, "y": 16, "mode": "dynamic"} -# ) -# check_result() - -# # test with fixed padding -# predict_image( -# any_model, -# in_path, -# out_path, -# padding={"x": original_shape[0], "y": original_shape[1], "mode": "fixed"}, -# ) -# check_result() - -# # test with automated padding -# predict_image(any_model, in_path, out_path, padding=True) -# check_result() - - -# # prediction with padding with the parameters above may not be suited for any model -# # so we only run it for the pytorch unet2d here -# def test_predict_image_with_padding(unet2d_fixed_shape_or_not, tmp_path): -# _test_predict_with_padding(unet2d_fixed_shape_or_not, tmp_path) - - -# # and with different output shape -# def test_predict_image_with_padding_diff_output_shape( -# unet2d_diff_output_shape, tmp_path -# ): -# _test_predict_with_padding(unet2d_diff_output_shape, tmp_path) - - -# def test_predict_image_with_padding_channel_last(stardist, tmp_path): -# _test_predict_with_padding(stardist, tmp_path) - - -# def _test_predict_image_with_tiling(model: Path, tmp_path: Path, exp_mean_deviation): -# from bioimageio.core.prediction import predict_image - -# spec = load_description(model) -# assert isinstance(spec, Model) -# inputs = spec.test_inputs -# assert len(inputs) == 1 -# exp = np.load(str(spec.test_outputs[0])) - -# out_path = tmp_path.with_suffix(".npy") - -# def check_result(): -# assert out_path.exists() -# res = np.load(out_path) -# assert res.shape == exp.shape -# # check that the mean deviation is smaller than the expected value -# # note that we can't use array_almost_equal here, because the numerical differences -# # between tiled and normal prediction are too large -# mean_deviation = np.abs(res - exp).mean() -# assert mean_deviation <= exp_mean_deviation - -# # with tiling config -# tiling = {"halo": {"x": 32, "y": 32}, "tile": {"x": 256, "y": 256}} -# predict_image(model, inputs, [out_path], tiling=tiling) -# check_result() - -# # with tiling determined from spec -# predict_image(model, inputs, [out_path], tiling=True) -# check_result() - - -# # prediction with tiling with the parameters above may not be suited for any model -# # so we only run it for the pytorch unet2d here -# def test_predict_image_with_tiling_1(unet2d_nuclei_broad_model: Path, tmp_path: Path): -# _test_predict_image_with_tiling(unet2d_nuclei_broad_model, tmp_path, 0.012) - - -# def test_predict_image_with_tiling_2(unet2d_diff_output_shape: Path, tmp_path: Path): -# _test_predict_image_with_tiling(unet2d_diff_output_shape, tmp_path, 0.06) - - -# def test_predict_image_with_tiling_3(shape_change_model: Path, tmp_path: Path): -# _test_predict_image_with_tiling(shape_change_model, tmp_path, 0.012) - - -# def test_predict_image_with_tiling_channel_last(stardist: Path, tmp_path: Path): -# _test_predict_image_with_tiling(stardist, tmp_path, 0.13) - - -# def test_predict_image_with_tiling_fixed_output_shape( -# unet2d_fixed_shape: Path, tmp_path: Path -# ): -# _test_predict_image_with_tiling(unet2d_fixed_shape, tmp_path, 0.025) - - -# def test_predict_images(unet2d_nuclei_broad_model: Path, tmp_path: Path): -# from bioimageio.core.prediction import predict_images - -# n_images = 5 -# shape = (256, 256) - -# in_paths = [] -# out_paths = [] -# for i in range(n_images): -# in_path = tmp_path / f"in{i}.tif" -# im = np.random.randint(0, 255, size=shape).astype("uint8") -# imageio.imwrite(in_path, im) -# in_paths.append(in_path) -# out_paths.append(tmp_path / f"out{i}.tif") -# predict_images(unet2d_nuclei_broad_model, in_paths, out_paths) - -# for outp in out_paths: -# assert outp.exists() -# out = imageio.imread(outp) -# assert out.shape == shape +from pathlib import Path +from typing import Literal, Mapping, NamedTuple + +import numpy as np +import pytest +import xarray as xr +from typing_extensions import assert_never + +from bioimageio.core import ( + AxisId, + MemberId, + PredictionPipeline, + Sample, + create_prediction_pipeline, + load_model, + predict, +) +from bioimageio.core.digest_spec import get_test_inputs, get_test_outputs +from bioimageio.spec import AnyModelDescr + + +def _assert_equal_samples(expected: Sample, actual: Sample): + assert expected.id == actual.id + assert expected.members == actual.members + + +class Prep(NamedTuple): + model: AnyModelDescr + prediction_pipeline: PredictionPipeline + input_sample: Sample + output_sample: Sample + + +@pytest.fixture(scope="module") +def prep(any_model: str): + model = load_model(any_model, perform_io_checks=False) + input_sample = get_test_inputs(model) + output_sample = get_test_outputs(model) + return Prep(model, create_prediction_pipeline(model), input_sample, output_sample) + + +def test_predict_with_pipeline(prep: Prep): + out = predict( + model=prep.prediction_pipeline, + inputs=prep.input_sample, + ) + _assert_equal_samples(out, prep.output_sample) + + +@pytest.mark.parametrize("tensor_input", ["numpy", "xarray"]) +def test_predict_with_model_description( + tensor_input: Literal["numpy", "xarray"], prep: Prep +): + if tensor_input == "xarray": + ipt = {m: t.data for m, t in prep.input_sample.members.items()} + assert all(isinstance(v, xr.DataArray) for v in ipt.values()) + elif tensor_input == "numpy": + ipt = {m: t.data.data for m, t in prep.input_sample.members.items()} + assert all(isinstance(v, np.ndarray) for v in ipt.values()) + else: + assert_never(tensor_input) + + out = predict( + model=prep.model, + inputs=ipt, + sample_id=prep.input_sample.id, + skip_preprocessing=False, + skip_postprocessing=False, + ) + _assert_equal_samples(out, prep.output_sample) + + +@pytest.mark.parametrize("with_procs", [True, False]) +def test_predict_with_blocking(with_procs: bool, prep: Prep): + try: + out = predict( + model=prep.prediction_pipeline, + inputs=prep.input_sample, + blocksize_parameter=3, + sample_id=prep.input_sample.id, + skip_preprocessing=with_procs, + skip_postprocessing=with_procs, + ) + except NotImplementedError as e: + pytest.skip(str(e)) + + if with_procs: + _assert_equal_samples(out, prep.output_sample) + else: + assert isinstance(out, Sample) + + +def test_predict_with_fixed_blocking(prep: Prep): + block_along = list(prep.input_sample.members) + input_block_shape: Mapping[MemberId, Mapping[AxisId, int]] = { + ba: { + "x": min( # pyright: ignore[reportAssignmentType] + 128, prep.input_sample.members[ba].tagged_shape[AxisId("x")] + ), + AxisId("y"): min( + 128, prep.input_sample.members[ba].tagged_shape[AxisId("y")] + ), + } + for ba in block_along + } + try: + out = predict( + model=prep.prediction_pipeline, + inputs=prep.input_sample, + input_block_shape=input_block_shape, + sample_id=prep.input_sample.id, + ) + except NotImplementedError as e: + pytest.skip(str(e)) + + _assert_equal_samples(out, prep.output_sample) + + +def test_predict_save_output(prep: Prep, tmp_path: Path): + save_path = tmp_path / "{member_id}_{sample_id}.h5" + out = predict( + model=prep.prediction_pipeline, + inputs=prep.input_sample, + save_output_path=save_path, + ) + _assert_equal_samples(out, prep.output_sample) + assert save_path.parent.exists() diff --git a/tests/test_prediction_pipeline_device_management.py b/tests/test_prediction_pipeline_device_management.py index 447eb698..0e241df1 100644 --- a/tests/test_prediction_pipeline_device_management.py +++ b/tests/test_prediction_pipeline_device_management.py @@ -52,36 +52,26 @@ def _test_device_management(model_package: Path, weight_format: WeightsFormat): assert_array_almost_equal(out, exp, decimal=4) -@skip_on( - TooFewDevicesException, reason="Too few devices" -) # pyright: ignore[reportArgumentType] +@skip_on(TooFewDevicesException, reason="Too few devices") def test_device_management_torch(any_torch_model: Path): _test_device_management(any_torch_model, "pytorch_state_dict") -@skip_on( - TooFewDevicesException, reason="Too few devices" -) # pyright: ignore[reportArgumentType] +@skip_on(TooFewDevicesException, reason="Too few devices") def test_device_management_torchscript(any_torchscript_model: Path): _test_device_management(any_torchscript_model, "torchscript") -@skip_on( - TooFewDevicesException, reason="Too few devices" -) # pyright: ignore[reportArgumentType] +@skip_on(TooFewDevicesException, reason="Too few devices") def test_device_management_onnx(any_onnx_model: Path): _test_device_management(any_onnx_model, "onnx") -@skip_on( - TooFewDevicesException, reason="Too few devices" -) # pyright: ignore[reportArgumentType] +@skip_on(TooFewDevicesException, reason="Too few devices") def test_device_management_tensorflow(any_tensorflow_model: Path): _test_device_management(any_tensorflow_model, "tensorflow_saved_model_bundle") -@skip_on( - TooFewDevicesException, reason="Too few devices" -) # pyright: ignore[reportArgumentType] +@skip_on(TooFewDevicesException, reason="Too few devices") def test_device_management_keras(any_keras_model: Path): _test_device_management(any_keras_model, "keras_hdf5") diff --git a/tests/test_resource_tests.py b/tests/test_resource_tests.py index b9d3cf66..203ca64b 100644 --- a/tests/test_resource_tests.py +++ b/tests/test_resource_tests.py @@ -1,6 +1,17 @@ +from typing import Literal + +import pytest + from bioimageio.spec import InvalidDescr +@pytest.mark.parametrize("mode", ["seed_only", "full"]) +def test_enable_determinism(mode: Literal["seed_only", "full"]): + from bioimageio.core import enable_determinism + + enable_determinism(mode) + + def test_error_for_wrong_shape(stardist_wrong_shape: str): from bioimageio.core._resource_tests import test_model