Skip to content

Commit

Permalink
Fixes regarding the part_id and the weights download
Browse files Browse the repository at this point in the history
  • Loading branch information
spirosmaggioros committed Dec 12, 2024
1 parent b11854d commit 936085c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
3 changes: 3 additions & 0 deletions DLMUSE/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ def main() -> None:
# Required Arguments
parser.add_argument(
"-i",
"--in_dir",
type=str,
required=True,
help="[REQUIRED] Input folder with LPS oriented T1 sMRI Intra Cranial Volumes (ICV) in Nifti format (nii.gz).",
)
parser.add_argument(
"-o",
"--out_dir",
type=str,
required=True,
help="[REQUIRED] Output folder for the segmentation results in Nifti format (nii.gz).",
Expand Down Expand Up @@ -217,6 +219,7 @@ def main() -> None:
args.device,
args.clear_cache,
args.d,
args.c,
args.part_id,
args.num_parts,
args.step_size,
Expand Down
14 changes: 8 additions & 6 deletions DLMUSE/dlmuse_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def run_pipeline(
out_dir: str,
device: str,
clear_cache: bool = False,
d: str = "901",
d: str = "903",
c: str = "3d_fullres",
part_id: int = 0,
num_parts: int = 1,
step_size: float = 0.5,
Expand Down Expand Up @@ -75,7 +76,7 @@ def run_pipeline(
model_folder = os.path.join(
Path(__file__).parent,
"nnunet_results",
"Dataset%s_Task%s_dlicv/nnUNetTrainer__nnUNetPlans__3d_fullres/" % (d, d),
"Dataset%s_Task%s_DLMUSEV2/nnUNetTrainer__nnUNetPlans__%s/" % (d, d, c),
)

if clear_cache:
Expand All @@ -90,15 +91,16 @@ def run_pipeline(
from huggingface_hub import snapshot_download

local_src = Path(__file__).parent
snapshot_download(repo_id="nichart/DLICV", local_dir=local_src)
print("DLICV model has been successfully downloaded!")
snapshot_download(repo_id="nichart/DLMUSE", local_dir=local_src)
print("DLMUSE model has been successfully downloaded!")
else:
print("Loading the model...")

prepare_data_folder(out_dir)

# Check for invalid arguments - advise users to see nnUNetv2 documentation
assert part_id < num_parts, "See nnUNetv2_predict -h."
assert (
part_id < num_parts
), "part_id < num_parts. Please see nnUNetv2_predict -h."

assert device in [
"cpu",
Expand Down

0 comments on commit 936085c

Please sign in to comment.