From 6c9bf0fbaca05e1a8b98453842a39245e4fc90ab Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 20 Nov 2023 14:58:31 +0100 Subject: [PATCH] WIP add model_utils --- bioimageio/core/model_utils.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 bioimageio/core/model_utils.py diff --git a/bioimageio/core/model_utils.py b/bioimageio/core/model_utils.py new file mode 100644 index 00000000..2c8dd51f --- /dev/null +++ b/bioimageio/core/model_utils.py @@ -0,0 +1,31 @@ +from functools import singledispatch +from typing import Any, List, Union + +import numpy as np +import xarray as xr +from numpy.typing import NDArray + +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.utils import download + +# @singledispatch +# def is_valid_tensor(description: object, tensor: Union[NDArray[Any], xr.DataArray]) -> bool: +# raise NotImplementedError(type(description)) + +# is_valid_tensor.register +# def _(description: v0_4.InputTensor, tensor: Union[NDArray[Any], xr.DataArray]): + + +@singledispatch +def get_test_input_tensors(model: object) -> List[xr.DataArray]: + raise NotImplementedError(type(model)) + + +@get_test_input_tensors.register +def _(model: v0_4.Model): + data = [np.load(download(ipt).path) for ipt in model.test_inputs] + assert all(isinstance(d, np.ndarray) for d in data) + + +# @get_test_input_tensors.register +# def _(model: v0_5.Model):