Skip to content

Commit

Permalink
series selection and update output paths
Browse files Browse the repository at this point in the history
  • Loading branch information
louisblankemeier committed Sep 26, 2023
1 parent 47448c4 commit c9aaf1d
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 14 deletions.
9 changes: 9 additions & 0 deletions comp2comp/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# import dicom2nifti
import dosma as dm
import SimpleITK as sitk
import pydicom

from comp2comp.inference_class_base import InferenceClass

Expand Down Expand Up @@ -115,6 +116,14 @@ def __call__(self, inference_pipeline):
def dicom_series_to_nifti(input_path, output_file, reorient_nifti):
reader = sitk.ImageSeriesReader()
dicom_names = reader.GetGDCMSeriesFileNames(str(input_path))
print("Reading dicoms...")
ds = pydicom.filereader.dcmread(dicom_names[0])
image_type_list = list(ds.ImageType)
if any("gsi" in s.lower() for s in image_type_list):
raise ValueError("GSI Image Type detected")
reader.SetFileNames(dicom_names)
image = reader.Execute()
if image.GetDirection() != (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0):
raise ValueError("Image orientation is not axial")
print(ds)
sitk.WriteImage(image, output_file)
2 changes: 1 addition & 1 deletion comp2comp/io/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_dicom_or_nifti_paths_and_num(path):
if path.endswith(".nii") or path.endswith(".nii.gz"):
return [(path, 1)]
dicom_nifti_paths = []
for root, _, files in os.walk(path):
for root, dirs, files in os.walk(path):
if len(files) > 0:
if all(file.endswith(".dcm") or file.endswith(".dicom") for file in files):
dicom_nifti_paths.append((root, len(files)))
Expand Down
11 changes: 8 additions & 3 deletions comp2comp/spine/spine.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,14 @@ def __call__(self, inference_pipeline):
# Compute ROIs
inference_pipeline.spine_model_type = self.spine_model_type

(spine_hus, rois, centroids_3d) = spine_utils.compute_rois(
(spine_hus, rois, segmentation_hus, centroids_3d) = spine_utils.compute_rois(
inference_pipeline.segmentation,
inference_pipeline.medical_volume,
self.spine_model_type,
)

inference_pipeline.spine_hus = spine_hus
inference_pipeline.segmentation_hus = segmentation_hus
inference_pipeline.rois = rois
inference_pipeline.centroids_3d = centroids_3d

Expand All @@ -272,6 +273,7 @@ def __init__(self):
def __call__(self, inference_pipeline):
"""Save metrics to a CSV file."""
self.spine_hus = inference_pipeline.spine_hus
self.seg_hus = inference_pipeline.segmentation_hus
self.output_dir = inference_pipeline.output_dir
self.csv_output_dir = os.path.join(self.output_dir, "metrics")
if not os.path.exists(self.csv_output_dir):
Expand All @@ -281,11 +283,13 @@ def __call__(self, inference_pipeline):

def save_results(self):
"""Save results to a CSV file."""
df = pd.DataFrame(columns=["Level", "ROI HU"])
df = pd.DataFrame(columns=["Level", "ROI HU", "Seg HU"])
for i, level in enumerate(self.spine_hus):
hu = self.spine_hus[level]
row = [level, hu]
seg_hu = self.seg_hus[level]
row = [level, hu, seg_hu]
df.loc[i] = row
df = df.iloc[::-1]
df.to_csv(os.path.join(self.csv_output_dir, "spine_metrics.csv"), index=False)


Expand Down Expand Up @@ -327,6 +331,7 @@ def __call__(self, inference_pipeline):
list(inference_pipeline.centroids_3d.values()),
output_path,
spine_hus=inference_pipeline.spine_hus,
seg_hus=inference_pipeline.segmentation_hus,
model_type=spine_model_type,
pixel_spacing=inference_pipeline.pixel_spacing_list,
format=self.format,
Expand Down
11 changes: 9 additions & 2 deletions comp2comp/spine/spine_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,15 +268,19 @@ def compute_rois(seg, img, spine_model_type):
rois = {}
spine_hus = {}
centroids_3d = {}
segmentation_hus = {}
for i, level in enumerate(slices):
slice = slices[level]
center_of_mass = compute_center_of_mass(slice)
centroid = np.array([centroids[level], center_of_mass[1], center_of_mass[0]])
roi = roi_from_mask(img, centroid)
spine_hus[level] = mean_img_mask(img.get_fdata(), roi, i)
image_numpy = img.get_fdata()
spine_hus[level] = mean_img_mask(image_numpy, roi, i)
rois[level] = roi
mask = (seg_np == spine_model_type.categories[level]).astype(int)
segmentation_hus[level] = mean_img_mask(image_numpy, mask, i)
centroids_3d[level] = centroid
return (spine_hus, rois, centroids_3d)
return (spine_hus, rois, segmentation_hus, centroids_3d)


def keep_two_largest_connected_components(mask: Dict):
Expand Down Expand Up @@ -358,6 +362,7 @@ def visualize_coronal_sagittal_spine(
centroids_3d: np.ndarray,
output_dir: str,
spine_hus=None,
seg_hus=None,
model_type=None,
pixel_spacing=None,
format="png",
Expand Down Expand Up @@ -453,6 +458,7 @@ def visualize_coronal_sagittal_spine(
output_dir,
sagittal_name,
spine_hus=spine_hus,
seg_hus=seg_hus,
model_type=model_type,
pixel_spacing=pixel_spacing,
)
Expand All @@ -462,6 +468,7 @@ def visualize_coronal_sagittal_spine(
output_dir,
coronal_name,
spine_hus=spine_hus,
seg_hus=seg_hus,
model_type=model_type,
pixel_spacing=pixel_spacing,
)
Expand Down
51 changes: 47 additions & 4 deletions comp2comp/spine/spine_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def spine_binary_segmentation_overlay(
file_name: str,
figure_text_key=None,
spine_hus=None,
seg_hus=None,
spine=True,
model_type=None,
pixel_spacing=None,
Expand Down Expand Up @@ -66,7 +67,7 @@ def spine_binary_segmentation_overlay(
_ROI_COLOR = np.array([1.000, 0.340, 0.200])

_SPINE_TEXT_OFFSET_FROM_TOP = 10.0
_SPINE_TEXT_OFFSET_FROM_RIGHT = 63.0
_SPINE_TEXT_OFFSET_FROM_RIGHT = 40.0
_SPINE_TEXT_VERTICAL_SPACING = 14.0

img_in = np.clip(img_in, -300, 1800)
Expand Down Expand Up @@ -108,18 +109,60 @@ def spine_binary_segmentation_overlay(
area_threshold=0,
)

vis.draw_text(
text="ROI",
position=(
mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT - 35,
_SPINE_TEXT_OFFSET_FROM_TOP,
),
color=[1, 1, 1],
font_size=9,
horizontal_alignment="center",
)

vis.draw_text(
text="Seg",
position=(
mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT,
_SPINE_TEXT_OFFSET_FROM_TOP,
),
color=[1, 1, 1],
font_size=9,
horizontal_alignment="center",
)

# draw text and lines
for i, level in enumerate(levels):
vis.draw_text(
text=f"{level}: {round(float(spine_hus[level]))}",
text=f"{level}:",
position=(
mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT,
_SPINE_TEXT_VERTICAL_SPACING * i + _SPINE_TEXT_OFFSET_FROM_TOP,
mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT - 80,
_SPINE_TEXT_VERTICAL_SPACING * (i + 1) + _SPINE_TEXT_OFFSET_FROM_TOP,
),
color=_COLORS[label_map[level]],
font_size=9,
horizontal_alignment="left",
)
vis.draw_text(
text=f"{round(float(spine_hus[level]))}",
position=(
mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT - 35,
_SPINE_TEXT_VERTICAL_SPACING * (i + 1) + _SPINE_TEXT_OFFSET_FROM_TOP,
),
color=_COLORS[label_map[level]],
font_size=9,
horizontal_alignment="center",
)
vis.draw_text(
text=f"{round(float(seg_hus[level]))}",
position=(
mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT,
_SPINE_TEXT_VERTICAL_SPACING * (i + 1) + _SPINE_TEXT_OFFSET_FROM_TOP,
),
color=_COLORS[label_map[level]],
font_size=9,
horizontal_alignment="center",
)

"""
vis.draw_line(
Expand Down
12 changes: 8 additions & 4 deletions comp2comp/utils/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from time import time
import shutil

from comp2comp.io.io_utils import get_dicom_or_nifti_paths_and_num
from comp2comp.io import io_utils


def process_2d(args, pipeline_builder):
Expand Down Expand Up @@ -49,7 +49,7 @@ def process_3d(args, pipeline_builder):
date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_path = os.path.join(output_path, date_time)

for path, num in get_dicom_or_nifti_paths_and_num(args.input_path):
for path, num in io_utils.get_dicom_or_nifti_paths_and_num(args.input_path):
try:
st = time()

Expand Down Expand Up @@ -87,7 +87,10 @@ def process_3d(args, pipeline_builder):
output_dir = Path(
os.path.join(
output_path,
Path(os.path.basename(os.path.normpath(path))),
Path(os.path.basename(os.path.normpath(args.input_path))),
os.path.relpath(
os.path.normpath(path), os.path.normpath(args.input_path)
),
)
)

Expand All @@ -98,7 +101,6 @@ def process_3d(args, pipeline_builder):

pipeline(output_dir=output_dir, model_dir=model_dir)


if not args.save_segmentations:
# remove the segmentations folder
segmentations_dir = os.path.join(output_dir, "segmentations")
Expand All @@ -110,4 +112,6 @@ def process_3d(args, pipeline_builder):
except Exception:
print(f"ERROR PROCESSING {path}\n")
traceback.print_exc()
if os.path.exists(output_dir):
shutil.rmtree(output_dir)
continue

0 comments on commit c9aaf1d

Please sign in to comment.