Skip to content

Commit

Permalink
Almost final verification code
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolphpienaar committed Apr 19, 2024
1 parent 81e5e4a commit d839bdf
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions spleenseg/core/neuralnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
# from monai.data.dataset import Dataset
from monai.data.utils import decollate_batch
from monai.data.meta_tensor import MetaTensor
from monai.handlers.utils import from_engine

# from monai.config.deviceconfig import print_config
# from monai.apps.utils import download_and_extract
import torch

Expand All @@ -52,7 +52,7 @@
# import shutil
# import glob
# import pudb
from typing import Any
from typing import Any, Sequence
import numpy as np

from spleenseg.transforms import transforms
Expand Down Expand Up @@ -296,7 +296,7 @@ def train(
)

def inference_metricsProcess(self) -> float:
metric: float = self.network.dice_metric.aggregate().item()
metric: float = self.network.dice_metric.aggregate().item() # type: ignore
self.trainingLog.metric_per_epoch.append(metric)
self.network.dice_metric.reset()
if metric > self.trainingLog.best_metric:
Expand Down Expand Up @@ -408,23 +408,38 @@ def plot_bestModel(
)
return 0.0

def bestModel_runOverValidationSpace(self):
self.network.model.load_state_dict(
torch.load(str(self.trainingParams.modelPth))
)
self.slidingWindowInference_do(self.validationSpace, self.plot_bestModel)

def diceMetric_onValidationSpacing(
self,
sample: dict[str, MetaTensor | torch.Tensor],
space: data.LoaderCache,
index: int,
result: torch.Tensor,
) -> float:
metric: float = -1.0
sample["pred"] = result
sample = [
self.f_outputPost(i)
for i in decollate_batch(sample) # type: ignore[arg-type]
]
self.network.model.load_state_dict(
torch.load(str(self.trainingParams.modelPth))
predictions: torch.Tensor
labels: torch.Tensor
predictions, labels = from_engine(["pred", "label"])(sample)
Dm: torch.Tensor = self.network.dice_metric(
y_pred=predictions, # type: ignore
y=labels, # type: ignore
)
self.slidingWindowInference_do(self.validationSpace, self.plot_bestModel)
return 0.0
print(f"Best prediction dice metric: {Dm}")
if space.loader.batch_size:
if index == len(space.cache) // space.loader.batch_size:
metric = self.network.dice_metric.aggregate().item()
print(f"metric on original image spacing: {metric}")
return metric

def bestModel_evaluateImageSpacings(self, validationTransforms: Compose):
self.network.model.load_state_dict(
Expand All @@ -437,3 +452,6 @@ def bestModel_evaluateImageSpacings(self, validationTransforms: Compose):
transforms.f_labelAsDiscreted(),
]
)
self.slidingWindowInference_do(
self.validationSpace, self.diceMetric_onValidationSpacing
)

0 comments on commit d839bdf

Please sign in to comment.