Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into predict_cmd
Browse files Browse the repository at this point in the history
# Conflicts:
#	bioimageio/core/_resource_tests.py
  • Loading branch information
FynnBe committed Sep 13, 2024
2 parents 2ec9e9e + db331df commit 88b5fca
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 608 deletions.
73 changes: 66 additions & 7 deletions bioimageio/core/_resource_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,18 @@ def test_model(
source: Union[v0_5.ModelDescr, PermissiveFileSource],
weight_format: Optional[WeightsFormat] = None,
devices: Optional[List[str]] = None,
decimal: int = 4,
absolute_tolerance: float = 1.5e-4,
relative_tolerance: float = 1e-4,
decimal: Optional[int] = None,
) -> ValidationSummary:
"""Test model inference"""
# NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
return test_description(
source,
weight_format=weight_format,
devices=devices,
absolute_tolerance=absolute_tolerance,
relative_tolerance=relative_tolerance,
decimal=decimal,
expected_type="model",
)
Expand All @@ -53,15 +58,20 @@ def test_description(
format_version: Union[Literal["discover", "latest"], str] = "discover",
weight_format: Optional[WeightsFormat] = None,
devices: Optional[Sequence[str]] = None,
decimal: int = 4,
absolute_tolerance: float = 1.5e-4,
relative_tolerance: float = 1e-4,
decimal: Optional[int] = None,
expected_type: Optional[str] = None,
) -> ValidationSummary:
"""Test a bioimage.io resource dynamically, e.g. prediction of test tensors for models"""
# NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
rd = load_description_and_test(
source,
format_version=format_version,
weight_format=weight_format,
devices=devices,
absolute_tolerance=absolute_tolerance,
relative_tolerance=relative_tolerance,
decimal=decimal,
expected_type=expected_type,
)
Expand All @@ -74,10 +84,13 @@ def load_description_and_test(
format_version: Union[Literal["discover", "latest"], str] = "discover",
weight_format: Optional[WeightsFormat] = None,
devices: Optional[Sequence[str]] = None,
decimal: int = 4,
absolute_tolerance: float = 1.5e-4,
relative_tolerance: float = 1e-4,
decimal: Optional[int] = None,
expected_type: Optional[str] = None,
) -> Union[ResourceDescr, InvalidDescr]:
"""Test RDF dynamically, e.g. model inference of test inputs"""
# NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
if (
isinstance(source, ResourceDescrBase)
and format_version != "discover"
Expand Down Expand Up @@ -110,7 +123,9 @@ def load_description_and_test(
else:
weight_formats = [weight_format]
for w in weight_formats:
_test_model_inference(rd, w, devices, decimal)
_test_model_inference(
rd, w, devices, absolute_tolerance, relative_tolerance, decimal
)
if not isinstance(rd, v0_4.ModelDescr):
_test_model_inference_parametrized(rd, w, devices)

Expand All @@ -124,12 +139,21 @@ def _test_model_inference(
model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
weight_format: WeightsFormat,
devices: Optional[Sequence[str]],
decimal: int,
absolute_tolerance: float,
relative_tolerance: float,
decimal: Optional[int],
) -> None:
test_name = f"Reproduce test outputs from test inputs ({weight_format})"
logger.info("starting '{}'", test_name)
error: Optional[str] = None
tb: List[str] = []

precision_args = _handle_legacy_precision_args(
absolute_tolerance=absolute_tolerance,
relative_tolerance=relative_tolerance,
decimal=decimal,
)

try:
inputs = get_test_inputs(model)
expected = get_test_outputs(model)
Expand All @@ -149,8 +173,11 @@ def _test_model_inference(
error = "Output tensors for test case may not be None"
break
try:
np.testing.assert_array_almost_equal(
res.data, exp.data, decimal=decimal
np.testing.assert_allclose(
res.data,
exp.data,
rtol=precision_args["relative_tolerance"],
atol=precision_args["absolute_tolerance"],
)
except AssertionError as e:
error = f"Output and expected output disagree:\n {e}"
Expand Down Expand Up @@ -361,6 +388,38 @@ def _test_expected_resource_type(
)


def _handle_legacy_precision_args(
absolute_tolerance: float, relative_tolerance: float, decimal: Optional[int]
) -> Dict[str, float]:
"""
Transform the precision arguments to conform with the current implementation.
If the deprecated `decimal` argument is used it overrides the new behaviour with
the old behaviour.
"""
# Already conforms with current implementation
if decimal is None:
return {
"absolute_tolerance": absolute_tolerance,
"relative_tolerance": relative_tolerance,
}

warnings.warn(
"The argument `decimal` has been depricated in favour of "
+ "`relative_tolerance` and `absolute_tolerance`, with different validation "
+ "logic, using `numpy.testing.assert_allclose, see "
+ "'https://numpy.org/doc/stable/reference/generated/"
+ "numpy.testing.assert_allclose.html'. Passing a value for `decimal` will "
+ "cause validation to revert to the old behaviour."
)
# decimal overrides new behaviour,
# have to convert the params to emulate old behaviour
return {
"absolute_tolerance": 1.5 * 10 ** (-decimal),
"relative_tolerance": 0,
}


# def debug_model(
# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],
# *,
Expand Down
3 changes: 2 additions & 1 deletion dev/env.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
name: core
channels:
- conda-forge
- defaults
dependencies:
- bioimageio.spec>=0.5.3.2
- black
Expand All @@ -11,6 +10,8 @@ dependencies:
- imageio>=2.5
- jupyter
- jupyter-black
- ipykernel
- matplotlib
- keras>=3.0
- loguru
- numpy
Expand Down
Loading

0 comments on commit 88b5fca

Please sign in to comment.