diff --git a/src/__main__.py b/src/__main__.py index 2ff29c8..c259f85 100644 --- a/src/__main__.py +++ b/src/__main__.py @@ -1,6 +1,7 @@ import argparse import json import os +from pathlib import Path import shutil import warnings @@ -182,6 +183,7 @@ def main() -> None: print(f"Renaming dic is saved to {datalist_file}") model_folder = os.path.join( + Path(__file__).parent, "nnunet_results", "Dataset%s_Task%s_dlicv/nnUNetTrainer__nnUNetPlans__3d_fullres/" % (args.d, args.d), @@ -193,8 +195,8 @@ def main() -> None: print("DLICV model not found, downloading...") from huggingface_hub import snapshot_download - - snapshot_download(repo_id="nichart/DLICV", local_dir=".") + local_src = Path(__file__).parent + snapshot_download(repo_id="nichart/DLICV", local_dir=local_src) print("DLICV model has been successfully downloaded!") else: print("Loading the model...")