diff --git a/src/__main__.py b/src/__main__.py index 124a00a..cc9ab76 100644 --- a/src/__main__.py +++ b/src/__main__.py @@ -172,22 +172,6 @@ def main() -> None: des_folder = os.path.join(args.o, "renamed_image") - # DELETE THIS ONCE THE CHANGE WORKS - # prepare_data_folder(des_folder) - # rename_dic, rename_back_dict = rename_and_copy_files(src_folder, des_folder) - - # datalist_file = os.path.join(des_folder, "renaming.json") - # with open(datalist_file, "w", encoding="utf-8") as f: - # json.dump(rename_dic, f, ensure_ascii=False, indent=4) - # print(f"Renaming dic is saved to {datalist_file}") - - # # model_folder = '../nnunet_results/Dataset901_Task901_dlicv/nnUNetTrainer__nnUNetPlans__3d_fullres/' - # model_folder = os.path.join( - # args.m, - # "Dataset%s_Task%s_dlicv/nnUNetTrainer__nnUNetPlans__3d_fullres/" - # % (args.d, args.d), - # ) - # check if -i argument is a folder, list (csv), or a single file (nii.gz) if os.path.isdir(args.i): # if args.i is a directory src_folder = args.i @@ -198,27 +182,18 @@ def main() -> None: json.dump(rename_dic, f, ensure_ascii=False, indent=4) print(f"Renaming dic is saved to {datalist_file}") - else: # if args.i is a file - if args.i.split(".")[-1] == "csv": # if args.i is a .csv list - print("List input (.csv) detected!") - sys.exit() # don't do anything for now - elif ( - args.i.split(".")[-2] == "nii" & args.i.split(".")[-2] == "gz" - ): # if args.i is a .nii.gz file - print("Nifti file (.nii.gz) input detected!") - sys.exit() # don't do anything for now - model_folder = os.path.join( "nnunet_results", "Dataset%s_Task%s_dlicv/nnUNetTrainer__nnUNetPlans__3d_fullres/" % (args.d, args.d), ) + # Check if model exists. If not exist, download using HuggingFace if not os.path.exists(model_folder): # HF download model - print("DLICV model not found, downloading") - from huggingface_hub import snapshot_download + print("DLICV model not found, downloading...") + from huggingface_hub import snapshot_download snapshot_download(repo_id="nichart/DLICV", local_dir=".") print("DLICV model has been successfully downloaded!") else: @@ -277,7 +252,6 @@ def main() -> None: ) # Final prediction - predictor.predict_from_files( des_folder, args.o,