From 1f1c7280b0005384d586733c949d5843e9787eed Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Thu, 8 Aug 2024 10:57:11 +0200 Subject: [PATCH 01/30] Added H5Slide reader and fixed reshaping in backend --- ahcore/backends.py | 38 +++++++++++++++++++++++++++++++++++++- ahcore/readers.py | 14 ++++++++++---- 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/ahcore/backends.py b/ahcore/backends.py index 94e66d9..cb719b2 100644 --- a/ahcore/backends.py +++ b/ahcore/backends.py @@ -4,7 +4,7 @@ from dlup.backends.common import AbstractSlideBackend from dlup.types import PathLike -from ahcore.readers import StitchingMode, ZarrFileImageReader +from ahcore.readers import StitchingMode, ZarrFileImageReader, H5FileImageReader class ZarrSlide(AbstractSlideBackend): @@ -42,3 +42,39 @@ def read_region(self, coordinates: tuple[int, int], level: int, size: tuple[int, def close(self): self._reader.close() + +class H5Slide(AbstractSlideBackend): + def __init__(self, filename: PathLike, stitching_mode: StitchingMode | str = StitchingMode.CROP) -> None: + super().__init__(filename) + self._reader: H5FileImageReader = H5FileImageReader(filename, stitching_mode=stitching_mode) + self._spacings = [(self._reader.mpp, self._reader.mpp)] + + @property + def size(self): + return self._reader.size + + @property + def level_dimensions(self) -> tuple[tuple[int, int], ...]: + return (self._reader.size,) + + @property + def level_downsamples(self) -> tuple[float, ...]: + return (1.0,) + + @property + def vendor(self) -> str: + return "H5FileImageReader" + + @property + def properties(self) -> dict[str, Any]: + return self._reader.metadata + + @property + def magnification(self): + return None + + def read_region(self, coordinates: tuple[int, int], level: int, size: tuple[int, int]) -> pyvips.Image: + return self._reader.read_region(coordinates, level, size) + + def close(self): + self._reader.close() \ No newline at end of file diff --git a/ahcore/readers.py b/ahcore/readers.py index 9c3b583..e61965d 100644 --- a/ahcore/readers.py +++ b/ahcore/readers.py @@ -156,11 +156,16 @@ def metadata(self) -> dict[str, Any]: assert self._metadata return self._metadata - def _decompress_data(self, tile: GenericNumberArray) -> GenericNumberArray: + def _decompress_and_reshape_data(self, tile: GenericNumberArray) -> GenericNumberArray: if self._is_binary: with PIL.Image.open(io.BytesIO(tile)) as img: - return np.array(img).transpose(2, 0, 1) + return np.array(img).transpose(2, 0, 1) # fixme: this also shouldn't work because the thing is flattened and doesn't have 3 dimensions else: + # If handling features, we need to expand dimensions to match the expected shape. 
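+            # A one-dimensional tile here is a single flattened feature vector; with a
+            # tile_size of [1, 1] the reshape below restores the channels-first
+            # (num_channels, 1, 1) layout that the rest of the reader expects.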
+ if tile.ndim == 1: # fixme: is this the correct location for this + if not self._tile_size == [1, 1]: + raise NotImplementedError(f"Tile is single dimensional and {self._tile_size=} should be [1, 1], other cases have not been considered and cause unwanted behaviour.") + return tile.reshape(self._num_channels, *self._tile_size) return tile def read_region(self, location: tuple[int, int], level: int, size: tuple[int, int]) -> pyvips.Image: @@ -201,7 +206,7 @@ def read_region(self, location: tuple[int, int], level: int, size: tuple[int, in total_rows = math.ceil((self._size[1] - self._tile_overlap[1]) / self._stride[1]) total_cols = math.ceil((self._size[0] - self._tile_overlap[0]) / self._stride[0]) - assert total_rows * total_cols == num_tiles + assert total_rows * total_cols == num_tiles # Equality only holds if features where created without mask x, y = location w, h = size @@ -230,7 +235,7 @@ def read_region(self, location: tuple[int, int], level: int, size: tuple[int, in tile = ( self._empty_tile() if tile_index_in_image_dataset == -1 - else self._decompress_data(image_dataset[tile_index_in_image_dataset]) + else self._decompress_and_reshape_data(image_dataset[tile_index_in_image_dataset]) ) start_y = i * self._stride[1] - y end_y = start_y + self._tile_size[1] @@ -242,6 +247,7 @@ def read_region(self, location: tuple[int, int], level: int, size: tuple[int, in img_start_x = max(0, start_x) img_end_x = min(w, end_x) + if self._stitching_mode == StitchingMode.CROP: crop_start_y = img_start_y - start_y crop_end_y = img_end_y - start_y From 5eb07b3e32a3df3d16f2207b47e6c479af1fa59c Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Thu, 8 Aug 2024 11:50:26 +0200 Subject: [PATCH 02/30] added database models for features --- ahcore/utils/database_models.py | 41 +++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/ahcore/utils/database_models.py b/ahcore/utils/database_models.py index 943e449..a7501bf 100644 --- a/ahcore/utils/database_models.py +++ b/ahcore/utils/database_models.py @@ -72,6 +72,7 @@ class Image(Base): annotations: Mapped[List["ImageAnnotations"]] = relationship("ImageAnnotations", back_populates="image") labels: Mapped[List["ImageLabels"]] = relationship("ImageLabels", back_populates="image") caches: Mapped[List["ImageCache"]] = relationship("ImageCache", back_populates="image") + features: Mapped[List["ImageFeature"]] = relationship("ImageFeature", back_populates="image") class ImageCache(Base): @@ -115,6 +116,46 @@ class CacheDescription(Base): cache: Mapped["ImageCache"] = relationship("ImageCache", back_populates="description") +class ImageFeature(Base): + """Image feature table.""" + + __tablename__ = "image_feature" + id = Column(Integer, primary_key=True) + # pylint: disable=E1102 + created = Column(DateTime(timezone=True), default=func.now()) + last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) + filename = Column(String, unique=True, nullable=False) + reader = Column(String) + num_tiles = Column(Integer) + image_id = Column(Integer, ForeignKey("image.id"), nullable=False) + + image: Mapped["Image"] = relationship("Image", back_populates="features") + description: Mapped["FeatureDescription"] = relationship("FeatureDescription", back_populates="image_feature") + +class FeatureDescription(Base): + """Feature description table.""" + + __tablename__ = "feature_description" + + id = Column(Integer, primary_key=True) + # pylint: disable=E1102 + created = Column(DateTime(timezone=True), 
default=func.now()) + last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) + + mpp = Column(Float) + tile_size_width = Column(Integer) + tile_size_height = Column(Integer) + tile_overlap_width = Column(Integer) + tile_overlap_height = Column(Integer) + description = Column(String) + + version = Column(String, unique=True, nullable=False) # use this to select which features we want to use + + model_name = Column(String) + model_path = Column(String) + feature_dimension = Column(Integer) + image_transforms_description = Column(String) # it would be nice to have a way to track which transforms the feature extractors used, but maybe this is not the best way to do it + class Mask(Base): """Mask table.""" From 117b66e2fc274c883237d6943a4ae11db9f6a429 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Thu, 8 Aug 2024 14:10:06 +0200 Subject: [PATCH 03/30] added loading features as a dataset + necessary utils --- ahcore/utils/data.py | 5 +-- ahcore/utils/manifest.py | 77 +++++++++++++++++++++++++++++++++++----- ahcore/writers.py | 3 +- 3 files changed, 73 insertions(+), 12 deletions(-) diff --git a/ahcore/utils/data.py b/ahcore/utils/data.py index 39a5fa9..8adac60 100644 --- a/ahcore/utils/data.py +++ b/ahcore/utils/data.py @@ -58,9 +58,10 @@ class DataDescription(BaseModel): manifest_database_uri: str manifest_name: str split_version: str + feature_version: Optional[str] = None annotations_dir: Path - training_grid: GridDescription - inference_grid: GridDescription + training_grid: Optional[GridDescription] = None + inference_grid: Optional[GridDescription] = None index_map: Optional[Dict[str, int]] remap_labels: Optional[Dict[str, str]] = None use_class_weights: Optional[bool] = False diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py index 7a83ab9..01ee2a4 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -34,6 +34,7 @@ Patient, Split, SplitDefinitions, + ImageFeature, ) from ahcore.utils.io import get_enum_key_from_value, get_logger from ahcore.utils.rois import compute_rois @@ -144,6 +145,28 @@ def get_labels_from_record(record: Image | Patient) -> list[tuple[str, str]] | N _labels = [(str(label.key), str(label.value)) for label in record.labels] if record.labels else None return _labels +def get_relevant_feature_info_from_record(record: ImageFeature, data_description: DataDescription) -> tuple[Path, PositiveFloat, tuple[PositiveInt, PositiveInt], tuple[PositiveInt, PositiveInt], TilingMode, ImageBackend, PositiveFloat]: + """Get the features from a record of type Image. + + Parameters + ---------- + record : Type[Image] + The record containing the features. + + Returns + ------- + tuple[Path, PositiveFloat, tuple[PositiveInt, PositiveInt], tuple[PositiveInt, PositiveInt], TilingMode, ImageBackend, PositiveFloat] + The features of the image. 
+ """ + image_path = data_description.data_dir / record.filename + mpp = record.mpp + tile_size = (record.num_tiles, 1) # this would load all the features in one go --> can be extended to only load relevant tile level features + tile_overlap = (0, 0) + tile_mode = TilingMode.C + backend = ImageBackend[str(record.reader)] + overwrite_mpp = record.mpp + return image_path, mpp, tile_size, tile_overlap, tile_mode, backend, overwrite_mpp + def _get_rois(mask: WsiAnnotations | None, data_description: DataDescription, stage: str) -> Optional[Rois]: if (mask is None) or (stage != "fit") or (not data_description.convert_mask_to_rois): @@ -300,6 +323,28 @@ def get_image_metadata_by_id(self, image_id: int) -> ImageMetadata: assert image is not None # mypy return fetch_image_metadata(image) + def get_image_features_by_image_and_feature_version(self, image_id: int, feature_version: str) -> ImageFeature: + """ + Fetch the features for an image based on its ID and feature version. + + Parameters + ---------- + image_id : int + The ID of the image. + feature_version : str + The version of the features. + + Returns + ------- + ImageFeature + The features of the image. + """ + image_feature = self._session.query(ImageFeature).filter_by(image_id=image_id, version=feature_version).first() + self._ensure_record(image_feature, f"No features found for image ID {image_id} and feature version {feature_version}") + assert image_feature is not None + # todo: make sure that this only allows to run one ImageFeature, I think it should be good bc of the unique constraint + return image_feature + def __enter__(self) -> "DataManager": return self @@ -333,19 +378,23 @@ def datasets_from_data_description( assert isinstance(stage, str), "Stage should be a string." if stage == "fit": - grid_description = data_description.training_grid + grid_description = data_description.training_grid else: - grid_description = data_description.inference_grid + grid_description = data_description.inference_grid patients = db_manager.get_records_by_split( manifest_name=data_description.manifest_name, split_version=data_description.split_version, split_category=stage, ) + + use_features = data_description.feature_version is not None + for patient in patients: patient_labels = get_labels_from_record(patient) for image in patient.images: + mask, annotations = get_mask_and_annotations_from_record(annotations_root, image) assert isinstance(mask, WsiAnnotations) or (mask is None) image_labels = get_labels_from_record(image) @@ -353,12 +402,22 @@ def datasets_from_data_description( rois = _get_rois(mask, data_description, stage) mask_threshold = 0.0 if stage != "fit" else data_description.mask_threshold + if use_features: + image_feature = db_manager.get_image_features_by_image_and_feature_version(image.id, data_description.feature_version) + image_path, mpp, tile_size, tile_overlap, tile_mode, backend, overwrite_mpp = get_relevant_feature_info_from_record(image_feature, data_description) + else: + image_path = image_root / image.filename + tile_size = grid_description.tile_size + tile_overlap = grid_description.tile_overlap + backend = ImageBackend[str(image.reader)] + mpp = grid_description.mpp + overwrite_mpp = image.mpp + dataset = TiledWsiDataset.from_standard_tiling( - path=image_root / image.filename, - mpp=grid_description.mpp, - tile_size=grid_description.tile_size, - tile_overlap=grid_description.tile_overlap, - tile_mode=TilingMode.overflow, + path=image_path, + mpp=mpp, + tile_size=tile_size, + tile_overlap=tile_overlap, 
grid_order=GridOrder.C, crop=False, mask=mask, @@ -368,8 +427,8 @@ def datasets_from_data_description( annotations=annotations if stage != "predict" else None, labels=labels, # type: ignore transform=transform, - backend=ImageBackend[str(image.reader)], - overwrite_mpp=(image.mpp, image.mpp), + backend=backend, + overwrite_mpp=(overwrite_mpp, overwrite_mpp), limit_bounds=True, apply_color_profile=data_description.apply_color_profile, internal_handler="vips", diff --git a/ahcore/writers.py b/ahcore/writers.py index 1a59997..19fa909 100644 --- a/ahcore/writers.py +++ b/ahcore/writers.py @@ -87,6 +87,7 @@ def __init__( precision: InferencePrecision | None = None, grid: Grid | None = None, ) -> None: + # todo: better documentation for this, can basically set everything it is almost independent from slide image self._grid = grid self._filename: Path = filename self._size: tuple[int, int] = size @@ -481,7 +482,7 @@ def insert_data(self, batch: GenericNumberArray) -> None: raise ValueError(f"Batch should have a single element when writing h5. Got batch shape {batch.shape}.") batch_size = batch.shape[0] self._data[self._current_index : self._current_index + batch_size] = ( - batch.flatten() if self._is_compressed_image else batch + batch.flatten() if self._is_compressed_image else batch # fixme: flatten shouldn't work here ) def create_dataset( From 0d19bf30bd89a90c32a1ba418dfad910ce0896f6 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Thu, 8 Aug 2024 15:52:27 +0200 Subject: [PATCH 04/30] fix black --- ahcore/backends.py | 3 ++- ahcore/data/dataset.py | 6 ++---- ahcore/readers.py | 9 ++++++--- ahcore/utils/database_models.py | 6 +++++- ahcore/utils/manifest.py | 32 +++++++++++++++++++++++++------- 5 files changed, 40 insertions(+), 16 deletions(-) diff --git a/ahcore/backends.py b/ahcore/backends.py index cb719b2..3c93c84 100644 --- a/ahcore/backends.py +++ b/ahcore/backends.py @@ -43,6 +43,7 @@ def read_region(self, coordinates: tuple[int, int], level: int, size: tuple[int, def close(self): self._reader.close() + class H5Slide(AbstractSlideBackend): def __init__(self, filename: PathLike, stitching_mode: StitchingMode | str = StitchingMode.CROP) -> None: super().__init__(filename) @@ -77,4 +78,4 @@ def read_region(self, coordinates: tuple[int, int], level: int, size: tuple[int, return self._reader.read_region(coordinates, level, size) def close(self): - self._reader.close() \ No newline at end of file + self._reader.close() diff --git a/ahcore/data/dataset.py b/ahcore/data/dataset.py index f4c7d13..ea2d7a9 100644 --- a/ahcore/data/dataset.py +++ b/ahcore/data/dataset.py @@ -88,12 +88,10 @@ def __len__(self) -> int: return self.cumulative_sizes[-1] @overload - def __getitem__(self, index: int) -> DlupDatasetSample: - ... + def __getitem__(self, index: int) -> DlupDatasetSample: ... @overload - def __getitem__(self, index: slice) -> list[DlupDatasetSample]: - ... + def __getitem__(self, index: slice) -> list[DlupDatasetSample]: ... 
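    # The @overload stubs above only inform static type checkers; the implementation
    # below handles both int and slice indices at runtime.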
def __getitem__(self, index: Union[int, slice]) -> DlupDatasetSample | list[DlupDatasetSample]: """Returns the sample at the given index.""" diff --git a/ahcore/readers.py b/ahcore/readers.py index e61965d..0d5889b 100644 --- a/ahcore/readers.py +++ b/ahcore/readers.py @@ -159,12 +159,16 @@ def metadata(self) -> dict[str, Any]: def _decompress_and_reshape_data(self, tile: GenericNumberArray) -> GenericNumberArray: if self._is_binary: with PIL.Image.open(io.BytesIO(tile)) as img: - return np.array(img).transpose(2, 0, 1) # fixme: this also shouldn't work because the thing is flattened and doesn't have 3 dimensions + return np.array(img).transpose( + 2, 0, 1 + ) # fixme: this also shouldn't work because the thing is flattened and doesn't have 3 dimensions else: # If handling features, we need to expand dimensions to match the expected shape. if tile.ndim == 1: # fixme: is this the correct location for this if not self._tile_size == [1, 1]: - raise NotImplementedError(f"Tile is single dimensional and {self._tile_size=} should be [1, 1], other cases have not been considered and cause unwanted behaviour.") + raise NotImplementedError( + f"Tile is single dimensional and {self._tile_size=} should be [1, 1], other cases have not been considered and cause unwanted behaviour." + ) return tile.reshape(self._num_channels, *self._tile_size) return tile @@ -247,7 +251,6 @@ def read_region(self, location: tuple[int, int], level: int, size: tuple[int, in img_start_x = max(0, start_x) img_end_x = min(w, end_x) - if self._stitching_mode == StitchingMode.CROP: crop_start_y = img_start_y - start_y crop_end_y = img_end_y - start_y diff --git a/ahcore/utils/database_models.py b/ahcore/utils/database_models.py index a7501bf..4356779 100644 --- a/ahcore/utils/database_models.py +++ b/ahcore/utils/database_models.py @@ -132,6 +132,7 @@ class ImageFeature(Base): image: Mapped["Image"] = relationship("Image", back_populates="features") description: Mapped["FeatureDescription"] = relationship("FeatureDescription", back_populates="image_feature") + class FeatureDescription(Base): """Feature description table.""" @@ -154,7 +155,10 @@ class FeatureDescription(Base): model_name = Column(String) model_path = Column(String) feature_dimension = Column(Integer) - image_transforms_description = Column(String) # it would be nice to have a way to track which transforms the feature extractors used, but maybe this is not the best way to do it + image_transforms_description = Column( + String + ) # it would be nice to have a way to track which transforms the feature extractors used, but maybe this is not the best way to do it + class Mask(Base): """Mask table.""" diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py index 01ee2a4..76c2760 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -145,7 +145,16 @@ def get_labels_from_record(record: Image | Patient) -> list[tuple[str, str]] | N _labels = [(str(label.key), str(label.value)) for label in record.labels] if record.labels else None return _labels -def get_relevant_feature_info_from_record(record: ImageFeature, data_description: DataDescription) -> tuple[Path, PositiveFloat, tuple[PositiveInt, PositiveInt], tuple[PositiveInt, PositiveInt], TilingMode, ImageBackend, PositiveFloat]: + +def get_relevant_feature_info_from_record(record: ImageFeature, data_description: DataDescription) -> tuple[ + Path, + PositiveFloat, + tuple[PositiveInt, PositiveInt], + tuple[PositiveInt, PositiveInt], + TilingMode, + ImageBackend, + PositiveFloat, +]: """Get the 
features from a record of type Image. Parameters @@ -160,7 +169,10 @@ def get_relevant_feature_info_from_record(record: ImageFeature, data_description """ image_path = data_description.data_dir / record.filename mpp = record.mpp - tile_size = (record.num_tiles, 1) # this would load all the features in one go --> can be extended to only load relevant tile level features + tile_size = ( + record.num_tiles, + 1, + ) # this would load all the features in one go --> can be extended to only load relevant tile level features tile_overlap = (0, 0) tile_mode = TilingMode.C backend = ImageBackend[str(record.reader)] @@ -340,7 +352,9 @@ def get_image_features_by_image_and_feature_version(self, image_id: int, feature The features of the image. """ image_feature = self._session.query(ImageFeature).filter_by(image_id=image_id, version=feature_version).first() - self._ensure_record(image_feature, f"No features found for image ID {image_id} and feature version {feature_version}") + self._ensure_record( + image_feature, f"No features found for image ID {image_id} and feature version {feature_version}" + ) assert image_feature is not None # todo: make sure that this only allows to run one ImageFeature, I think it should be good bc of the unique constraint return image_feature @@ -378,9 +392,9 @@ def datasets_from_data_description( assert isinstance(stage, str), "Stage should be a string." if stage == "fit": - grid_description = data_description.training_grid + grid_description = data_description.training_grid else: - grid_description = data_description.inference_grid + grid_description = data_description.inference_grid patients = db_manager.get_records_by_split( manifest_name=data_description.manifest_name, @@ -403,8 +417,12 @@ def datasets_from_data_description( mask_threshold = 0.0 if stage != "fit" else data_description.mask_threshold if use_features: - image_feature = db_manager.get_image_features_by_image_and_feature_version(image.id, data_description.feature_version) - image_path, mpp, tile_size, tile_overlap, tile_mode, backend, overwrite_mpp = get_relevant_feature_info_from_record(image_feature, data_description) + image_feature = db_manager.get_image_features_by_image_and_feature_version( + image.id, data_description.feature_version + ) + image_path, mpp, tile_size, tile_overlap, tile_mode, backend, overwrite_mpp = ( + get_relevant_feature_info_from_record(image_feature, data_description) + ) else: image_path = image_root / image.filename tile_size = grid_description.tile_size From 7fcf23e4b36722a80e489311c0f655b94be9e188 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Thu, 8 Aug 2024 16:19:35 +0200 Subject: [PATCH 05/30] improved classification pre_transforms and added random sampling of tiles --- ahcore/transforms/pre_transforms.py | 38 ++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/ahcore/transforms/pre_transforms.py b/ahcore/transforms/pre_transforms.py index 2430778..6fabaaf 100644 --- a/ahcore/transforms/pre_transforms.py +++ b/ahcore/transforms/pre_transforms.py @@ -97,6 +97,13 @@ def for_wsi_classification( cls, data_description: DataDescription, requires_target: bool = True ) -> PreTransformTaskFactory: transforms: list[PreTransformCallable] = [] + + transforms.append(ImageToTensor()) + + transforms.append(SampleNFeatures(n=1000)) + + transforms.append(AllowCollate()) + if not requires_target: return cls(transforms) @@ -117,6 +124,35 @@ def __repr__(self) -> str: return f"PreTransformTaskFactory(transforms={self._transforms})" +class 
SampleNFeatures: + def __init__(self, n=1000): + self.n = n + logger.critical( + f"Sampling {n} features from the image. Sampling WITH replacement is done if there are not enough tiles." + ) + + def __call__(self, sample: DlupDatasetSample) -> DlupDatasetSample: + features = sample["image"] + if not type(features) is torch.Tensor: + raise ValueError( + f"Expected features to be a torch.Tensor, got {type(features)}. Apply ImageToTensor transform first." + ) + + feature_dim, h, w = features.shape + + if not 0 < w < 1: + raise ValueError(f"Expected features to have a width dimension of 1, got {w}.") + + # TODO: DO WE WANT THIS????? SAMPLING WITH REPLACEMENT OR NOT??? + n_random_indices = ( + np.random.choice(h, self.n, replace=False) if h > self.n else np.random.choice(h, self.n, replace=True) + ) + + sample["image"] = features[:, n_random_indices, :] + + return sample + + class LabelToClassIndex: """ Maps label values to class indices according to the index_map specified in the data description. @@ -218,7 +254,7 @@ class ImageToTensor: def __call__(self, sample: DlupDatasetSample) -> dict[str, DlupDatasetSample]: tile: pyvips.Image = sample["image"] # Flatten the image to remove the alpha channel, using white as the background color - tile_ = tile.flatten(background=[255, 255, 255]) + tile_ = tile.flatten(background=[255, 255, 255]) # todo: check if this doesn't mess up features # Convert VIPS image to a numpy array then to a torch tensor np_image = tile_.numpy() From c8e1b7cdb59172709d29b0e24959645233058206 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Thu, 15 Aug 2024 13:04:09 +0200 Subject: [PATCH 06/30] Added specific tile_size and size to the writer, so that the reader can read features correctly --- ahcore/readers.py | 12 ++++++------ ahcore/writers.py | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/ahcore/readers.py b/ahcore/readers.py index 0d5889b..4eab384 100644 --- a/ahcore/readers.py +++ b/ahcore/readers.py @@ -112,12 +112,12 @@ def _open_file(self) -> None: self._read_metadata() if not self._metadata: - raise ValueError("Metadata of h5 file is empty.") + raise ValueError("Metadata of file is empty.") self._mpp = self._metadata["mpp"] - self._tile_size = self._metadata["tile_size"] + self._tile_size = self._metadata["reader_tile_size"] if "reader_tile_size" in self._metadata.keys() else self._metadata["tile_size"] self._tile_overlap = self._metadata["tile_overlap"] - self._size = self._metadata["size"] + self._size = self._metadata["reader_size"] if "reader_size" in self._metadata.keys() else self._metadata["size"] self._num_channels = self._metadata["num_channels"] self._dtype = self._metadata["dtype"] self._precision = self._metadata["precision"] @@ -165,9 +165,9 @@ def _decompress_and_reshape_data(self, tile: GenericNumberArray) -> GenericNumbe else: # If handling features, we need to expand dimensions to match the expected shape. if tile.ndim == 1: # fixme: is this the correct location for this - if not self._tile_size == [1, 1]: + if not self._tile_size[1] == 1: raise NotImplementedError( - f"Tile is single dimensional and {self._tile_size=} should be [1, 1], other cases have not been considered and cause unwanted behaviour." + f"Tile is single dimensional and {self._tile_size=} should be [x, 1], other cases have not been considered and cause unwanted behaviour." 
) return tile.reshape(self._num_channels, *self._tile_size) return tile @@ -210,7 +210,7 @@ def read_region(self, location: tuple[int, int], level: int, size: tuple[int, in total_rows = math.ceil((self._size[1] - self._tile_overlap[1]) / self._stride[1]) total_cols = math.ceil((self._size[0] - self._tile_overlap[0]) / self._stride[0]) - assert total_rows * total_cols == num_tiles # Equality only holds if features where created without mask + # assert total_rows * total_cols == num_tiles # Equality only holds if features where created without mask x, y = location w, h = size diff --git a/ahcore/writers.py b/ahcore/writers.py index 19fa909..6b4303e 100644 --- a/ahcore/writers.py +++ b/ahcore/writers.py @@ -231,6 +231,16 @@ def construct_metadata(self, writer_metadata: WriterMetadata) -> dict[str, Any]: "has_color_profile": self._color_profile is not None, } + if not self._is_compressed_image: + # When writing features, features are of size (num_tiles, 1). + # Setting reader_size to this value makes the readers able to read these features as images. + # Setting the reader_tile_size to this value allows for reading all the features in one go without loops. + # If these values are specified the reader will use these, instead of tile_size and size + + assert len(self._grid) == self._num_samples, "Number of tiles should be equal to the number of samples, this holds in the writer_callback." + + metadata.update({"reader_tile_size": (len(self._grid), 1), "reader_size": (len(self._grid), 1)}) + if self._extra_metadata: metadata.update(self._extra_metadata) @@ -329,6 +339,10 @@ def init_writer(self, first_coordinates: GenericNumberArray, first_batch: Generi self.set_grid() assert self._grid + + if self._grid.order != GridOrder.C: + raise ValueError(f"Grid order should be C, other orderings are not supported. Got {self._grid.order}") + num_tiles = len(self._grid) self._tile_indices = self.create_dataset( From d9d6346bd1b9ecbd7e254441091fcbae625de4fa Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Thu, 15 Aug 2024 14:57:40 +0200 Subject: [PATCH 07/30] Added DataFormat enum that handles reading features in the readers and is set in the writers --- ahcore/callbacks/file_writer_callback.py | 6 ++-- ahcore/cli/tiling.py | 4 ++- ahcore/readers.py | 15 ++++++--- ahcore/utils/types.py | 8 +++++ ahcore/writers.py | 39 +++++++++++------------- 5 files changed, 42 insertions(+), 30 deletions(-) diff --git a/ahcore/callbacks/file_writer_callback.py b/ahcore/callbacks/file_writer_callback.py index 2000648..f7591e6 100644 --- a/ahcore/callbacks/file_writer_callback.py +++ b/ahcore/callbacks/file_writer_callback.py @@ -11,7 +11,7 @@ from ahcore.utils.callbacks import get_output_filename as get_output_filename_ from ahcore.utils.data import DataDescription, GridDescription from ahcore.utils.io import get_logger -from ahcore.utils.types import InferencePrecision, NormalizationType +from ahcore.utils.types import InferencePrecision, NormalizationType, DataFormat from ahcore.writers import Writer logger = get_logger(__name__) @@ -27,6 +27,7 @@ def __init__( normalization_type: str = NormalizationType.LOGITS, precision: str = InferencePrecision.FP32, callbacks: list[ConvertCallbacks] | None = None, + data_format = DataFormat.IMAGE, ): """ Callback to write predictions to H5 files. 
This callback is used to write whole-slide predictions to single H5 @@ -54,6 +55,7 @@ def __init__( self._suffix = ".cache" self._normalization_type: NormalizationType = NormalizationType(normalization_type) self._precision: InferencePrecision = InferencePrecision(precision) + self._data_format = data_format super().__init__( writer_class=writer_class, @@ -128,7 +130,7 @@ def build_writer_class(self, pl_module: AhCoreLightningModule, stage: str, filen tile_overlap=tile_overlap, num_samples=num_samples, color_profile=None, - is_compressed_image=False, + data_format=self._data_format, progress=None, precision=InferencePrecision(self._precision), grid=grid, diff --git a/ahcore/cli/tiling.py b/ahcore/cli/tiling.py index 3f07540..e7deaef 100644 --- a/ahcore/cli/tiling.py +++ b/ahcore/cli/tiling.py @@ -31,6 +31,8 @@ from ahcore.cli import dir_path, file_path from ahcore.writers import H5FileImageWriter, Writer, ZarrFileImageWriter +from ahcore.utils.types import DataFormat + _WriterClass = Type[Writer] logger = getLogger(__name__) @@ -359,7 +361,7 @@ def _tiling_pipeline( tile_size=dataset_cfg.tile_size, tile_overlap=dataset_cfg.tile_overlap, num_samples=len(dataset), - is_compressed_image=compression != "none", + data_format=DataFormat.COMPRESSED_IMAGE if compression != "none" else DataFormat.IMAGE, color_profile=color_profile, extra_metadata=extra_metadata, grid=dataset.grids[0][0], diff --git a/ahcore/readers.py b/ahcore/readers.py index 4eab384..cac795b 100644 --- a/ahcore/readers.py +++ b/ahcore/readers.py @@ -24,7 +24,7 @@ from zarr.storage import ZipStore from ahcore.utils.io import get_logger -from ahcore.utils.types import BoundingBoxType, GenericNumberArray, InferencePrecision +from ahcore.utils.types import BoundingBoxType, GenericNumberArray, InferencePrecision, DataFormat logger = get_logger(__name__) @@ -48,6 +48,8 @@ def __init__(self, filename: Path, stitching_mode: StitchingMode) -> None: self._file: Optional[Any] = None self._metadata = None + self._num_tiles = None + self._data_format = None self._mpp = None self._tile_size = None self._tile_overlap = None @@ -114,10 +116,13 @@ def _open_file(self) -> None: if not self._metadata: raise ValueError("Metadata of file is empty.") + self._num_tiles = self._metadata["num_tiles"] + self._data_format = DataFormat(self._metadata["data_format"]) if "data_format" in self._metadata.keys() else DataFormat.IMAGE self._mpp = self._metadata["mpp"] - self._tile_size = self._metadata["reader_tile_size"] if "reader_tile_size" in self._metadata.keys() else self._metadata["tile_size"] + # features are always read at tile_size (1, 1), possibly faster to read the whole feature at once + self._tile_size = (self._num_tiles, 1) if self._data_format == DataFormat.FEATURE else self._metadata["tile_size"] self._tile_overlap = self._metadata["tile_overlap"] - self._size = self._metadata["reader_size"] if "reader_size" in self._metadata.keys() else self._metadata["size"] + self._size = (self._num_tiles, 1) if self._data_format == DataFormat.FEATURE else self._metadata["size"] self._num_channels = self._metadata["num_channels"] self._dtype = self._metadata["dtype"] self._precision = self._metadata["precision"] @@ -204,13 +209,13 @@ def read_region(self, location: tuple[int, int], level: int, size: tuple[int, in assert self._tile_overlap is not None, "self._tile_overlap should not be None" image_dataset = self._file["data"] - num_tiles = self._metadata["num_tiles"] + tile_indices = self._file["tile_indices"] total_rows = math.ceil((self._size[1] - 
self._tile_overlap[1]) / self._stride[1]) total_cols = math.ceil((self._size[0] - self._tile_overlap[0]) / self._stride[0]) - # assert total_rows * total_cols == num_tiles # Equality only holds if features where created without mask + assert total_rows * total_cols == self._num_tiles or self._data_format == DataFormat.FEATURE, f"{total_rows=}, {total_cols=} and {self._num_tiles=}" # Equality only holds if features where created without mask x, y = location w, h = size diff --git a/ahcore/utils/types.py b/ahcore/utils/types.py index ba0b137..6b18131 100644 --- a/ahcore/utils/types.py +++ b/ahcore/utils/types.py @@ -86,3 +86,11 @@ class ViTEmbedMode(str, Enum): CONCAT_MEAN = "embed_concat_mean" CONCAT = "embed_concat" # Extend as necessary + +class DataFormat(str, Enum): + """Data format for the writer.""" + + FEATURE = "feature" + IMAGE = "image" + COMPRESSED_IMAGE = "compressed_image" + MASK = "mask" diff --git a/ahcore/writers.py b/ahcore/writers.py index 6b4303e..3aad8b3 100644 --- a/ahcore/writers.py +++ b/ahcore/writers.py @@ -14,6 +14,7 @@ from contextlib import contextmanager from pathlib import Path from typing import Any, Generator, NamedTuple, Optional +from enum import Enum import dlup import h5py @@ -26,11 +27,14 @@ import ahcore from ahcore.utils.io import get_git_hash, get_logger -from ahcore.utils.types import GenericNumberArray, InferencePrecision +from ahcore.utils.types import GenericNumberArray, InferencePrecision, DataFormat logger = get_logger(__name__) + + + def decode_array_to_pil(array: npt.NDArray[np.uint8]) -> PIL.Image.Image: """Convert encoded array to PIL image @@ -80,7 +84,7 @@ def __init__( tile_size: tuple[int, int], tile_overlap: tuple[int, int], num_samples: int, - is_compressed_image: bool = False, + data_format: DataFormat = DataFormat.IMAGE, color_profile: bytes | None = None, progress: Optional[Any] = None, extra_metadata: Optional[dict[str, Any]] = None, @@ -95,7 +99,7 @@ def __init__( self._tile_size: tuple[int, int] = tile_size self._tile_overlap: tuple[int, int] = tile_overlap self._num_samples: int = num_samples - self._is_compressed_image: bool = is_compressed_image + self._data_format = data_format self._color_profile: bytes | None = color_profile self._extra_metadata = extra_metadata self._precision = precision @@ -113,6 +117,7 @@ def __init__( self._partial_suffix: str = f"{self._filename.suffix}.partial" + @abc.abstractmethod def open_file(self, mode: str = "w") -> Any: pass @@ -163,7 +168,7 @@ def _batch_generator( yield tile def get_writer_metadata(self, first_batch: GenericNumberArray) -> WriterMetadata: - if self._is_compressed_image: + if self._data_format == DataFormat.COMPRESSED_IMAGE: if self._precision is not None: raise ValueError("Precision cannot be set when writing compressed images.") # We need to read the first batch as it is a compressed PIL image @@ -222,7 +227,7 @@ def construct_metadata(self, writer_metadata: WriterMetadata) -> dict[str, Any]: "mode": writer_metadata.mode, "format": writer_metadata.format, "dtype": writer_metadata.dtype, - "is_binary": self._is_compressed_image, + "data_format": self._data_format.value, "grid_offset": writer_metadata.grid_offset, "precision": self._precision.value if self._precision else str(InferencePrecision.FP32), "multiplier": ( @@ -231,16 +236,6 @@ def construct_metadata(self, writer_metadata: WriterMetadata) -> dict[str, Any]: "has_color_profile": self._color_profile is not None, } - if not self._is_compressed_image: - # When writing features, features are of size (num_tiles, 1). 
- # Setting reader_size to this value makes the readers able to read these features as images. - # Setting the reader_tile_size to this value allows for reading all the features in one go without loops. - # If these values are specified the reader will use these, instead of tile_size and size - - assert len(self._grid) == self._num_samples, "Number of tiles should be equal to the number of samples, this holds in the writer_callback." - - metadata.update({"reader_tile_size": (len(self._grid), 1), "reader_size": (len(self._grid), 1)}) - if self._extra_metadata: metadata.update(self._extra_metadata) @@ -304,7 +299,7 @@ def consume(self, batch_generator: Generator[tuple[GenericNumberArray, GenericNu batch_size = batch.shape[0] coordinates_dataset[self._current_index : self._current_index + batch_size] = coordinates - if self._is_compressed_image: + if self._data_format == DataFormat.COMPRESSED_IMAGE: # When the batch has variable lengths, we need to insert each sample separately for sample in batch: self.insert_data(sample[np.newaxis, ...]) @@ -353,7 +348,7 @@ def init_writer(self, first_coordinates: GenericNumberArray, first_batch: Generi compression="gzip", ) - if not self._is_compressed_image: + if self._data_format == DataFormat.COMPRESSED_IMAGE: shape = first_batch.shape[1:] self._data = self.create_dataset( file, @@ -448,14 +443,14 @@ def create_variable_length_dataset( def insert_data(self, batch: GenericNumberArray) -> None: """Insert a batch into a Zarr dataset.""" - if not batch.shape[0] == 1 and self._is_compressed_image: + if not batch.shape[0] == 1 and self._data_format == DataFormat.COMPRESSED_IMAGE: raise ValueError(f"Batch should have a single element when writing zarr. Got batch shape {batch.shape}.") - if self._is_compressed_image: + if self._data_format == DataFormat.COMPRESSED_IMAGE: self._data[self._current_index] = batch.reshape(-1) else: self._data[self._current_index : self._current_index + batch.shape[0]] = ( - batch.flatten() if self._is_compressed_image else batch + batch.flatten() if self._data_format == DataFormat.COMPRESSED_IMAGE else batch ) def write_metadata(self, metadata: dict[str, Any], file: Any) -> None: @@ -492,11 +487,11 @@ def write_metadata(self, metadata: Any, file: Any) -> None: def insert_data(self, batch: GenericNumberArray) -> None: """Insert a batch into a H5 dataset.""" - if not batch.shape[0] == 1 and self._is_compressed_image: + if not batch.shape[0] == 1 and self._data_format == DataFormat.COMPRESSED_IMAGE: raise ValueError(f"Batch should have a single element when writing h5. 
Got batch shape {batch.shape}.") batch_size = batch.shape[0] self._data[self._current_index : self._current_index + batch_size] = ( - batch.flatten() if self._is_compressed_image else batch # fixme: flatten shouldn't work here + batch.flatten() if self._data_format == DataFormat.COMPRESSED_IMAGE else batch # fixme: flatten shouldn't work here ) def create_dataset( From 7094379fd3ec0642ef496ac908bf10fe109eb766 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Fri, 16 Aug 2024 13:47:42 +0200 Subject: [PATCH 08/30] fix bugs in database models --- ahcore/utils/database_models.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/ahcore/utils/database_models.py b/ahcore/utils/database_models.py index 4356779..e2df80c 100644 --- a/ahcore/utils/database_models.py +++ b/ahcore/utils/database_models.py @@ -113,7 +113,7 @@ class CacheDescription(Base): mask_threshold = Column(Float) grid_order = Column(String) - cache: Mapped["ImageCache"] = relationship("ImageCache", back_populates="description") + cache: Mapped[List["ImageCache"]] = relationship("ImageCache", back_populates="description") class ImageFeature(Base): @@ -128,9 +128,10 @@ class ImageFeature(Base): reader = Column(String) num_tiles = Column(Integer) image_id = Column(Integer, ForeignKey("image.id"), nullable=False) + feature_description_id = Column(Integer, ForeignKey("feature_description.id"), nullable=False) image: Mapped["Image"] = relationship("Image", back_populates="features") - description: Mapped["FeatureDescription"] = relationship("FeatureDescription", back_populates="image_feature") + feature_description: Mapped["FeatureDescription"] = relationship("FeatureDescription", back_populates="features") class FeatureDescription(Base): @@ -150,15 +151,18 @@ class FeatureDescription(Base): tile_overlap_height = Column(Integer) description = Column(String) - version = Column(String, unique=True, nullable=False) # use this to select which features we want to use + # use this to select which features we want to use + version = Column(String, unique=True, nullable=False) + model_name = Column(String) model_path = Column(String) feature_dimension = Column(Integer) - image_transforms_description = Column( - String - ) # it would be nice to have a way to track which transforms the feature extractors used, but maybe this is not the best way to do it + image_transforms_description = Column(String) + # it would be nice to have a way to track which transforms the feature extractors used, + # but maybe this is not the best way to do it + features: Mapped[List["ImageFeature"]] = relationship("ImageFeature", back_populates="feature_description") class Mask(Base): """Mask table.""" From e29fa3dbc37f21f3c45335b74ade06694c036f21 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Fri, 16 Aug 2024 13:48:50 +0200 Subject: [PATCH 09/30] added ahcore ImageBackend enum which includes both ahcore and dlup backends --- ahcore/backends.py | 22 ++++++++++++++++++++++ ahcore/readers.py | 1 + ahcore/utils/manifest.py | 2 +- 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/ahcore/backends.py b/ahcore/backends.py index 3c93c84..afc9195 100644 --- a/ahcore/backends.py +++ b/ahcore/backends.py @@ -6,6 +6,28 @@ from ahcore.readers import StitchingMode, ZarrFileImageReader, H5FileImageReader +from enum import Enum +from typing import Any, Callable + +from dlup.backends.openslide_backend import OpenSlideSlide +from dlup.backends.tifffile_backend import TifffileSlide +from dlup.backends.pyvips_backend import 
PyVipsSlide +from dlup.types import PathLike + + + +class ImageBackend(Enum): + """Available image backends.""" + + OPENSLIDE: Callable[[PathLike], OpenSlideSlide] = OpenSlideSlide + PYVIPS: Callable[[PathLike], PyVipsSlide] = PyVipsSlide + TIFFFILE: Callable[[PathLike], TifffileSlide] = TifffileSlide + H5: Callable[[PathLike], H5Slide] = H5Slide + ZARR: Callable[[PathLike], ZarrSlide] = ZarrSlide + + def __call__(self, *args: "ImageBackend" | str) -> Any: + return self.value(*args) + class ZarrSlide(AbstractSlideBackend): def __init__(self, filename: PathLike, stitching_mode: StitchingMode | str = StitchingMode.CROP) -> None: diff --git a/ahcore/readers.py b/ahcore/readers.py index cac795b..6b3b580 100644 --- a/ahcore/readers.py +++ b/ahcore/readers.py @@ -117,6 +117,7 @@ def _open_file(self) -> None: raise ValueError("Metadata of file is empty.") self._num_tiles = self._metadata["num_tiles"] + # set a standard value if it is not present self._data_format = DataFormat(self._metadata["data_format"]) if "data_format" in self._metadata.keys() else DataFormat.IMAGE self._mpp = self._metadata["mpp"] # features are always read at tile_size (1, 1), possibly faster to read the whole feature at once diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py index 76c2760..e1e7060 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -12,7 +12,7 @@ from dlup import SlideImage from dlup.annotations import WsiAnnotations -from dlup.backends import ImageBackend +from ahcore.backends import ImageBackend from dlup.data.dataset import RegionFromWsiDatasetSample, TiledWsiDataset, TileSample from dlup.tiling import GridOrder, TilingMode from pydantic import BaseModel From 65d3ef2ab5e1cb985caf725324cc56e69df8f9d2 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Tue, 20 Aug 2024 13:04:26 +0200 Subject: [PATCH 10/30] added dataformat enum and fixed loading of datasets to work with features --- ahcore/backends.py | 27 +++++++++++++-------------- ahcore/utils/manifest.py | 32 ++++++++++++++++++++------------ 2 files changed, 33 insertions(+), 26 deletions(-) diff --git a/ahcore/backends.py b/ahcore/backends.py index afc9195..e87be60 100644 --- a/ahcore/backends.py +++ b/ahcore/backends.py @@ -15,20 +15,6 @@ from dlup.types import PathLike - -class ImageBackend(Enum): - """Available image backends.""" - - OPENSLIDE: Callable[[PathLike], OpenSlideSlide] = OpenSlideSlide - PYVIPS: Callable[[PathLike], PyVipsSlide] = PyVipsSlide - TIFFFILE: Callable[[PathLike], TifffileSlide] = TifffileSlide - H5: Callable[[PathLike], H5Slide] = H5Slide - ZARR: Callable[[PathLike], ZarrSlide] = ZarrSlide - - def __call__(self, *args: "ImageBackend" | str) -> Any: - return self.value(*args) - - class ZarrSlide(AbstractSlideBackend): def __init__(self, filename: PathLike, stitching_mode: StitchingMode | str = StitchingMode.CROP) -> None: super().__init__(filename) @@ -101,3 +87,16 @@ def read_region(self, coordinates: tuple[int, int], level: int, size: tuple[int, def close(self): self._reader.close() + + +class ImageBackend(Enum): + """Available image backends.""" + + OPENSLIDE: Callable[[PathLike], OpenSlideSlide] = OpenSlideSlide + PYVIPS: Callable[[PathLike], PyVipsSlide] = PyVipsSlide + TIFFFILE: Callable[[PathLike], TifffileSlide] = TifffileSlide + H5: Callable[[PathLike], H5Slide] = H5Slide + ZARR: Callable[[PathLike], ZarrSlide] = ZarrSlide + + def __call__(self, *args) -> Any: + return self.value(*args) \ No newline at end of file diff --git a/ahcore/utils/manifest.py 
b/ahcore/utils/manifest.py index e1e7060..46ba02a 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -35,6 +35,7 @@ Split, SplitDefinitions, ImageFeature, + FeatureDescription, ) from ahcore.utils.io import get_enum_key_from_value, get_logger from ahcore.utils.rois import compute_rois @@ -146,12 +147,13 @@ def get_labels_from_record(record: Image | Patient) -> list[tuple[str, str]] | N return _labels -def get_relevant_feature_info_from_record(record: ImageFeature, data_description: DataDescription) -> tuple[ +def get_relevant_feature_info_from_record(record: ImageFeature, + data_description: DataDescription, + feature_description: FeatureDescription) -> tuple[ Path, PositiveFloat, tuple[PositiveInt, PositiveInt], tuple[PositiveInt, PositiveInt], - TilingMode, ImageBackend, PositiveFloat, ]: @@ -168,16 +170,16 @@ def get_relevant_feature_info_from_record(record: ImageFeature, data_description The features of the image. """ image_path = data_description.data_dir / record.filename - mpp = record.mpp + mpp = feature_description.mpp tile_size = ( record.num_tiles, 1, ) # this would load all the features in one go --> can be extended to only load relevant tile level features tile_overlap = (0, 0) - tile_mode = TilingMode.C - backend = ImageBackend[str(record.reader)] - overwrite_mpp = record.mpp - return image_path, mpp, tile_size, tile_overlap, tile_mode, backend, overwrite_mpp + + backend = ImageBackend[str(record.reader)].value + overwrite_mpp = feature_description.mpp + return image_path, mpp, tile_size, tile_overlap, backend, overwrite_mpp def _get_rois(mask: WsiAnnotations | None, data_description: DataDescription, stage: str) -> Optional[Rois]: @@ -351,13 +353,14 @@ def get_image_features_by_image_and_feature_version(self, image_id: int, feature ImageFeature The features of the image. 
""" - image_feature = self._session.query(ImageFeature).filter_by(image_id=image_id, version=feature_version).first() + feature_description = self._session.query(FeatureDescription).filter_by(version=feature_version).first() + image_feature = self._session.query(ImageFeature).filter_by(image_id=image_id, feature_description_id=feature_description.id).first() self._ensure_record( image_feature, f"No features found for image ID {image_id} and feature version {feature_version}" ) assert image_feature is not None # todo: make sure that this only allows to run one ImageFeature, I think it should be good bc of the unique constraint - return image_feature + return image_feature, feature_description def __enter__(self) -> "DataManager": return self @@ -417,12 +420,13 @@ def datasets_from_data_description( mask_threshold = 0.0 if stage != "fit" else data_description.mask_threshold if use_features: - image_feature = db_manager.get_image_features_by_image_and_feature_version( + image_feature, feature_description = db_manager.get_image_features_by_image_and_feature_version( image.id, data_description.feature_version ) - image_path, mpp, tile_size, tile_overlap, tile_mode, backend, overwrite_mpp = ( - get_relevant_feature_info_from_record(image_feature, data_description) + image_path, mpp, tile_size, tile_overlap, backend, overwrite_mpp = ( + get_relevant_feature_info_from_record(image_feature, data_description, feature_description) ) + tile_mode = TilingMode.skip else: image_path = image_root / image.filename tile_size = grid_description.tile_size @@ -430,6 +434,9 @@ def datasets_from_data_description( backend = ImageBackend[str(image.reader)] mpp = grid_description.mpp overwrite_mpp = image.mpp + tile_mode = TilingMode.overflow + + # fixme: something is still wrong here, in the reader or in the database, got empty tiles dataset = TiledWsiDataset.from_standard_tiling( path=image_path, @@ -437,6 +444,7 @@ def datasets_from_data_description( tile_size=tile_size, tile_overlap=tile_overlap, grid_order=GridOrder.C, + tile_mode=tile_mode, crop=False, mask=mask, mask_threshold=mask_threshold, From eedee6fbaf4098b90afda11a4f935f5c17465133 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Tue, 20 Aug 2024 13:06:11 +0200 Subject: [PATCH 11/30] Fixes pretransforms to be used on features, also allows for option to turn of cache --- ahcore/data/dataset.py | 7 +++- ahcore/transforms/pre_transforms.py | 50 ++++++++++++++++++++--------- ahcore/utils/data.py | 1 + 3 files changed, 42 insertions(+), 16 deletions(-) diff --git a/ahcore/data/dataset.py b/ahcore/data/dataset.py index ea2d7a9..a6438bb 100644 --- a/ahcore/data/dataset.py +++ b/ahcore/data/dataset.py @@ -117,6 +117,7 @@ def __init__( num_workers: int = 16, persistent_workers: bool = False, pin_memory: bool = False, + use_cache: bool = True, ) -> None: """ Construct a DataModule based on a manifest. 
@@ -176,6 +177,7 @@ def __init__( self._num_workers = num_workers self._persistent_workers = persistent_workers self._pin_memory = pin_memory + self._use_cache = use_cache self._fit_data_iterator: Iterator[_DlupDataset] | None = None self._validate_data_iterator: Iterator[_DlupDataset] | None = None @@ -196,6 +198,7 @@ def __init__( self._limit_fit_samples = None self._limit_predict_samples = None + @property def data_manager(self) -> DataManager: return self._data_manager @@ -244,7 +247,7 @@ def construct_dataset() -> ConcatDataset: return ConcatDataset(datasets=datasets) self._logger.info("Constructing dataset for stage %s (this can take a while)", stage) - dataset = self._load_from_cache(construct_dataset, stage=stage) + dataset = self._load_from_cache(construct_dataset, stage=stage) if self._use_cache else construct_dataset() setattr(self, f"{stage}_dataset", dataset) lengths = np.asarray([len(ds) for ds in dataset.datasets]) @@ -364,4 +367,6 @@ def uuid(self) -> uuid_module.UUID: str A unique identifier for this datamodule. """ + + # todo: It doesn't take into account different types of pretransforms, which can be important. return basemodel_to_uuid(self.data_description) diff --git a/ahcore/transforms/pre_transforms.py b/ahcore/transforms/pre_transforms.py index 6fabaaf..bedb795 100644 --- a/ahcore/transforms/pre_transforms.py +++ b/ahcore/transforms/pre_transforms.py @@ -98,11 +98,8 @@ def for_wsi_classification( ) -> PreTransformTaskFactory: transforms: list[PreTransformCallable] = [] - transforms.append(ImageToTensor()) - transforms.append(SampleNFeatures(n=1000)) - transforms.append(AllowCollate()) if not requires_target: return cls(transforms) @@ -111,6 +108,11 @@ def for_wsi_classification( if index_map is None: raise ConfigurationError("`index_map` is required for classification models when the target is required.") + label_keys = data_description.label_keys + + if label_keys is not None: + transforms.append(SelectSpecificLabels(keys=label_keys)) + transforms.append(LabelToClassIndex(index_map=index_map)) return cls(transforms) @@ -127,28 +129,32 @@ def __repr__(self) -> str: class SampleNFeatures: def __init__(self, n=1000): self.n = n - logger.critical( + logger.warning( f"Sampling {n} features from the image. Sampling WITH replacement is done if there are not enough tiles." ) def __call__(self, sample: DlupDatasetSample) -> DlupDatasetSample: features = sample["image"] - if not type(features) is torch.Tensor: - raise ValueError( - f"Expected features to be a torch.Tensor, got {type(features)}. Apply ImageToTensor transform first." - ) - feature_dim, h, w = features.shape + # Get the dimensions of the image + feature_dim = features.bands # Number of channels (similar to feature_dim) + h = features.height # Height + w = features.width # Width - if not 0 < w < 1: - raise ValueError(f"Expected features to have a width dimension of 1, got {w}.") + if h != 1: + raise ValueError(f"Expected features to have a width dimension of 1, got {h}.") - # TODO: DO WE WANT THIS????? SAMPLING WITH REPLACEMENT OR NOT??? 
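        # Pick exactly self.n tile indices out of the w available columns; slides with fewer
        # than n tiles are sampled with replacement so the output always has n feature columns.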
n_random_indices = ( - np.random.choice(h, self.n, replace=False) if h > self.n else np.random.choice(h, self.n, replace=True) + np.random.choice(w, self.n, replace=False) if w > self.n else np.random.choice(w, self.n, replace=True) ) - sample["image"] = features[:, n_random_indices, :] + # Extract the selected columns (indices) from the image + # Create a new image from the selected indices + + selected_columns = [features.crop(idx, 0, 1, h) for idx in n_random_indices] + + # Combine the selected columns back into a single image + sample["image"] = pyvips.Image.arrayjoin(selected_columns, across=1) return sample @@ -176,6 +182,16 @@ def __call__(self, sample: DlupDatasetSample) -> DlupDatasetSample: return sample +class SelectSpecificLabels: + def __init__(self, keys: list[str] | str): + if isinstance(keys, str): + keys = [keys] + self._keys = keys + + def __call__(self, sample: DlupDatasetSample) -> DlupDatasetSample: + sample["labels"] = {label_key: label_value for label_key, label_value in sample["labels"].items() if label_key in self._keys} + return sample + class OneHotEncodeMask: def __init__(self, index_map: dict[str, int]): """Create the one-hot encoding of the mask for segmentation. @@ -254,7 +270,11 @@ class ImageToTensor: def __call__(self, sample: DlupDatasetSample) -> dict[str, DlupDatasetSample]: tile: pyvips.Image = sample["image"] # Flatten the image to remove the alpha channel, using white as the background color - tile_ = tile.flatten(background=[255, 255, 255]) # todo: check if this doesn't mess up features + if tile.bands > 4: + # assuming that more than four bands/channels means that we are handling features + tile_ = tile + else: + tile_ = tile.flatten(background=[255, 255, 255]) # todo: check if this doesn't mess up features # Convert VIPS image to a numpy array then to a torch tensor np_image = tile_.numpy() diff --git a/ahcore/utils/data.py b/ahcore/utils/data.py index 8adac60..c5362ae 100644 --- a/ahcore/utils/data.py +++ b/ahcore/utils/data.py @@ -64,6 +64,7 @@ class DataDescription(BaseModel): inference_grid: Optional[GridDescription] = None index_map: Optional[Dict[str, int]] remap_labels: Optional[Dict[str, str]] = None + label_keys: list[str] | str | None = None use_class_weights: Optional[bool] = False convert_mask_to_rois: bool = True use_roi: bool = True From f58c656fd0bee1ed52eb6f256b85e1ec0a92a1b6 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Tue, 20 Aug 2024 13:06:58 +0200 Subject: [PATCH 12/30] minor fixes to allow models for classification --- ahcore/lit_module.py | 2 ++ ahcore/utils/io.py | 6 +++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/ahcore/lit_module.py b/ahcore/lit_module.py index 3001501..535e077 100644 --- a/ahcore/lit_module.py +++ b/ahcore/lit_module.py @@ -66,6 +66,8 @@ def __init__( except AttributeError: raise AttributeError("num_classes must be specified in data_description") self._model = model(out_channels=self._num_classes) + elif isinstance(model, nn.Module): + self._model = model else: raise TypeError(f"The class of models: {model.__class__} is not supported on ahcore") self._augmentations = augmentations diff --git a/ahcore/utils/io.py b/ahcore/utils/io.py index bc7b2dd..cde38aa 100644 --- a/ahcore/utils/io.py +++ b/ahcore/utils/io.py @@ -240,6 +240,8 @@ def load_weights(model: LightningModule, config: DictConfig) -> LightningModule: _model = getattr(model, "_model") if isinstance(_model, BaseAhcoreJitModel): return model + if config.ckpt_path == "" or config.ckpt_path is None: + raise 
ValueError(f"Checkpoint path not provided in config.") else: # Load checkpoint weights lit_ckpt = torch.load(config.ckpt_path) @@ -262,11 +264,13 @@ def validate_checkpoint_paths(config: DictConfig) -> DictConfig: """ # Extract paths with clear fallbacks checkpoint_path = config.get("ckpt_path") + # this is not right and a bit hacky with the new models jit_path = config.get("lit_module", {}).get("model", {}).get("jit_path") # Validate configuration paths_defined = [path for path in [checkpoint_path, jit_path] if path] if len(paths_defined) == 0: - raise RuntimeError("No checkpoint or jit path provided in config.") + logging.warning("No checkpoint or jit path provided in config.") + return config elif len(paths_defined) > 1: raise RuntimeError("Checkpoint path and jit path cannot be defined simultaneously.") else: From aad86376bffa7dbfe22339a5fff959797de39ef7 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Tue, 20 Aug 2024 13:07:29 +0200 Subject: [PATCH 13/30] Adapt for dataformat enum --- ahcore/readers.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/ahcore/readers.py b/ahcore/readers.py index 6b3b580..b9e5fce 100644 --- a/ahcore/readers.py +++ b/ahcore/readers.py @@ -116,14 +116,15 @@ def _open_file(self) -> None: if not self._metadata: raise ValueError("Metadata of file is empty.") + self._num_samples = self._metadata["num_samples"] self._num_tiles = self._metadata["num_tiles"] # set a standard value if it is not present self._data_format = DataFormat(self._metadata["data_format"]) if "data_format" in self._metadata.keys() else DataFormat.IMAGE self._mpp = self._metadata["mpp"] # features are always read at tile_size (1, 1), possibly faster to read the whole feature at once - self._tile_size = (self._num_tiles, 1) if self._data_format == DataFormat.FEATURE else self._metadata["tile_size"] + self._tile_size = (self._num_samples, 1) if self._data_format == DataFormat.FEATURE else self._metadata["tile_size"] self._tile_overlap = self._metadata["tile_overlap"] - self._size = (self._num_tiles, 1) if self._data_format == DataFormat.FEATURE else self._metadata["size"] + self._size = (self._num_samples, 1) if self._data_format == DataFormat.FEATURE else self._metadata["size"] self._num_channels = self._metadata["num_channels"] self._dtype = self._metadata["dtype"] self._precision = self._metadata["precision"] @@ -237,6 +238,31 @@ def read_region(self, location: tuple[int, int], level: int, size: tuple[int, in if self._stitching_mode == StitchingMode.AVERAGE: average_mask = np.zeros((h, w), dtype=self._dtype) + if self._data_format == DataFormat.FEATURE: + if self._stitching_mode != StitchingMode.CROP: + raise NotImplementedError("Stitching mode other than CROP is not supported for features.") + + if image_dataset.shape[0] != self._num_samples: + raise ValueError(f"Reading features expects that the saved feature vectors are the same " + f"length as the number of samples in the dataset. " + f"Feature vector length was {image_dataset.shape[0]}, " + f"number of samples in the dataset was {self._num_samples}") + + if x+w > self._num_samples or y+h > 1: + if x+w == self._num_samples + 3: + # fixme: this is ugly, but dlup does some resizing... 
+ w = w - 3 + else: + raise ValueError(f"Feature vectors are saved as (num_samples, 1) and the requested size {size} at location {location} is too large.") + + # this simplified version of the crop is done as it is faster than the general crop + return pyvips.Image.new_from_array(np.expand_dims(image_dataset[x: x+w, :], axis=0)) + + + + + + for i in range(start_row, end_row): for j in range(start_col, end_col): tile_idx = (i * total_cols) + j From 0a3dcb3e3cfd26bf7f336a3d2f9ede080a75e9bb Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Wed, 21 Aug 2024 16:25:23 +0200 Subject: [PATCH 14/30] added SetTarget method which chooses what the target will be in the loss. In the future maybe this should be done with an argument from the configs. Now it works with whatever is available --- ahcore/transforms/pre_transforms.py | 38 +++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/ahcore/transforms/pre_transforms.py b/ahcore/transforms/pre_transforms.py index bedb795..e12b060 100644 --- a/ahcore/transforms/pre_transforms.py +++ b/ahcore/transforms/pre_transforms.py @@ -5,6 +5,7 @@ from __future__ import annotations +import logging from typing import Any, Callable import numpy as np @@ -40,10 +41,7 @@ def __init__(self, transforms: list[PreTransformCallable]): List of transforms to be used. """ # These are always finally added. - transforms += [ - ImageToTensor(), - AllowCollate(), - ] + transforms += [ImageToTensor(), AllowCollate(), SetTarget()] self._transforms = transforms @classmethod @@ -100,7 +98,6 @@ def for_wsi_classification( transforms.append(SampleNFeatures(n=1000)) - if not requires_target: return cls(transforms) @@ -137,7 +134,6 @@ def __call__(self, sample: DlupDatasetSample) -> DlupDatasetSample: features = sample["image"] # Get the dimensions of the image - feature_dim = features.bands # Number of channels (similar to feature_dim) h = features.height # Height w = features.width # Width @@ -189,9 +185,12 @@ def __init__(self, keys: list[str] | str): self._keys = keys def __call__(self, sample: DlupDatasetSample) -> DlupDatasetSample: - sample["labels"] = {label_key: label_value for label_key, label_value in sample["labels"].items() if label_key in self._keys} + sample["labels"] = { + label_key: label_value for label_key, label_value in sample["labels"].items() if label_key in self._keys + } return sample + class OneHotEncodeMask: def __init__(self, index_map: dict[str, int]): """Create the one-hot encoding of the mask for segmentation. 
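A minimal usage sketch (editor's illustration, not part of the patch) of the SelectSpecificLabels transform touched in the hunk above; the sample layout and label names are hypothetical:

    # Relies on SelectSpecificLabels as defined above; "er_status" etc. are made-up label keys.
    sample = {"image": None, "labels": {"er_status": 1, "pr_status": 0, "her2_status": 1}}

    transform = SelectSpecificLabels(keys="er_status")  # a bare string is wrapped into a one-element list
    sample = transform(sample)
    assert sample["labels"] == {"er_status": 1}  # all other label keys are dropped
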
@@ -262,6 +261,25 @@ def __call__(self, sample: TileSample) -> dict[str, Any]: return output +class SetTarget: + + def __call__(self, sample: DlupDatasetSample) -> DlupDatasetSample: + if "annotations_data" in sample and "mask" in sample["annotation_data"] and "labels" in sample.keys(): + sample["target"] = (sample["annotation_data"]["mask"], sample["labels"]) + elif "annotations_data" in sample and "mask" in sample["annotation_data"]: + sample["target"] = sample["annotation_data"]["mask"] + elif "labels" in sample.keys(): + if len(sample["labels"].keys()) == 1: + # if there is only one label, then we just set this without retaining the key + # this makes it compatible with standard loss functions + sample["labels"] = next(iter(sample["labels"].values())) + sample["target"] = sample["labels"] + else: + logging.warning("No target set") + + return sample + + class ImageToTensor: """ Transform to translate the output of a dlup dataset to data_description supported by AhCore @@ -283,6 +301,10 @@ def __call__(self, sample: DlupDatasetSample) -> dict[str, DlupDatasetSample]: if sample["image"].sum() == 0: raise RuntimeError(f"Empty tile for {sample['path']} at {sample['coordinates']}") + if "labels" in sample: + for key, value in sample["labels"].items(): + sample["labels"][key] = torch.tensor(value) + # annotation_data is added by the ConvertPolygonToMask transform. if "annotation_data" not in sample: return sample @@ -292,7 +314,7 @@ def __call__(self, sample: DlupDatasetSample) -> dict[str, DlupDatasetSample]: if len(mask.shape) == 2: # Mask is not one-hot encoded mask = mask[np.newaxis, ...] - sample["target"] = torch.from_numpy(mask).float() + sample["annotation_data"]["mask"] = torch.from_numpy(mask).float() if "roi" in sample["annotation_data"] and sample["annotation_data"]["roi"] is not None: roi = sample["annotation_data"]["roi"] From 2e24cb174326ccedc15c31250ae81652780d9c04 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Wed, 21 Aug 2024 16:25:54 +0200 Subject: [PATCH 15/30] precommit fixes and removed the three pixel check, dlup will fix that --- ahcore/backends.py | 30 ++++++------- ahcore/callbacks/file_writer_callback.py | 2 +- ahcore/data/dataset.py | 1 - ahcore/lit_module.py | 3 +- ahcore/readers.py | 48 ++++++++++++--------- ahcore/utils/database_models.py | 2 +- ahcore/utils/io.py | 4 +- ahcore/utils/manifest.py | 54 +++++++++++++++++------- ahcore/utils/types.py | 1 + ahcore/writers.py | 13 ++---- 10 files changed, 89 insertions(+), 69 deletions(-) diff --git a/ahcore/backends.py b/ahcore/backends.py index e87be60..c98a847 100644 --- a/ahcore/backends.py +++ b/ahcore/backends.py @@ -1,19 +1,15 @@ -from typing import Any - -import pyvips -from dlup.backends.common import AbstractSlideBackend -from dlup.types import PathLike - -from ahcore.readers import StitchingMode, ZarrFileImageReader, H5FileImageReader - from enum import Enum from typing import Any, Callable +import pyvips +from dlup.backends.common import AbstractSlideBackend from dlup.backends.openslide_backend import OpenSlideSlide -from dlup.backends.tifffile_backend import TifffileSlide from dlup.backends.pyvips_backend import PyVipsSlide +from dlup.backends.tifffile_backend import TifffileSlide from dlup.types import PathLike +from ahcore.readers import H5FileImageReader, StitchingMode, ZarrFileImageReader + class ZarrSlide(AbstractSlideBackend): def __init__(self, filename: PathLike, stitching_mode: StitchingMode | str = StitchingMode.CROP) -> None: @@ -22,7 +18,7 @@ def __init__(self, filename: PathLike, 
stitching_mode: StitchingMode | str = Sti self._spacings = [(self._reader.mpp, self._reader.mpp)] @property - def size(self): + def size(self) -> tuple[int, int]: return self._reader.size @property @@ -42,13 +38,13 @@ def properties(self) -> dict[str, Any]: return self._reader.metadata @property - def magnification(self): + def magnification(self) -> None: return None def read_region(self, coordinates: tuple[int, int], level: int, size: tuple[int, int]) -> pyvips.Image: return self._reader.read_region(coordinates, level, size) - def close(self): + def close(self) -> None: self._reader.close() @@ -59,7 +55,7 @@ def __init__(self, filename: PathLike, stitching_mode: StitchingMode | str = Sti self._spacings = [(self._reader.mpp, self._reader.mpp)] @property - def size(self): + def size(self) -> tuple[int, int]: return self._reader.size @property @@ -79,13 +75,13 @@ def properties(self) -> dict[str, Any]: return self._reader.metadata @property - def magnification(self): + def magnification(self) -> None: return None def read_region(self, coordinates: tuple[int, int], level: int, size: tuple[int, int]) -> pyvips.Image: return self._reader.read_region(coordinates, level, size) - def close(self): + def close(self) -> None: self._reader.close() @@ -98,5 +94,5 @@ class ImageBackend(Enum): H5: Callable[[PathLike], H5Slide] = H5Slide ZARR: Callable[[PathLike], ZarrSlide] = ZarrSlide - def __call__(self, *args) -> Any: - return self.value(*args) \ No newline at end of file + def __call__(self, *args) -> "ImageBackend": + return self.value(*args) diff --git a/ahcore/callbacks/file_writer_callback.py b/ahcore/callbacks/file_writer_callback.py index f7591e6..fa3a043 100644 --- a/ahcore/callbacks/file_writer_callback.py +++ b/ahcore/callbacks/file_writer_callback.py @@ -27,7 +27,7 @@ def __init__( normalization_type: str = NormalizationType.LOGITS, precision: str = InferencePrecision.FP32, callbacks: list[ConvertCallbacks] | None = None, - data_format = DataFormat.IMAGE, + data_format=DataFormat.IMAGE, ): """ Callback to write predictions to H5 files. 
This callback is used to write whole-slide predictions to single H5 diff --git a/ahcore/data/dataset.py b/ahcore/data/dataset.py index a6438bb..2977e35 100644 --- a/ahcore/data/dataset.py +++ b/ahcore/data/dataset.py @@ -198,7 +198,6 @@ def __init__( self._limit_fit_samples = None self._limit_predict_samples = None - @property def data_manager(self) -> DataManager: return self._data_manager diff --git a/ahcore/lit_module.py b/ahcore/lit_module.py index 535e077..53df6df 100644 --- a/ahcore/lit_module.py +++ b/ahcore/lit_module.py @@ -10,6 +10,7 @@ from typing import Any import pytorch_lightning as pl +import torch import torch.optim.optimizer from pytorch_lightning.trainer.states import TrainerFn from torch import nn @@ -38,7 +39,7 @@ class AhCoreLightningModule(pl.LightningModule): def __init__( self, - model: nn.Module | BaseAhcoreJitModel, + model: nn.Module | BaseAhcoreJitModel | functools.partial, optimizer: torch.optim.Optimizer, # noqa data_description: DataDescription, loss: nn.Module | None = None, diff --git a/ahcore/readers.py b/ahcore/readers.py index b9e5fce..3364266 100644 --- a/ahcore/readers.py +++ b/ahcore/readers.py @@ -119,10 +119,14 @@ def _open_file(self) -> None: self._num_samples = self._metadata["num_samples"] self._num_tiles = self._metadata["num_tiles"] # set a standard value if it is not present - self._data_format = DataFormat(self._metadata["data_format"]) if "data_format" in self._metadata.keys() else DataFormat.IMAGE + self._data_format = ( + DataFormat(self._metadata["data_format"]) if "data_format" in self._metadata.keys() else DataFormat.IMAGE + ) self._mpp = self._metadata["mpp"] # features are always read at tile_size (1, 1), possibly faster to read the whole feature at once - self._tile_size = (self._num_samples, 1) if self._data_format == DataFormat.FEATURE else self._metadata["tile_size"] + self._tile_size = ( + (self._num_samples, 1) if self._data_format == DataFormat.FEATURE else self._metadata["tile_size"] + ) self._tile_overlap = self._metadata["tile_overlap"] self._size = (self._num_samples, 1) if self._data_format == DataFormat.FEATURE else self._metadata["size"] self._num_channels = self._metadata["num_channels"] @@ -164,6 +168,8 @@ def metadata(self) -> dict[str, Any]: return self._metadata def _decompress_and_reshape_data(self, tile: GenericNumberArray) -> GenericNumberArray: + assert self._tile_size is not None, "Cannot happen as this is called inside read_region which also checks this" + if self._is_binary: with PIL.Image.open(io.BytesIO(tile)) as img: return np.array(img).transpose( @@ -174,7 +180,8 @@ def _decompress_and_reshape_data(self, tile: GenericNumberArray) -> GenericNumbe if tile.ndim == 1: # fixme: is this the correct location for this if not self._tile_size[1] == 1: raise NotImplementedError( - f"Tile is single dimensional and {self._tile_size=} should be [x, 1], other cases have not been considered and cause unwanted behaviour." + f"Tile is single dimensional and {self._tile_size=} should be [x, 1], " + f"other cases have not been considered and cause unwanted behaviour." 
) return tile.reshape(self._num_channels, *self._tile_size) return tile @@ -217,7 +224,10 @@ def read_region(self, location: tuple[int, int], level: int, size: tuple[int, in total_rows = math.ceil((self._size[1] - self._tile_overlap[1]) / self._stride[1]) total_cols = math.ceil((self._size[0] - self._tile_overlap[0]) / self._stride[0]) - assert total_rows * total_cols == self._num_tiles or self._data_format == DataFormat.FEATURE, f"{total_rows=}, {total_cols=} and {self._num_tiles=}" # Equality only holds if features where created without mask + assert ( + total_rows * total_cols == self._num_tiles or self._data_format == DataFormat.FEATURE + ), f"{total_rows=}, {total_cols=} and {self._num_tiles=}" + # Equality only holds if features where created without mask x, y = location w, h = size @@ -243,25 +253,21 @@ def read_region(self, location: tuple[int, int], level: int, size: tuple[int, in raise NotImplementedError("Stitching mode other than CROP is not supported for features.") if image_dataset.shape[0] != self._num_samples: - raise ValueError(f"Reading features expects that the saved feature vectors are the same " - f"length as the number of samples in the dataset. " - f"Feature vector length was {image_dataset.shape[0]}, " - f"number of samples in the dataset was {self._num_samples}") - - if x+w > self._num_samples or y+h > 1: - if x+w == self._num_samples + 3: - # fixme: this is ugly, but dlup does some resizing... - w = w - 3 - else: - raise ValueError(f"Feature vectors are saved as (num_samples, 1) and the requested size {size} at location {location} is too large.") - - # this simplified version of the crop is done as it is faster than the general crop - return pyvips.Image.new_from_array(np.expand_dims(image_dataset[x: x+w, :], axis=0)) - - - + raise ValueError( + f"Reading features expects that the saved feature vectors are the same " + f"length as the number of samples in the dataset. " + f"Feature vector length was {image_dataset.shape[0]}, " + f"number of samples in the dataset was {self._num_samples}" + ) + if x + w > self._num_samples or y + h > 1: + raise ValueError( + f"Feature vectors are saved as (num_samples, 1) " + f"and the requested size {size} at location {location} is too large." 
+ ) + # this simplified version of the crop is done as it is faster than the general crop + return pyvips.Image.new_from_array(np.expand_dims(image_dataset[x : x + w, :], axis=0)) for i in range(start_row, end_row): for j in range(start_col, end_col): diff --git a/ahcore/utils/database_models.py b/ahcore/utils/database_models.py index e2df80c..d0928a6 100644 --- a/ahcore/utils/database_models.py +++ b/ahcore/utils/database_models.py @@ -154,7 +154,6 @@ class FeatureDescription(Base): # use this to select which features we want to use version = Column(String, unique=True, nullable=False) - model_name = Column(String) model_path = Column(String) feature_dimension = Column(Integer) @@ -164,6 +163,7 @@ class FeatureDescription(Base): features: Mapped[List["ImageFeature"]] = relationship("ImageFeature", back_populates="feature_description") + class Mask(Base): """Mask table.""" diff --git a/ahcore/utils/io.py b/ahcore/utils/io.py index cde38aa..d33ad37 100644 --- a/ahcore/utils/io.py +++ b/ahcore/utils/io.py @@ -241,7 +241,7 @@ def load_weights(model: LightningModule, config: DictConfig) -> LightningModule: if isinstance(_model, BaseAhcoreJitModel): return model if config.ckpt_path == "" or config.ckpt_path is None: - raise ValueError(f"Checkpoint path not provided in config.") + raise ValueError("Checkpoint path not provided in config.") else: # Load checkpoint weights lit_ckpt = torch.load(config.ckpt_path) @@ -280,7 +280,7 @@ def validate_checkpoint_paths(config: DictConfig) -> DictConfig: return config -def get_git_hash(): +def get_git_hash() -> Optional[str]: try: # Check if we're in a git repository subprocess.run(["git", "rev-parse", "--is-inside-work-tree"], check=True, capture_output=True, text=True) diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py index 46ba02a..7f353a6 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -8,7 +8,7 @@ import functools from pathlib import Path from types import TracebackType -from typing import Any, Callable, Generator, Literal, Optional, Type, TypedDict, cast +from typing import Any, Callable, Generator, Literal, Optional, Type, TypedDict, cast, Tuple from dlup import SlideImage from dlup.annotations import WsiAnnotations @@ -16,7 +16,7 @@ from dlup.data.dataset import RegionFromWsiDatasetSample, TiledWsiDataset, TileSample from dlup.tiling import GridOrder, TilingMode from pydantic import BaseModel -from sqlalchemy import create_engine +from sqlalchemy import create_engine, Column from sqlalchemy.engine import Engine from sqlalchemy.inspection import inspect from sqlalchemy.orm import Session, sessionmaker @@ -147,9 +147,9 @@ def get_labels_from_record(record: Image | Patient) -> list[tuple[str, str]] | N return _labels -def get_relevant_feature_info_from_record(record: ImageFeature, - data_description: DataDescription, - feature_description: FeatureDescription) -> tuple[ +def get_relevant_feature_info_from_record( + record: ImageFeature, data_description: DataDescription, feature_description: FeatureDescription +) -> tuple[ Path, PositiveFloat, tuple[PositiveInt, PositiveInt], @@ -166,19 +166,20 @@ def get_relevant_feature_info_from_record(record: ImageFeature, Returns ------- - tuple[Path, PositiveFloat, tuple[PositiveInt, PositiveInt], tuple[PositiveInt, PositiveInt], TilingMode, ImageBackend, PositiveFloat] + tuple[Path, PositiveFloat, tuple[PositiveInt, PositiveInt], + tuple[PositiveInt, PositiveInt], TilingMode, ImageBackend, PositiveFloat] The features of the image. 
""" image_path = data_description.data_dir / record.filename - mpp = feature_description.mpp + mpp = float(feature_description.mpp) tile_size = ( - record.num_tiles, + int(record.num_tiles), 1, ) # this would load all the features in one go --> can be extended to only load relevant tile level features tile_overlap = (0, 0) backend = ImageBackend[str(record.reader)].value - overwrite_mpp = feature_description.mpp + overwrite_mpp = float(feature_description.mpp) return image_path, mpp, tile_size, tile_overlap, backend, overwrite_mpp @@ -186,6 +187,8 @@ def _get_rois(mask: WsiAnnotations | None, data_description: DataDescription, st if (mask is None) or (stage != "fit") or (not data_description.convert_mask_to_rois): return None + assert data_description.training_grid is not None + tile_size = data_description.training_grid.tile_size tile_overlap = data_description.training_grid.tile_overlap @@ -337,7 +340,9 @@ def get_image_metadata_by_id(self, image_id: int) -> ImageMetadata: assert image is not None # mypy return fetch_image_metadata(image) - def get_image_features_by_image_and_feature_version(self, image_id: int, feature_version: str) -> ImageFeature: + def get_image_features_by_image_and_feature_version( + self, image_id: Column[int], feature_version: str | None + ) -> Tuple[ImageFeature, FeatureDescription]: """ Fetch the features for an image based on its ID and feature version. @@ -353,13 +358,25 @@ def get_image_features_by_image_and_feature_version(self, image_id: int, feature ImageFeature The features of the image. """ + if feature_version is None: + raise ValueError("feature_version cannot be None") + feature_description = self._session.query(FeatureDescription).filter_by(version=feature_version).first() - image_feature = self._session.query(ImageFeature).filter_by(image_id=image_id, feature_description_id=feature_description.id).first() + + if feature_description is None: + raise ValueError(f"Couldn't find feature description matching version {feature_version}") + + image_feature = ( + self._session.query(ImageFeature) + .filter_by(image_id=image_id, feature_description_id=feature_description.id) + .first() + ) self._ensure_record( image_feature, f"No features found for image ID {image_id} and feature version {feature_version}" ) assert image_feature is not None - # todo: make sure that this only allows to run one ImageFeature, I think it should be good bc of the unique constraint + # todo: make sure that this only allows to run one ImageFeature, + # I think it should be good bc of the unique constraint return image_feature, feature_description def __enter__(self) -> "DataManager": @@ -399,14 +416,14 @@ def datasets_from_data_description( else: grid_description = data_description.inference_grid + use_features = data_description.feature_version is not None + patients = db_manager.get_records_by_split( manifest_name=data_description.manifest_name, split_version=data_description.split_version, split_category=stage, ) - use_features = data_description.feature_version is not None - for patient in patients: patient_labels = get_labels_from_record(patient) @@ -428,12 +445,17 @@ def datasets_from_data_description( ) tile_mode = TilingMode.skip else: + if grid_description is None: + raise ValueError( + f"grid_description for stage {stage} is None should be set if images are supposed to be used" + ) + image_path = image_root / image.filename tile_size = grid_description.tile_size tile_overlap = grid_description.tile_overlap backend = ImageBackend[str(image.reader)] - mpp = 
grid_description.mpp - overwrite_mpp = image.mpp + mpp = grid_description.mpp # type: ignore + overwrite_mpp = float(image.mpp) tile_mode = TilingMode.overflow # fixme: something is still wrong here, in the reader or in the database, got empty tiles diff --git a/ahcore/utils/types.py b/ahcore/utils/types.py index 6b18131..7dd5f1f 100644 --- a/ahcore/utils/types.py +++ b/ahcore/utils/types.py @@ -87,6 +87,7 @@ class ViTEmbedMode(str, Enum): CONCAT = "embed_concat" # Extend as necessary + class DataFormat(str, Enum): """Data format for the writer.""" diff --git a/ahcore/writers.py b/ahcore/writers.py index 3aad8b3..edf80b8 100644 --- a/ahcore/writers.py +++ b/ahcore/writers.py @@ -14,7 +14,6 @@ from contextlib import contextmanager from pathlib import Path from typing import Any, Generator, NamedTuple, Optional -from enum import Enum import dlup import h5py @@ -32,9 +31,6 @@ logger = get_logger(__name__) - - - def decode_array_to_pil(array: npt.NDArray[np.uint8]) -> PIL.Image.Image: """Convert encoded array to PIL image @@ -117,7 +113,6 @@ def __init__( self._partial_suffix: str = f"{self._filename.suffix}.partial" - @abc.abstractmethod def open_file(self, mode: str = "w") -> Any: pass @@ -449,9 +444,7 @@ def insert_data(self, batch: GenericNumberArray) -> None: if self._data_format == DataFormat.COMPRESSED_IMAGE: self._data[self._current_index] = batch.reshape(-1) else: - self._data[self._current_index : self._current_index + batch.shape[0]] = ( - batch.flatten() if self._data_format == DataFormat.COMPRESSED_IMAGE else batch - ) + self._data[self._current_index : self._current_index + batch.shape[0]] = batch def write_metadata(self, metadata: dict[str, Any], file: Any) -> None: """Write metadata to Zarr group attributes.""" @@ -491,7 +484,9 @@ def insert_data(self, batch: GenericNumberArray) -> None: raise ValueError(f"Batch should have a single element when writing h5. 
Got batch shape {batch.shape}.") batch_size = batch.shape[0] self._data[self._current_index : self._current_index + batch_size] = ( - batch.flatten() if self._data_format == DataFormat.COMPRESSED_IMAGE else batch # fixme: flatten shouldn't work here + batch.flatten() + if self._data_format == DataFormat.COMPRESSED_IMAGE + else batch # fixme: flatten shouldn't work here ) def create_dataset( From cc9ea49092e8dec5ddd70856f05ee2c03dcd1c62 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Thu, 22 Aug 2024 15:34:15 +0200 Subject: [PATCH 16/30] model will expect Bxnum_tilesxfeature_dim, so the ToTensor method should return that for features --- ahcore/transforms/pre_transforms.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/ahcore/transforms/pre_transforms.py b/ahcore/transforms/pre_transforms.py index e12b060..744502e 100644 --- a/ahcore/transforms/pre_transforms.py +++ b/ahcore/transforms/pre_transforms.py @@ -146,7 +146,7 @@ def __call__(self, sample: DlupDatasetSample) -> DlupDatasetSample: # Extract the selected columns (indices) from the image # Create a new image from the selected indices - + # todo: this can probably be done without a for-loop quicker selected_columns = [features.crop(idx, 0, 1, h) for idx in n_random_indices] # Combine the selected columns back into a single image @@ -288,15 +288,23 @@ class ImageToTensor: def __call__(self, sample: DlupDatasetSample) -> dict[str, DlupDatasetSample]: tile: pyvips.Image = sample["image"] # Flatten the image to remove the alpha channel, using white as the background color + using_features = False + if tile.bands > 4: # assuming that more than four bands/channels means that we are handling features + using_features = True tile_ = tile else: tile_ = tile.flatten(background=[255, 255, 255]) # todo: check if this doesn't mess up features # Convert VIPS image to a numpy array then to a torch tensor np_image = tile_.numpy() - sample["image"] = torch.from_numpy(np_image).permute(2, 0, 1).float() + if using_features: + # n_tiles x 1 x feature_dim --> n_tiles x feature_dim + sample["image"] = torch.from_numpy(np_image).squeeze(1).float() + else: + # h x w x c --> c x h x w + sample["image"] = torch.from_numpy(np_image).permute(2, 0, 1).float() if sample["image"].sum() == 0: raise RuntimeError(f"Empty tile for {sample['path']} at {sample['coordinates']}") From 0d97c4aacc2ebb294e975beb4d06248e05a0df1a Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Wed, 4 Sep 2024 17:59:35 +0200 Subject: [PATCH 17/30] fix some mypy --- ahcore/backends.py | 3 +-- ahcore/callbacks/file_writer_callback.py | 8 +++++--- ahcore/lit_module.py | 6 ++++-- ahcore/transforms/pre_transforms.py | 3 +-- ahcore/utils/manifest.py | 24 ++++++++++++++---------- 5 files changed, 25 insertions(+), 19 deletions(-) diff --git a/ahcore/backends.py b/ahcore/backends.py index 1880e54..0adc868 100644 --- a/ahcore/backends.py +++ b/ahcore/backends.py @@ -3,7 +3,6 @@ import pyvips from dlup.backends.common import AbstractSlideBackend -from dlup.types import PathLike # type: ignore from dlup.backends.openslide_backend import OpenSlideSlide from dlup.backends.pyvips_backend import PyVipsSlide from dlup.backends.tifffile_backend import TifffileSlide @@ -95,5 +94,5 @@ class ImageBackend(Enum): H5: Callable[[PathLike], H5Slide] = H5Slide ZARR: Callable[[PathLike], ZarrSlide] = ZarrSlide - def __call__(self, *args) -> "ImageBackend": + def __call__(self, *args: Any) -> OpenSlideSlide | PyVipsSlide | TifffileSlide | H5Slide | ZarrSlide: return 
self.value(*args) diff --git a/ahcore/callbacks/file_writer_callback.py b/ahcore/callbacks/file_writer_callback.py index fa3a043..bdf3916 100644 --- a/ahcore/callbacks/file_writer_callback.py +++ b/ahcore/callbacks/file_writer_callback.py @@ -11,7 +11,7 @@ from ahcore.utils.callbacks import get_output_filename as get_output_filename_ from ahcore.utils.data import DataDescription, GridDescription from ahcore.utils.io import get_logger -from ahcore.utils.types import InferencePrecision, NormalizationType, DataFormat +from ahcore.utils.types import DataFormat, InferencePrecision, NormalizationType from ahcore.writers import Writer logger = get_logger(__name__) @@ -27,8 +27,8 @@ def __init__( normalization_type: str = NormalizationType.LOGITS, precision: str = InferencePrecision.FP32, callbacks: list[ConvertCallbacks] | None = None, - data_format=DataFormat.IMAGE, - ): + data_format: DataFormat = DataFormat.IMAGE, + ) -> None: """ Callback to write predictions to H5 files. This callback is used to write whole-slide predictions to single H5 files in a separate thread. @@ -105,6 +105,8 @@ def build_writer_class(self, pl_module: AhCoreLightningModule, stage: str, filen num_samples = len(current_dataset) data_description: DataDescription = pl_module.data_description + if data_description.inference_grid is None: + raise ValueError("Inference grid is not defined in the data description.") inference_grid: GridDescription = data_description.inference_grid mpp = inference_grid.mpp diff --git a/ahcore/lit_module.py b/ahcore/lit_module.py index 19a56ce..041a5a9 100644 --- a/ahcore/lit_module.py +++ b/ahcore/lit_module.py @@ -36,10 +36,12 @@ class AhCoreLightningModule(pl.LightningModule): "grid_index", ] + _model: nn.Module | BaseAhcoreJitModel + def __init__( self, - model: nn.Module | BaseAhcoreJitModel | functools.partial, - optimizer: torch.optim.optimzer.Optimizer, # noqa + model: nn.Module | BaseAhcoreJitModel | functools.partial[nn.Module], + optimizer: torch.optim.optimizer.Optimizer, # noqa data_description: DataDescription, loss: nn.Module | None = None, augmentations: dict[str, nn.Module] | None = None, diff --git a/ahcore/transforms/pre_transforms.py b/ahcore/transforms/pre_transforms.py index 744502e..f8934e1 100644 --- a/ahcore/transforms/pre_transforms.py +++ b/ahcore/transforms/pre_transforms.py @@ -124,7 +124,7 @@ def __repr__(self) -> str: class SampleNFeatures: - def __init__(self, n=1000): + def __init__(self, n: int = 1000) -> None: self.n = n logger.warning( f"Sampling {n} features from the image. Sampling WITH replacement is done if there are not enough tiles." 
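The warning above refers to the index sampling done inside SampleNFeatures; a small sketch (editor's illustration, not part of the patch) of that logic with a hypothetical tile count:

    import numpy as np

    n = 1000
    num_tiles = 600  # hypothetical number of feature columns available in the pyvips image
    # Mirrors SampleNFeatures.__call__: sample without replacement when enough tiles exist,
    # otherwise sample with replacement so the output always has n columns.
    indices = (
        np.random.choice(num_tiles, n, replace=False)
        if num_tiles > n
        else np.random.choice(num_tiles, n, replace=True)
    )
    assert indices.shape == (n,)
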
@@ -262,7 +262,6 @@ def __call__(self, sample: TileSample) -> dict[str, Any]: class SetTarget: - def __call__(self, sample: DlupDatasetSample) -> DlupDatasetSample: if "annotations_data" in sample and "mask" in sample["annotation_data"] and "labels" in sample.keys(): sample["target"] = (sample["annotation_data"]["mask"], sample["labels"]) diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py index 7f353a6..701aa9f 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -8,34 +8,34 @@ import functools from pathlib import Path from types import TracebackType -from typing import Any, Callable, Generator, Literal, Optional, Type, TypedDict, cast, Tuple +from typing import Any, Callable, Generator, Literal, Optional, Tuple, Type, TypedDict, cast from dlup import SlideImage from dlup.annotations import WsiAnnotations -from ahcore.backends import ImageBackend from dlup.data.dataset import RegionFromWsiDatasetSample, TiledWsiDataset, TileSample from dlup.tiling import GridOrder, TilingMode from pydantic import BaseModel -from sqlalchemy import create_engine, Column +from sqlalchemy import Column, create_engine from sqlalchemy.engine import Engine from sqlalchemy.inspection import inspect from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.sql import exists +from ahcore.backends import ImageBackend from ahcore.exceptions import RecordNotFoundError from ahcore.utils.data import DataDescription from ahcore.utils.database_models import ( Base, CategoryEnum, + FeatureDescription, Image, ImageAnnotations, + ImageFeature, Manifest, Mask, Patient, Split, SplitDefinitions, - ImageFeature, - FeatureDescription, ) from ahcore.utils.io import get_enum_key_from_value, get_logger from ahcore.utils.rois import compute_rois @@ -428,7 +428,6 @@ def datasets_from_data_description( patient_labels = get_labels_from_record(patient) for image in patient.images: - mask, annotations = get_mask_and_annotations_from_record(annotations_root, image) assert isinstance(mask, WsiAnnotations) or (mask is None) image_labels = get_labels_from_record(image) @@ -440,9 +439,14 @@ def datasets_from_data_description( image_feature, feature_description = db_manager.get_image_features_by_image_and_feature_version( image.id, data_description.feature_version ) - image_path, mpp, tile_size, tile_overlap, backend, overwrite_mpp = ( - get_relevant_feature_info_from_record(image_feature, data_description, feature_description) - ) + ( + image_path, + mpp, + tile_size, + tile_overlap, + backend, + overwrite_mpp, + ) = get_relevant_feature_info_from_record(image_feature, data_description, feature_description) tile_mode = TilingMode.skip else: if grid_description is None: @@ -475,7 +479,7 @@ def datasets_from_data_description( annotations=annotations if stage != "predict" else None, labels=labels, # type: ignore transform=transform, - backend=backend, + backend=backend, # type: ignore overwrite_mpp=(overwrite_mpp, overwrite_mpp), limit_bounds=True, apply_color_profile=data_description.apply_color_profile, From f5c08dc38156a5658b08433b56f5b994de8dd9f6 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Thu, 5 Sep 2024 13:42:43 +0200 Subject: [PATCH 18/30] fixes mypy --- ahcore/backends.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ahcore/backends.py b/ahcore/backends.py index 0adc868..16431c3 100644 --- a/ahcore/backends.py +++ b/ahcore/backends.py @@ -8,6 +8,7 @@ from dlup.backends.tifffile_backend import TifffileSlide from dlup.types import PathLike # type: ignore + from 
ahcore.readers import H5FileImageReader, StitchingMode, ZarrFileImageReader @@ -95,4 +96,4 @@ class ImageBackend(Enum): ZARR: Callable[[PathLike], ZarrSlide] = ZarrSlide def __call__(self, *args: Any) -> OpenSlideSlide | PyVipsSlide | TifffileSlide | H5Slide | ZarrSlide: - return self.value(*args) + return self.value(*args) # type: ignore From ec9285def0335c92cfac38c7f8bc4321cdb88ff3 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Fri, 6 Sep 2024 11:37:13 +0200 Subject: [PATCH 19/30] fix test for readers --- tests/test_readers/test_readers.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_readers/test_readers.py b/tests/test_readers/test_readers.py index 26278d2..722fb86 100644 --- a/tests/test_readers/test_readers.py +++ b/tests/test_readers/test_readers.py @@ -66,6 +66,7 @@ def _read_metadata(self) -> None: "is_binary": self._file.attrs["is_binary"], "has_color_profile": self._file.attrs["has_color_profile"], "num_tiles": self._file.attrs["num_tiles"], + "data_format": self._file.attrs["data_format"], } def _open_file(self) -> None: @@ -84,6 +85,7 @@ def _open_file(self) -> None: self._precision = None self._multiplier = None self._is_binary = None + self._data_format = None if not self._filename.is_file(): raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), str(self._filename)) @@ -109,6 +111,9 @@ def _open_file(self) -> None: self._tile_size[1] - self._tile_overlap[1], ) + self._num_tiles = self._metadata["num_tiles"] + self._data_format = self._metadata["data_format"] + def close(self) -> None: if self._file is not None: self._file.close() @@ -141,6 +146,7 @@ def temp_h5_file(tmp_path: Path) -> Generator[Path, None, None]: f.attrs["is_binary"] = is_binary f.attrs["has_color_profile"] = has_color_profile f.attrs["num_tiles"] = num_tiles + f.attrs["data_format"] = "image" f.create_dataset("data", (num_tiles, num_channels, tile_size[0], tile_size[1]), dtype=dtype) f.create_dataset("tile_indices", (num_tiles,), dtype=np.uint8) From 0c0d1e4be0e30e12c2b5fa98713824da20c301e0 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Fri, 6 Sep 2024 17:18:41 +0200 Subject: [PATCH 20/30] added tests for features --- ahcore/readers.py | 11 +++-- ahcore/writers.py | 8 ++-- tests/test_readers/test_readers.py | 5 --- tests/test_writers/test_h5_writer.py | 67 ++++++++++++++++++++++++++-- 4 files changed, 75 insertions(+), 16 deletions(-) diff --git a/ahcore/readers.py b/ahcore/readers.py index 3364266..b9b48e3 100644 --- a/ahcore/readers.py +++ b/ahcore/readers.py @@ -24,7 +24,7 @@ from zarr.storage import ZipStore from ahcore.utils.io import get_logger -from ahcore.utils.types import BoundingBoxType, GenericNumberArray, InferencePrecision, DataFormat +from ahcore.utils.types import BoundingBoxType, DataFormat, GenericNumberArray, InferencePrecision logger = get_logger(__name__) @@ -133,7 +133,12 @@ def _open_file(self) -> None: self._dtype = self._metadata["dtype"] self._precision = self._metadata["precision"] self._multiplier = self._metadata["multiplier"] - self._is_binary = self._metadata["is_binary"] + self._is_binary = getattr(self._metadata, "is_binary", None) + if self._is_binary is not None: + logger.warning( + f"Found is_binary in metadata, of file {self._filename}. 
" + f"This tag is deprecated and might be removed in future versions" + ) self._stride = ( self._tile_size[0] - self._tile_overlap[0], self._tile_size[1] - self._tile_overlap[1], @@ -170,7 +175,7 @@ def metadata(self) -> dict[str, Any]: def _decompress_and_reshape_data(self, tile: GenericNumberArray) -> GenericNumberArray: assert self._tile_size is not None, "Cannot happen as this is called inside read_region which also checks this" - if self._is_binary: + if self._is_binary or self._data_format == DataFormat.COMPRESSED_IMAGE: with PIL.Image.open(io.BytesIO(tile)) as img: return np.array(img).transpose( 2, 0, 1 diff --git a/ahcore/writers.py b/ahcore/writers.py index fe86f95..3bcb816 100644 --- a/ahcore/writers.py +++ b/ahcore/writers.py @@ -26,7 +26,7 @@ import ahcore from ahcore.utils.io import get_git_hash, get_logger -from ahcore.utils.types import GenericNumberArray, InferencePrecision, DataFormat +from ahcore.utils.types import DataFormat, GenericNumberArray, InferencePrecision logger = get_logger(__name__) @@ -343,7 +343,7 @@ def init_writer(self, first_coordinates: GenericNumberArray, first_batch: Generi compression="gzip", ) - if self._data_format == DataFormat.COMPRESSED_IMAGE: + if self._data_format != DataFormat.COMPRESSED_IMAGE: shape = first_batch.shape[1:] self._data = self.create_dataset( file, @@ -484,9 +484,7 @@ def insert_data(self, batch: GenericNumberArray) -> None: raise ValueError(f"Batch should have a single element when writing h5. Got batch shape {batch.shape}.") batch_size = batch.shape[0] self._data[self._current_index : self._current_index + batch_size] = ( - batch.flatten() - if self._data_format == DataFormat.COMPRESSED_IMAGE - else batch # fixme: flatten shouldn't work here + batch.flatten() if self._data_format == DataFormat.COMPRESSED_IMAGE else batch ) def create_dataset( diff --git a/tests/test_readers/test_readers.py b/tests/test_readers/test_readers.py index 189c5b5..e382810 100644 --- a/tests/test_readers/test_readers.py +++ b/tests/test_readers/test_readers.py @@ -65,7 +65,6 @@ def _read_metadata(self) -> None: "dtype": self._file.attrs["dtype"], "precision": self._file.attrs["precision"], "multiplier": self._file.attrs["multiplier"], - "is_binary": self._file.attrs["is_binary"], "has_color_profile": self._file.attrs["has_color_profile"], "num_tiles": self._file.attrs["num_tiles"], "data_format": self._file.attrs["data_format"], @@ -86,7 +85,6 @@ def _open_file(self) -> None: self._stride = None self._precision = None self._multiplier = None - self._is_binary = None self._data_format = None if not self._filename.is_file(): @@ -107,7 +105,6 @@ def _open_file(self) -> None: self._dtype = self._metadata["dtype"] self._precision = self._metadata["precision"] self._multiplier = self._metadata["multiplier"] - self._is_binary = self._metadata["is_binary"] self._stride = ( self._tile_size[0] - self._tile_overlap[0], self._tile_size[1] - self._tile_overlap[1], @@ -132,7 +129,6 @@ def temp_h5_file(tmp_path: Path) -> Generator[Path, None, None]: dtype = "uint8" precision = "FP32" multiplier = 1.0 - is_binary = False has_color_profile = False num_tiles = 4 # 2x2 grid of tiles @@ -145,7 +141,6 @@ def temp_h5_file(tmp_path: Path) -> Generator[Path, None, None]: f.attrs["dtype"] = dtype f.attrs["precision"] = precision f.attrs["multiplier"] = multiplier - f.attrs["is_binary"] = is_binary f.attrs["has_color_profile"] = has_color_profile f.attrs["num_tiles"] = num_tiles f.attrs["data_format"] = "image" diff --git a/tests/test_writers/test_h5_writer.py 
b/tests/test_writers/test_h5_writer.py index 34537e8..8a08bb6 100644 --- a/tests/test_writers/test_h5_writer.py +++ b/tests/test_writers/test_h5_writer.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from ahcore.utils.types import GenericNumberArray, InferencePrecision +from ahcore.utils.types import DataFormat, GenericNumberArray, InferencePrecision from ahcore.writers import H5FileImageWriter @@ -18,6 +18,14 @@ def temp_h5_file(tmp_path: Path) -> Generator[Path, None, None]: h5_file_path.unlink() +@pytest.fixture +def temp_h5_feature_file(tmp_path: Path) -> Generator[Path, None, None]: + h5_file_path = tmp_path / "test_data_feature.h5" + yield h5_file_path + if h5_file_path.exists(): + h5_file_path.unlink() + + @pytest.fixture def dummy_batch_data() -> Generator[tuple[GenericNumberArray, GenericNumberArray], None, None]: dummy_coordinates = np.array([[0, 0]]) @@ -25,6 +33,13 @@ def dummy_batch_data() -> Generator[tuple[GenericNumberArray, GenericNumberArray yield dummy_coordinates, dummy_batch +@pytest.fixture +def dummy_feature_batch_data() -> Generator[tuple[GenericNumberArray, GenericNumberArray], None, None]: + dummy_coordinates = np.array([[i, 0] for i in range(16)]).astype(np.float32) + dummy_batch = np.random.rand(16, 512).astype(np.float32) + yield dummy_coordinates, dummy_batch + + @pytest.fixture def dummy_batch_generator( dummy_batch_data: tuple[GenericNumberArray, GenericNumberArray] @@ -35,6 +50,17 @@ def generator() -> Generator[tuple[GenericNumberArray, GenericNumberArray], None return generator +@pytest.fixture +def dummy_feature_batch_generator( + dummy_feature_batch_data: tuple[GenericNumberArray, GenericNumberArray] +) -> Generator[tuple[GenericNumberArray, GenericNumberArray], None, None]: + def generator() -> Generator[tuple[GenericNumberArray, GenericNumberArray], None, None]: + for _ in range(1): + yield dummy_feature_batch_data + + return generator + + def test_h5_file_image_writer_creation(temp_h5_file: Path) -> None: size = (200, 200) mpp = 0.5 @@ -88,6 +114,41 @@ def test_h5_file_image_writer_consume(temp_h5_file: Path, dummy_batch_generator: assert np.allclose(h5file["coordinates"], dummy_coordinates) +def test_h5_file_image_writer_consume_feature(temp_h5_feature_file: Path, dummy_feature_batch_generator: Any) -> None: + size = (32, 512) + mpp = 0.5 + tile_size = (1, 1) + tile_overlap = (0, 0) + num_samples = 32 + data_format = DataFormat.FEATURE + batch_size = 16 + + writer = H5FileImageWriter( + filename=temp_h5_feature_file, + size=size, + mpp=mpp, + tile_size=tile_size, + tile_overlap=tile_overlap, + num_samples=num_samples, + data_format=data_format, + ) + + writer.consume(dummy_feature_batch_generator()) + + with h5py.File(temp_h5_feature_file, "r") as h5file: + assert "data" in h5file + assert "coordinates" in h5file + assert np.array(h5file["data"]).shape == (num_samples, size[1]) + assert np.array(h5file["coordinates"]).shape == (num_samples, 2) + + gen = dummy_feature_batch_generator() + for idx, (dummy_coordinates, dummy_batch) in enumerate(gen): + assert np.allclose(h5file["data"][idx * batch_size : (idx + 1) * batch_size, :], dummy_batch) + assert np.allclose(h5file["coordinates"][idx * batch_size : (idx + 1) * batch_size, :], dummy_coordinates) + + # + + def test_h5_file_image_writer_metadata(temp_h5_file: Path, dummy_batch_generator: Any) -> None: size = (200, 200) mpp = 0.5 @@ -99,10 +160,10 @@ def test_h5_file_image_writer_metadata(temp_h5_file: Path, dummy_batch_generator tiling_mode = "overflow" format = "RAW" dtype = "float32" - 
is_binary = False grid_offset = (0, 0) precision = InferencePrecision.FP32 multiplier = 1.0 + data_format = DataFormat.IMAGE writer = H5FileImageWriter( filename=temp_h5_file, @@ -128,10 +189,10 @@ def test_h5_file_image_writer_metadata(temp_h5_file: Path, dummy_batch_generator assert metadata["tiling_mode"] == tiling_mode assert metadata["format"] == format assert metadata["dtype"] == dtype - assert metadata["is_binary"] == is_binary assert metadata["grid_offset"] == list(grid_offset) assert metadata["precision"] == precision assert metadata["multiplier"] == multiplier + assert metadata["data_format"] == data_format def test_h5_file_image_writer_multiple_tiles(temp_h5_file: Path) -> None: From 8adffcc70c869115ff7ea91b896d9373489829d9 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Wed, 11 Sep 2024 11:33:47 +0200 Subject: [PATCH 21/30] cross_entropy now handles BxC inputs as well, also improved logic --- ahcore/losses.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/ahcore/losses.py b/ahcore/losses.py index 0f69b7c..0bc35c0 100644 --- a/ahcore/losses.py +++ b/ahcore/losses.py @@ -7,6 +7,7 @@ from __future__ import annotations +import logging from typing import Callable, Optional, Union, cast import numpy as np @@ -67,7 +68,7 @@ def __init__( if class_proportions is not None: _class_weights = 1 / class_proportions _class_weights[_class_weights.isnan()] = 0.0 - _class_weights = _class_weights / _class_weights.max() + _class_weights = _class_weights / _class_weights.max() # todo: check, shouldn't this be .sum? self._class_weights = _class_weights else: self._class_weights = None @@ -141,26 +142,34 @@ def cross_entropy( else: roi_sum = torch.tensor([np.prod(tuple(input.shape)[2:])]).to(input.device) + if input.dim() != target.dim(): + raise ValueError(f"Dimension do not match for input and target. 
Got {input.dim()} and {target.dim()}") + + if input.dim() == 2 and target.dim() == 2: + # handle cls task as an image of size 1x1 + input = input.unsqueeze(-1).unsqueeze(-1) + target = target.unsqueeze(-1).unsqueeze(-1) + if ignore_index is None: ignore_index = -100 # compute cross_entropy pixel by pixel - if not multiclass: - _cross_entropy = F.cross_entropy( + if multiclass: + _cross_entropy = F.binary_cross_entropy_with_logits( input, - target.argmax(dim=1), - ignore_index=ignore_index, + target, weight=None if weight is None else weight.to(input.device), reduction="none", - label_smoothing=label_smoothing, + pos_weight=None, ) else: - _cross_entropy = F.binary_cross_entropy_with_logits( + _cross_entropy = F.cross_entropy( input, - target, + target.argmax(dim=1), + ignore_index=ignore_index, weight=None if weight is None else weight.to(input.device), reduction="none", - pos_weight=None, + label_smoothing=label_smoothing, ) if limit is not None: @@ -181,6 +190,7 @@ def cross_entropy( ) / (roi_sum * topk) + def soft_dice( input: torch.Tensor, target: torch.Tensor, From 5eea88d0acd7eaf48e669f594c6252d6ff9b38a2 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Wed, 11 Sep 2024 11:34:47 +0200 Subject: [PATCH 22/30] make dimension work out for labels --- ahcore/transforms/pre_transforms.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ahcore/transforms/pre_transforms.py b/ahcore/transforms/pre_transforms.py index f8934e1..45d00ee 100644 --- a/ahcore/transforms/pre_transforms.py +++ b/ahcore/transforms/pre_transforms.py @@ -271,10 +271,11 @@ def __call__(self, sample: DlupDatasetSample) -> DlupDatasetSample: if len(sample["labels"].keys()) == 1: # if there is only one label, then we just set this without retaining the key # this makes it compatible with standard loss functions - sample["labels"] = next(iter(sample["labels"].values())) + sample["labels"] = next(iter(sample["labels"].values())) # todo: make this nice sample["target"] = sample["labels"] else: - logging.warning("No target set") + # logging.warning("No target set") # this can be done only for training and validation??? + pass return sample @@ -308,9 +309,11 @@ def __call__(self, sample: DlupDatasetSample) -> dict[str, DlupDatasetSample]: if sample["image"].sum() == 0: raise RuntimeError(f"Empty tile for {sample['path']} at {sample['coordinates']}") - if "labels" in sample: + if "labels" in sample.keys() and sample["labels"] is not None: for key, value in sample["labels"].items(): - sample["labels"][key] = torch.tensor(value) + sample["labels"][key] = torch.tensor(value, dtype=torch.float32) + if sample["labels"][key].dim() == 0: + sample["labels"][key] = sample["labels"][key].unsqueeze(0) # annotation_data is added by the ConvertPolygonToMask transform. 
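        # (Editor's note, not part of the diff.) With the label conversion above, a scalar label such
        # as 1 becomes torch.tensor(1.0) with dim() == 0; the unsqueeze(0) then gives it shape (1,),
        # so after batching the targets carry an explicit label dimension.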
if "annotation_data" not in sample: From c26910af6ba2a2c14ca79c8e8aee7dd08005bf1a Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Wed, 11 Sep 2024 11:35:22 +0200 Subject: [PATCH 23/30] simplify reader and dataset builders --- ahcore/readers.py | 174 ++++++++++++++++++++------------------- ahcore/utils/manifest.py | 65 ++++++++------- 2 files changed, 128 insertions(+), 111 deletions(-) diff --git a/ahcore/readers.py b/ahcore/readers.py index b9b48e3..4b6d5b9 100644 --- a/ahcore/readers.py +++ b/ahcore/readers.py @@ -124,11 +124,10 @@ def _open_file(self) -> None: ) self._mpp = self._metadata["mpp"] # features are always read at tile_size (1, 1), possibly faster to read the whole feature at once - self._tile_size = ( - (self._num_samples, 1) if self._data_format == DataFormat.FEATURE else self._metadata["tile_size"] - ) + # this should be fixed in the writer and features have to be remade so this doesn't have to be hardcoded + self._tile_size = self._metadata["tile_size"] self._tile_overlap = self._metadata["tile_overlap"] - self._size = (self._num_samples, 1) if self._data_format == DataFormat.FEATURE else self._metadata["size"] + self._size = self._metadata["size"] self._num_channels = self._metadata["num_channels"] self._dtype = self._metadata["dtype"] self._precision = self._metadata["precision"] @@ -139,6 +138,8 @@ def _open_file(self) -> None: f"Found is_binary in metadata, of file {self._filename}. " f"This tag is deprecated and might be removed in future versions" ) + if self._is_binary: + self._data_format = DataFormat.COMPRESSED_IMAGE self._stride = ( self._tile_size[0] - self._tile_overlap[0], self._tile_size[1] - self._tile_overlap[1], @@ -175,62 +176,20 @@ def metadata(self) -> dict[str, Any]: def _decompress_and_reshape_data(self, tile: GenericNumberArray) -> GenericNumberArray: assert self._tile_size is not None, "Cannot happen as this is called inside read_region which also checks this" - if self._is_binary or self._data_format == DataFormat.COMPRESSED_IMAGE: + if self._data_format == DataFormat.COMPRESSED_IMAGE: with PIL.Image.open(io.BytesIO(tile)) as img: return np.array(img).transpose( 2, 0, 1 ) # fixme: this also shouldn't work because the thing is flattened and doesn't have 3 dimensions else: - # If handling features, we need to expand dimensions to match the expected shape. - if tile.ndim == 1: # fixme: is this the correct location for this - if not self._tile_size[1] == 1: - raise NotImplementedError( - f"Tile is single dimensional and {self._tile_size=} should be [x, 1], " - f"other cases have not been considered and cause unwanted behaviour." - ) - return tile.reshape(self._num_channels, *self._tile_size) return tile - def read_region(self, location: tuple[int, int], level: int, size: tuple[int, int]) -> pyvips.Image: - """ - Reads a region in the stored h5 file. This function stitches the regions as saved in the cache file. Doing this - it takes into account: - 1) The region overlap, several region merging strategies are implemented: cropping, averaging across borders - and taking the maximum across borders. - 2) If tiles are saved or not. In case the tiles are skipped due to a background mask, an empty tile is returned. - - Parameters - ---------- - location : tuple[int, int] - Coordinates (x, y) of the upper left corner of the region. - level : int - The level of the region. Only level 0 is supported. - size : tuple[int, int] - The (h, w) size of the extracted region. 
- - Returns - ------- - pyvips.Image - Extracted region - """ - if level != 0: - raise ValueError("Only level 0 is supported") - - if self._file is None: - self._open_file() - assert self._file, "File is not open. Should not happen" - assert self._tile_size is not None, "self._tile_size should not be None" - assert self._tile_overlap is not None, "self._tile_overlap should not be None" - - image_dataset = self._file["data"] - - tile_indices = self._file["tile_indices"] - + def _read_image_region(self, image_dataset, tile_indices, size, location): total_rows = math.ceil((self._size[1] - self._tile_overlap[1]) / self._stride[1]) total_cols = math.ceil((self._size[0] - self._tile_overlap[0]) / self._stride[0]) assert ( - total_rows * total_cols == self._num_tiles or self._data_format == DataFormat.FEATURE + total_rows * total_cols == self._num_tiles ), f"{total_rows=}, {total_cols=} and {self._num_tiles=}" # Equality only holds if features where created without mask @@ -253,27 +212,6 @@ def read_region(self, location: tuple[int, int], level: int, size: tuple[int, in if self._stitching_mode == StitchingMode.AVERAGE: average_mask = np.zeros((h, w), dtype=self._dtype) - if self._data_format == DataFormat.FEATURE: - if self._stitching_mode != StitchingMode.CROP: - raise NotImplementedError("Stitching mode other than CROP is not supported for features.") - - if image_dataset.shape[0] != self._num_samples: - raise ValueError( - f"Reading features expects that the saved feature vectors are the same " - f"length as the number of samples in the dataset. " - f"Feature vector length was {image_dataset.shape[0]}, " - f"number of samples in the dataset was {self._num_samples}" - ) - - if x + w > self._num_samples or y + h > 1: - raise ValueError( - f"Feature vectors are saved as (num_samples, 1) " - f"and the requested size {size} at location {location} is too large." - ) - - # this simplified version of the crop is done as it is faster than the general crop - return pyvips.Image.new_from_array(np.expand_dims(image_dataset[x : x + w, :], axis=0)) - for i in range(start_row, end_row): for j in range(start_col, end_col): tile_idx = (i * total_cols) + j @@ -315,8 +253,9 @@ def read_region(self, location: tuple[int, int], level: int, size: tuple[int, in average_mask[img_start_y:img_end_y, img_start_x:img_end_x] += 1 stitched_image[:, img_start_y:img_end_y, img_start_x:img_end_x] += tile[ - :, tile_start_y:tile_end_y, tile_start_x:tile_end_x - ] + :, tile_start_y:tile_end_y, + tile_start_x:tile_end_x + ] elif self._stitching_mode == StitchingMode.MAXIMUM: tile_start_y = max(0, -start_y) @@ -327,25 +266,94 @@ def read_region(self, location: tuple[int, int], level: int, size: tuple[int, in if i == start_row and j == start_col: # The first tile cannot be compared with anything. So, we just copy it. stitched_image[:, img_start_y:img_end_y, img_start_x:img_end_x] = tile[ - :, tile_start_y:tile_end_y, tile_start_x:tile_end_x - ] + :, tile_start_y:tile_end_y, + tile_start_x:tile_end_x + ] stitched_image[:, img_start_y:img_end_y, img_start_x:img_end_x] = np.maximum( stitched_image[:, img_start_y:img_end_y, img_start_x:img_end_x], tile[:, tile_start_y:tile_end_y, tile_start_x:tile_end_x], ) - # Adjust the precision and convert to float32 before averaging to avoid loss of precision. 
- if self._precision != str(InferencePrecision.UINT8) or self._stitching_mode == StitchingMode.AVERAGE: - stitched_image = stitched_image / self._multiplier - stitched_image = stitched_image.astype(np.float32) + # Adjust the precision and convert to float32 before averaging to avoid loss of precision. + if self._precision != str(InferencePrecision.UINT8) or self._stitching_mode == StitchingMode.AVERAGE: + stitched_image = stitched_image / self._multiplier + stitched_image = stitched_image.astype(np.float32) + + if self._stitching_mode == StitchingMode.AVERAGE: + overlap_regions = average_mask > 0 + # Perform division to average the accumulated pixel values + stitched_image[:, overlap_regions] = stitched_image[:, overlap_regions] / average_mask[overlap_regions] + + return pyvips.Image.new_from_array(stitched_image.transpose(1, 2, 0)) + + def _read_feature_region(self, image_dataset, tile_indices, size, location): + x, y = location + w, h = size + + if self._stitching_mode != StitchingMode.CROP: + raise NotImplementedError("Stitching mode other than CROP is not supported for features.") + + if image_dataset.shape[0] != self._num_samples: + raise ValueError( + f"Reading features expects that the saved feature vectors are the same " + f"length as the number of samples in the dataset. " + f"Feature vector length was {image_dataset.shape[0]}, " + f"number of samples in the dataset was {self._num_samples}" + ) + + if x + w > self._num_samples or y + h > 1: + raise ValueError( + f"Feature vectors are saved as (num_samples, 1) " + f"and the requested size {size} at location {location} is too large." + ) + + # this simplified version of the crop is done as it is faster than the crop for images crop + return pyvips.Image.new_from_array(np.expand_dims(image_dataset[x: x + w, :], axis=0)) + + def read_region(self, location: tuple[int, int], level: int, size: tuple[int, int]) -> pyvips.Image: + """ + Reads a region in the stored h5 file. This function stitches the regions as saved in the cache file. Doing this + it takes into account: + 1) The region overlap, several region merging strategies are implemented: cropping, averaging across borders + and taking the maximum across borders. + 2) If tiles are saved or not. In case the tiles are skipped due to a background mask, an empty tile is returned. + + Parameters + ---------- + location : tuple[int, int] + Coordinates (x, y) of the upper left corner of the region. + level : int + The level of the region. Only level 0 is supported. + size : tuple[int, int] + The (h, w) size of the extracted region. + + Returns + ------- + pyvips.Image + Extracted region + """ + if level != 0: + raise ValueError("Only level 0 is supported") + + if self._file is None: + self._open_file() + assert self._file, "File is not open. 
Should not happen" + assert self._tile_size is not None, "self._tile_size should not be None" + assert self._tile_overlap is not None, "self._tile_overlap should not be None" + + image_dataset = self._file["data"] + + tile_indices = self._file["tile_indices"] + + if self._data_format == DataFormat.IMAGE: + return self._read_image_region(image_dataset, tile_indices, size, location) + elif self._data_format == DataFormat.FEATURE: + return self._read_feature_region(image_dataset, tile_indices, size, location) + + - if self._stitching_mode == StitchingMode.AVERAGE: - overlap_regions = average_mask > 0 - # Perform division to average the accumulated pixel values - stitched_image[:, overlap_regions] = stitched_image[:, overlap_regions] / average_mask[overlap_regions] - return pyvips.Image.new_from_array(stitched_image.transpose(1, 2, 0)) @abc.abstractmethod def close(self) -> None: diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py index 701aa9f..a4d7925 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -14,6 +14,7 @@ from dlup.annotations import WsiAnnotations from dlup.data.dataset import RegionFromWsiDatasetSample, TiledWsiDataset, TileSample from dlup.tiling import GridOrder, TilingMode +from dlup.backends import ImageBackend as DLUPImageBackend from pydantic import BaseModel from sqlalchemy import Column, create_engine from sqlalchemy.engine import Engine @@ -397,6 +398,41 @@ def close(self) -> None: self.__session.close() self.__session = None +def get_image_info(db_manager: DataManager, + data_description: DataDescription, + image: Image) -> tuple[Path, tuple[PositiveInt, PositiveInt], tuple[PositiveInt, PositiveInt], ImageBackend, PositiveFloat, PositiveFloat, TilingMode]: + if data_description.feature_version is not None: + # if feature_version is defined we use features + # right now this selects all features, todo: add some argument tile_size to overwrite this + + image_feature, feature_description = db_manager.get_image_features_by_image_and_feature_version( + image.id, data_description.feature_version + ) + + ( + image_path, + mpp, + tile_size, + tile_overlap, + backend, + overwrite_mpp, + ) = get_relevant_feature_info_from_record(image_feature, data_description, feature_description) + + tile_mode = TilingMode.skip + + return image_path, tile_size, tile_overlap, backend, mpp, overwrite_mpp, tile_mode + + else: + image_path = data_description.data_dir / image.filename + tile_size = (data_description.training_grid.tile_size, data_description.training_grid.tile_size) + tile_overlap = (data_description.training_grid.tile_overlap, data_description.training_grid.tile_overlap) + backend = DLUPImageBackend[str(image.reader)] + mpp = data_description.training_grid.mpp + overwrite_mpp = float(image.mpp) + tile_mode = TilingMode.overflow + return image_path, tile_size, tile_overlap, backend, mpp, overwrite_mpp, tile_mode + + def datasets_from_data_description( db_manager: DataManager, @@ -435,34 +471,7 @@ def datasets_from_data_description( rois = _get_rois(mask, data_description, stage) mask_threshold = 0.0 if stage != "fit" else data_description.mask_threshold - if use_features: - image_feature, feature_description = db_manager.get_image_features_by_image_and_feature_version( - image.id, data_description.feature_version - ) - ( - image_path, - mpp, - tile_size, - tile_overlap, - backend, - overwrite_mpp, - ) = get_relevant_feature_info_from_record(image_feature, data_description, feature_description) - tile_mode = TilingMode.skip - else: - if 
grid_description is None: - raise ValueError( - f"grid_description for stage {stage} is None should be set if images are supposed to be used" - ) - - image_path = image_root / image.filename - tile_size = grid_description.tile_size - tile_overlap = grid_description.tile_overlap - backend = ImageBackend[str(image.reader)] - mpp = grid_description.mpp # type: ignore - overwrite_mpp = float(image.mpp) - tile_mode = TilingMode.overflow - - # fixme: something is still wrong here, in the reader or in the database, got empty tiles + image_path, tile_size, tile_overlap, backend, mpp, overwrite_mpp, tile_mode = get_image_info(db_manager, data_description, image) dataset = TiledWsiDataset.from_standard_tiling( path=image_path, From 7efaa7021d81e1263421339807078c1ee2e6199c Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Wed, 11 Sep 2024 14:55:41 +0200 Subject: [PATCH 24/30] fix writers and tests for writers --- ahcore/utils/types.py | 1 - ahcore/writers.py | 17 ++++++++++++++--- tests/test_writers/test_h5_writer.py | 21 +++++++++------------ 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/ahcore/utils/types.py b/ahcore/utils/types.py index 7dd5f1f..be3e076 100644 --- a/ahcore/utils/types.py +++ b/ahcore/utils/types.py @@ -94,4 +94,3 @@ class DataFormat(str, Enum): FEATURE = "feature" IMAGE = "image" COMPRESSED_IMAGE = "compressed_image" - MASK = "mask" diff --git a/ahcore/writers.py b/ahcore/writers.py index 3bcb816..80b9e87 100644 --- a/ahcore/writers.py +++ b/ahcore/writers.py @@ -343,7 +343,7 @@ def init_writer(self, first_coordinates: GenericNumberArray, first_batch: Generi compression="gzip", ) - if self._data_format != DataFormat.COMPRESSED_IMAGE: + if self._data_format == DataFormat.FEATURE: shape = first_batch.shape[1:] self._data = self.create_dataset( file, @@ -351,9 +351,10 @@ def init_writer(self, first_coordinates: GenericNumberArray, first_batch: Generi shape=(self._num_samples,) + shape, dtype=first_batch.dtype, compression="gzip", - chunks=(1,) + shape, + chunks=(self._num_samples,) + shape, # this should be the fastest as we are loading + # all features everytime ) - else: + elif self._data_format == DataFormat.COMPRESSED_IMAGE: self._data = self.create_variable_length_dataset( file, name="data", @@ -361,6 +362,16 @@ def init_writer(self, first_coordinates: GenericNumberArray, first_batch: Generi chunks=(1,), compression="gzip", ) + else: # data_format == DataFormat.IMAGE + shape = first_batch.shape[1:] + self._data = self.create_dataset( + file, + "data", + shape=(self._num_samples,) + shape, + dtype=first_batch.dtype, + compression="gzip", + chunks=(1,) + shape, + ) if self._color_profile: data = np.frombuffer(self._color_profile, dtype=np.uint8) diff --git a/tests/test_writers/test_h5_writer.py b/tests/test_writers/test_h5_writer.py index 8a08bb6..53d6416 100644 --- a/tests/test_writers/test_h5_writer.py +++ b/tests/test_writers/test_h5_writer.py @@ -33,12 +33,6 @@ def dummy_batch_data() -> Generator[tuple[GenericNumberArray, GenericNumberArray yield dummy_coordinates, dummy_batch -@pytest.fixture -def dummy_feature_batch_data() -> Generator[tuple[GenericNumberArray, GenericNumberArray], None, None]: - dummy_coordinates = np.array([[i, 0] for i in range(16)]).astype(np.float32) - dummy_batch = np.random.rand(16, 512).astype(np.float32) - yield dummy_coordinates, dummy_batch - @pytest.fixture def dummy_batch_generator( @@ -52,11 +46,14 @@ def generator() -> Generator[tuple[GenericNumberArray, GenericNumberArray], None @pytest.fixture def 
dummy_feature_batch_generator( - dummy_feature_batch_data: tuple[GenericNumberArray, GenericNumberArray] ) -> Generator[tuple[GenericNumberArray, GenericNumberArray], None, None]: + batch_size = 16 def generator() -> Generator[tuple[GenericNumberArray, GenericNumberArray], None, None]: - for _ in range(1): - yield dummy_feature_batch_data + np.random.seed(42) + for batch_idx in range(2): + dummy_coordinates = np.array([[i, 0] for i in range(batch_size * batch_idx, batch_size * (batch_idx + 1))]).astype(np.float32) + dummy_batch = np.random.rand(batch_size, 512).astype(np.float32) + yield dummy_coordinates, dummy_batch return generator @@ -115,7 +112,7 @@ def test_h5_file_image_writer_consume(temp_h5_file: Path, dummy_batch_generator: def test_h5_file_image_writer_consume_feature(temp_h5_feature_file: Path, dummy_feature_batch_generator: Any) -> None: - size = (32, 512) + size = (32, 1) mpp = 0.5 tile_size = (1, 1) tile_overlap = (0, 0) @@ -133,12 +130,13 @@ def test_h5_file_image_writer_consume_feature(temp_h5_feature_file: Path, dummy_ data_format=data_format, ) + writer.consume(dummy_feature_batch_generator()) with h5py.File(temp_h5_feature_file, "r") as h5file: assert "data" in h5file assert "coordinates" in h5file - assert np.array(h5file["data"]).shape == (num_samples, size[1]) + assert np.array(h5file["data"]).shape == (num_samples, 512) assert np.array(h5file["coordinates"]).shape == (num_samples, 2) gen = dummy_feature_batch_generator() @@ -146,7 +144,6 @@ def test_h5_file_image_writer_consume_feature(temp_h5_feature_file: Path, dummy_ assert np.allclose(h5file["data"][idx * batch_size : (idx + 1) * batch_size, :], dummy_batch) assert np.allclose(h5file["coordinates"][idx * batch_size : (idx + 1) * batch_size, :], dummy_coordinates) - # def test_h5_file_image_writer_metadata(temp_h5_file: Path, dummy_batch_generator: Any) -> None: From 2d83964da2aeb429d338aceca07d7827c7ae4b79 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Fri, 13 Sep 2024 13:11:33 +0200 Subject: [PATCH 25/30] fix test, mypy and file_writer + callback --- ahcore/backends.py | 3 +- ahcore/callbacks/file_writer_callback.py | 78 +++++++++++++++--------- ahcore/data/dataset.py | 6 +- ahcore/losses.py | 2 - ahcore/readers.py | 56 +++++++++-------- ahcore/utils/manifest.py | 60 ++++++++++++------ ahcore/utils/types.py | 2 +- tests/test_writers/test_h5_writer.py | 27 +++++--- 8 files changed, 143 insertions(+), 91 deletions(-) diff --git a/ahcore/backends.py b/ahcore/backends.py index 16431c3..25d9a80 100644 --- a/ahcore/backends.py +++ b/ahcore/backends.py @@ -8,7 +8,6 @@ from dlup.backends.tifffile_backend import TifffileSlide from dlup.types import PathLike # type: ignore - from ahcore.readers import H5FileImageReader, StitchingMode, ZarrFileImageReader @@ -96,4 +95,4 @@ class ImageBackend(Enum): ZARR: Callable[[PathLike], ZarrSlide] = ZarrSlide def __call__(self, *args: Any) -> OpenSlideSlide | PyVipsSlide | TifffileSlide | H5Slide | ZarrSlide: - return self.value(*args) # type: ignore + return self.value(*args) # type: ignore diff --git a/ahcore/callbacks/file_writer_callback.py b/ahcore/callbacks/file_writer_callback.py index bdf3916..f2d2c97 100644 --- a/ahcore/callbacks/file_writer_callback.py +++ b/ahcore/callbacks/file_writer_callback.py @@ -4,6 +4,7 @@ from typing import Type from dlup.data.dataset import TiledWsiDataset +from dlup.tiling import Grid from ahcore.callbacks.abstract_writer_callback import AbstractWriterCallback from ahcore.callbacks.converters.common import ConvertCallbacks @@ -27,7 
+28,7 @@ def __init__( normalization_type: str = NormalizationType.LOGITS, precision: str = InferencePrecision.FP32, callbacks: list[ConvertCallbacks] | None = None, - data_format: DataFormat = DataFormat.IMAGE, + data_format: str = DataFormat.IMAGE, ) -> None: """ Callback to write predictions to H5 files. This callback is used to write whole-slide predictions to single H5 @@ -55,7 +56,7 @@ def __init__( self._suffix = ".cache" self._normalization_type: NormalizationType = NormalizationType(normalization_type) self._precision: InferencePrecision = InferencePrecision(precision) - self._data_format = data_format + self._data_format = DataFormat(data_format) super().__init__( writer_class=writer_class, @@ -99,36 +100,15 @@ def build_writer_class(self, pl_module: AhCoreLightningModule, stage: str, filen with open(link_fn, "a" if link_fn.is_file() else "w") as file: file.write(f"{filename},{output_filename}\n") - current_dataset: TiledWsiDataset - current_dataset, _ = self._total_dataset.index_to_dataset(self._dataset_index) # type: ignore - slide_image = current_dataset.slide_image - num_samples = len(current_dataset) - - data_description: DataDescription = pl_module.data_description - if data_description.inference_grid is None: - raise ValueError("Inference grid is not defined in the data description.") - inference_grid: GridDescription = data_description.inference_grid - - mpp = inference_grid.mpp - if mpp is None: - mpp = slide_image.mpp - - _, size = slide_image.get_scaled_slide_bounds(slide_image.get_scaling(mpp)) - - # Let's get the data_description, so we can figure out the tile size and things like that - tile_size = inference_grid.tile_size - tile_overlap = inference_grid.tile_overlap - - if stage == "validate": - grid = current_dataset._grids[0][0] # pylint: disable=protected-access - else: - grid = None # During inference we don't have a grid around ROI + size, mpp, tile_size, tile_overlap, num_samples, grid = self._get_writer_data_args( + pl_module, data_format=self._data_format, stage=stage + ) writer = self._writer_class( output_filename, - size=size, + size=size, # --> (num_samples,1) mpp=mpp, - tile_size=tile_size, + tile_size=tile_size, # --> (1,1) tile_overlap=tile_overlap, num_samples=num_samples, color_profile=None, @@ -139,3 +119,45 @@ def build_writer_class(self, pl_module: AhCoreLightningModule, stage: str, filen ) return writer + + def _get_writer_data_args( + self, pl_module: AhCoreLightningModule, data_format: DataFormat, stage: str + ) -> tuple[tuple[int, int], float, tuple[int, int], tuple[int, int], int, Grid | None]: + current_dataset: TiledWsiDataset + current_dataset, _ = self._total_dataset.index_to_dataset(self._dataset_index) # type: ignore + slide_image = current_dataset.slide_image + num_samples = len(current_dataset) + + if data_format == DataFormat.IMAGE or data_format == DataFormat.COMPRESSED_IMAGE: + data_description: DataDescription = pl_module.data_description + if data_description.inference_grid is None: + raise ValueError("Inference grid is not defined in the data description.") + inference_grid: GridDescription = data_description.inference_grid + + mpp = inference_grid.mpp + if mpp is None: + mpp = slide_image.mpp + + _, size = slide_image.get_scaled_slide_bounds(slide_image.get_scaling(mpp)) + + # Let's get the data_description, so we can figure out the tile size and things like that + tile_size = inference_grid.tile_size + tile_overlap = inference_grid.tile_overlap + + if stage == "validate": + grid = current_dataset._grids[0][0] # pylint: 
disable=protected-access + else: + grid = None # During inference we don't have a grid around ROI + + elif data_format == DataFormat.FEATURE: + size = (num_samples, 1) + mpp = 1.0 + tile_size = (1, 1) + tile_overlap = (0, 0) + num_samples = num_samples + grid = None # just let the writer make a new grid + + else: + raise NotImplementedError(f"Data format {data_format} is not yet supported.") + + return size, mpp, tile_size, tile_overlap, num_samples, grid diff --git a/ahcore/data/dataset.py b/ahcore/data/dataset.py index 85a5bd0..94d88ef 100644 --- a/ahcore/data/dataset.py +++ b/ahcore/data/dataset.py @@ -88,10 +88,12 @@ def __len__(self) -> int: return self.cumulative_sizes[-1] @overload - def __getitem__(self, index: int) -> DlupDatasetSample: ... + def __getitem__(self, index: int) -> DlupDatasetSample: + ... @overload - def __getitem__(self, index: slice) -> list[DlupDatasetSample]: ... + def __getitem__(self, index: slice) -> list[DlupDatasetSample]: + ... def __getitem__(self, index: Union[int, slice]) -> DlupDatasetSample | list[DlupDatasetSample]: """Returns the sample at the given index.""" diff --git a/ahcore/losses.py b/ahcore/losses.py index 0bc35c0..9fecee3 100644 --- a/ahcore/losses.py +++ b/ahcore/losses.py @@ -7,7 +7,6 @@ from __future__ import annotations -import logging from typing import Callable, Optional, Union, cast import numpy as np @@ -190,7 +189,6 @@ def cross_entropy( ) / (roi_sum * topk) - def soft_dice( input: torch.Tensor, target: torch.Tensor, diff --git a/ahcore/readers.py b/ahcore/readers.py index 4b6d5b9..ad231b6 100644 --- a/ahcore/readers.py +++ b/ahcore/readers.py @@ -53,13 +53,14 @@ def __init__(self, filename: Path, stitching_mode: StitchingMode) -> None: self._mpp = None self._tile_size = None self._tile_overlap = None - self._size = None + self._size: Optional[tuple[int, int]] = None self._num_channels = None self._dtype = None self._stride = None self._precision = None self._multiplier = None self._is_binary = None + self._num_samples = None @classmethod def from_file_path(cls, filename: Path, stitching_mode: StitchingMode = StitchingMode.CROP) -> "FileImageReader": @@ -69,8 +70,9 @@ def from_file_path(cls, filename: Path, stitching_mode: StitchingMode = Stitchin def size(self) -> tuple[int, int]: if not self._size: self._open_file() - assert self._size - return tuple(self._size) + + assert self._size and len(self.size) == 2 + return self._size[0], self._size[1] @property def mpp(self) -> float: @@ -123,8 +125,6 @@ def _open_file(self) -> None: DataFormat(self._metadata["data_format"]) if "data_format" in self._metadata.keys() else DataFormat.IMAGE ) self._mpp = self._metadata["mpp"] - # features are always read at tile_size (1, 1), possibly faster to read the whole feature at once - # this should be fixed in the writer and features have to be remade so this doesn't have to be hardcoded self._tile_size = self._metadata["tile_size"] self._tile_overlap = self._metadata["tile_overlap"] self._size = self._metadata["size"] @@ -184,13 +184,19 @@ def _decompress_and_reshape_data(self, tile: GenericNumberArray) -> GenericNumbe else: return tile - def _read_image_region(self, image_dataset, tile_indices, size, location): + def _read_image_region(self, size: tuple[int, int], location: tuple[int, int]) -> pyvips.Image: + assert self._size is not None + assert self._tile_size is not None + assert self._tile_overlap is not None + + image_dataset: h5py.Dataset = self._file["data"] + + tile_indices: h5py.Dataset = self._file["tile_indices"] + total_rows = 
math.ceil((self._size[1] - self._tile_overlap[1]) / self._stride[1]) total_cols = math.ceil((self._size[0] - self._tile_overlap[0]) / self._stride[0]) - assert ( - total_rows * total_cols == self._num_tiles - ), f"{total_rows=}, {total_cols=} and {self._num_tiles=}" + assert total_rows * total_cols == self._num_tiles, f"{total_rows=}, {total_cols=} and {self._num_tiles=}" # Equality only holds if features where created without mask x, y = location @@ -253,9 +259,8 @@ def _read_image_region(self, image_dataset, tile_indices, size, location): average_mask[img_start_y:img_end_y, img_start_x:img_end_x] += 1 stitched_image[:, img_start_y:img_end_y, img_start_x:img_end_x] += tile[ - :, tile_start_y:tile_end_y, - tile_start_x:tile_end_x - ] + :, tile_start_y:tile_end_y, tile_start_x:tile_end_x + ] elif self._stitching_mode == StitchingMode.MAXIMUM: tile_start_y = max(0, -start_y) @@ -266,9 +271,8 @@ def _read_image_region(self, image_dataset, tile_indices, size, location): if i == start_row and j == start_col: # The first tile cannot be compared with anything. So, we just copy it. stitched_image[:, img_start_y:img_end_y, img_start_x:img_end_x] = tile[ - :, tile_start_y:tile_end_y, - tile_start_x:tile_end_x - ] + :, tile_start_y:tile_end_y, tile_start_x:tile_end_x + ] stitched_image[:, img_start_y:img_end_y, img_start_x:img_end_x] = np.maximum( stitched_image[:, img_start_y:img_end_y, img_start_x:img_end_x], @@ -287,7 +291,13 @@ def _read_image_region(self, image_dataset, tile_indices, size, location): return pyvips.Image.new_from_array(stitched_image.transpose(1, 2, 0)) - def _read_feature_region(self, image_dataset, tile_indices, size, location): + def _read_feature_region(self, size: tuple[int, int], location: tuple[int, int]) -> pyvips.Image: + assert self._num_samples is not None + + image_dataset: h5py.Dataset = self._file["data"] + + assert image_dataset is not None + x, y = location w, h = size @@ -309,7 +319,7 @@ def _read_feature_region(self, image_dataset, tile_indices, size, location): ) # this simplified version of the crop is done as it is faster than the crop for images crop - return pyvips.Image.new_from_array(np.expand_dims(image_dataset[x: x + w, :], axis=0)) + return pyvips.Image.new_from_array(np.expand_dims(image_dataset[x : x + w, :], axis=0)) def read_region(self, location: tuple[int, int], level: int, size: tuple[int, int]) -> pyvips.Image: """ @@ -342,18 +352,10 @@ def read_region(self, location: tuple[int, int], level: int, size: tuple[int, in assert self._tile_size is not None, "self._tile_size should not be None" assert self._tile_overlap is not None, "self._tile_overlap should not be None" - image_dataset = self._file["data"] - - tile_indices = self._file["tile_indices"] - if self._data_format == DataFormat.IMAGE: - return self._read_image_region(image_dataset, tile_indices, size, location) + return self._read_image_region(size, location) elif self._data_format == DataFormat.FEATURE: - return self._read_feature_region(image_dataset, tile_indices, size, location) - - - - + return self._read_feature_region(size, location) @abc.abstractmethod def close(self) -> None: diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py index a4d7925..b74ac6c 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -12,9 +12,9 @@ from dlup import SlideImage from dlup.annotations import WsiAnnotations +from dlup.backends import ImageBackend as DLUPImageBackend from dlup.data.dataset import RegionFromWsiDatasetSample, TiledWsiDataset, TileSample from dlup.tiling 
import GridOrder, TilingMode -from dlup.backends import ImageBackend as DLUPImageBackend from pydantic import BaseModel from sqlalchemy import Column, create_engine from sqlalchemy.engine import Engine @@ -398,9 +398,19 @@ def close(self) -> None: self.__session.close() self.__session = None -def get_image_info(db_manager: DataManager, - data_description: DataDescription, - image: Image) -> tuple[Path, tuple[PositiveInt, PositiveInt], tuple[PositiveInt, PositiveInt], ImageBackend, PositiveFloat, PositiveFloat, TilingMode]: + +def get_image_info( + db_manager: DataManager, data_description: DataDescription, image: Image, stage: str +) -> tuple[ + Path, + tuple[PositiveInt, PositiveInt], + tuple[PositiveInt, PositiveInt], + ImageBackend, + PositiveFloat, + PositiveFloat, + TilingMode, + Optional[Tuple[int, int]], +]: if data_description.feature_version is not None: # if feature_version is defined we use features # right now this selects all features, todo: add some argument tile_size to overwrite this @@ -420,18 +430,29 @@ def get_image_info(db_manager: DataManager, tile_mode = TilingMode.skip - return image_path, tile_size, tile_overlap, backend, mpp, overwrite_mpp, tile_mode + output_tile_size = None + + return image_path, tile_size, tile_overlap, backend, mpp, overwrite_mpp, tile_mode, output_tile_size else: + if stage == "fit": + grid_description = data_description.training_grid + else: + grid_description = data_description.inference_grid + + if grid_description is None: + raise ValueError(f"Grid (for stage {stage}) is not defined in the data description.") + image_path = data_description.data_dir / image.filename - tile_size = (data_description.training_grid.tile_size, data_description.training_grid.tile_size) - tile_overlap = (data_description.training_grid.tile_overlap, data_description.training_grid.tile_overlap) + tile_size = grid_description.tile_size + tile_overlap = grid_description.tile_overlap backend = DLUPImageBackend[str(image.reader)] - mpp = data_description.training_grid.mpp + mpp = getattr(grid_description, "mpp", 1.0) overwrite_mpp = float(image.mpp) tile_mode = TilingMode.overflow - return image_path, tile_size, tile_overlap, backend, mpp, overwrite_mpp, tile_mode + output_tile_size = getattr(grid_description, "output_tile_size", None) + return image_path, tile_size, tile_overlap, backend, mpp, overwrite_mpp, tile_mode, output_tile_size def datasets_from_data_description( @@ -442,18 +463,10 @@ def datasets_from_data_description( ) -> Generator[TiledWsiDataset, None, None]: logger.info(f"Reading manifest from {data_description.manifest_database_uri} for stage {stage}") - image_root = data_description.data_dir annotations_root = data_description.annotations_dir assert isinstance(stage, str), "Stage should be a string." 
- if stage == "fit": - grid_description = data_description.training_grid - else: - grid_description = data_description.inference_grid - - use_features = data_description.feature_version is not None - patients = db_manager.get_records_by_split( manifest_name=data_description.manifest_name, split_version=data_description.split_version, @@ -471,7 +484,16 @@ def datasets_from_data_description( rois = _get_rois(mask, data_description, stage) mask_threshold = 0.0 if stage != "fit" else data_description.mask_threshold - image_path, tile_size, tile_overlap, backend, mpp, overwrite_mpp, tile_mode = get_image_info(db_manager, data_description, image) + ( + image_path, + tile_size, + tile_overlap, + backend, + mpp, + overwrite_mpp, + tile_mode, + output_tile_size, + ) = get_image_info(db_manager, data_description, image, stage) dataset = TiledWsiDataset.from_standard_tiling( path=image_path, @@ -483,7 +505,7 @@ def datasets_from_data_description( crop=False, mask=mask, mask_threshold=mask_threshold, - output_tile_size=getattr(grid_description, "output_tile_size", None), + output_tile_size=output_tile_size, rois=rois if rois is not None else None, annotations=annotations if stage != "predict" else None, labels=labels, # type: ignore diff --git a/ahcore/utils/types.py b/ahcore/utils/types.py index be3e076..5bcd1b7 100644 --- a/ahcore/utils/types.py +++ b/ahcore/utils/types.py @@ -28,7 +28,7 @@ def is_non_negative(v: int | float) -> int | float: NonNegativeFloat = Annotated[float, AfterValidator(is_non_negative)] BoundingBoxType = tuple[tuple[int, int], tuple[int, int]] Rois = list[BoundingBoxType] -GenericNumberArray = npt.NDArray[np.int_ | np.float_] +GenericNumberArray = npt.NDArray[np.int_ | np.float_ | np.float64 | np.float32] DlupDatasetSample = dict[str, Any] _DlupDataset = Dataset[DlupDatasetSample] diff --git a/tests/test_writers/test_h5_writer.py b/tests/test_writers/test_h5_writer.py index 53d6416..57f9884 100644 --- a/tests/test_writers/test_h5_writer.py +++ b/tests/test_writers/test_h5_writer.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Any, Generator +from typing import Any, Callable, Generator import h5py import numpy as np @@ -33,11 +33,10 @@ def dummy_batch_data() -> Generator[tuple[GenericNumberArray, GenericNumberArray yield dummy_coordinates, dummy_batch - @pytest.fixture def dummy_batch_generator( dummy_batch_data: tuple[GenericNumberArray, GenericNumberArray] -) -> Generator[tuple[GenericNumberArray, GenericNumberArray], None, None]: +) -> Callable[[], Generator[tuple[GenericNumberArray, GenericNumberArray], None, None]]: def generator() -> Generator[tuple[GenericNumberArray, GenericNumberArray], None, None]: yield dummy_batch_data @@ -45,13 +44,17 @@ def generator() -> Generator[tuple[GenericNumberArray, GenericNumberArray], None @pytest.fixture -def dummy_feature_batch_generator( -) -> Generator[tuple[GenericNumberArray, GenericNumberArray], None, None]: +def dummy_feature_batch_generator() -> ( + Callable[[], Generator[tuple[GenericNumberArray, GenericNumberArray], None, None]] +): batch_size = 16 + def generator() -> Generator[tuple[GenericNumberArray, GenericNumberArray], None, None]: np.random.seed(42) for batch_idx in range(2): - dummy_coordinates = np.array([[i, 0] for i in range(batch_size * batch_idx, batch_size * (batch_idx + 1))]).astype(np.float32) + dummy_coordinates = np.array( + [[i, 0] for i in range(batch_size * batch_idx, batch_size * (batch_idx + 1))] + ).astype(np.float32) dummy_batch = np.random.rand(batch_size, 
512).astype(np.float32) yield dummy_coordinates, dummy_batch @@ -130,7 +133,6 @@ def test_h5_file_image_writer_consume_feature(temp_h5_feature_file: Path, dummy_ data_format=data_format, ) - writer.consume(dummy_feature_batch_generator()) with h5py.File(temp_h5_feature_file, "r") as h5file: @@ -145,7 +147,6 @@ def test_h5_file_image_writer_consume_feature(temp_h5_feature_file: Path, dummy_ assert np.allclose(h5file["coordinates"][idx * batch_size : (idx + 1) * batch_size, :], dummy_coordinates) - def test_h5_file_image_writer_metadata(temp_h5_file: Path, dummy_batch_generator: Any) -> None: size = (200, 200) mpp = 0.5 @@ -219,7 +220,13 @@ def multiple_tile_generator() -> Generator[tuple[GenericNumberArray, GenericNumb with h5py.File(temp_h5_file, "r") as h5file: assert "data" in h5file assert "coordinates" in h5file - assert h5file["data"].shape == (num_samples, 3, 200, 200) - assert h5file["coordinates"].shape == (num_samples, 2) + dataset: h5py.Dataset = h5file["data"] + coordinates: h5py.Dataset = h5file["coordinates"] + + assert isinstance(dataset, h5py.Dataset) + assert isinstance(coordinates, h5py.Dataset) + + assert dataset.shape == (num_samples, 3, 200, 200) # pylint: disable=no-member + assert coordinates.shape == (num_samples, 2) # pylint: disable=no-member for i in range(num_samples): assert np.allclose(h5file["coordinates"][i], [i * 200, 0]) From 71bb871355c91e963da2a58852add6f2abf1245b Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Fri, 13 Sep 2024 14:24:01 +0200 Subject: [PATCH 26/30] now also passes tests... --- ahcore/readers.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/ahcore/readers.py b/ahcore/readers.py index ad231b6..2c2c6e1 100644 --- a/ahcore/readers.py +++ b/ahcore/readers.py @@ -279,17 +279,17 @@ def _read_image_region(self, size: tuple[int, int], location: tuple[int, int]) - tile[:, tile_start_y:tile_end_y, tile_start_x:tile_end_x], ) - # Adjust the precision and convert to float32 before averaging to avoid loss of precision. - if self._precision != str(InferencePrecision.UINT8) or self._stitching_mode == StitchingMode.AVERAGE: - stitched_image = stitched_image / self._multiplier - stitched_image = stitched_image.astype(np.float32) + # Adjust the precision and convert to float32 before averaging to avoid loss of precision. 
+ if self._precision != str(InferencePrecision.UINT8) or self._stitching_mode == StitchingMode.AVERAGE: + stitched_image = stitched_image / self._multiplier + stitched_image = stitched_image.astype(np.float32) - if self._stitching_mode == StitchingMode.AVERAGE: - overlap_regions = average_mask > 0 - # Perform division to average the accumulated pixel values - stitched_image[:, overlap_regions] = stitched_image[:, overlap_regions] / average_mask[overlap_regions] + if self._stitching_mode == StitchingMode.AVERAGE: + overlap_regions = average_mask > 0 + # Perform division to average the accumulated pixel values + stitched_image[:, overlap_regions] = stitched_image[:, overlap_regions] / average_mask[overlap_regions] - return pyvips.Image.new_from_array(stitched_image.transpose(1, 2, 0)) + return pyvips.Image.new_from_array(stitched_image.transpose(1, 2, 0)) def _read_feature_region(self, size: tuple[int, int], location: tuple[int, int]) -> pyvips.Image: assert self._num_samples is not None @@ -356,6 +356,8 @@ def read_region(self, location: tuple[int, int], level: int, size: tuple[int, in return self._read_image_region(size, location) elif self._data_format == DataFormat.FEATURE: return self._read_feature_region(size, location) + else: + raise NotImplementedError(f"Data format {self._data_format} is not supported.") @abc.abstractmethod def close(self) -> None: From ae580cb17c814f614b152f7a778eb650c37550cb Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Mon, 16 Sep 2024 16:00:02 +0200 Subject: [PATCH 27/30] mypy, pylint and it runs now --- ahcore/readers.py | 5 +- ahcore/utils/manifest.py | 174 ++++++++++++++++++++++++++++++--------- 2 files changed, 136 insertions(+), 43 deletions(-) diff --git a/ahcore/readers.py b/ahcore/readers.py index 2c2c6e1..564cd6e 100644 --- a/ahcore/readers.py +++ b/ahcore/readers.py @@ -71,7 +71,8 @@ def size(self) -> tuple[int, int]: if not self._size: self._open_file() - assert self._size and len(self.size) == 2 + assert self._size, "Size should be set after opening the file" + return self._size[0], self._size[1] @property @@ -319,7 +320,7 @@ def _read_feature_region(self, size: tuple[int, int], location: tuple[int, int]) ) # this simplified version of the crop is done as it is faster than the crop for images crop - return pyvips.Image.new_from_array(np.expand_dims(image_dataset[x : x + w, :], axis=0)) + return pyvips.Image.new_from_array(np.expand_dims(image_dataset[x : x + w, :], axis=0).astype(np.float32)) def read_region(self, location: tuple[int, int], level: int, size: tuple[int, int]) -> pyvips.Image: """ diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py index b74ac6c..40d8bab 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -6,6 +6,7 @@ from __future__ import annotations import functools +import logging from pathlib import Path from types import TracebackType from typing import Any, Callable, Generator, Literal, Optional, Tuple, Type, TypedDict, cast @@ -60,12 +61,27 @@ class _AnnotationReadersDict(TypedDict): "ASAP_XML": WsiAnnotations.from_asap_xml, "DARWIN_JSON": WsiAnnotations.from_darwin_json, "GEOJSON": WsiAnnotations.from_geojson, - "PYVIPS": functools.partial(SlideImage.from_file_path, backend=ImageBackend.PYVIPS), - "TIFFFILE": functools.partial(SlideImage.from_file_path, backend=ImageBackend.TIFFFILE), - "OPENSLIDE": functools.partial(SlideImage.from_file_path, backend=ImageBackend.OPENSLIDE), + "PYVIPS": functools.partial(SlideImage.from_file_path, backend=DLUPImageBackend.PYVIPS), + "TIFFFILE": 
functools.partial(SlideImage.from_file_path, backend=DLUPImageBackend.TIFFFILE), + "OPENSLIDE": functools.partial(SlideImage.from_file_path, backend=DLUPImageBackend.OPENSLIDE), } +class ImageInfoDict(TypedDict): + image_path: Optional[Path] + tile_size: Optional[Tuple[int, int]] + tile_overlap: Optional[Tuple[int, int]] + backend: Optional[ImageBackend] + mpp: Optional[float] + overwrite_mpp: Optional[float] + tile_mode: Optional[TilingMode] + output_tile_size: Optional[Tuple[int, int]] + mask: Optional[_AnnotationReturnTypes] + mask_threshold: Optional[float] + rois: Optional[Rois] + annotations: Optional[_AnnotationReturnTypes] + + def parse_annotations_from_record( annotations_root: Path, record: list[Mask] | list[ImageAnnotations] ) -> _AnnotationReturnTypes | None: @@ -184,7 +200,7 @@ def get_relevant_feature_info_from_record( return image_path, mpp, tile_size, tile_overlap, backend, overwrite_mpp -def _get_rois(mask: WsiAnnotations | None, data_description: DataDescription, stage: str) -> Optional[Rois]: +def _get_rois(mask: Optional[WsiAnnotations], data_description: DataDescription, stage: str) -> Optional[Rois]: if (mask is None) or (stage != "fit") or (not data_description.convert_mask_to_rois): return None @@ -343,7 +359,7 @@ def get_image_metadata_by_id(self, image_id: int) -> ImageMetadata: def get_image_features_by_image_and_feature_version( self, image_id: Column[int], feature_version: str | None - ) -> Tuple[ImageFeature, FeatureDescription]: + ) -> Tuple[ImageFeature | None, FeatureDescription]: """ Fetch the features for an image based on its ID and feature version. @@ -372,10 +388,8 @@ def get_image_features_by_image_and_feature_version( .filter_by(image_id=image_id, feature_description_id=feature_description.id) .first() ) - self._ensure_record( - image_feature, f"No features found for image ID {image_id} and feature version {feature_version}" - ) - assert image_feature is not None + if not image_feature: + logging.warning(f"No features found for image ID {image_id} and feature version {feature_version}") # todo: make sure that this only allows to run one ImageFeature, # I think it should be good bc of the unique constraint return image_feature, feature_description @@ -401,16 +415,25 @@ def close(self) -> None: def get_image_info( db_manager: DataManager, data_description: DataDescription, image: Image, stage: str -) -> tuple[ - Path, - tuple[PositiveInt, PositiveInt], - tuple[PositiveInt, PositiveInt], - ImageBackend, - PositiveFloat, - PositiveFloat, - TilingMode, - Optional[Tuple[int, int]], -]: +) -> ImageInfoDict: + # Initialize the output dictionary with all keys set to None + image_info: ImageInfoDict = { + "image_path": None, + "tile_size": None, + "tile_overlap": None, + "backend": None, + "mpp": None, + "overwrite_mpp": None, + "tile_mode": None, + "output_tile_size": None, + "mask": None, + "mask_threshold": None, + "rois": None, + "annotations": None, + } + + annotations_root = data_description.annotations_dir + if data_description.feature_version is not None: # if feature_version is defined we use features # right now this selects all features, todo: add some argument tile_size to overwrite this @@ -419,6 +442,10 @@ def get_image_info( image.id, data_description.feature_version ) + if image_feature is None: + # Directly return the initialized dictionary with None values + return image_info + ( image_path, mpp, @@ -428,11 +455,25 @@ def get_image_info( overwrite_mpp, ) = get_relevant_feature_info_from_record(image_feature, data_description, 
feature_description) - tile_mode = TilingMode.skip - - output_tile_size = None + # Update the dictionary with the actual values + image_info.update( + { + "image_path": image_path, + "tile_size": tile_size, + "tile_overlap": tile_overlap, + "backend": backend, + "mpp": mpp, + "overwrite_mpp": overwrite_mpp, + "tile_mode": TilingMode.skip, + "output_tile_size": None, + "mask": None, + "mask_threshold": None, + "rois": None, + "annotations": None, + } + ) - return image_path, tile_size, tile_overlap, backend, mpp, overwrite_mpp, tile_mode, output_tile_size + return image_info else: if stage == "fit": @@ -443,6 +484,11 @@ def get_image_info( if grid_description is None: raise ValueError(f"Grid (for stage {stage}) is not defined in the data description.") + mask, annotations = get_mask_and_annotations_from_record(annotations_root, image) + assert isinstance(mask, WsiAnnotations) or (mask is None) + mask_threshold = 0.0 if stage != "fit" else data_description.mask_threshold + rois = _get_rois(mask, data_description, stage) + image_path = data_description.data_dir / image.filename tile_size = grid_description.tile_size tile_overlap = grid_description.tile_overlap @@ -452,7 +498,25 @@ def get_image_info( tile_mode = TilingMode.overflow output_tile_size = getattr(grid_description, "output_tile_size", None) - return image_path, tile_size, tile_overlap, backend, mpp, overwrite_mpp, tile_mode, output_tile_size + # Update the dictionary with the actual values + image_info.update( + { + "image_path": image_path, + "tile_size": tile_size, + "tile_overlap": tile_overlap, + "backend": backend, + "mpp": mpp, + "overwrite_mpp": overwrite_mpp, + "tile_mode": tile_mode, + "output_tile_size": output_tile_size, + "mask": mask, + "mask_threshold": mask_threshold, + "rois": rois, + "annotations": annotations, + } + ) + + return image_info def datasets_from_data_description( @@ -463,8 +527,6 @@ def datasets_from_data_description( ) -> Generator[TiledWsiDataset, None, None]: logger.info(f"Reading manifest from {data_description.manifest_database_uri} for stage {stage}") - annotations_root = data_description.annotations_dir - assert isinstance(stage, str), "Stage should be a string." patients = db_manager.get_records_by_split( @@ -477,23 +539,53 @@ def datasets_from_data_description( patient_labels = get_labels_from_record(patient) for image in patient.images: - mask, annotations = get_mask_and_annotations_from_record(annotations_root, image) - assert isinstance(mask, WsiAnnotations) or (mask is None) image_labels = get_labels_from_record(image) labels = None if patient_labels is image_labels is None else (patient_labels or []) + (image_labels or []) - rois = _get_rois(mask, data_description, stage) - mask_threshold = 0.0 if stage != "fit" else data_description.mask_threshold - - ( - image_path, - tile_size, - tile_overlap, - backend, - mpp, - overwrite_mpp, - tile_mode, - output_tile_size, - ) = get_image_info(db_manager, data_description, image, stage) + + image_info = get_image_info(db_manager, data_description, image, stage) + + image_path = image_info["image_path"] + + if image_path is None: + # if no feature is found... 
+ continue + + mpp = image_info["mpp"] + tile_size = image_info["tile_size"] + tile_overlap = image_info["tile_overlap"] + backend = image_info["backend"] + overwrite_mpp = image_info["overwrite_mpp"] + tile_mode = image_info["tile_mode"] + output_tile_size = image_info["output_tile_size"] + mask = image_info["mask"] + mask_threshold = image_info["mask_threshold"] + rois = image_info["rois"] + annotations = image_info["annotations"] + + assert isinstance(image_path, Path) + assert isinstance(mpp, float) + assert ( + isinstance(tile_size, tuple) + and len(tile_size) == 2 + and all(isinstance(i, int) for i in tile_size) # pylint: disable=not-an-iterable + ) + assert ( + isinstance(tile_overlap, tuple) + and len(tile_overlap) == 2 + and all(isinstance(i, int) for i in tile_overlap) # pylint: disable=not-an-iterable + ) + assert backend is not None + assert isinstance(overwrite_mpp, float) or overwrite_mpp is None + assert isinstance(tile_mode, TilingMode) + assert ( + isinstance(output_tile_size, tuple) + and len(output_tile_size) == 2 + and all(isinstance(i, int) for i in output_tile_size) # pylint: disable=not-an-iterable + or (output_tile_size is None) + ) + assert isinstance(mask, WsiAnnotations) or (mask is None) + assert isinstance(mask_threshold, float) or mask_threshold is None + assert isinstance(annotations, WsiAnnotations) or (annotations is None) dataset = TiledWsiDataset.from_standard_tiling( path=image_path, From ea8a330b558d3d226990dc4f0be8e8c1644f125d Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Tue, 24 Sep 2024 13:21:40 +0200 Subject: [PATCH 28/30] bugfixes to make the writer working --- ahcore/callbacks/file_writer_callback.py | 2 +- ahcore/transforms/pre_transforms.py | 20 +++++++++++--------- ahcore/utils/data.py | 1 + ahcore/utils/manifest.py | 20 ++++++++++++++------ 4 files changed, 27 insertions(+), 16 deletions(-) diff --git a/ahcore/callbacks/file_writer_callback.py b/ahcore/callbacks/file_writer_callback.py index f2d2c97..4535584 100644 --- a/ahcore/callbacks/file_writer_callback.py +++ b/ahcore/callbacks/file_writer_callback.py @@ -155,7 +155,7 @@ def _get_writer_data_args( tile_size = (1, 1) tile_overlap = (0, 0) num_samples = num_samples - grid = None # just let the writer make a new grid + grid = current_dataset._grids[0][0] # give grid, bc doesn't work otherwise else: raise NotImplementedError(f"Data format {data_format} is not yet supported.") diff --git a/ahcore/transforms/pre_transforms.py b/ahcore/transforms/pre_transforms.py index 45d00ee..b752cfb 100644 --- a/ahcore/transforms/pre_transforms.py +++ b/ahcore/transforms/pre_transforms.py @@ -5,7 +5,6 @@ from __future__ import annotations -import logging from typing import Any, Callable import numpy as np @@ -26,7 +25,9 @@ class PreTransformTaskFactory: - def __init__(self, transforms: list[PreTransformCallable]): + def __init__( + self, transforms: list[PreTransformCallable], data_description: DataDescription, requires_target: bool + ) -> None: """ Pre-transforms are transforms that are applied to the samples directly originating from the dataset. 
These transforms are typically the same for the specific tasks (e.g., segmentation, @@ -71,7 +72,7 @@ def for_segmentation( """ transforms: list[PreTransformCallable] = [] if not requires_target: - return cls(transforms) + return cls(transforms, data_description, requires_target) if data_description.index_map is None: raise ConfigurationError("`index_map` is required for segmentation models when the target is required.") @@ -88,7 +89,7 @@ def for_segmentation( if not multiclass: transforms.append(OneHotEncodeMask(index_map=data_description.index_map)) - return cls(transforms) + return cls(transforms, data_description, requires_target) @classmethod def for_wsi_classification( @@ -99,7 +100,7 @@ def for_wsi_classification( transforms.append(SampleNFeatures(n=1000)) if not requires_target: - return cls(transforms) + return cls(transforms, data_description, requires_target) index_map = data_description.index_map if index_map is None: @@ -112,7 +113,7 @@ def for_wsi_classification( transforms.append(LabelToClassIndex(index_map=index_map)) - return cls(transforms) + return cls(transforms, data_description, requires_target) def __call__(self, data: DlupDatasetSample) -> DlupDatasetSample: for transform in self._transforms: @@ -311,9 +312,10 @@ def __call__(self, sample: DlupDatasetSample) -> dict[str, DlupDatasetSample]: if "labels" in sample.keys() and sample["labels"] is not None: for key, value in sample["labels"].items(): - sample["labels"][key] = torch.tensor(value, dtype=torch.float32) - if sample["labels"][key].dim() == 0: - sample["labels"][key] = sample["labels"][key].unsqueeze(0) + if isinstance(value, float) or isinstance(value, int): + sample["labels"][key] = torch.tensor(value, dtype=torch.float32) + if sample["labels"][key].dim() == 0: + sample["labels"][key] = sample["labels"][key].unsqueeze(0) # annotation_data is added by the ConvertPolygonToMask transform. 
if "annotation_data" not in sample: diff --git a/ahcore/utils/data.py b/ahcore/utils/data.py index c5362ae..463d144 100644 --- a/ahcore/utils/data.py +++ b/ahcore/utils/data.py @@ -69,3 +69,4 @@ class DataDescription(BaseModel): convert_mask_to_rois: bool = True use_roi: bool = True apply_color_profile: bool = False + tiling_mode: Optional[str] = None diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py index 40d8bab..9627f80 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -62,7 +62,10 @@ class _AnnotationReadersDict(TypedDict): "DARWIN_JSON": WsiAnnotations.from_darwin_json, "GEOJSON": WsiAnnotations.from_geojson, "PYVIPS": functools.partial(SlideImage.from_file_path, backend=DLUPImageBackend.PYVIPS), - "TIFFFILE": functools.partial(SlideImage.from_file_path, backend=DLUPImageBackend.TIFFFILE), + "TIFFFILE": functools.partial( + SlideImage.from_file_path, + backend=DLUPImageBackend.TIFFFILE, + ), "OPENSLIDE": functools.partial(SlideImage.from_file_path, backend=DLUPImageBackend.OPENSLIDE), } @@ -200,11 +203,12 @@ def get_relevant_feature_info_from_record( return image_path, mpp, tile_size, tile_overlap, backend, overwrite_mpp -def _get_rois(mask: Optional[WsiAnnotations], data_description: DataDescription, stage: str) -> Optional[Rois]: +def _get_rois(mask: Optional[_AnnotationReturnTypes], data_description: DataDescription, stage: str) -> Optional[Rois]: if (mask is None) or (stage != "fit") or (not data_description.convert_mask_to_rois): return None assert data_description.training_grid is not None + assert isinstance(mask, WsiAnnotations) # this is necessary for the compute_rois to work tile_size = data_description.training_grid.tile_size tile_overlap = data_description.training_grid.tile_overlap @@ -485,7 +489,7 @@ def get_image_info( raise ValueError(f"Grid (for stage {stage}) is not defined in the data description.") mask, annotations = get_mask_and_annotations_from_record(annotations_root, image) - assert isinstance(mask, WsiAnnotations) or (mask is None) + assert isinstance(mask, WsiAnnotations) or (mask is None) or isinstance(mask, SlideImage) mask_threshold = 0.0 if stage != "fit" else data_description.mask_threshold rois = _get_rois(mask, data_description, stage) @@ -495,7 +499,11 @@ def get_image_info( backend = DLUPImageBackend[str(image.reader)] mpp = getattr(grid_description, "mpp", 1.0) overwrite_mpp = float(image.mpp) - tile_mode = TilingMode.overflow + tile_mode = ( + TilingMode(data_description.tiling_mode) + if data_description.tiling_mode is not None + else TilingMode.overflow + ) output_tile_size = getattr(grid_description, "output_tile_size", None) # Update the dictionary with the actual values @@ -535,7 +543,7 @@ def datasets_from_data_description( split_category=stage, ) - for patient in patients: + for patient_idx, patient in enumerate(patients): patient_labels = get_labels_from_record(patient) for image in patient.images: @@ -583,7 +591,7 @@ def datasets_from_data_description( and all(isinstance(i, int) for i in output_tile_size) # pylint: disable=not-an-iterable or (output_tile_size is None) ) - assert isinstance(mask, WsiAnnotations) or (mask is None) + assert isinstance(mask, WsiAnnotations) or (mask is None) or isinstance(mask, SlideImage) assert isinstance(mask_threshold, float) or mask_threshold is None assert isinstance(annotations, WsiAnnotations) or (annotations is None) From 88dff49b50910a4a99cb8454e7bad8a5d1931687 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Tue, 24 Sep 2024 16:39:53 +0200 Subject: 
[PATCH 29/30] cleaned feature description and manifest --- ahcore/utils/database_models.py | 5 -- ahcore/utils/manifest.py | 141 +++++++++++++++----------------- 2 files changed, 65 insertions(+), 81 deletions(-) diff --git a/ahcore/utils/database_models.py b/ahcore/utils/database_models.py index d0928a6..5e1d29d 100644 --- a/ahcore/utils/database_models.py +++ b/ahcore/utils/database_models.py @@ -149,17 +149,12 @@ class FeatureDescription(Base): tile_size_height = Column(Integer) tile_overlap_width = Column(Integer) tile_overlap_height = Column(Integer) - description = Column(String) # use this to select which features we want to use version = Column(String, unique=True, nullable=False) model_name = Column(String) - model_path = Column(String) feature_dimension = Column(Integer) - image_transforms_description = Column(String) - # it would be nice to have a way to track which transforms the feature extractors used, - # but maybe this is not the best way to do it features: Mapped[List["ImageFeature"]] = relationship("ImageFeature", back_populates="feature_description") diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py index 9627f80..483ef12 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -169,14 +169,7 @@ def get_labels_from_record(record: Image | Patient) -> list[tuple[str, str]] | N def get_relevant_feature_info_from_record( record: ImageFeature, data_description: DataDescription, feature_description: FeatureDescription -) -> tuple[ - Path, - PositiveFloat, - tuple[PositiveInt, PositiveInt], - tuple[PositiveInt, PositiveInt], - ImageBackend, - PositiveFloat, -]: +) -> ImageInfoDict: """Get the features from a record of type Image. Parameters @@ -200,7 +193,68 @@ def get_relevant_feature_info_from_record( backend = ImageBackend[str(record.reader)].value overwrite_mpp = float(feature_description.mpp) - return image_path, mpp, tile_size, tile_overlap, backend, overwrite_mpp + + output_dict: ImageInfoDict = { + "image_path": image_path, + "tile_size": tile_size, + "tile_overlap": tile_overlap, + "backend": backend, + "mpp": mpp, + "overwrite_mpp": overwrite_mpp, + "tile_mode": TilingMode.skip, + "output_tile_size": None, + "mask": None, + "mask_threshold": None, + "rois": None, + "annotations": None, + } + + return output_dict + + +def get_relevant_image_info_from_record( + image: Image, data_description: DataDescription, annotations_root: Path, stage: str +) -> ImageInfoDict: + if stage == "fit": + grid_description = data_description.training_grid + else: + grid_description = data_description.inference_grid + + if grid_description is None: + raise ValueError(f"Grid (for stage {stage}) is not defined in the data description.") + + mask, annotations = get_mask_and_annotations_from_record(annotations_root, image) + assert isinstance(mask, WsiAnnotations) or (mask is None) or isinstance(mask, SlideImage) + mask_threshold = 0.0 if stage != "fit" else data_description.mask_threshold + rois = _get_rois(mask, data_description, stage) + + image_path = data_description.data_dir / image.filename + tile_size = grid_description.tile_size + tile_overlap = grid_description.tile_overlap + backend = DLUPImageBackend[str(image.reader)] + mpp = getattr(grid_description, "mpp", 1.0) + overwrite_mpp = float(image.mpp) + tile_mode = ( + TilingMode(data_description.tiling_mode) if data_description.tiling_mode is not None else TilingMode.overflow + ) + output_tile_size = getattr(grid_description, "output_tile_size", None) + + output_dict: ImageInfoDict = { + "image_path": 
image_path, + "tile_size": tile_size, + "tile_overlap": tile_overlap, + "backend": backend, + "mpp": mpp, + "overwrite_mpp": overwrite_mpp, + "tile_mode": tile_mode, + "output_tile_size": output_tile_size, + "mask": mask, + "mask_threshold": mask_threshold, + "rois": rois, + "annotations": annotations, + } + + return output_dict def _get_rois(mask: Optional[_AnnotationReturnTypes], data_description: DataDescription, stage: str) -> Optional[Rois]: @@ -450,79 +504,14 @@ def get_image_info( # Directly return the initialized dictionary with None values return image_info - ( - image_path, - mpp, - tile_size, - tile_overlap, - backend, - overwrite_mpp, - ) = get_relevant_feature_info_from_record(image_feature, data_description, feature_description) - # Update the dictionary with the actual values - image_info.update( - { - "image_path": image_path, - "tile_size": tile_size, - "tile_overlap": tile_overlap, - "backend": backend, - "mpp": mpp, - "overwrite_mpp": overwrite_mpp, - "tile_mode": TilingMode.skip, - "output_tile_size": None, - "mask": None, - "mask_threshold": None, - "rois": None, - "annotations": None, - } - ) + image_info.update(get_relevant_feature_info_from_record(image_feature, data_description, feature_description)) return image_info else: - if stage == "fit": - grid_description = data_description.training_grid - else: - grid_description = data_description.inference_grid - - if grid_description is None: - raise ValueError(f"Grid (for stage {stage}) is not defined in the data description.") - - mask, annotations = get_mask_and_annotations_from_record(annotations_root, image) - assert isinstance(mask, WsiAnnotations) or (mask is None) or isinstance(mask, SlideImage) - mask_threshold = 0.0 if stage != "fit" else data_description.mask_threshold - rois = _get_rois(mask, data_description, stage) - - image_path = data_description.data_dir / image.filename - tile_size = grid_description.tile_size - tile_overlap = grid_description.tile_overlap - backend = DLUPImageBackend[str(image.reader)] - mpp = getattr(grid_description, "mpp", 1.0) - overwrite_mpp = float(image.mpp) - tile_mode = ( - TilingMode(data_description.tiling_mode) - if data_description.tiling_mode is not None - else TilingMode.overflow - ) - output_tile_size = getattr(grid_description, "output_tile_size", None) - # Update the dictionary with the actual values - image_info.update( - { - "image_path": image_path, - "tile_size": tile_size, - "tile_overlap": tile_overlap, - "backend": backend, - "mpp": mpp, - "overwrite_mpp": overwrite_mpp, - "tile_mode": tile_mode, - "output_tile_size": output_tile_size, - "mask": mask, - "mask_threshold": mask_threshold, - "rois": rois, - "annotations": annotations, - } - ) + image_info.update(get_relevant_image_info_from_record(image, data_description, annotations_root, stage)) return image_info From e43c5aef45b6a3b6fe240debdcae51ffeb30cae1 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Wed, 2 Oct 2024 16:27:00 +0200 Subject: [PATCH 30/30] fixes review comments --- ahcore/cli/tiling.py | 5 ++--- ahcore/readers.py | 9 +++++++++ ahcore/transforms/pre_transforms.py | 30 +++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/ahcore/cli/tiling.py b/ahcore/cli/tiling.py index 22f4b70..8b0debe 100644 --- a/ahcore/cli/tiling.py +++ b/ahcore/cli/tiling.py @@ -33,9 +33,8 @@ from rich.progress import Progress from ahcore.cli import dir_path, file_path -from ahcore.writers import H5FileImageWriter, Writer, ZarrFileImageWriter - from ahcore.utils.types 
From e43c5aef45b6a3b6fe240debdcae51ffeb30cae1 Mon Sep 17 00:00:00 2001
From: "Marek (on hp-zbook)"
Date: Wed, 2 Oct 2024 16:27:00 +0200
Subject: [PATCH 30/30] fixes review comments

---
 ahcore/cli/tiling.py                |  5 ++---
 ahcore/readers.py                   |  9 +++++++++
 ahcore/transforms/pre_transforms.py | 30 +++++++++++++++++++++++++++++
 3 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/ahcore/cli/tiling.py b/ahcore/cli/tiling.py
index 22f4b70..8b0debe 100644
--- a/ahcore/cli/tiling.py
+++ b/ahcore/cli/tiling.py
@@ -33,9 +33,8 @@
 from rich.progress import Progress
 
 from ahcore.cli import dir_path, file_path
-from ahcore.writers import H5FileImageWriter, Writer, ZarrFileImageWriter
-
 from ahcore.utils.types import DataFormat
+from ahcore.writers import H5FileImageWriter, Writer, ZarrFileImageWriter
 
 _WriterClass = Type[Writer]
 
@@ -365,7 +364,7 @@ def _tiling_pipeline(
         tile_size=dataset_cfg.tile_size,
         tile_overlap=dataset_cfg.tile_overlap,
         num_samples=len(dataset),
-        data_format=DataFormat.COMPRESSED_IMAGE if compression != "none" else DataFormat.IMAGE,
+        data_format=DataFormat.IMAGE,
         color_profile=color_profile,
         extra_metadata=extra_metadata,
         grid=dataset.grids[0][0],
diff --git a/ahcore/readers.py b/ahcore/readers.py
index 564cd6e..a7fbfa3 100644
--- a/ahcore/readers.py
+++ b/ahcore/readers.py
@@ -186,6 +186,10 @@ def _decompress_and_reshape_data(self, tile: GenericNumberArray) -> GenericNumbe
         return tile
 
     def _read_image_region(self, size: tuple[int, int], location: tuple[int, int]) -> pyvips.Image:
+        """
+        Reads a region in the stored h5 file. This function allows for multiple stitching modes,
+        such as cropping, averaging and taking the maximum across borders.
+        """
         assert self._size is not None
         assert self._tile_size is not None
         assert self._tile_overlap is not None
@@ -293,6 +297,8 @@ def _read_image_region(self, size: tuple[int, int], location: tuple[int, int]) -
         return pyvips.Image.new_from_array(stitched_image.transpose(1, 2, 0))
 
     def _read_feature_region(self, size: tuple[int, int], location: tuple[int, int]) -> pyvips.Image:
+        """Reads a region in the stored h5 file. This function reads the feature vectors as saved in the cache file.
+        Features are assumed to have no overlap and are only supported with stitching mode CROP."""
         assert self._num_samples is not None
 
         image_dataset: h5py.Dataset = self._file["data"]
@@ -313,6 +319,9 @@ def _read_feature_region(self, size: tuple[int, int], location: tuple[int, int])
                 f"number of samples in the dataset was {self._num_samples}"
             )
 
+        if self._tile_overlap != (0, 0):
+            raise ValueError("Reading features expects that the saved feature vectors have no overlap.")
+
         if x + w > self._num_samples or y + h > 1:
             raise ValueError(
                 f"Feature vectors are saved as (num_samples, 1) "
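
The new check in _read_feature_region encodes the storage contract stated in its docstring: cached feature vectors form a (num_samples, 1) grid with no overlap, so a region request reduces to slicing rows. A standalone sketch of that bounds logic; the read_feature_rows function and the array shapes are illustrative assumptions, not the reader's actual code:

    import numpy as np

    def read_feature_rows(data: np.ndarray, location: tuple[int, int], size: tuple[int, int]) -> np.ndarray:
        # data has shape (num_samples, feature_dimension); each "tile" is one feature vector.
        x, y = location
        w, h = size
        if x + w > data.shape[0] or y + h > 1:
            raise ValueError("Feature vectors are saved as (num_samples, 1); the requested region is out of bounds.")
        # With no overlap, the region starting at (x, 0) with width w is simply rows x .. x + w - 1.
        return data[x : x + w]

    features = np.zeros((512, 768))  # 512 cached feature vectors of dimension 768
    region = read_feature_rows(features, location=(10, 0), size=(4, 1))
    assert region.shape == (4, 768)
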
+ """ + def __init__(self, keys: list[str] | str): if isinstance(keys, str): keys = [keys]