From 936085c3de88edfa74d316b05ccc683892647939 Mon Sep 17 00:00:00 2001 From: Spiros Maggioros Date: Thu, 12 Dec 2024 03:10:38 +0200 Subject: [PATCH] Fixes regarding the part_id and the weights download --- DLMUSE/__main__.py | 3 +++ DLMUSE/dlmuse_pipeline.py | 14 ++++++++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/DLMUSE/__main__.py b/DLMUSE/__main__.py index 7f8365e..bc0aa2b 100644 --- a/DLMUSE/__main__.py +++ b/DLMUSE/__main__.py @@ -39,12 +39,14 @@ def main() -> None: # Required Arguments parser.add_argument( "-i", + "--in_dir", type=str, required=True, help="[REQUIRED] Input folder with LPS oriented T1 sMRI Intra Cranial Volumes (ICV) in Nifti format (nii.gz).", ) parser.add_argument( "-o", + "--out_dir", type=str, required=True, help="[REQUIRED] Output folder for the segmentation results in Nifti format (nii.gz).", @@ -217,6 +219,7 @@ def main() -> None: args.device, args.clear_cache, args.d, + args.c, args.part_id, args.num_parts, args.step_size, diff --git a/DLMUSE/dlmuse_pipeline.py b/DLMUSE/dlmuse_pipeline.py index d51aa52..a416801 100644 --- a/DLMUSE/dlmuse_pipeline.py +++ b/DLMUSE/dlmuse_pipeline.py @@ -15,7 +15,8 @@ def run_pipeline( out_dir: str, device: str, clear_cache: bool = False, - d: str = "901", + d: str = "903", + c: str = "3d_fullres", part_id: int = 0, num_parts: int = 1, step_size: float = 0.5, @@ -75,7 +76,7 @@ def run_pipeline( model_folder = os.path.join( Path(__file__).parent, "nnunet_results", - "Dataset%s_Task%s_dlicv/nnUNetTrainer__nnUNetPlans__3d_fullres/" % (d, d), + "Dataset%s_Task%s_DLMUSEV2/nnUNetTrainer__nnUNetPlans__%s/" % (d, d, c), ) if clear_cache: @@ -90,15 +91,16 @@ def run_pipeline( from huggingface_hub import snapshot_download local_src = Path(__file__).parent - snapshot_download(repo_id="nichart/DLICV", local_dir=local_src) - print("DLICV model has been successfully downloaded!") + snapshot_download(repo_id="nichart/DLMUSE", local_dir=local_src) + print("DLMUSE model has been successfully downloaded!") else: print("Loading the model...") prepare_data_folder(out_dir) - # Check for invalid arguments - advise users to see nnUNetv2 documentation - assert part_id < num_parts, "See nnUNetv2_predict -h." + assert ( + part_id < num_parts + ), "part_id < num_parts. Please see nnUNetv2_predict -h." assert device in [ "cpu",