diff --git a/FigureGenerator/screenshot_maker.py b/FigureGenerator/screenshot_maker.py index 9bec310..6f8992e 100644 --- a/FigureGenerator/screenshot_maker.py +++ b/FigureGenerator/screenshot_maker.py @@ -140,7 +140,6 @@ def __init__(self, args): self.read_images_and_store_arrays() def read_images_and_store_arrays(self): - input_images = [ rescale_intensity(resample_image(sitk.ReadImage(image))) for image in self.images @@ -262,7 +261,9 @@ def read_images_and_store_arrays(self): self.input_masks_bounded = None # if mask is not defined, pick the middle of the array max_id = ( - np.around(np.true_divide(np.array(self.input_images_bounded)[0].shape, 2)) + np.around( + np.true_divide(np.array(self.input_images_bounded)[0].shape, 2) + ) .astype(int) .tolist() ) @@ -315,7 +316,6 @@ def save_image(self, output_file): for mask_slice in mask_slices: for image_slice in image_slices: for i, _ in enumerate(image_slice): - mask = None if mask_slice[i] is not None: mask = mask_slice[i] @@ -387,3 +387,47 @@ def save_image(self, output_file): plt.tight_layout() plt.savefig(os.path.join(output_file)) + + +def figure_generator( + input_images: str, + ylabels: str, + output: str, + input_mask: str = None, + opacity: float = 0.5, + borderpc: float = 0.05, + axisrow: bool = False, + fontsize: int = 15, + boundtype: str = "image", +) -> None: + """ + This is a functional interface to the class :class:`FigureGenerator`. It takes in the same arguments as the class and generates the figure. + + Args: + input_images (str): The input images separated by comma. The images should be in the same order as the ylabels. + ylabels (str): The ylabels separated by comma. The ylabels should be in the same order as the input images. + output (str): The output file name. + input_mask (str, optional): The input masks separated by comma. The masks should be in the same order as the input images. Defaults to None. + opacity (float, optional): The opacity of the mask. Defaults to 0.5. + borderpc (float, optional): The border percentage of the mask. Defaults to 0.05. + axisrow (bool, optional): Whether to show the axis row. Defaults to False. + fontsize (int, optional): The fontsize of the figure. Defaults to 15. + boundtype (str, optional): The type of bounding. Can be "image" or "mask". Defaults to "image". + """ + assert len(input_images.split(",")) == len( + ylabels.split(",") + ), "Number of images and number of ylabels should be same" + import argparse + + # save the screenshot + args_for_fig_gen = argparse.ArgumentParser() + args_for_fig_gen.images = input_images + args_for_fig_gen.ylabels = ylabels + args_for_fig_gen.opacity = opacity + args_for_fig_gen.borderpc = borderpc + args_for_fig_gen.axisrow = axisrow + args_for_fig_gen.fontsize = fontsize + args_for_fig_gen.boundtype = boundtype + args_for_fig_gen.output = output + fig_generator = FigureGenerator(args_for_fig_gen) + fig_generator.save_image(args_for_fig_gen.output)