Skip to content

Commit

Permalink
WIP measures
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Nov 28, 2023
1 parent a227388 commit 6ea6c85
Show file tree
Hide file tree
Showing 8 changed files with 381 additions and 289 deletions.
26 changes: 1 addition & 25 deletions bioimageio/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from attr import dataclass
from typing_extensions import Final

from bioimageio.core.stat_measures import Measure
from bioimageio.core.stat_measures import MeasureBase
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import TensorId

Expand All @@ -19,27 +19,3 @@

PER_SAMPLE = "per_sample"
PER_DATASET = "per_dataset"


MeasureVar = TypeVar("MeasureVar", bound=Measure)
ModeVar = TypeVar("ModeVar", Literal["per_sample"], Literal["per_dataset"])


@dataclass(frozen=True)
class RequiredMeasure(Generic[MeasureVar, ModeVar]):
measure: MeasureVar
tensor_id: TensorId
mode: ModeVar


@dataclass(frozen=True)
class SampleMeasure(RequiredMeasure[MeasureVar, Literal["per_sample"]]):
pass


@dataclass(frozen=True)
class DatasetMeasure(RequiredMeasure[MeasureVar, Literal["per_dataset"]]):
pass


MeasureValue = xr.DataArray
246 changes: 167 additions & 79 deletions bioimageio/core/proc_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import InitVar, dataclass, field, fields
from types import MappingProxyType
from typing import (
Any,
ClassVar,
FrozenSet,
Generic,
Expand All @@ -23,13 +24,21 @@
import numpy as np
import xarray as xr
from numpy.typing import DTypeLike
from typing_extensions import LiteralString, assert_never

from bioimageio.core.common import MeasureValue, ProcessingDescrBase, ProcessingKwargs, RequiredMeasure, Sample
from bioimageio.core.stat_measures import Mean, Percentile, Std
from typing_extensions import LiteralString, assert_never, Unpack

from bioimageio.core.common import (
AnyRequiredMeasure,
AxisId,
MeasureVar,
ProcessingDescrBase,
ProcessingKwargs,
RequiredMeasure,
Sample,
)
from bioimageio.core.stat_measures import Mean, MeasureValue, Percentile, Std
from bioimageio.spec._internal.base_nodes import NodeWithExplicitlySetFields
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import NonBatchAxisId, TensorId
from bioimageio.spec.model.v0_5 import NonBatchAxisId, TensorId, BinarizeKwargs

AssertProcessingId = Literal["assert_dtype"]

Expand All @@ -48,112 +57,191 @@ class AssertDtype(AssertProcessingBase):
kwargs: AssertDtypeKwargs


M = TypeVar("M", RequiredMeasure, MeasureValue)
M = TypeVar("M", AnyRequiredMeasure, MeasureValue)


@dataclass
class NamedMeasures(Generic[M]):
class NamedMeasures:
"""Named Measures that specifies all required/computed measures of a Processing instance"""

def get_set(self) -> Set[M]:
def get_set(self) -> Set[RequiredMeasure[Any, Any]]:
return {getattr(self, f.name) for f in fields(self)}


# The two generics are conceptually a higher kinded generic
R = TypeVar("R", bound=NamedMeasures[RequiredMeasure])
R = TypeVar("R", bound=NamedMeasures[RequiredMeasure[Any, Any]])
C = TypeVar("C", bound=NamedMeasures[MeasureValue])


PKwargs = TypeVar("PKwargs", bound=ProcessingKwargs)
ProcInput = TypeVar("ProcInput", xr.DataArray, Sample)


@dataclass(frozen=True)
class ProcessingImplBase(Generic[PKwargs, R, C], ABC):
"""Base class for all Pre- and Postprocessing implementations."""
Tensor = xr.DataArray

tensor_id: TensorId
"""id of tensor to operate on"""
@dataclass
class Operator(Generic[PKwargs], ABC):
kwargs: PKwargs
computed_measures: InitVar[Mapping[RequiredMeasure, MeasureValue]] = field(
default=MappingProxyType[RequiredMeasure, MeasureValue]({})
)
assert type(R) is type(C), "R and C are conceptually a higher kindes generic, their class has to be identical"
required: R = field(init=False)
computed: C = field(init=False)

def __post_init__(self, computed_measures: Mapping[RequiredMeasure, MeasureValue]) -> None:
object.__setattr__(self, "required", self.get_required_measures(self.tensor_id, self.kwargs))
selected = {}
for f in fields(self.required):
req = getattr(self.required, f.name)
if req in computed_measures:
selected[f.name] = computed_measures[req]
else:
raise ValueError(f"Missing computed measure: {req} (as '{f.name}').")

object.__setattr__(self, "computed", self.required.__class__(**selected))

computed:
@abstractmethod
@classmethod
def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> NamedMeasures[RequiredMeasure]:
def __call__(self) -> Tensor:
...

def __call__(self, __input: ProcInput, /) -> ProcInput:
if isinstance(__input, xr.DataArray):
return self.apply(__input)
else:
return self.apply_to_sample(__input)

@abstractmethod
def apply(self, tensor: xr.DataArray) -> xr.DataArray:
"""apply processing"""
...
@dataclass
class Binarize(Operator):
threshold: float


# class Source(Operator):
# def __call__(self) -> Tensor:
# return Tensor()

def apply_to_sample(self, sample: Sample) -> Sample:
ret = dict(sample)
ret[self.tensor_id] = self.apply(sample[self.tensor_id])
return ret

@dataclass
class Smooth(Operator):
sigma: float
tensor_source: Source

@abstractmethod
def get_descr(self) -> Union[ProcessingDescrBase, AssertProcessingBase]:
@classmethod
def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> R:
...

def __call__(self) -> Tensor:
return tensor * self.sigma # fixme

@dataclass(frozen=True)
class ProcessingImplBaseWoMeasures(
ProcessingImplBase[PKwargs, NamedMeasures[RequiredMeasure], NamedMeasures[MeasureValue]]
):
@classmethod
def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> NamedMeasures[RequiredMeasure]:
return NamedMeasures()

class Diff(Operator):
def __call__(self, a: Tensor, b: Tensor) -> Tensor:
return a - b

@dataclass(frozen=True)
class AssertDtypeImpl(ProcessingImplBaseWoMeasures[AssertDtypeKwargs]):
kwargs_class = AssertDtypeKwargs
_assert_with: Tuple[Type[DTypeLike], ...] = field(init=False)

def __post_init__(self, computed_measures: Mapping[RequiredMeasure, MeasureValue]) -> None:
super().__post_init__(computed_measures)
if isinstance(self.kwargs.dtype, str):
dtype = [self.kwargs.dtype]
else:
dtype = self.kwargs.dtype

object.__setattr__(self, "assert_with", tuple(type(numpy.dtype(dt)) for dt in dtype))

def apply(self, tensor: xr.DataArray) -> xr.DataArray:
assert isinstance(tensor.dtype, self._assert_with)
return tensor

def get_descr(self):
return AssertDtype(kwargs=self.kwargs)


@dataclass(frozen=True)
class BinarizeImpl(ProcessingImplBaseWoMeasures[Union[v0_4.BinarizeKwargs, v0_5.BinarizeKwargs]]):
@dataclass
class SimpleOperator(Operator, ABC):
input_id: TensorId
output_id: TensorId

def __call__(self, sample: Sample, /) -> Sample:
ret = dict(sample)
ret[self.output_id] = self.apply(sample[self.input_id])
return ret


@abstractmethod
def apply(self, tensor: xr.DataArray) -> xr.DataArray:
...

# @dataclass(frozen=True)
# class ProcessingImplBase(Generic[PKwargs, R, C], ABC):
# """Base class for all Pre- and Postprocessing implementations."""

# tensor_id: TensorId
# """id of tensor to operate on"""
# kwargs: PKwargs
# computed_measures: InitVar[Mapping[AnyRequiredMeasure, MeasureValue]] = field(
# default=MappingProxyType[AnyRequiredMeasure, MeasureValue]({})
# )
# assert type(R) is type(C), "R and C are conceptually a higher kindes generic, their class has to be identical"
# required: R = field(init=False)
# computed: C = field(init=False)

# def __post_init__(self, computed_measures: Mapping[AnyRequiredMeasure, MeasureValue]) -> None:
# object.__setattr__(self, "required", self.get_required_measures(self.tensor_id, self.kwargs))
# selected = {}
# for f in fields(self.required):
# req = getattr(self.required, f.name)
# if req in computed_measures:
# selected[f.name] = computed_measures[req]
# else:
# raise ValueError(f"Missing computed measure: {req} (as '{f.name}').")

# object.__setattr__(self, "computed", self.required.__class__(**selected))

# @abstractmethod
# @classmethod
# def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> R:
# ...

# def __call__(self, __input: ProcInput, /) -> ProcInput:
# if isinstance(__input, xr.DataArray):
# return self.apply(__input)
# else:
# return self.apply_to_sample(__input)

# @abstractmethod
# def apply(self, tensor: xr.DataArray) -> xr.DataArray:
# """apply processing"""
# ...

# def apply_to_sample(self, sample: Sample) -> Sample:
# ret = dict(sample)
# ret[self.tensor_id] = self.apply(sample[self.tensor_id])
# return ret

# @abstractmethod
# def get_descr(self) -> Union[ProcessingDescrBase, AssertProcessingBase]:
# ...


# @dataclass(frozen=True)
# class ProcessingImplBaseWoMeasures(
# ProcessingImplBase[PKwargs, NamedMeasures[AnyRequiredMeasure], NamedMeasures[MeasureValue]]
# ):
# @classmethod
# def get_required_measures(cls, tensor_id: TensorId, kwargs: PKwargs) -> NamedMeasures[AnyRequiredMeasure]:
# return NamedMeasures()


# @dataclass(frozen=True)
# class AssertDtypeImpl(ProcessingImplBaseWoMeasures[AssertDtypeKwargs]):
# kwargs_class = AssertDtypeKwargs
# _assert_with: Tuple[Type[DTypeLike], ...] = field(init=False)

# def __post_init__(self, computed_measures: Mapping[AnyRequiredMeasure, MeasureValue]) -> None:
# super().__post_init__(computed_measures)
# if isinstance(self.kwargs.dtype, str):
# dtype = [self.kwargs.dtype]
# else:
# dtype = self.kwargs.dtype

# object.__setattr__(self, "assert_with", tuple(type(numpy.dtype(dt)) for dt in dtype))

# def apply(self, tensor: xr.DataArray) -> xr.DataArray:
# assert isinstance(tensor.dtype, self._assert_with)
# return tensor

# def get_descr(self):
# return AssertDtype(kwargs=self.kwargs)

# class AssertDtype(Operator):
# dtype: Sequence[DTypeLike]
# _assert_with: Tuple[Type[DTypeLike], ...] = field(init=False)

# def __post_init__(self, computed_measures: Mapping[AnyRequiredMeasure, MeasureValue]) -> None:
# super().__post_init__(computed_measures)
# if isinstance(self.kwargs.dtype, str):
# dtype = [self.kwargs.dtype]
# else:
# dtype = self.kwargs.dtype

# object.__setattr__(self, "assert_with", tuple(type(numpy.dtype(dt)) for dt in dtype))

# def apply(self, tensor: xr.DataArray) -> xr.DataArray:
# assert isinstance(tensor.dtype, self._assert_with)
# return tensor

# def get_descr(self):
# return AssertDtype(kwargs=self.kwargs)

@dataclass
class BinarizeImpl(Operator):
"""'output = tensor > threshold'."""
threshold: float

def apply(self, tensor: xr.DataArray) -> xr.DataArray:
return tensor > self.kwargs.threshold
Expand Down Expand Up @@ -237,7 +325,7 @@ class NamedMeasuresScaleMeanVariance(NamedMeasures[M]):
class ScaleMeanVarianceImpl(
ProcessingImplBase[
Union[v0_4.ScaleMeanVarianceKwargs, v0_5.ScaleMeanVarianceKwargs],
NamedMeasuresScaleMeanVariance[RequiredMeasure],
NamedMeasuresScaleMeanVariance[AnyRequiredMeasure],
NamedMeasuresScaleMeanVariance[MeasureValue],
]
):
Expand All @@ -248,7 +336,7 @@ def get_required_measures(
if kwargs.axes is None:
axes = None
elif isinstance(kwargs.axes, str):
axes = tuple(NonBatchAxisId(a) for a in kwargs.axes)
axes = tuple(AxisId(a) for a in kwargs.axes)
elif isinstance(kwargs.axes, collections.abc.Sequence): # pyright: ignore[reportUnnecessaryIsInstance]
axes = tuple(kwargs.axes)
else:
Expand Down Expand Up @@ -283,14 +371,14 @@ class NamedMeasuresScaleRange(NamedMeasures[M]):
class ScaleRangeImpl(
ProcessingImplBase[
Union[v0_4.ScaleRangeKwargs, v0_5.ScaleRangeKwargs],
NamedMeasuresScaleRange[RequiredMeasure],
NamedMeasuresScaleRange[RequiredMeasure[Percentile, Any]],
NamedMeasuresScaleRange[MeasureValue],
]
):
@classmethod
def get_required_measures(cls, tensor_id: TensorId, kwargs: Union[v0_4.ScaleRangeKwargs, v0_5.ScaleRangeKwargs]):
ref_name = kwargs.reference_tensor or tensor_id
axes = None if kwargs.axes is None else tuple(NonBatchAxisId(a) for a in kwargs.axes)
axes = None if kwargs.axes is None else tuple(AxisId(a) for a in kwargs.axes)
return NamedMeasuresScaleRange(
lower=RequiredMeasure(Percentile(kwargs.min_percentile, axes=axes), cast(TensorId, ref_name), kwargs.mode),
upper=RequiredMeasure(Percentile(kwargs.max_percentile, axes=axes), cast(TensorId, ref_name), kwargs.mode),
Expand Down
2 changes: 1 addition & 1 deletion bioimageio/core/proc_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class _SetupProcessing(NamedTuple):
def setup_pre_and_postprocessing(model: ModelDescr, dataset: Iterator[Sample]) -> _SetupProcessing:
Prepared = List[Tuple[Type[ProcessingImplBase[Any, Any, Any]], ProcessingKwargs, TensorId]]

required_measures: Set[RequiredMeasure] = set()
required_measures: Set[RequiredMeasure[Any, Any]] = set()

def prepare_procs(tensor_descrs: Sequence[TensorDescr]):
prepared: Prepared = []
Expand Down
Loading

0 comments on commit 6ea6c85

Please sign in to comment.