Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Load features for classification #104

Merged
merged 32 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
1f1c728
Added H5Slide reader and fixed reshaping in backend
moerlemans Aug 8, 2024
5eb07b3
added database models for features
moerlemans Aug 8, 2024
117b66e
added loading features as a dataset + necessary utils
moerlemans Aug 8, 2024
0d19bf3
fix black
moerlemans Aug 8, 2024
7fcf23e
improved classification pre_transforms and added random sampling of t…
moerlemans Aug 8, 2024
c8e1b7c
Added specific tile_size and size to the writer, so that the reader c…
moerlemans Aug 15, 2024
d9d6346
Added DataFormat enum that handles reading features in the readers an…
moerlemans Aug 15, 2024
7094379
fix bugs in database models
moerlemans Aug 16, 2024
e29fa3d
added ahcore ImageBackend enum which includes both ahcore and dlup ba…
moerlemans Aug 16, 2024
65d3ef2
added dataformat enum and fixed loading of datasets to work with feat…
moerlemans Aug 20, 2024
eedee6f
Fixes pretransforms to be used on features, also allows for option to…
moerlemans Aug 20, 2024
f58c656
minor fixes to allow models for classification
moerlemans Aug 20, 2024
aad8637
Adapt for dataformat enum
moerlemans Aug 20, 2024
0a3dcb3
added SetTarget method which chooses what the target will be in the l…
moerlemans Aug 21, 2024
2e24cb1
precommit fixes and removed the three pixel check, dlup will fix that
moerlemans Aug 21, 2024
cc9ea49
model will expect Bxnum_tilesxfeature_dim, so the ToTensor method sho…
moerlemans Aug 22, 2024
e60768d
merge with main, fixes mypy
moerlemans Sep 4, 2024
0d97c4a
fix some mypy
moerlemans Sep 4, 2024
f5c08dc
fixes mypy
moerlemans Sep 5, 2024
ec9285d
fix test for readers
moerlemans Sep 6, 2024
cb2d42b
Merge branch 'main' into feature/feature-dataset
moerlemans Sep 6, 2024
0c0d1e4
added tests for features
moerlemans Sep 6, 2024
8adffcc
cross_entropy now handles BxC inputs as well, also improved logic
moerlemans Sep 11, 2024
5eea88d
make dimension work out for labels
moerlemans Sep 11, 2024
c26910a
simplify reader and dataset builders
moerlemans Sep 11, 2024
7efaa70
fix writers and tests for writers
moerlemans Sep 11, 2024
2d83964
fix test, mypy and file_writer + callback
moerlemans Sep 13, 2024
71bb871
now also passes tests...
moerlemans Sep 13, 2024
ae580cb
mypy, pylint and it runs now
moerlemans Sep 16, 2024
ea8a330
bugfixes to make the writer working
moerlemans Sep 24, 2024
88dff49
cleaned feature description and manifest
moerlemans Sep 24, 2024
e43c5ae
fixes review comments
moerlemans Oct 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions ahcore/backends.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import Any
from enum import Enum
from typing import Any, Callable

import pyvips
from dlup.backends.common import AbstractSlideBackend
from dlup.backends.openslide_backend import OpenSlideSlide
from dlup.backends.pyvips_backend import PyVipsSlide
from dlup.backends.tifffile_backend import TifffileSlide
from dlup.types import PathLike # type: ignore

from ahcore.readers import StitchingMode, ZarrFileImageReader
from ahcore.readers import H5FileImageReader, StitchingMode, ZarrFileImageReader


class ZarrSlide(AbstractSlideBackend):
Expand Down Expand Up @@ -42,3 +46,53 @@ def read_region(self, coordinates: tuple[int, int], level: int, size: tuple[int,

def close(self) -> None:
self._reader.close()


class H5Slide(AbstractSlideBackend):
def __init__(self, filename: PathLike, stitching_mode: StitchingMode | str = StitchingMode.CROP) -> None:
super().__init__(filename)
self._reader: H5FileImageReader = H5FileImageReader(filename, stitching_mode=stitching_mode)
self._spacings = [(self._reader.mpp, self._reader.mpp)]

@property
def size(self) -> tuple[int, int]:
return self._reader.size

@property
def level_dimensions(self) -> tuple[tuple[int, int], ...]:
return (self._reader.size,)

@property
def level_downsamples(self) -> tuple[float, ...]:
return (1.0,)

@property
def vendor(self) -> str:
return "H5FileImageReader"

@property
def properties(self) -> dict[str, Any]:
return self._reader.metadata

@property
def magnification(self) -> None:
return None

def read_region(self, coordinates: tuple[int, int], level: int, size: tuple[int, int]) -> pyvips.Image:
return self._reader.read_region(coordinates, level, size)

def close(self) -> None:
self._reader.close()


class ImageBackend(Enum):
"""Available image backends."""

OPENSLIDE: Callable[[PathLike], OpenSlideSlide] = OpenSlideSlide
PYVIPS: Callable[[PathLike], PyVipsSlide] = PyVipsSlide
TIFFFILE: Callable[[PathLike], TifffileSlide] = TifffileSlide
H5: Callable[[PathLike], H5Slide] = H5Slide
ZARR: Callable[[PathLike], ZarrSlide] = ZarrSlide

def __call__(self, *args: Any) -> OpenSlideSlide | PyVipsSlide | TifffileSlide | H5Slide | ZarrSlide:
return self.value(*args) # type: ignore
80 changes: 53 additions & 27 deletions ahcore/callbacks/file_writer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from typing import Type

from dlup.data.dataset import TiledWsiDataset
from dlup.tiling import Grid

from ahcore.callbacks.abstract_writer_callback import AbstractWriterCallback
from ahcore.callbacks.converters.common import ConvertCallbacks
from ahcore.lit_module import AhCoreLightningModule
from ahcore.utils.callbacks import get_output_filename as get_output_filename_
from ahcore.utils.data import DataDescription, GridDescription
from ahcore.utils.io import get_logger
from ahcore.utils.types import InferencePrecision, NormalizationType
from ahcore.utils.types import DataFormat, InferencePrecision, NormalizationType
from ahcore.writers import Writer

logger = get_logger(__name__)
Expand All @@ -27,7 +28,8 @@ def __init__(
normalization_type: str = NormalizationType.LOGITS,
precision: str = InferencePrecision.FP32,
callbacks: list[ConvertCallbacks] | None = None,
):
data_format: str = DataFormat.IMAGE,
) -> None:
"""
Callback to write predictions to H5 files. This callback is used to write whole-slide predictions to single H5
files in a separate thread.
Expand All @@ -54,6 +56,7 @@ def __init__(
self._suffix = ".cache"
self._normalization_type: NormalizationType = NormalizationType(normalization_type)
self._precision: InferencePrecision = InferencePrecision(precision)
self._data_format = DataFormat(data_format)

super().__init__(
writer_class=writer_class,
Expand Down Expand Up @@ -97,41 +100,64 @@ def build_writer_class(self, pl_module: AhCoreLightningModule, stage: str, filen
with open(link_fn, "a" if link_fn.is_file() else "w") as file:
file.write(f"{filename},{output_filename}\n")

current_dataset: TiledWsiDataset
current_dataset, _ = self._total_dataset.index_to_dataset(self._dataset_index) # type: ignore
slide_image = current_dataset.slide_image
num_samples = len(current_dataset)

data_description: DataDescription = pl_module.data_description
inference_grid: GridDescription = data_description.inference_grid

mpp = inference_grid.mpp
if mpp is None:
mpp = slide_image.mpp

_, size = slide_image.get_scaled_slide_bounds(slide_image.get_scaling(mpp))

# Let's get the data_description, so we can figure out the tile size and things like that
tile_size = inference_grid.tile_size
tile_overlap = inference_grid.tile_overlap

if stage == "validate":
grid = current_dataset._grids[0][0] # pylint: disable=protected-access
else:
grid = None # During inference we don't have a grid around ROI
size, mpp, tile_size, tile_overlap, num_samples, grid = self._get_writer_data_args(
pl_module, data_format=self._data_format, stage=stage
)

writer = self._writer_class(
output_filename,
size=size,
size=size, # --> (num_samples,1)
mpp=mpp,
tile_size=tile_size,
tile_size=tile_size, # --> (1,1)
tile_overlap=tile_overlap,
num_samples=num_samples,
color_profile=None,
is_compressed_image=False,
data_format=self._data_format,
progress=None,
precision=InferencePrecision(self._precision),
grid=grid,
)

return writer

def _get_writer_data_args(
self, pl_module: AhCoreLightningModule, data_format: DataFormat, stage: str
) -> tuple[tuple[int, int], float, tuple[int, int], tuple[int, int], int, Grid | None]:
current_dataset: TiledWsiDataset
current_dataset, _ = self._total_dataset.index_to_dataset(self._dataset_index) # type: ignore
slide_image = current_dataset.slide_image
num_samples = len(current_dataset)

if data_format == DataFormat.IMAGE or data_format == DataFormat.COMPRESSED_IMAGE:
data_description: DataDescription = pl_module.data_description
if data_description.inference_grid is None:
raise ValueError("Inference grid is not defined in the data description.")
inference_grid: GridDescription = data_description.inference_grid

mpp = inference_grid.mpp
if mpp is None:
mpp = slide_image.mpp

_, size = slide_image.get_scaled_slide_bounds(slide_image.get_scaling(mpp))

# Let's get the data_description, so we can figure out the tile size and things like that
tile_size = inference_grid.tile_size
tile_overlap = inference_grid.tile_overlap

if stage == "validate":
grid = current_dataset._grids[0][0] # pylint: disable=protected-access
else:
grid = None # During inference we don't have a grid around ROI

elif data_format == DataFormat.FEATURE:
size = (num_samples, 1)
mpp = 1.0
tile_size = (1, 1)
tile_overlap = (0, 0)
Comment on lines +152 to +156
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would make the mpp equal to the factor of the tile size and mpp.

Copy link
Contributor Author

@moerlemans moerlemans Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not fully understand what you mean here, but these settings are mainly here so that the reader gets them correctly. I can change the mpp as long as the reader also reads at that specific mpp. 1.0 was chosen now for ease

num_samples = num_samples
grid = current_dataset._grids[0][0] # give grid, bc doesn't work otherwise

else:
raise NotImplementedError(f"Data format {data_format} is not yet supported.")

return size, mpp, tile_size, tile_overlap, num_samples, grid
3 changes: 2 additions & 1 deletion ahcore/cli/tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from rich.progress import Progress

from ahcore.cli import dir_path, file_path
from ahcore.utils.types import DataFormat
from ahcore.writers import H5FileImageWriter, Writer, ZarrFileImageWriter

_WriterClass = Type[Writer]
Expand Down Expand Up @@ -363,7 +364,7 @@ def _tiling_pipeline(
tile_size=dataset_cfg.tile_size,
tile_overlap=dataset_cfg.tile_overlap,
num_samples=len(dataset),
is_compressed_image=compression != "none",
data_format=DataFormat.IMAGE,
color_profile=color_profile,
extra_metadata=extra_metadata,
grid=dataset.grids[0][0],
Expand Down
6 changes: 5 additions & 1 deletion ahcore/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(
num_workers: int = 16,
persistent_workers: bool = False,
pin_memory: bool = False,
use_cache: bool = True,
) -> None:
"""
Construct a DataModule based on a manifest.
Expand Down Expand Up @@ -178,6 +179,7 @@ def __init__(
self._num_workers = num_workers
self._persistent_workers = persistent_workers
self._pin_memory = pin_memory
self._use_cache = use_cache

self._fit_data_iterator: Iterator[_DlupDataset] | None = None
self._validate_data_iterator: Iterator[_DlupDataset] | None = None
Expand Down Expand Up @@ -245,7 +247,7 @@ def construct_dataset() -> ConcatDataset:
return ConcatDataset(datasets=datasets)

self._logger.info("Constructing dataset for stage %s (this can take a while)", stage)
dataset = self._load_from_cache(construct_dataset, stage=stage)
dataset = self._load_from_cache(construct_dataset, stage=stage) if self._use_cache else construct_dataset()
setattr(self, f"{stage}_dataset", dataset)

lengths = np.asarray([len(ds) for ds in dataset.datasets])
Expand Down Expand Up @@ -365,4 +367,6 @@ def uuid(self) -> uuid_module.UUID:
str
A unique identifier for this datamodule.
"""

# todo: It doesn't take into account different types of pretransforms, which can be important.
return basemodel_to_uuid(self.data_description)
6 changes: 5 additions & 1 deletion ahcore/lit_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ class AhCoreLightningModule(pl.LightningModule):
"grid_index",
]

_model: nn.Module | BaseAhcoreJitModel

def __init__(
self,
model: nn.Module | BaseAhcoreJitModel,
model: nn.Module | BaseAhcoreJitModel | functools.partial[nn.Module],
optimizer: torch.optim.optimizer.Optimizer, # noqa
data_description: DataDescription,
loss: nn.Module | None = None,
Expand Down Expand Up @@ -66,6 +68,8 @@ def __init__(
except AttributeError:
raise AttributeError("num_classes must be specified in data_description")
self._model = model(out_channels=self._num_classes)
elif isinstance(model, nn.Module):
self._model = model
else:
raise TypeError(f"The class of models: {model.__class__} is not supported on ahcore")
self._augmentations = augmentations
Expand Down
26 changes: 17 additions & 9 deletions ahcore/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
if class_proportions is not None:
_class_weights = 1 / class_proportions
_class_weights[_class_weights.isnan()] = 0.0
_class_weights = _class_weights / _class_weights.max()
_class_weights = _class_weights / _class_weights.max() # todo: check, shouldn't this be .sum?
self._class_weights = _class_weights
else:
self._class_weights = None
Expand Down Expand Up @@ -141,26 +141,34 @@ def cross_entropy(
else:
roi_sum = torch.tensor([np.prod(tuple(input.shape)[2:])]).to(input.device)

if input.dim() != target.dim():
raise ValueError(f"Dimension do not match for input and target. Got {input.dim()} and {target.dim()}")

if input.dim() == 2 and target.dim() == 2:
# handle cls task as an image of size 1x1
input = input.unsqueeze(-1).unsqueeze(-1)
target = target.unsqueeze(-1).unsqueeze(-1)

if ignore_index is None:
ignore_index = -100

# compute cross_entropy pixel by pixel
if not multiclass:
_cross_entropy = F.cross_entropy(
if multiclass:
_cross_entropy = F.binary_cross_entropy_with_logits(
input,
target.argmax(dim=1),
ignore_index=ignore_index,
target,
weight=None if weight is None else weight.to(input.device),
reduction="none",
label_smoothing=label_smoothing,
pos_weight=None,
)
else:
_cross_entropy = F.binary_cross_entropy_with_logits(
_cross_entropy = F.cross_entropy(
input,
target,
target.argmax(dim=1),
ignore_index=ignore_index,
weight=None if weight is None else weight.to(input.device),
reduction="none",
pos_weight=None,
label_smoothing=label_smoothing,
)

if limit is not None:
Expand Down
Loading