-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixing Tests, Fixing MPP issue, Adding tolerance to in memory dataset
- Loading branch information
1 parent
115b324
commit 9326a09
Showing
9 changed files
with
117 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,3 +20,4 @@ pathopatch.egg-info/ | |
dist | ||
build | ||
push_build.yml | ||
debug |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,11 @@ | |
# @ Fabian Hörst, [email protected] | ||
# Institute for Artifical Intelligence in Medicine, | ||
# University Medicine Essen | ||
import sys | ||
|
||
from click import Option | ||
|
||
sys.path.append("/Users/fhoerst/Fabian-Projekte/Preprocessing/PathoPatcher") | ||
|
||
import logging | ||
import os | ||
|
@@ -22,8 +26,7 @@ | |
from shapely.geometry import Polygon | ||
from torch.utils.data import Dataset | ||
from torchvision.transforms.v2 import ToTensor | ||
|
||
from pathopatch import logger | ||
from PIL import Image | ||
from pathopatch.utils.exceptions import WrongParameterException | ||
from pathopatch.utils.patch_util import ( | ||
DeepZoomGeneratorOS, | ||
|
@@ -43,11 +46,11 @@ | |
warnings.filterwarnings("ignore", category=UserWarning) | ||
|
||
|
||
class PreProcessingDatasetConfig(BaseModel): | ||
class LivePatchWSIConfig(BaseModel): | ||
"""Storing the configuration for the PatchWSIDataset | ||
Args: | ||
wsipath (str): Path to the WSI | ||
wsi_path (str): Path to the WSI | ||
wsi_properties (dict, optional): Dictionary with manual WSI metadata, but just applies if metadata cannot be derived from OpenSlide (e.g., for .tiff files). Supported keys are slide_mpp and magnification | ||
patch_size (int, optional): The size of the patches in pixel that will be retrieved from the WSI, e.g. 256 for 256px. Defaults to 256. | ||
patch_overlap (float, optional): The percentage amount pixels that should overlap between two different patches. | ||
|
@@ -63,6 +66,7 @@ class PreProcessingDatasetConfig(BaseModel): | |
expresses which kind of downsampling should be used with | ||
respect to the highest possible resolution. Defaults to 0. | ||
level (int, optional): The tile level for sampling, alternative to downsample. Defaults to None. | ||
target_mpp_tolerance(float, optional): Tolerance for the target_mpp. If wsi mpp is within a range target_mpp +/- tolarance, no rescaling is performed. Defaults to 0.0. | ||
annotation_path (str, optional): Path to the .json file with the annotations. Defaults to None. | ||
label_map_file (str, optional): Path to the .json file with the label map. Defaults to None. | ||
label_map (dict, optional): Dictionary with the label map. Defaults to None. | ||
|
@@ -96,6 +100,7 @@ class PreProcessingDatasetConfig(BaseModel): | |
target_mpp: Optional[float] | ||
target_mag: Optional[float] | ||
level: Optional[int] | ||
target_mpp_tolerance: Optional[float] = 0.0 | ||
|
||
# annotation specific settings | ||
annotation_path: Optional[str] | ||
|
@@ -174,7 +179,7 @@ def __post_init_post_parse__(self) -> None: | |
class LivePatchWSIDataset(Dataset): | ||
def __init__( | ||
self, | ||
slide_processor_config: PreProcessingDatasetConfig, | ||
slide_processor_config: LivePatchWSIConfig, | ||
logger: logging.Logger = None, | ||
transforms: Callable = ToTensor(), | ||
) -> None: | ||
|
@@ -184,7 +189,7 @@ def __init__( | |
functionality for loading and processing WSIs. | ||
Args: | ||
slide_processor_config (PreProcessingDatasetConfig): Configuration for preprocessing the dataset. | ||
slide_processor_config (LivePatchWSIConfig): Configuration for preprocessing the dataset. | ||
logger (logging.Logger, optional): Logger for logging events. Defaults to None. | ||
transforms (Callable, optional): Transforms to apply to the patches. Defaults to ToTensor(). | ||
|
@@ -195,7 +200,7 @@ def __init__( | |
wsi_metadata (dict): Metadata of the WSI | ||
deepzoomgenerator (Union[DeepZoomGeneratorOS, Any]): Class for tile extraction, deepzoom-interface | ||
tile_extractor (Union[DeepZoomGeneratorOS, Any]): Instance of self.deepzoomgenerator | ||
config (PreProcessingDatasetConfig): Configuration for preprocessing the dataset | ||
config (LivePatchWSIConfig): Configuration for preprocessing the dataset | ||
logger (logging.Logger): Logger for logging events | ||
rescaling_factor (int): Rescaling factor for the slide | ||
interesting_coords (List[Tuple[int, int, float]]): List of interesting coordinates (patches -> row, col, ratio) | ||
|
@@ -296,9 +301,6 @@ def _set_tissue_detector(self) -> None: | |
Raises: | ||
ImportError: If torch or torchvision cannot be imported. | ||
Returns: | ||
None | ||
""" | ||
try: | ||
import torch.nn as nn | ||
|
@@ -318,7 +320,7 @@ def _set_tissue_detector(self) -> None: | |
"cuda:0" if torch.cuda.is_available() else "cpu" | ||
) | ||
if self.detector_device == "cpu": | ||
logger.warning( | ||
self.logger.warning( | ||
"No CUDA device detected - Speed may be very slow. Please consider performing extraction on CUDA device or disable tissue detector!" | ||
) | ||
model = mobilenet_v3_small().to(device=self.detector_device) | ||
|
@@ -383,7 +385,7 @@ def _prepare_slide( | |
# Extract the float value | ||
if match: | ||
slide_mpp = float(match.group(1)) | ||
logger.warning( | ||
self.logger.warning( | ||
f"MPP {slide_mpp:.4f} was extracted from the comment of the WSI (Tiff-Metadata comment string) - Please check for correctness!" | ||
) | ||
else: | ||
|
@@ -414,15 +416,25 @@ def _prepare_slide( | |
resulting_mpp = None | ||
|
||
if self.config.target_mpp is not None: | ||
self.config.downsample, self.rescaling_factor = target_mpp_to_downsample( | ||
slide_properties["mpp"], | ||
self.config.target_mpp, | ||
) | ||
if ( | ||
not slide_properties["mpp"] - self.config.target_mpp_tolerance | ||
<= self.config.target_mpp | ||
<= slide_properties["mpp"] + self.config.target_mpp_tolerance | ||
): | ||
( | ||
self.config.downsample, | ||
self.rescaling_factor, | ||
) = target_mpp_to_downsample( | ||
slide_properties["mpp"], | ||
self.config.target_mpp, | ||
) | ||
else: | ||
self.config.downsample = 1 | ||
self.rescaling_factor = 1.0 | ||
if self.rescaling_factor != 1.0: | ||
resulting_mpp = ( | ||
slide_properties["mpp"] | ||
* self.rescaling_factor | ||
/ 2 | ||
* self.config.downsample | ||
) | ||
else: | ||
|
@@ -519,7 +531,7 @@ def _prepare_slide( | |
) | ||
self.logger.debug(f"Number of patches sampled: {len(interesting_coords)}") | ||
if len(interesting_coords) == 0: | ||
logger.warning(f"No patches sampled from {self.config.wsi_path}") | ||
self.logger.warning(f"No patches sampled from {self.config.wsi_path}") | ||
|
||
self.wsi_metadata = { | ||
"orig_n_tiles_cols": n_cols, | ||
|
@@ -539,7 +551,7 @@ def _prepare_slide( | |
|
||
return list(interesting_coords), level, polygons_downsampled, region_labels | ||
|
||
def _get_wsi_annotations(self, downsample: int): | ||
def _get_wsi_annotations(self, downsample: int): # TODO: docstring | ||
region_labels: List[str] = [] | ||
polygons: List[Polygon] = [] | ||
polygons_downsampled: List[Polygon] = [] | ||
|
@@ -610,7 +622,7 @@ def __getitem__(self, index: int) -> Tuple[np.ndarray, dict, np.ndarray]: | |
ratio = {} | ||
patch_mask = np.zeros( | ||
(self.res_tile_size, self.res_tile_size), dtype=np.uint8 | ||
) # TODO: | ||
) | ||
else: | ||
intersected_labels, ratio, patch_mask = get_intersected_labels( | ||
tile_size=self.res_tile_size, | ||
|
@@ -634,6 +646,19 @@ def __getitem__(self, index: int) -> Tuple[np.ndarray, dict, np.ndarray]: | |
normalization_vector_path=self.config.normalization_vector_json, | ||
) | ||
|
||
if self.res_tile_size != self.config.patch_size: | ||
image_tile = Image.fromarray(image_tile) | ||
if self.res_tile_size > self.config.patch_size: | ||
image_tile.thumbnail( | ||
(self.config.patch_size, self.config.patch_size), | ||
getattr(Image, "Resampling", Image).LANCZOS, | ||
) | ||
else: | ||
image_tile = image_tile.resize( | ||
(self.config.patch_size, self.config.patch_size), | ||
getattr(Image, "Resampling", Image).LANCZOS, | ||
) | ||
image_tile = np.array(image_tile) | ||
try: | ||
image_tile = self.transforms(image_tile) | ||
except TypeError: | ||
|
@@ -651,24 +676,24 @@ def __getitem__(self, index: int) -> Tuple[np.ndarray, dict, np.ndarray]: | |
return image_tile, patch_metadata, patch_mask | ||
|
||
|
||
class PatchWSIDataloader: | ||
"""Dataloader for PatchWSIDataset | ||
class LivePatchWSIDataloader: | ||
"""Dataloader for LivePatchWSIDataset | ||
Args: | ||
dataset (PatchWSIDataset): Dataset to load patches from. | ||
dataset (LivePatchWSIDataset): Dataset to load patches from. | ||
batch_size (int): Batch size for the dataloader. | ||
shuffle (bool, optional): To shuffle iterations. Defaults to False. | ||
seed (int, optional): Seed for shuffle. Defaults to 42. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
dataset: PatchWSIDataset, | ||
dataset: LivePatchWSIDataset, | ||
batch_size: int, | ||
shuffle: bool = False, | ||
seed: int = 42, | ||
) -> None: | ||
assert isinstance(dataset, PatchWSIDataset) | ||
assert isinstance(dataset, LivePatchWSIDataset) | ||
assert isinstance(batch_size, int) | ||
assert isinstance(shuffle, bool) | ||
assert isinstance(seed, int) | ||
|
@@ -682,6 +707,7 @@ def __init__( | |
if self.shuffle: | ||
grtr = np.random.default_rng(seed) | ||
self.element_list = grtr.permutation(self.element_list) | ||
self.i = 0 | ||
|
||
def __iter__(self): | ||
self.i = 0 | ||
|
@@ -732,4 +758,21 @@ def __next__(self) -> Tuple[torch.Tensor, List[dict], List[np.ndarray]]: | |
raise StopIteration | ||
|
||
def __len__(self): | ||
return int(np.ceil(len(self.dataset) / self.batch_size)) | ||
return int(np.ceil((len(self.dataset) - self.discard_count) / self.batch_size)) | ||
|
||
|
||
if __name__ == "__main__": | ||
"""Just for testing purposes""" | ||
logger = logging.getLogger(__name__) | ||
logger.setLevel(logging.DEBUG) | ||
logger.info("Test") | ||
config = LivePatchWSIConfig( | ||
wsi_path="/Users/fhoerst/Fabian-Projekte/Selocan/RicardoScans/266819.svs", | ||
patch_size=256, | ||
patch_overlap=0, | ||
target_mpp=0.3, | ||
target_mpp_tolerance=0.1, | ||
) | ||
ps_dataset = LivePatchWSIDataset(config, logger) | ||
ps_dataloader = LivePatchWSIDataloader(ps_dataset, batch_size=8) | ||
ps_dataloader.__next__() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file modified
BIN
+18.8 KB
(130%)
...arget_mpp_macenko/results/CMU-1-Small-Region/patches/CMU-1-Small-Region_1_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters