diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index 96c71592..73f9c4a9 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -5,7 +5,7 @@ from attr import dataclass from typing_extensions import Final -from bioimageio.core.stat_measures import Measure +from bioimageio.core.stat_measures import MeasureBase from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import TensorId @@ -19,27 +19,3 @@ PER_SAMPLE = "per_sample" PER_DATASET = "per_dataset" - - -MeasureVar = TypeVar("MeasureVar", bound=Measure) -ModeVar = TypeVar("ModeVar", Literal["per_sample"], Literal["per_dataset"]) - - -@dataclass(frozen=True) -class RequiredMeasure(Generic[MeasureVar, ModeVar]): - measure: MeasureVar - tensor_id: TensorId - mode: ModeVar - - -@dataclass(frozen=True) -class SampleMeasure(RequiredMeasure[MeasureVar, Literal["per_sample"]]): - pass - - -@dataclass(frozen=True) -class DatasetMeasure(RequiredMeasure[MeasureVar, Literal["per_dataset"]]): - pass - - -MeasureValue = xr.DataArray diff --git a/bioimageio/core/proc_impl.py b/bioimageio/core/proc_impl.py index d061a1f8..26de8cdf 100644 --- a/bioimageio/core/proc_impl.py +++ b/bioimageio/core/proc_impl.py @@ -3,6 +3,7 @@ from dataclasses import InitVar, dataclass, field, fields from types import MappingProxyType from typing import ( + Any, ClassVar, FrozenSet, Generic, @@ -23,13 +24,21 @@ import numpy as np import xarray as xr from numpy.typing import DTypeLike -from typing_extensions import LiteralString, assert_never - -from bioimageio.core.common import MeasureValue, ProcessingDescrBase, ProcessingKwargs, RequiredMeasure, Sample -from bioimageio.core.stat_measures import Mean, Percentile, Std +from typing_extensions import LiteralString, assert_never, Unpack + +from bioimageio.core.common import ( + AnyRequiredMeasure, + AxisId, + MeasureVar, + ProcessingDescrBase, + ProcessingKwargs, + RequiredMeasure, + Sample, +) +from bioimageio.core.stat_measures import Mean, MeasureValue, Percentile, Std from bioimageio.spec._internal.base_nodes import NodeWithExplicitlySetFields from bioimageio.spec.model import v0_4, v0_5 -from bioimageio.spec.model.v0_5 import NonBatchAxisId, TensorId +from bioimageio.spec.model.v0_5 import NonBatchAxisId, TensorId, BinarizeKwargs AssertProcessingId = Literal["assert_dtype"] @@ -48,19 +57,19 @@ class AssertDtype(AssertProcessingBase): kwargs: AssertDtypeKwargs -M = TypeVar("M", RequiredMeasure, MeasureValue) +M = TypeVar("M", AnyRequiredMeasure, MeasureValue) @dataclass -class NamedMeasures(Generic[M]): +class NamedMeasures: """Named Measures that specifies all required/computed measures of a Processing instance""" - def get_set(self) -> Set[M]: + def get_set(self) -> Set[RequiredMeasure[Any, Any]]: return {getattr(self, f.name) for f in fields(self)} # The two generics are conceptually a higher kinded generic -R = TypeVar("R", bound=NamedMeasures[RequiredMeasure]) +R = TypeVar("R", bound=NamedMeasures[RequiredMeasure[Any, Any]]) C = TypeVar("C", bound=NamedMeasures[MeasureValue]) @@ -68,92 +77,171 @@ def get_set(self) -> Set[M]: ProcInput = TypeVar("ProcInput", xr.DataArray, Sample) -@dataclass(frozen=True) -class ProcessingImplBase(Generic[PKwargs, R, C], ABC): - """Base class for all Pre- and Postprocessing implementations.""" +Tensor = xr.DataArray - tensor_id: TensorId - """id of tensor to operate on""" +@dataclass +class Operator(Generic[PKwargs], ABC): kwargs: PKwargs - computed_measures: InitVar[Mapping[RequiredMeasure, MeasureValue]] = field( - default=MappingProxyType[RequiredMeasure, MeasureValue]({}) - ) - assert type(R) is type(C), "R and C are conceptually a higher kindes generic, their class has to be identical" - required: R = field(init=False) - computed: C = field(init=False) - - def __post_init__(self, computed_measures: Mapping[RequiredMeasure, MeasureValue]) -> None: - object.__setattr__(self, "required", self.get_required_measures(self.tensor_id, self.kwargs)) - selected = {} - for f in fields(self.required): - req = getattr(self.required, f.name) - if req in computed_measures: - selected[f.name] = computed_measures[req] - else: - raise ValueError(f"Missing computed measure: {req} (as '{f.name}').") - - object.__setattr__(self, "computed", self.required.__class__(**selected)) - + computed: @abstractmethod - @classmethod - def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> NamedMeasures[RequiredMeasure]: + def __call__(self) -> Tensor: ... - def __call__(self, __input: ProcInput, /) -> ProcInput: - if isinstance(__input, xr.DataArray): - return self.apply(__input) - else: - return self.apply_to_sample(__input) - @abstractmethod - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - """apply processing""" - ... +@dataclass +class Binarize(Operator): + threshold: float + + +# class Source(Operator): +# def __call__(self) -> Tensor: +# return Tensor() - def apply_to_sample(self, sample: Sample) -> Sample: - ret = dict(sample) - ret[self.tensor_id] = self.apply(sample[self.tensor_id]) - return ret + +@dataclass +class Smooth(Operator): + sigma: float + tensor_source: Source @abstractmethod - def get_descr(self) -> Union[ProcessingDescrBase, AssertProcessingBase]: + @classmethod + def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> R: ... + def __call__(self) -> Tensor: + return tensor * self.sigma # fixme -@dataclass(frozen=True) -class ProcessingImplBaseWoMeasures( - ProcessingImplBase[PKwargs, NamedMeasures[RequiredMeasure], NamedMeasures[MeasureValue]] -): - @classmethod - def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> NamedMeasures[RequiredMeasure]: - return NamedMeasures() +class Diff(Operator): + def __call__(self, a: Tensor, b: Tensor) -> Tensor: + return a - b -@dataclass(frozen=True) -class AssertDtypeImpl(ProcessingImplBaseWoMeasures[AssertDtypeKwargs]): - kwargs_class = AssertDtypeKwargs - _assert_with: Tuple[Type[DTypeLike], ...] = field(init=False) - - def __post_init__(self, computed_measures: Mapping[RequiredMeasure, MeasureValue]) -> None: - super().__post_init__(computed_measures) - if isinstance(self.kwargs.dtype, str): - dtype = [self.kwargs.dtype] - else: - dtype = self.kwargs.dtype - object.__setattr__(self, "assert_with", tuple(type(numpy.dtype(dt)) for dt in dtype)) - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - assert isinstance(tensor.dtype, self._assert_with) - return tensor - def get_descr(self): - return AssertDtype(kwargs=self.kwargs) -@dataclass(frozen=True) -class BinarizeImpl(ProcessingImplBaseWoMeasures[Union[v0_4.BinarizeKwargs, v0_5.BinarizeKwargs]]): +@dataclass +class SimpleOperator(Operator, ABC): + input_id: TensorId + output_id: TensorId + + def __call__(self, sample: Sample, /) -> Sample: + ret = dict(sample) + ret[self.output_id] = self.apply(sample[self.input_id]) + return ret + + + @abstractmethod + def apply(self, tensor: xr.DataArray) -> xr.DataArray: + ... + +# @dataclass(frozen=True) +# class ProcessingImplBase(Generic[PKwargs, R, C], ABC): +# """Base class for all Pre- and Postprocessing implementations.""" + +# tensor_id: TensorId +# """id of tensor to operate on""" +# kwargs: PKwargs +# computed_measures: InitVar[Mapping[AnyRequiredMeasure, MeasureValue]] = field( +# default=MappingProxyType[AnyRequiredMeasure, MeasureValue]({}) +# ) +# assert type(R) is type(C), "R and C are conceptually a higher kindes generic, their class has to be identical" +# required: R = field(init=False) +# computed: C = field(init=False) + +# def __post_init__(self, computed_measures: Mapping[AnyRequiredMeasure, MeasureValue]) -> None: +# object.__setattr__(self, "required", self.get_required_measures(self.tensor_id, self.kwargs)) +# selected = {} +# for f in fields(self.required): +# req = getattr(self.required, f.name) +# if req in computed_measures: +# selected[f.name] = computed_measures[req] +# else: +# raise ValueError(f"Missing computed measure: {req} (as '{f.name}').") + +# object.__setattr__(self, "computed", self.required.__class__(**selected)) + +# @abstractmethod +# @classmethod +# def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> R: +# ... + +# def __call__(self, __input: ProcInput, /) -> ProcInput: +# if isinstance(__input, xr.DataArray): +# return self.apply(__input) +# else: +# return self.apply_to_sample(__input) + +# @abstractmethod +# def apply(self, tensor: xr.DataArray) -> xr.DataArray: +# """apply processing""" +# ... + +# def apply_to_sample(self, sample: Sample) -> Sample: +# ret = dict(sample) +# ret[self.tensor_id] = self.apply(sample[self.tensor_id]) +# return ret + +# @abstractmethod +# def get_descr(self) -> Union[ProcessingDescrBase, AssertProcessingBase]: +# ... + + +# @dataclass(frozen=True) +# class ProcessingImplBaseWoMeasures( +# ProcessingImplBase[PKwargs, NamedMeasures[AnyRequiredMeasure], NamedMeasures[MeasureValue]] +# ): +# @classmethod +# def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> NamedMeasures[AnyRequiredMeasure]: +# return NamedMeasures() + + +# @dataclass(frozen=True) +# class AssertDtypeImpl(ProcessingImplBaseWoMeasures[AssertDtypeKwargs]): +# kwargs_class = AssertDtypeKwargs +# _assert_with: Tuple[Type[DTypeLike], ...] = field(init=False) + +# def __post_init__(self, computed_measures: Mapping[AnyRequiredMeasure, MeasureValue]) -> None: +# super().__post_init__(computed_measures) +# if isinstance(self.kwargs.dtype, str): +# dtype = [self.kwargs.dtype] +# else: +# dtype = self.kwargs.dtype + +# object.__setattr__(self, "assert_with", tuple(type(numpy.dtype(dt)) for dt in dtype)) + +# def apply(self, tensor: xr.DataArray) -> xr.DataArray: +# assert isinstance(tensor.dtype, self._assert_with) +# return tensor + +# def get_descr(self): +# return AssertDtype(kwargs=self.kwargs) + +# class AssertDtype(Operator): +# dtype: Sequence[DTypeLike] +# _assert_with: Tuple[Type[DTypeLike], ...] = field(init=False) + +# def __post_init__(self, computed_measures: Mapping[AnyRequiredMeasure, MeasureValue]) -> None: +# super().__post_init__(computed_measures) +# if isinstance(self.kwargs.dtype, str): +# dtype = [self.kwargs.dtype] +# else: +# dtype = self.kwargs.dtype + +# object.__setattr__(self, "assert_with", tuple(type(numpy.dtype(dt)) for dt in dtype)) + +# def apply(self, tensor: xr.DataArray) -> xr.DataArray: +# assert isinstance(tensor.dtype, self._assert_with) +# return tensor + +# def get_descr(self): +# return AssertDtype(kwargs=self.kwargs) + +@dataclass +class BinarizeImpl(Operator): """'output = tensor > threshold'.""" + threshold: float def apply(self, tensor: xr.DataArray) -> xr.DataArray: return tensor > self.kwargs.threshold @@ -237,7 +325,7 @@ class NamedMeasuresScaleMeanVariance(NamedMeasures[M]): class ScaleMeanVarianceImpl( ProcessingImplBase[ Union[v0_4.ScaleMeanVarianceKwargs, v0_5.ScaleMeanVarianceKwargs], - NamedMeasuresScaleMeanVariance[RequiredMeasure], + NamedMeasuresScaleMeanVariance[AnyRequiredMeasure], NamedMeasuresScaleMeanVariance[MeasureValue], ] ): @@ -248,7 +336,7 @@ def get_required_measures( if kwargs.axes is None: axes = None elif isinstance(kwargs.axes, str): - axes = tuple(NonBatchAxisId(a) for a in kwargs.axes) + axes = tuple(AxisId(a) for a in kwargs.axes) elif isinstance(kwargs.axes, collections.abc.Sequence): # pyright: ignore[reportUnnecessaryIsInstance] axes = tuple(kwargs.axes) else: @@ -283,14 +371,14 @@ class NamedMeasuresScaleRange(NamedMeasures[M]): class ScaleRangeImpl( ProcessingImplBase[ Union[v0_4.ScaleRangeKwargs, v0_5.ScaleRangeKwargs], - NamedMeasuresScaleRange[RequiredMeasure], + NamedMeasuresScaleRange[RequiredMeasure[Percentile, Any]], NamedMeasuresScaleRange[MeasureValue], ] ): @classmethod def get_required_measures(cls, tensor_id: TensorId, kwargs: Union[v0_4.ScaleRangeKwargs, v0_5.ScaleRangeKwargs]): ref_name = kwargs.reference_tensor or tensor_id - axes = None if kwargs.axes is None else tuple(NonBatchAxisId(a) for a in kwargs.axes) + axes = None if kwargs.axes is None else tuple(AxisId(a) for a in kwargs.axes) return NamedMeasuresScaleRange( lower=RequiredMeasure(Percentile(kwargs.min_percentile, axes=axes), cast(TensorId, ref_name), kwargs.mode), upper=RequiredMeasure(Percentile(kwargs.max_percentile, axes=axes), cast(TensorId, ref_name), kwargs.mode), diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py index a66c46aa..e0079b89 100644 --- a/bioimageio/core/proc_setup.py +++ b/bioimageio/core/proc_setup.py @@ -35,7 +35,7 @@ class _SetupProcessing(NamedTuple): def setup_pre_and_postprocessing(model: ModelDescr, dataset: Iterator[Sample]) -> _SetupProcessing: Prepared = List[Tuple[Type[ProcessingImplBase[Any, Any, Any]], ProcessingKwargs, TensorId]] - required_measures: Set[RequiredMeasure] = set() + required_measures: Set[RequiredMeasure[Any, Any]] = set() def prepare_procs(tensor_descrs: Sequence[TensorDescr]): prepared: Prepared = [] diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py index f31bcfe0..d465cd6c 100644 --- a/bioimageio/core/stat_calculators.py +++ b/bioimageio/core/stat_calculators.py @@ -25,6 +25,7 @@ Tuple, Type, Union, + assert_never, ) import numpy as np @@ -35,14 +36,27 @@ PER_DATASET, PER_SAMPLE, AxisId, - DatasetMeasure, - MeasureVar, - RequiredMeasure, Sample, - SampleMeasure, TensorId, ) -from bioimageio.core.stat_measures import Mean, Measure, Percentile, Std, Var +from bioimageio.core.stat_measures import ( + DatasetMean, + DatasetMeasureBase, + DatasetMeasureVar, + DatasetPercentile, + DatasetStd, + DatasetVar, + Measure, + MeasureVar, + Percentile, + SampleMean, + SampleMeasureBase, + SamplePercentile, + SampleStd, + SampleVar, + Std, + Var, +) try: import crick # type: ignore @@ -52,29 +66,29 @@ MeasureValue = Union[xr.DataArray, float] -class SampleMeasureCalculator(ABC, Generic[MeasureVar]): - """group of measures for more efficient computation of multiple measures per sample""" +# class SampleMeasureCalculator(ABC): +# """group of measures for more efficient computation of multiple measures per sample""" - @abstractmethod - def compute(self, sample: Sample) -> Mapping[SampleMeasure[MeasureVar], MeasureValue]: - ... +# @abstractmethod +# def compute(self, sample: Sample) -> Mapping[SampleMeasure, MeasureValue]: +# ... -class DatasetMeasureCalculator(ABC, Generic[MeasureVar]): - """group of measures for more efficient computation of multiple measures per dataset""" +# class DatasetMeasureCalculator(ABC): +# """group of measures for more efficient computation of multiple measures per dataset""" - @abstractmethod - def update_with_sample(self, sample: Sample) -> None: - """update intermediate representation with a data sample""" - ... +# @abstractmethod +# def update_with_sample(self, sample: Sample) -> None: +# """update intermediate representation with a data sample""" +# ... - @abstractmethod - def finalize(self) -> Mapping[DatasetMeasure[MeasureVar], MeasureValue]: - """compute statistics from intermediate representation""" - ... +# @abstractmethod +# def finalize(self) -> Mapping[DatasetMeasure, MeasureValue]: +# """compute statistics from intermediate representation""" +# ... -class MeanCalculator(SampleMeasureCalculator[Mean], DatasetMeasureCalculator[Mean]): +class MeanCalculator: def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]): super().__init__() self._axes = None if axes is None else tuple(axes) @@ -83,11 +97,8 @@ def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]): self._mean: Optional[xr.DataArray] = None def compute(self, sample: Sample): - return { - SampleMeasure(measure=Mean(axes=self._axes), tensor_id=self._tensor_id): sample[self._tensor_id].mean( - dim=self._axes - ) - } + mean = SampleMean(axes=self._axes, tensor_id=self._tensor_id) + return {mean: mean.compute(sample)} def update_with_sample(self, sample: Sample): tensor = sample[self._tensor_id].astype(np.float64, copy=False) @@ -106,14 +117,14 @@ def update_with_sample(self, sample: Sample): self._mean = (n_a * mean_a + n_b * mean_b) / n assert self._mean.dtype == np.float64 - def finalize(self) -> Mapping[DatasetMeasure, MeasureValue]: + def finalize(self) -> Mapping[DatasetMeasureBase, MeasureValue]: if self._mean is None: return {} else: - return {DatasetMeasure(measure=Mean(axes=self._axes), tensor_id=self._tensor_id): self._mean} + return {DatasetMean(axes=self._axes, tensor_id=self._tensor_id): self._mean} -class MeanVarStdCalculator(SampleMeasureCalculator, DatasetMeasureCalculator): +class MeanVarStdCalculator: def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]]): super().__init__() self._axes = None if axes is None else tuple(axes) @@ -134,9 +145,9 @@ def compute(self, sample: Sample): var = xr.dot(c, c, dims=self._axes) / n std = np.sqrt(var) return { - SampleMeasure(Mean(axes=self._axes), tensor_id=self._tensor_id): mean, - SampleMeasure(Var(axes=self._axes), tensor_id=self._tensor_id): var, - SampleMeasure(Std(axes=self._axes), tensor_id=self._tensor_id): std, + SampleMean(axes=self._axes, tensor_id=self._tensor_id): mean, + SampleVar(axes=self._axes, tensor_id=self._tensor_id): var, + SampleStd(axes=self._axes, tensor_id=self._tensor_id): std, } def update_with_sample(self, sample: Sample): @@ -163,7 +174,7 @@ def update_with_sample(self, sample: Sample): self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n assert self._m2.dtype == np.float64 - def finalize(self) -> Mapping[DatasetMeasure, MeasureValue]: + def finalize(self) -> Mapping[DatasetMeasureBase, MeasureValue]: if self._mean is None: return {} else: @@ -171,13 +182,13 @@ def finalize(self) -> Mapping[DatasetMeasure, MeasureValue]: var = self._m2 / self._n sqrt: xr.DataArray = np.sqrt(var) # type: ignore return { - DatasetMeasure(tensor_id=self._tensor_id, measure=Mean(axes=self._axes)): self._mean, - DatasetMeasure(tensor_id=self._tensor_id, measure=Var(axes=self._axes)): var, - DatasetMeasure(tensor_id=self._tensor_id, measure=Std(axes=self._axes)): sqrt, + DatasetMean(tensor_id=self._tensor_id, axes=self._axes): self._mean, + DatasetVar(tensor_id=self._tensor_id, axes=self._axes): var, + DatasetStd(tensor_id=self._tensor_id, axes=self._axes): sqrt, } -class SamplePercentilesCalculator(SampleMeasureCalculator): +class SamplePercentilesCalculator: def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Sequence[float]): super().__init__() assert all(0 <= n <= 100 for n in ns) @@ -189,13 +200,10 @@ def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Se def compute(self, sample: Sample): tensor = sample[self._tensor_id] ps = tensor.quantile(self._qs, dim=self._axes) # type: ignore - return { - SampleMeasure(measure=Percentile(n=n, axes=self._axes), tensor_id=self._tensor_id): p - for n, p in zip(self.ns, ps) - } + return {SamplePercentile(n=n, axes=self._axes, tensor_id=self._tensor_id): p for n, p in zip(self.ns, ps)} -class MeanPercentilesCalculator(DatasetMeasureCalculator): +class MeanPercentilesCalculator: def __init__(self, tensor_id: TensorId, axes: Optional[Sequence[AxisId]], ns: Sequence[float]): super().__init__() assert all(0 <= n <= 100 for n in ns) @@ -222,18 +230,18 @@ def update_with_sample(self, sample: Sample): self._n += n - def finalize(self) -> Mapping[DatasetMeasure, MeasureValue]: + def finalize(self) -> Mapping[DatasetPercentile, MeasureValue]: if self._estimates is None: return {} else: warnings.warn("Computed dataset percentiles naively by averaging percentiles of samples.") return { - DatasetMeasure(measure=Percentile(n=n, axes=self._axes), tensor_id=self._tensor_id): e + DatasetPercentile(n=n, axes=self._axes, tensor_id=self._tensor_id): e for n, e in zip(self._ns, self._estimates) } -class CrickPercentilesCalculator(DatasetMeasureCalculator): +class CrickPercentilesCalculator: def __init__(self, tensor_name: TensorId, axes: Optional[Sequence[AxisId]], ns: Sequence[float]): warnings.warn("Computing dataset percentiles with experimental 'crick' library.") super().__init__() @@ -273,16 +281,14 @@ def update_with_sample(self, sample: Sample): for i, idx in enumerate(self._indices): self._digest[i].update(tensor.isel(dict(zip(self._dims[1:], idx)))) - def finalize(self) -> Dict[DatasetMeasure, MeasureValue]: + def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: if self._digest is None: return {} else: assert self._dims is not None vs: NDArray[Any] = np.asarray([[d.quantile(q) for d in self._digest] for q in self._qs]).reshape(self._shape) # type: ignore return { - DatasetMeasure(measure=Percentile(n=n, axes=self._axes), tensor_id=self._tensor_id): xr.DataArray( - v, dims=self._dims[1:] - ) + DatasetPercentile(n=n, axes=self._axes, tensor_id=self._tensor_id): xr.DataArray(v, dims=self._dims[1:]) for n, v in zip(self._ns, vs) } @@ -295,24 +301,26 @@ def finalize(self) -> Dict[DatasetMeasure, MeasureValue]: DatasetPercentileCalculator = CrickPercentilesCalculator -class NaivSampleMeasureCalculator(SampleMeasureCalculator): - """wrapper for measures to match interface of SampleMeasureGroup""" +class NaivSampleMeasureCalculator: + """wrapper for measures to match interface of other sample measure calculators""" - def __init__(self, tensor_id: TensorId, measure: Measure): + def __init__(self, tensor_id: TensorId, measure: SampleMeasureBase): super().__init__() self.tensor_name = tensor_id self.measure = measure - def compute(self, sample: Sample) -> Mapping[SampleMeasure, MeasureValue]: - return { - SampleMeasure(measure=self.measure, tensor_id=self.tensor_name): self.measure.compute( - sample[self.tensor_name] - ) - } + def compute(self, sample: Sample) -> Mapping[SampleMeasureBase, MeasureValue]: + return {self.measure: self.measure.compute(sample)} + + +SampleMeasureCalculator = Union[ + MeanCalculator, MeanVarStdCalculator, SamplePercentilesCalculator, NaivSampleMeasureCalculator +] +DatasetMeasureCalculator = Union[MeanCalculator, MeanVarStdCalculator, DatasetPercentileCalculator] def get_measure_calculators( - required_measures: Iterable[RequiredMeasure], + required_measures: Iterable[Measure], ) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]: """determines which calculators are needed to compute the required measures efficiently""" @@ -320,50 +328,58 @@ def get_measure_calculators( dataset_calculators: List[DatasetMeasureCalculator] = [] # split required measures into groups - required_means: Set[RequiredMeasure] = set() - required_mean_var_std: Set[RequiredMeasure] = set() - required_percentiles: Set[RequiredMeasure] = set() + required_sample_means: Set[SampleMean] = set() + required_dataset_means: Set[DatasetMean] = set() + required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set() + required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = set() + required_sample_percentiles: Set[SamplePercentile] = set() + required_dataset_percentiles: Set[DatasetPercentile] = set() for rm in required_measures: - if isinstance(rm.measure, Mean): - required_means.add(rm) - elif isinstance(rm.measure, (Var, Std)): - required_mean_var_std.update( - { - RequiredMeasure(measure=msv(rm.measure.axes), tensor_id=rm.tensor_id, mode=rm.mode) - for msv in (Mean, Std, Var) - } + if isinstance(rm, SampleMean): + required_sample_means.add(rm) + elif isinstance(rm, DatasetMean): + required_dataset_means.add(rm) + elif isinstance(rm, (SampleVar, SampleStd)): + required_sample_mean_var_std.update( + {msv(axes=rm.axes, tensor_id=rm.tensor_id) for msv in (SampleMean, SampleStd, SampleVar)} + ) + assert rm in required_sample_mean_var_std + elif isinstance(rm, (DatasetVar, DatasetStd)): + required_dataset_mean_var_std.update( + {msv(axes=rm.axes, tensor_id=rm.tensor_id) for msv in (DatasetMean, DatasetStd, DatasetVar)} ) - assert rm in required_mean_var_std - elif isinstance(rm.measure, Percentile): - required_percentiles.add(rm) - elif rm.mode == PER_SAMPLE: - sample_calculators.append(NaivSampleMeasureCalculator(tensor_id=rm.tensor_id, measure=rm.measure)) + assert rm in required_dataset_mean_var_std + elif isinstance(rm, SamplePercentile): + required_sample_percentiles.add(rm) + elif isinstance(rm, DatasetPercentile): # pyright: ignore[reportUnnecessaryIsInstance] + required_dataset_percentiles.add(rm) else: - raise NotImplementedError(f"Computing statistics for {rm.measure} {rm.mode} not yet implemented") + assert_never(rm) + + for rm in required_sample_means: + if rm in required_sample_mean_var_std: + # computed togehter with var and std + continue + + sample_calculators.append(MeanCalculator(tensor_id=rm.tensor_id, axes=rm.axes)) + + for rm in required_sample_mean_var_std: + sample_calculators.append(MeanVarStdCalculator(tensor_id=rm.tensor_id, axes=rm.axes)) - for rm in required_means: - if rm in required_mean_var_std: + for rm in required_dataset_means: + if rm in required_dataset_mean_var_std: # computed togehter with var and std continue - if rm.mode == PER_SAMPLE: - sample_calculators.append(MeanCalculator(tensor_id=rm.tensor_id, axes=rm.measure.axes)) - # add all mean measures that are not included in a mean/var/std group - for tn, m in means: - if (tn, m.axes) not in required_mean_var_std: - # compute only mean - if mode == PER_SAMPLE: - calculators[mode].append(NaivSampleMeasureCalculator(tensor_id=tn, measure=m)) - elif mode == PER_DATASET: - calculators[mode].append(DatasetMeanCalculator(tensor_id=tn, axes=m.axes)) - else: - raise NotImplementedError(mode) - - for tn, axes in mean_var_std_groups: - calculators[mode].append(MeanVarStdCalculator(tensor_id=tn, axes=axes)) - - for (tn, axes), ns in required_percentiles.items(): + dataset_calculators.append(MeanCalculator(tensor_id=rm.tensor_id, axes=rm.axes)) + + for rm in required_dataset_mean_var_std: + dataset_calculators.append(MeanVarStdCalculator(tensor_id=rm.tensor_id, axes=rm.axes)) + + for rm in required_sample_percentiles: + sample_calculators.append(SamplePercentilesCalculator(tensor_id=rm.tensor_id, axes=axes)) + for (tn, axes), ns in required_sample_percentiles.items(): if mode == PER_SAMPLE: calculators[mode].append(SamplePercentilesCalculator(tensor_id=tn, axes=axes, ns=ns)) elif mode == PER_DATASET: @@ -375,7 +391,7 @@ def get_measure_calculators( def compute_measures( - measures: RequiredMeasures, *, sample: Optional[Sample] = None, dataset: Iterator[Sample] = () + measures: Set[Measure], *, sample: Optional[Sample] = None, dataset: Iterator[Sample] = () ) -> ComputedMeasures: ms_groups = get_measure_calculators(measures) ret = {PER_SAMPLE: {}, PER_DATASET: {}} diff --git a/bioimageio/core/stat_measures.py b/bioimageio/core/stat_measures.py index 29d6857a..6f4f3aa9 100644 --- a/bioimageio/core/stat_measures.py +++ b/bioimageio/core/stat_measures.py @@ -2,48 +2,84 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Tuple, TypeVar, Union import xarray as xr -from bioimageio.core.common import MeasureValue -from bioimageio.spec.model.v0_5 import AxisId +from bioimageio.core.common import Sample +from bioimageio.spec.model.v0_5 import AxisId, TensorId + +MeasureValue = Union[float, xr.DataArray] + + +@dataclass(frozen=True) +class MeasureBase(ABC): + tensor_id: TensorId @dataclass(frozen=True) -class Measure(ABC): +class SampleMeasureBase(MeasureBase, ABC): @abstractmethod - def compute(self, tensor: xr.DataArray) -> MeasureValue: - """compute the measure (and also associated other Measures)""" + def compute(self, sample: Sample) -> MeasureValue: + """compute the measure""" ... @dataclass(frozen=True) -class Mean(Measure): +class DatasetMeasureBase(MeasureBase, ABC): + pass + + +@dataclass(frozen=True) +class _Mean(MeasureBase): axes: Optional[Tuple[AxisId, ...]] = None - def compute(self, tensor: xr.DataArray) -> xr.DataArray: - return tensor.mean(dim=self.axes) + +@dataclass(frozen=True) +class SampleMean(_Mean, SampleMeasureBase): + def compute(self, sample: Sample) -> MeasureValue: + return sample[self.tensor_id].mean(dim=self.axes) + + +@dataclass(frozen=True) +class DatasetMean(_Mean, DatasetMeasureBase): + pass @dataclass(frozen=True) -class Std(Measure): +class _Std(MeasureBase): axes: Optional[Tuple[AxisId, ...]] = None - def compute(self, tensor: xr.DataArray) -> xr.DataArray: - return tensor.std(dim=self.axes) + +@dataclass(frozen=True) +class SampleStd(_Std, SampleMeasureBase): + def compute(self, sample: Sample) -> MeasureValue: + return sample[self.tensor_id].std(dim=self.axes) + + +@dataclass(frozen=True) +class DatasetStd(_Std, DatasetMeasureBase): + pass @dataclass(frozen=True) -class Var(Measure): +class _Var(MeasureBase): axes: Optional[Tuple[AxisId, ...]] = None - def compute(self, tensor: xr.DataArray) -> xr.DataArray: - return tensor.var(dim=self.axes) + +@dataclass(frozen=True) +class SampleVar(_Var, SampleMeasureBase): + def compute(self, sample: Sample) -> MeasureValue: + return sample[self.tensor_id].var(dim=self.axes) + + +@dataclass(frozen=True) +class DatasetVar(_Var, DatasetMeasureBase): + pass @dataclass(frozen=True) -class Percentile(Measure): +class _Percentile(MeasureBase): n: float axes: Optional[Tuple[AxisId, ...]] = None @@ -51,5 +87,23 @@ def __post_init__(self): assert self.n >= 0 assert self.n <= 100 - def compute(self, tensor: xr.DataArray) -> xr.DataArray: - return tensor.quantile(self.n / 100.0, dim=self.axes) + +@dataclass(frozen=True) +class SamplePercentile(_Percentile, SampleMeasureBase): + def compute(self, sample: Sample) -> MeasureValue: + return sample[self.tensor_id].tensor.quantile(self.n / 100.0, dim=self.axes) + + +@dataclass(frozen=True) +class DatasetPercentile(_Percentile, DatasetMeasureBase): + pass + + +SampleMeasure = Union[SampleMean, SampleStd, SampleVar, SamplePercentile] +DatasetMeasure = Union[DatasetMean, DatasetStd, DatasetVar, DatasetPercentile] +Measure = Union[SampleMeasure, DatasetMeasure] + +# MeasureVar = TypeVar("MeasureVar", bound=MeasureBase) +# SampleMeasureVar = TypeVar("SampleMeasureVar", bound=SampleMeasureBase) +# DatasetMeasureVar = TypeVar("DatasetMeasureVar", bound=DatasetMeasureBase) +# ModeVar = TypeVar("ModeVar", bound=Literal["per_sample", "per_dataset"]) diff --git a/bioimageio/core/stat_state.py b/bioimageio/core/stat_state.py index 107383be..24f062c9 100644 --- a/bioimageio/core/stat_state.py +++ b/bioimageio/core/stat_state.py @@ -3,9 +3,9 @@ from tqdm import tqdm -from bioimageio.core.common import PER_DATASET, PER_SAMPLE, MeasureValue, RequiredMeasure, Sample, TensorId +from bioimageio.core.common import PER_DATASET, PER_SAMPLE, RequiredMeasure, Sample, TensorId from bioimageio.core.stat_calculators import MeasureGroups, MeasureValue, get_measure_calculators -from bioimageio.core.stat_measures import Measure +from bioimageio.core.stat_measures import MeasureBase, MeasureValue @dataclass @@ -15,7 +15,7 @@ class StatsState: required_measures: Iterable[RequiredMeasure] -def compute_statistics() +def compute_statistics(): dataset: Iterable[Sample] update_dataset_stats_after_n_samples: Optional[int] = None update_dataset_stats_for_n_samples: Union[int, float] = float("inf") diff --git a/bioimageio/core/utils.py b/bioimageio/core/utils.py index 46fff6db..e69de29b 100644 --- a/bioimageio/core/utils.py +++ b/bioimageio/core/utils.py @@ -1,32 +0,0 @@ -from functools import singledispatch -from typing import Any, Dict, List, Union - -import numpy as np -import xarray as xr -from numpy.typing import NDArray - -from bioimageio.spec.model import v0_4, v0_5 -from bioimageio.spec.model.v0_5 import TensorId -from bioimageio.spec.utils import download, load_array - -# @singledispatch -# def is_valid_tensor(description: object, tensor: Union[NDArray[Any], xr.DataArray]) -> bool: -# raise NotImplementedError(type(description)) - -# is_valid_tensor.register -# def _(description: v0_4.InputTensor, tensor: Union[NDArray[Any], xr.DataArray]): - - -@singledispatch -def get_test_input_tensors(model: object) -> List[xr.DataArray]: - raise NotImplementedError(type(model)) - - -@get_test_input_tensors.register -def _(model: v0_4.Model): - data = [load_array(download(ipt).path) for ipt in model.test_inputs] - assert all(isinstance(d, np.ndarray) for d in data) - - -# @get_test_input_tensors.register -# def _(model: v0_5.Model): diff --git a/bioimageio/core/weight_converter/torch/torchscript.py b/bioimageio/core/weight_converter/torch/torchscript.py index bace789e..e01ac34f 100644 --- a/bioimageio/core/weight_converter/torch/torchscript.py +++ b/bioimageio/core/weight_converter/torch/torchscript.py @@ -1,22 +1,24 @@ -from typing import List, Sequence -from typing_extensions import Any, assert_never from pathlib import Path -from typing import Union +from typing import List, Sequence, Union import numpy as np import torch from numpy.testing import assert_array_almost_equal +from typing_extensions import Any, assert_never -from bioimageio.spec import load_description -from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec import load_description from bioimageio.spec.common import InvalidDescription +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.model.v0_5 import Version from bioimageio.spec.utils import download from .utils import load_model + # FIXME: remove Any -def _check_predictions(model: Any, scripted_model: Any, model_spec: "v0_4.ModelDescr | v0_5.ModelDescr", input_data: Sequence[torch.Tensor]): +def _check_predictions( + model: Any, scripted_model: Any, model_spec: "v0_4.ModelDescr | v0_5.ModelDescr", input_data: Sequence[torch.Tensor] +): def _check(input_: Sequence[torch.Tensor]) -> None: expected_tensors = model(*input_) if isinstance(expected_tensors, torch.Tensor): @@ -37,7 +39,7 @@ def _check(input_: Sequence[torch.Tensor]) -> None: _check(input_data) if len(model_spec.inputs) > 1: - return # FIXME: why don't we check multiple inputs? + return # FIXME: why don't we check multiple inputs? input_descr = model_spec.inputs[0] if isinstance(input_descr, v0_4.InputTensorDescr): @@ -57,7 +59,7 @@ def _check(input_: Sequence[torch.Tensor]) -> None: step.append(0) elif isinstance(axis.size, (v0_5.AxisId, v0_5.TensorAxisId, type(None))): raise NotImplementedError(f"Can't verify inputs that don't specify their shape fully: {axis}") - elif isinstance(axis.size, v0_5.SizeReference): # pyright: ignore [reportUnnecessaryIsInstance] + elif isinstance(axis.size, v0_5.SizeReference): # pyright: ignore [reportUnnecessaryIsInstance] raise NotImplementedError(f"Can't handle axes like '{axis}' yet") else: assert_never(axis.size) @@ -74,36 +76,26 @@ def _check(input_: Sequence[torch.Tensor]) -> None: raise ValueError(f"Mismatched shapes: {this_shape}. Expected at least {min_shape}") _check(this_input) + def convert_weights_to_torchscript( - model_spec: Union[str, Path, v0_4.ModelDescr, v0_5.ModelDescr], output_path: Path, use_tracing: bool = True -): + model_descr: Union[v0_4.ModelDescr, v0_5.ModelDescr], output_path: Path, use_tracing: bool = True +) -> v0_5.TorchscriptWeightsDescr: """Convert model weights from format 'pytorch_state_dict' to 'torchscript'. Args: - model_spec: location of the resource for the input bioimageio model + model_descr: location of the resource for the input bioimageio model output_path: where to save the torchscript weights use_tracing: whether to use tracing or scripting to export the torchscript format """ - if isinstance(model_spec, (str, Path)): - loaded_spec = load_description(Path(model_spec)) - if isinstance(loaded_spec, InvalidDescription): - raise ValueError(f"Bad resource description: {loaded_spec}") - if not isinstance(loaded_spec, (v0_4.ModelDescr, v0_5.ModelDescr)): - raise TypeError(f"Path {model_spec} is a {loaded_spec.__class__.__name__}, expected a v0_4.ModelDescr or v0_5.ModelDescr") - model_spec = loaded_spec - - state_dict_weights_descr = model_spec.weights.pytorch_state_dict + + state_dict_weights_descr = model_descr.weights.pytorch_state_dict if state_dict_weights_descr is None: - raise ValueError(f"The provided model does not have weights in the pytorch state dict format") + raise ValueError("The provided model does not have weights in the pytorch state dict format") - with torch.no_grad(): - if isinstance(model_spec, v0_4.ModelDescr): - downloaded_test_inputs = [download(inp) for inp in model_spec.test_inputs] - else: - downloaded_test_inputs = [inp.test_tensor.download() for inp in model_spec.inputs] + input_data = model_descr.get_input_test_arrays() - input_data = [np.load(dl.path).astype("float32") for dl in downloaded_test_inputs] - input_data = [torch.from_numpy(inp) for inp in input_data] + with torch.no_grad(): + input_data = [torch.from_numpy(inp.astype("float32")) for inp in input_data] model = load_model(state_dict_weights_descr) @@ -113,13 +105,11 @@ def convert_weights_to_torchscript( else: scripted_model: Any = torch.jit.script(model) - ret = _check_predictions( - model=model, - scripted_model=scripted_model, - model_spec=model_spec, - input_data=input_data - ) + _check_predictions(model=model, scripted_model=scripted_model, model_spec=model_descr, input_data=input_data) # save the torchscript model scripted_model.save(str(output_path)) # does not support Path, so need to cast to str - return ret + + return v0_5.TorchscriptWeightsDescr( + source=output_path, pytorch_version=Version(torch.__version__), parent="pytorch_state_dict" + )