Skip to content

Commit

Permalink
Add more transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolphpienaar committed Apr 22, 2024
1 parent 3425b52 commit d823c13
Showing 1 changed file with 87 additions and 21 deletions.
108 changes: 87 additions & 21 deletions spleenseg/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,58 @@
from monai.transforms.croppad.dictionary import CropForegroundd, RandCropByPosNegLabeld
from monai.transforms.spatial.dictionary import Orientationd, RandAffined, Spacingd
from monai.transforms.intensity.dictionary import ScaleIntensityRanged
from typing import Any, Optional, Callable, Hashable, Mapping, Dict, Union
from monai.transforms.io.array import LoadImage
from monai.config.type_definitions import PathLike
from typing import Any, Callable, Hashable, Mapping, Sequence
from monai.data.meta_tensor import MetaTensor
import numpy as np
from numpy import ndarray
from pathlib import Path
from spleenseg.plotting import plotting


def f_LoadImaged() -> Callable[[dict[str, Any]], dict[str, Any]]:
return LoadImaged(keys=["image", "label"])
def f_LoadImaged(
keys: list[str] = ["image", "label"],
) -> Callable[[dict[str, Any]], dict[str, Any]]:
return LoadImaged(keys=keys)


def f_LoadImage() -> Callable[
[PathLike | Sequence[PathLike]],
torch.Tensor
| Any
| MetaTensor
| tuple[
torch.Tensor | Any | MetaTensor,
dict[Any, Any]
| Any
| ndarray[Any, Any]
| tuple[Any, ...]
| list[Any]
| bool
| str
| float
| int
| None,
],
]:
return LoadImage()


def f_SaveImaged(outputDir: Path) -> Callable[[dict[str, Any]], dict[str, Any]]:
return SaveImaged(
keys="pred",
meta_keys="pred_meta_dict",
output_dir=str(outputDir),
output_postfix="seg",
resample=False,
)


def f_EnsureChannelFirstd() -> (
Callable[[Mapping[Hashable, torch.Tensor]], Mapping[Hashable, torch.Tensor]]
):
return EnsureChannelFirstd(keys=["image", "label"])
def f_EnsureChannelFirstd(
keys: list[str] = ["image", "label"],
) -> Callable[[Mapping[Hashable, torch.Tensor]], Mapping[Hashable, torch.Tensor]]:
return EnsureChannelFirstd(keys=keys)


def f_ScaleIntensityRanged() -> (
Expand All @@ -52,25 +90,26 @@ def f_ScaleIntensityRanged() -> (
)


def f_CropForegroundd() -> (
Callable[[Mapping[Hashable, torch.Tensor]], Mapping[Hashable, torch.Tensor]]
):
return CropForegroundd(keys=["image", "label"], source_key="image")
def f_CropForegroundd(
keys: list[str] = ["image", "label"],
) -> Callable[[Mapping[Hashable, torch.Tensor]], Mapping[Hashable, torch.Tensor]]:
return CropForegroundd(keys=keys, source_key="image", allow_smaller=True)


def f_Orientationd() -> (
Callable[[Mapping[Hashable, torch.Tensor]], Mapping[Hashable, torch.Tensor]]
):
return Orientationd(keys=["image", "label"], axcodes="RAS")
def f_Orientationd(
keys: list[str] = ["image", "label"],
) -> Callable[[Mapping[Hashable, torch.Tensor]], Mapping[Hashable, torch.Tensor]]:
return Orientationd(keys=keys, axcodes="RAS")


def f_Spaceingd() -> (
Callable[[Mapping[Hashable, torch.Tensor]], Mapping[Hashable, torch.Tensor]]
):
def f_Spaceingd(
keys: list[str] = ["image", "label"],
mode: tuple[str, ...] = ("bilinear", "nearest"),
) -> Callable[[Mapping[Hashable, torch.Tensor]], Mapping[Hashable, torch.Tensor]]:
return Spacingd(
keys=["image", "label"],
keys=keys,
pixdim=(1.5, 1.5, 2.0),
mode=("bilinear", "nearest"),
mode=mode,
)


Expand Down Expand Up @@ -162,6 +201,30 @@ def trainingAndValidation_transformsSetup() -> tuple[Compose, Compose]:
return trainingTransforms, validationTransforms


def validation_transformsOnOriginal() -> Compose:
transforms: list = [
f_LoadImaged(),
f_EnsureChannelFirstd(),
f_Orientationd(["image"]),
f_Spaceingd(["image"], tuple(["bilinear"])),
f_ScaleIntensityRanged(),
f_CropForegroundd(["image"]),
]
return transforms_build(transforms)


def inferenceUse_transforms() -> Compose:
transforms: list = [
f_LoadImaged(["image"]),
f_EnsureChannelFirstd(["image"]),
f_Orientationd(["image"]),
f_Spaceingd(["image"], tuple(["bilinear"])),
f_ScaleIntensityRanged(),
f_CropForegroundd(["image"]),
]
return transforms_build(transforms)


def transforms_check(
outputdir: Path, files: list[dict[str, str]], transforms: Compose
) -> bool:
Expand All @@ -171,6 +234,9 @@ def transforms_check(
if not check_data:
return False
image, label = (check_data["image"][0][0], check_data["label"][0][0])
print(f"image shape: {image.shape}, label shape: {label.shape}")
print("")
print("Checking transforms... :")
print(f"sample image shape: {image.shape}")
print(f"sample label shape: {label.shape}")
plotting.plot_imageAndLabel(image, label, outputdir / "exemplar_image_label.jpg")
return True

0 comments on commit d823c13

Please sign in to comment.