Skip to content

Commit

Permalink
🎨 Update code structure for dataloader_units
Browse files Browse the repository at this point in the history
  • Loading branch information
shaneahmed committed Aug 16, 2024
1 parent da0ce4f commit 606a2c0
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 9 deletions.
11 changes: 11 additions & 0 deletions tests/engines/test_patch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,17 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
assert predictor._ioconfig.input_resolutions[0]["units"] == "baseline"
shutil.rmtree(tmp_path / "dump", ignore_errors=True)

predictor.run(
images=[mini_wsi_svs],
units="level",
resolution=0,
patch_mode=False,
save_dir=f"{tmp_path}/dump",
)
assert predictor._ioconfig.input_resolutions[0]["units"] == "level"
assert predictor._ioconfig.input_resolutions[0]["resolution"] == 0
shutil.rmtree(tmp_path / "dump", ignore_errors=True)


def test_patch_predictor_api(
sample_patch1: Path,
Expand Down
31 changes: 23 additions & 8 deletions tiatoolbox/models/engine/engine_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import copy
import shutil
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, TypedDict
Expand Down Expand Up @@ -639,7 +640,7 @@ def post_process_patches(

def save_predictions(
self: EngineABC,
processed_predictions: dict,
processed_predictions: dict | Path,
output_type: str,
save_dir: Path | None = None,
**kwargs: dict,
Expand Down Expand Up @@ -681,16 +682,23 @@ def save_predictions(
# class_dict set from kwargs
class_dict = kwargs.get("class_dict")

processed_predictions_path: str | Path | None = None

# Need to add support for zarr conversion.
if self.cache_mode:
processed_predictions_path = processed_predictions
processed_predictions = zarr.open(processed_predictions, mode="r")

return dict_to_store(
out_file = dict_to_store(
processed_predictions,
scale_factor,
class_dict,
save_path,
)
if processed_predictions_path is not None:
shutil.rmtree(processed_predictions_path)

return out_file

return (
dict_to_zarr(
Expand Down Expand Up @@ -1057,15 +1065,22 @@ def _run_wsi_mode(
dataloader_units = dataloader.dataset.units
dataloader_resolution = dataloader.dataset.resolution

slide_resolution = (1.0, 1.0)
if dataloader_units != "baseline":
            # If the dataloader units are "baseline", the slide resolution is 1.0;
            # in that case dataloader resolution / slide resolution is simply
            # equal to the dataloader resolution.
scale_factor = dataloader_resolution

if dataloader_units == "mpp":
wsimeta_dict = dataloader.dataset.reader.info.as_dict()
slide_resolution = wsimeta_dict[dataloader_units]
scale_factor = tuple(np.divide(slide_resolution, dataloader_resolution))

scale_factor = tuple(np.divide(slide_resolution, dataloader_resolution))

if dataloader_units != "mpp":
scale_factor = tuple(np.divide(dataloader_resolution, slide_resolution))
if dataloader_units == "level":
wsimeta_dict = dataloader.dataset.reader.info.as_dict()
downsample_ratio = wsimeta_dict["level_downsamples"][
dataloader_resolution
]
scale_factor = (1.0 / downsample_ratio, 1.0 / downsample_ratio)

raw_predictions = self.infer_wsi(
dataloader=dataloader,
Expand Down
2 changes: 1 addition & 1 deletion tiatoolbox/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,7 +1309,7 @@ def dict_to_store(

# if a save director is provided, then dump store into a file
if save_path:
# ensure parent directory exisits
# ensure parent directory exists
save_path.parent.absolute().mkdir(parents=True, exist_ok=True)
# ensure proper db extension
save_path = save_path.parent.absolute() / (save_path.stem + ".db")
Expand Down

0 comments on commit 606a2c0

Please sign in to comment.