From 87b3f4b7ca48f324e3100fc879b0b9f0d82d2a1d Mon Sep 17 00:00:00 2001 From: AlbertDominguez Date: Tue, 15 Oct 2024 13:20:21 +0200 Subject: [PATCH] minor changes to cli --- spotiflow/cli/train.py | 50 ++++++++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/spotiflow/cli/train.py b/spotiflow/cli/train.py index f689de0..901a003 100644 --- a/spotiflow/cli/train.py +++ b/spotiflow/cli/train.py @@ -24,7 +24,7 @@ log.addHandler(console_handler) ALLOWED_EXTENSIONS = ("tif", "tiff", "png", "jpg", "jpeg") -PRETRAINED_MODELS = list_registered() +PRETRAINED_MODELS = tuple(["general"]+sorted([m for m in list_registered() if m != "general"])) def get_data( @@ -59,8 +59,8 @@ def get_args() -> argparse.Namespace: help="Path to directory containing images and annotations. Please refer to the documentation (https://weigertlab.github.io/spotiflow/train.html#data-format) to see the required format.", ) required.add_argument( - "-s", - "--save-dir", + "-o", + "--outdir", type=Path, required=True, help="Output directory where the model will be stored.", @@ -182,7 +182,14 @@ def get_args() -> argparse.Namespace: default=42, help="Seed for reproducibility. Defaults to 42.", ) - + train_args.add_argument( + "--logger", + type=str, + required=False, + choices=["none", "tensorboard", "wandb"], + default="tensorboard", + help="Logger to use for monitoring training. Defaults to 'tensorboard'.", + ) args = parser.parse_args() return args @@ -195,20 +202,18 @@ def main(): log.info("Loading training data...") train_images, train_spots = get_data(args.data_dir / "train", is_3d=args.is_3d) - assert len(train_images) == len( - train_spots - ), "Number of images and spots do not match." - assert ( - len(train_images) > 0 - ), "No images were found in the 'train' subfolder of the given directory." + if len(train_images) != len(train_spots): + raise ValueError(f"Number of images and spots in {args.data_dir/'train'} do not match.") + if len(train_images) == 0: + raise ValueError(f"No images were found in the {args.data_dir/'train'}.") log.info(f"Training data loaded (N={len(train_images)}).") log.info("Loading validation data...") val_images, val_spots = get_data(args.data_dir / "val", is_3d=args.is_3d) - assert len(val_images) == len(val_spots), "Number of images and spots do not match." - assert ( - len(val_images) > 0 - ), "No images were found in the 'val' subfolder of the given directory." + if len(val_images) != len(val_spots): + raise ValueError(f"Number of images and spots in {args.data_dir/'val'} do not match.") + if len(val_images) == 0: + raise ValueError(f"No images were found in the {args.data_dir/'val'}.") log.info(f"Validation data loaded (N={len(val_images)}).") if args.finetune_from is None: @@ -236,12 +241,14 @@ def main(): verbose=True, ) else: - assert ( - Path(args.finetune_from) != args.save_dir - ), "The save directory cannot be the same as the pre-trained model to be finetuned!" - assert ( - Path(args.finetune_from).exists() - ), f"Given pre-trained model '{args.finetune_from}' does not exist!" + if Path(args.finetune_from) == args.outdir: + err_msg = "The save directory cannot be the same as the pre-trained model to be finetuned!" + raise ValueError(err_msg) + if not Path(args.finetune_from).is_dir(): + err_msg = f"Given pre-trained model '{args.finetune_from}' does not exist! Please provide either one of the pre-trained models ({', '.join(PRETRAINED_MODELS)}) or a valid directory containing a model.".strip().replace( + "\n", " " + ) + raise ValueError(err_msg) log.info(f"Finetuning local model '{args.finetune_from}' to be fine-tuned.") model = Spotiflow.from_folder( args.finetune_from, @@ -256,8 +263,9 @@ def main(): train_spots, val_images, val_spots, - save_dir=args.save_dir, + save_dir=args.outdir, device=args.device, + logger=args.logger, train_config={ "batch_size": args.batch_size, "crop_size": args.crop_size,