Skip to content

Commit

Permalink
change 2Detect modes for more flexiblity
Browse files Browse the repository at this point in the history
  • Loading branch information
AnderBiguri committed Mar 19, 2024
1 parent 3573b52 commit b80d057
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 89 deletions.
221 changes: 133 additions & 88 deletions LION/data_loaders/deteCT.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,6 @@ def __init__(
else:
self.angle_index = list(range(len(self.geo.angles)))

if parameters.sinogram_mode != parameters.reconstruction_mode:
warnings.warn(
"Sinogram mode and reconstruction mode don't match, so reconstruction is not from the sinogram you are getting... \n This should be an error, but I trust that you won what you are doing"
)
if parameters.query == "" and (
not parameters.flat_field_correction or not parameters.dark_field_correction
):
Expand Down Expand Up @@ -87,38 +83,39 @@ def __init__(
- the mix
- the detector it was sampled with
"""
# Defining the sinogram mode
self.sinogram_mode = parameters.sinogram_mode
"""
The sinogram_mode (str) argument is a keyword defining what sinogram mode of the dataset to use:
| mode1 | mode2 | mode3
Tube Voltage | 90kV | 90kV | 60kV
Tube power | 3W | 90W | 60W
Filter | Thoraeus | Thoraeus | No Filter
"""
# Defining the reconstruction mode
self.reconstruction_mode = parameters.reconstruction_mode
# Defining the input and target mode
self.input_mode = parameters.input_mode
self.target_mode = parameters.target_mode

"""
The reconstruction_mode (str) argument is a keyword defining what image mode of the dataset to use:
The input_mode (str) argument is a keyword defining what input mode of the dataset to use:
| mode1 | mode2 | mode3
Tube Voltage | 90kV | 90kV | 60kV
Tube power | 3W | 90W | 60W
Filter | Thoraeus | Thoraeus | No Filter
"""

# Defining the task
self.task = parameters.task
"""
The task (str) argument is a keyword defining what is the dataset used for:
- task == 'reconstruction' -> the dataset returns the sinogram and the reconstruction
- task == 'segmentation' -> the dataset returns the reconstruction and the segmentation
- task == 'joint' -> the dataset returns the sinogram, the reconstruction and the segmentation
- task == 'sino2sino' -> input and target are both sinograms
- task == 'sino2recon' -> input is a sinogram and target is a reconstruction
- task == 'recon2recon' -> input and target are both reconstructions
- task == 'recon2seg' -> input is a reconstruction and target is a segmentation
- task == 'sino2seg' -> input is a sinogram and target is a segmentation
- task == 'joint' -> input is a sinogram, target is a reconstruction and segmentation
"""

assert self.task in [
"reconstruction",
"segmentation",
"sino2sino",
"sino2recon",
"recon2recon",
"recon2seg",
"sino2seg",
"joint",
], f'Wrong task argument, must be in ["reconstruction", "segmentation", "joint"]'
], f'Wrong task argument, must be in ["sino2sino", "sino2recon", "recon2recon", "recon2seg", "sino2seg", "joint"]'

assert mode in [
"train",
"validation",
Expand Down Expand Up @@ -194,9 +191,9 @@ def __init__(
def default_parameters():
param = LIONParameter()
param.path_to_dataset = DETECT_PROCESSED_DATASET_PATH
param.sinogram_mode = "mode2"
param.reconstruction_mode = "mode2"
param.task = "reconstruction"
param.input_mode = "mode2"
param.target_mode = "mode2"
param.task = "sino2recon"
param.training_proportion = 0.8
param.validation_proportion = 0.1
param.test_proportion = 0.1
Expand Down Expand Up @@ -307,76 +304,124 @@ def __len__(self):
+ 1
)

def __getitem__(self, index):
def __load_and_preprocess_sinogram__(self, index, mode):

slice_row = self.slice_dataframe.iloc[index]
path_to_sinogram = self.path_to_dataset.joinpath(
f"{slice_row['slice_identifier']}/{self.sinogram_mode}"
path_to_input = self.path_to_dataset.joinpath(
f"{slice_row['slice_identifier']}/{mode}"
)
path_to_reconstruction = self.path_to_dataset.joinpath(
f"{slice_row['slice_identifier']}/{self.reconstruction_mode}"
sinogram = torch.from_numpy(
np.load(path_to_input.joinpath("sinogram.npy"))
).unsqueeze(0)
if self.flat_field_correction:
flat = torch.from_numpy(
np.load(path_to_input.joinpath("flat.npy"))
).unsqueeze(0)
else:
flat = 1
if self.dark_field_correction:
dark = torch.from_numpy(
np.load(path_to_input.joinpath("dark.npy"))
).unsqueeze(0)
else:
dark = 0
sinogram = (sinogram - dark) / (flat - dark)
if self.log_transform:
sinogram = -torch.log(sinogram)

sinogram = torch.flip(sinogram, [2])
sinogram = sinogram[:, self.angle_index, :]

# Interpolate if geometry is not default
if self.geo.detector_shape != self.get_default_geometry().detector_shape:
sinogram = torch.nn.functional.interpolate(
sinogram.unsqueeze(0),
size=(sinogram.shape[1], self.geo.detector_shape[1]),
mode="bilinear",
)
sinogram = torch.squeeze(sinogram, 0)

return sinogram

def __load_and_preprocess_reconstruction__(self, index, mode):
slice_row = self.slice_dataframe.iloc[index]
path_to_input = self.path_to_dataset.joinpath(
f"{slice_row['slice_identifier']}/{mode}"
)
path_to_segmentation = self.path_to_dataset.joinpath(
reconstruction = torch.from_numpy(
np.load(path_to_input.joinpath("reconstruction.npy"))
).unsqueeze(0)
# Interpolate if geometry is not default
if self.geo.image_shape != self.get_default_geometry().image_shape:
reconstruction = torch.nn.functional.interpolate(
reconstruction.unsqueeze(0),
size=(self.geo.image_shape[1], self.geo.image_shape[2]),
mode="bilinear",
)
reconstruction = torch.squeeze(reconstruction, 0)
return reconstruction

def __load_and_preprocess_segmentation__(self, index):
slice_row = self.slice_dataframe.iloc[index]
path_to_input = self.path_to_dataset.joinpath(
f"{slice_row['slice_identifier']}/mode2"
)
segmentation = torch.from_numpy(
np.load(path_to_input.joinpath("segmentation.npy"))
).unsqueeze(0)
# Interpolate if geometry is not default
if self.geo.image_shape != self.get_default_geometry().image_shape:
segmentation = torch.nn.functional.interpolate(
segmentation.unsqueeze(0),
size=(self.geo.image_shape[1], self.geo.image_shape[2]),
mode="nearest",
)
segmentation = torch.squeeze(segmentation, 0)
return segmentation

if self.task == ["segmentation", "joint"]:
segmentation = torch.from_numpy(
np.load(path_to_segmentation.joinpath("segmentation.npy"))
).unsqueeze(0)
def __getitem__(self, index):

if self.task in ["reconstruction", "joint"]:
sinogram = torch.from_numpy(
np.load(path_to_sinogram.joinpath("sinogram.npy"))
).unsqueeze(0)
if self.flat_field_correction:
flat = torch.from_numpy(
np.load(path_to_sinogram.joinpath("flat.npy"))
).unsqueeze(0)
else:
flat = 1
if self.dark_field_correction:
dark = torch.from_numpy(
np.load(path_to_sinogram.joinpath("dark.npy"))
).unsqueeze(0)
# If input is sinogram, we need to load the sinogram
if self.task in ["sino2sino", "sino2recon", "sino2seg", "joint"]:
input = self.__load_and_preprocess_sinogram__(index, self.input_mode)
# if input is reconstruction, we need to load the reconstruction
else:
# Even if input is recon, we may want to actually do it ourselves.
if self.do_recon:
sinogram = self.__load_and_preprocess_sinogram__(index, self.input_mode)
op = deteCT.get_operator()
if self.recon_algo == "nag_ls":
input = nag_ls(op, sinogram, 100, min_constraint=0)
elif self.recon_algo == "fdk":
input = fdk(op, sinogram)
else:
dark = 0
sinogram = (sinogram - dark) / (flat - dark)
if self.log_transform:
sinogram = -torch.log(sinogram)

sinogram = torch.flip(sinogram, [2])
sinogram = sinogram[:, self.angle_index, :]

# Interpolate if geometry is not default
if self.geo.detector_shape != self.get_default_geometry().detector_shape:
sinogram = torch.nn.functional.interpolate(
sinogram.unsqueeze(0),
size=(sinogram.shape[1], self.geo.detector_shape[1]),
mode="bilinear",
# Otherwise just load the recon
input = self.__load_and_preprocess_reconstruction__(
index, self.input_mode
)
sinogram = torch.squeeze(sinogram, 0)
if self.do_recon:
op = deteCT.get_operator()
if self.recon_algo == "nag_ls":
reconstruction = nag_ls(op, sinogram, 100, min_constraint=0)
elif self.recon_algo == "fdk":
reconstruction = fdk(op, sinogram)
else:
reconstruction = torch.from_numpy(
np.load(path_to_reconstruction.joinpath("reconstruction.npy"))
).unsqueeze(0)
# Interpolate if geometry is not default
if self.geo.image_shape != self.get_default_geometry().image_shape:
reconstruction = torch.nn.functional.interpolate(
reconstruction.unsqueeze(0),
size=(self.geo.image_shape[1], self.geo.image_shape[2]),
mode="bilinear",

# If target is sinogram, we need to load the sinogram
if self.task in ["sino2sino", "recon2sino", "joint"]:
target = self.__load_and_preprocess_sinogram__(index, self.target_mode)
# if target is reconstruction, we need to load the reconstruction
elif self.task in ["recon2recon", "sino2recon"]:
if self.do_recon:
sinogram = self.__load_and_preprocess_sinogram__(
index, self.target_mode
)
reconstruction = torch.squeeze(reconstruction, 0)

if self.task == "reconstruction":
return sinogram, reconstruction
if self.task == "segmentation":
return reconstruction, segmentation
if self.task == "joint":
return sinogram, reconstruction, segmentation
op = deteCT.get_operator()
if self.recon_algo == "nag_ls":
target = nag_ls(op, sinogram, 100, min_constraint=0)
elif self.recon_algo == "fdk":
target = fdk(op, sinogram)
target = self.__load_and_preprocess_reconstruction__(
index, self.target_mode
)
elif self.task in ["recon2seg", "sino2seg"]:
# Get paths to the dataset
target = self.__load_and_preprocess_segmentation__(index)

if self.task != "joint":
return input, target
else:
return input, target, self.__load_and_preprocess_segmentation__(index)
2 changes: 1 addition & 1 deletion demos/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Demos of LION
WIP
WIP, do not rely on thisfolder

## The following demos are currently available:

Expand Down

0 comments on commit b80d057

Please sign in to comment.