Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add predict command #406

Merged
merged 81 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
81 commits
Select commit Hold shift + click to select a range
4aba807
add predict command
FynnBe Jul 23, 2024
25aa7c0
remove unused import
FynnBe Jul 24, 2024
9146c3b
improve model_descr type check
FynnBe Jul 24, 2024
87dc749
remove commented old prediction commands
FynnBe Jul 24, 2024
b47e2b3
allow to package to folder
FynnBe Jul 24, 2024
bbfbbac
document packaging to folder
FynnBe Jul 24, 2024
dcd8334
add Get started section
FynnBe Jul 24, 2024
70a19e8
update type annotations for pyright 1.1.373
FynnBe Jul 25, 2024
abf5d20
pin spec
FynnBe Jul 25, 2024
b3402f7
use load_model_description
FynnBe Jul 25, 2024
21e24cb
improve predict command error for missing input
FynnBe Jul 25, 2024
17b4bbe
update test_bioimageio_spec_version
FynnBe Jul 25, 2024
eb9872d
include weight format in test name
FynnBe Aug 1, 2024
216bf00
make sure directory to save tensor in exists
FynnBe Aug 1, 2024
397c7a2
allow a mapping for save_sample input arg path
FynnBe Aug 1, 2024
7c78620
only calculate stats if any measures are missing
FynnBe Aug 1, 2024
3765812
add inspection helpers to get required measures
FynnBe Aug 1, 2024
4e256f3
allow axis ids to be given as strings
FynnBe Aug 1, 2024
1724589
use pydantic for stat measures and make (small) tensors json serializ…
FynnBe Aug 1, 2024
ea2cac7
rewrite CLI
FynnBe Aug 1, 2024
a0b00fb
we need a main func in main for the endpoint
FynnBe Aug 1, 2024
b97c330
update dependencies
FynnBe Aug 1, 2024
aa3e934
remove invalid alias
FynnBe Aug 1, 2024
9c98c7e
WIP update README.md
FynnBe Aug 1, 2024
3f8af93
fix typing issue
FynnBe Aug 5, 2024
d7f0a78
update tests
FynnBe Aug 5, 2024
b557940
add default path
FynnBe Aug 5, 2024
4b01578
update test_scale_range_axes
FynnBe Aug 5, 2024
2d22241
add default package path
FynnBe Aug 5, 2024
04e9f0c
read command line arguments from file
FynnBe Aug 5, 2024
8418fed
remove default path
FynnBe Aug 5, 2024
d8fb60f
reference file formats from imageio
FynnBe Aug 5, 2024
dcfd021
add cli file example
FynnBe Aug 5, 2024
6a9bb2d
complete test command
FynnBe Aug 5, 2024
bfd7c13
add output path explicitly
FynnBe Aug 5, 2024
c36fec9
bump patch version
FynnBe Aug 5, 2024
4497fbf
set output path explicitly
FynnBe Aug 5, 2024
f24b96a
black
FynnBe Aug 5, 2024
d7c4547
add conda env doc link
FynnBe Aug 6, 2024
3dfcc6c
remove clutter
FynnBe Aug 6, 2024
352422f
improve CLI
FynnBe Aug 9, 2024
4052343
log inputs
FynnBe Aug 12, 2024
b8b7f6a
pass without shorten input sequence
FynnBe Aug 12, 2024
4418a08
drop singleton batch axis when saving a tensor
FynnBe Aug 12, 2024
fe471ef
improve logging
FynnBe Aug 12, 2024
01e43fc
improve example
FynnBe Aug 12, 2024
d704e53
remove unused imports
FynnBe Aug 12, 2024
ba29ddc
improve doc strings
FynnBe Aug 12, 2024
c629a91
use argparse.RawTextHelpFormatter
FynnBe Aug 12, 2024
aa5c316
add weight_format option to predict command
FynnBe Aug 12, 2024
eec7bde
make sure example dir exists
FynnBe Aug 12, 2024
915e56c
fail for missing input samples
FynnBe Aug 12, 2024
103ca42
ignore empty initial dataset measures
FynnBe Aug 13, 2024
ee9be64
perform IO checks based on env var
FynnBe Aug 13, 2024
a0ae60c
add section on logging level
FynnBe Aug 13, 2024
5b1bd86
fix tqdm call
FynnBe Aug 13, 2024
d9fd4f6
insert singleton axis at right position
FynnBe Aug 13, 2024
83d3290
improve stat serialization
FynnBe Aug 13, 2024
d5a0814
fix tensor_custom_before_validator
FynnBe Aug 13, 2024
ca0169e
try all array permutations to match singleton requirements
FynnBe Aug 13, 2024
2605728
fix _get_array_view
FynnBe Aug 13, 2024
e2d1616
actually remove YAML_FILE
FynnBe Aug 14, 2024
090d979
improve saving with imageio
FynnBe Aug 14, 2024
7be6a7e
bump imageio to make sure v3 is available
FynnBe Aug 14, 2024
cbe02ba
simplify io with imageio.v3
FynnBe Aug 14, 2024
6c68179
allow dims to be AxisLike
FynnBe Aug 14, 2024
b9897da
improve help text formatting
FynnBe Aug 14, 2024
29de8e6
update 'Get started' section
FynnBe Aug 14, 2024
f6c3bd7
do not rule out singleton axis as easily
FynnBe Aug 16, 2024
e5d83ba
allow space and time axes to be singletons
FynnBe Aug 21, 2024
6147d4d
avoid 'ABCMeta' object is not subscriptable
FynnBe Aug 21, 2024
1d3000b
bump spec
FynnBe Aug 21, 2024
c766f12
bump spec in dev envs
FynnBe Aug 21, 2024
14f599e
add predict_sample_with_fixed_blocking
FynnBe Aug 22, 2024
bc98d65
do not convert axis id
FynnBe Aug 23, 2024
5a9c1c7
do not convert axes ids for proc ops
FynnBe Aug 23, 2024
83d6a92
expose more functions
FynnBe Aug 23, 2024
9726e60
update digest_spec
FynnBe Aug 23, 2024
2ec9e9e
AxisId is also AxisLike!
FynnBe Aug 23, 2024
88b5fca
Merge remote-tracking branch 'origin/main' into predict_cmd
FynnBe Sep 13, 2024
7306f98
remove unused dependency
FynnBe Sep 13, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
329 changes: 289 additions & 40 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion bioimageio/core/VERSION
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"version": "0.6.8"
"version": "0.6.9"
}
5 changes: 4 additions & 1 deletion bioimageio/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@

from bioimageio.spec import build_description as build_description
from bioimageio.spec import dump_description as dump_description
from bioimageio.spec import load_dataset_description as load_dataset_description
from bioimageio.spec import load_description as load_description
from bioimageio.spec import (
load_description_and_validate_format_only as load_description_and_validate_format_only,
)
from bioimageio.spec import load_model_description as load_model_description
from bioimageio.spec import save_bioimageio_package as save_bioimageio_package
from bioimageio.spec import (
save_bioimageio_package_as_folder as save_bioimageio_package_as_folder,
)
from bioimageio.spec import save_bioimageio_yaml_only as save_bioimageio_yaml_only
from bioimageio.spec import validate_format as validate_format

from . import digest_spec as digest_spec
from ._prediction_pipeline import PredictionPipeline as PredictionPipeline
from ._prediction_pipeline import (
create_prediction_pipeline as create_prediction_pipeline,
Expand All @@ -38,4 +41,4 @@
# aliases
test_resource = test_description
load_resource = load_description
load_model = load_description
load_model = load_model_description
8 changes: 7 additions & 1 deletion bioimageio/core/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from bioimageio.core.commands import main
from bioimageio.core.cli import Bioimageio


def main():
cli = Bioimageio() # pyright: ignore[reportCallIssue]
cli.run()


if __name__ == "__main__":
main()
78 changes: 48 additions & 30 deletions bioimageio/core/_prediction_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def __init__(
postprocessing: List[Processing],
model_adapter: ModelAdapter,
default_ns: Union[
v0_5.ParameterizedSize.N,
Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize.N],
v0_5.ParameterizedSize_N,
Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
] = 10,
default_batch_size: int = 1,
) -> None:
Expand Down Expand Up @@ -179,40 +179,17 @@ def get_output_sample_id(self, input_sample_id: SampleId):
self.model_description.id or self.model_description.name
)

def predict_sample_with_blocking(
def predict_sample_with_fixed_blocking(
self,
sample: Sample,
input_block_shape: Mapping[MemberId, Mapping[AxisId, int]],
*,
skip_preprocessing: bool = False,
skip_postprocessing: bool = False,
ns: Optional[
Union[
v0_5.ParameterizedSize.N,
Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize.N],
]
] = None,
batch_size: Optional[int] = None,
) -> Sample:
"""predict a sample by splitting it into blocks according to the model and the `ns` parameter"""
if not skip_preprocessing:
self.apply_preprocessing(sample)

if isinstance(self.model_description, v0_4.ModelDescr):
raise NotImplementedError(
"predict with blocking not implemented for v0_4.ModelDescr {self.model_description.name}"
)

ns = ns or self._default_ns
if isinstance(ns, int):
ns = {
(ipt.id, a.id): ns
for ipt in self.model_description.inputs
for a in ipt.axes
if isinstance(a.size, v0_5.ParameterizedSize)
}
input_block_shape = self.model_description.get_tensor_sizes(
ns, batch_size or self._default_batch_size
).inputs

n_blocks, input_blocks = sample.split_into_blocks(
input_block_shape,
halo=self._default_input_halo,
Expand All @@ -239,6 +216,47 @@ def predict_sample_with_blocking(

return predicted_sample

def predict_sample_with_blocking(
self,
sample: Sample,
skip_preprocessing: bool = False,
skip_postprocessing: bool = False,
ns: Optional[
Union[
v0_5.ParameterizedSize_N,
Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
]
] = None,
batch_size: Optional[int] = None,
) -> Sample:
"""predict a sample by splitting it into blocks according to the model and the `ns` parameter"""

if isinstance(self.model_description, v0_4.ModelDescr):
raise NotImplementedError(
"`predict_sample_with_blocking` not implemented for v0_4.ModelDescr"
+ f" {self.model_description.name}."
+ " Consider using `predict_sample_with_fixed_blocking`"
)

ns = ns or self._default_ns
if isinstance(ns, int):
ns = {
(ipt.id, a.id): ns
for ipt in self.model_description.inputs
for a in ipt.axes
if isinstance(a.size, v0_5.ParameterizedSize)
}
input_block_shape = self.model_description.get_tensor_sizes(
ns, batch_size or self._default_batch_size
).inputs

return self.predict_sample_with_fixed_blocking(
sample,
input_block_shape=input_block_shape,
skip_preprocessing=skip_preprocessing,
skip_postprocessing=skip_postprocessing,
)

# def predict(
# self,
# inputs: Predict_IO,
Expand Down Expand Up @@ -310,8 +328,8 @@ def create_prediction_pipeline(
),
model_adapter: Optional[ModelAdapter] = None,
ns: Union[
v0_5.ParameterizedSize.N,
Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize.N],
v0_5.ParameterizedSize_N,
Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize_N],
] = 10,
**deprecated_kwargs: Any,
) -> PredictionPipeline:
Expand Down
16 changes: 8 additions & 8 deletions bioimageio/core/_resource_tests.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import traceback
import warnings
from itertools import product
from typing import Dict, Hashable, List, Literal, Optional, Set, Tuple, Union
from typing import Dict, Hashable, List, Literal, Optional, Sequence, Set, Tuple, Union

import numpy as np
from loguru import logger
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_description(
*,
format_version: Union[Literal["discover", "latest"], str] = "discover",
weight_format: Optional[WeightsFormat] = None,
devices: Optional[List[str]] = None,
devices: Optional[Sequence[str]] = None,
absolute_tolerance: float = 1.5e-4,
relative_tolerance: float = 1e-4,
decimal: Optional[int] = None,
Expand All @@ -83,7 +83,7 @@ def load_description_and_test(
*,
format_version: Union[Literal["discover", "latest"], str] = "discover",
weight_format: Optional[WeightsFormat] = None,
devices: Optional[List[str]] = None,
devices: Optional[Sequence[str]] = None,
absolute_tolerance: float = 1.5e-4,
relative_tolerance: float = 1e-4,
decimal: Optional[int] = None,
Expand Down Expand Up @@ -138,12 +138,12 @@ def load_description_and_test(
def _test_model_inference(
model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
weight_format: WeightsFormat,
devices: Optional[List[str]],
devices: Optional[Sequence[str]],
absolute_tolerance: float,
relative_tolerance: float,
decimal: Optional[int],
) -> None:
test_name = "Reproduce test outputs from test inputs"
test_name = f"Reproduce test outputs from test inputs ({weight_format})"
logger.info("starting '{}'", test_name)
error: Optional[str] = None
tb: List[str] = []
Expand Down Expand Up @@ -209,15 +209,15 @@ def _test_model_inference(
def _test_model_inference_parametrized(
model: v0_5.ModelDescr,
weight_format: WeightsFormat,
devices: Optional[List[str]],
devices: Optional[Sequence[str]],
) -> None:
if not any(
isinstance(a.size, v0_5.ParameterizedSize)
for ipt in model.inputs
for a in ipt.axes
):
# no parameterized sizes => set n=0
ns: Set[v0_5.ParameterizedSize.N] = {0}
ns: Set[v0_5.ParameterizedSize_N] = {0}
else:
ns = {0, 1, 2}

Expand All @@ -236,7 +236,7 @@ def _test_model_inference_parametrized(
# no batch axis
batch_sizes = {1}

test_cases: Set[Tuple[v0_5.ParameterizedSize.N, BatchSize]] = {
test_cases: Set[Tuple[v0_5.ParameterizedSize_N, BatchSize]] = {
(n, b) for n, b in product(sorted(ns), sorted(batch_sizes))
}
logger.info(
Expand Down
27 changes: 6 additions & 21 deletions bioimageio/core/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,6 @@ def _get_axis_type(a: Literal["b", "t", "i", "c", "x", "y", "z"]):
S = TypeVar("S", bound=str)


def _get_axis_id(a: Union[Literal["b", "t", "i", "c"], S]):
if a == "b":
return AxisId("batch")
elif a == "t":
return AxisId("time")
elif a == "i":
return AxisId("index")
elif a == "c":
return AxisId("channel")
else:
return AxisId(a)


AxisId = v0_5.AxisId

T = TypeVar("T")
Expand All @@ -47,7 +34,7 @@ def _get_axis_id(a: Union[Literal["b", "t", "i", "c"], S]):
BatchSize = int

AxisLetter = Literal["b", "i", "t", "c", "z", "y", "x"]
AxisLike = Union[AxisLetter, v0_5.AnyAxis, "Axis"]
AxisLike = Union[AxisId, AxisLetter, v0_5.AnyAxis, "Axis"]


@dataclass
Expand All @@ -62,7 +49,7 @@ def create(cls, axis: AxisLike) -> Axis:
elif isinstance(axis, Axis):
return Axis(id=axis.id, type=axis.type)
elif isinstance(axis, str):
return Axis(id=_get_axis_id(axis), type=_get_axis_type(axis))
return Axis(id=AxisId(axis), type=_get_axis_type(axis))
elif isinstance(axis, v0_5.AxisBase):
return Axis(id=AxisId(axis.id), type=axis.type)
else:
Expand All @@ -71,7 +58,7 @@ def create(cls, axis: AxisLike) -> Axis:

@dataclass
class AxisInfo(Axis):
maybe_singleton: bool
maybe_singleton: bool # TODO: replace 'maybe_singleton' with size min/max for better axis guessing

@classmethod
def create(cls, axis: AxisLike, maybe_singleton: Optional[bool] = None) -> AxisInfo:
Expand All @@ -80,18 +67,16 @@ def create(cls, axis: AxisLike, maybe_singleton: Optional[bool] = None) -> AxisI

axis_base = super().create(axis)
if maybe_singleton is None:
if isinstance(axis, Axis):
maybe_singleton = False
elif isinstance(axis, str):
maybe_singleton = axis == "b"
if isinstance(axis, (Axis, str)):
maybe_singleton = True
else:
if axis.size is None:
maybe_singleton = True
elif isinstance(axis.size, int):
maybe_singleton = axis.size == 1
elif isinstance(axis.size, v0_5.SizeReference):
maybe_singleton = (
False # TODO: check if singleton is ok for a `SizeReference`
True # TODO: check if singleton is ok for a `SizeReference`
)
elif isinstance(
axis.size, (v0_5.ParameterizedSize, v0_5.DataDependentSize)
Expand Down
Loading
Loading