Skip to content

Commit

Permalink
minor changes to cli
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbertDominguez committed Oct 15, 2024
1 parent 51b7b2b commit 87b3f4b
Showing 1 changed file with 29 additions and 21 deletions.
50 changes: 29 additions & 21 deletions spotiflow/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
log.addHandler(console_handler)

ALLOWED_EXTENSIONS = ("tif", "tiff", "png", "jpg", "jpeg")
PRETRAINED_MODELS = list_registered()
PRETRAINED_MODELS = tuple(["general"]+sorted([m for m in list_registered() if m != "general"]))


def get_data(
Expand Down Expand Up @@ -59,8 +59,8 @@ def get_args() -> argparse.Namespace:
help="Path to directory containing images and annotations. Please refer to the documentation (https://weigertlab.github.io/spotiflow/train.html#data-format) to see the required format.",
)
required.add_argument(
"-s",
"--save-dir",
"-o",
"--outdir",
type=Path,
required=True,
help="Output directory where the model will be stored.",
Expand Down Expand Up @@ -182,7 +182,14 @@ def get_args() -> argparse.Namespace:
default=42,
help="Seed for reproducibility. Defaults to 42.",
)

train_args.add_argument(
"--logger",
type=str,
required=False,
choices=["none", "tensorboard", "wandb"],
default="tensorboard",
help="Logger to use for monitoring training. Defaults to 'tensorboard'.",
)
args = parser.parse_args()
return args

Expand All @@ -195,20 +202,18 @@ def main():

log.info("Loading training data...")
train_images, train_spots = get_data(args.data_dir / "train", is_3d=args.is_3d)
assert len(train_images) == len(
train_spots
), "Number of images and spots do not match."
assert (
len(train_images) > 0
), "No images were found in the 'train' subfolder of the given directory."
if len(train_images) != len(train_spots):
raise ValueError(f"Number of images and spots in {args.data_dir/'train'} do not match.")
if len(train_images) == 0:
raise ValueError(f"No images were found in the {args.data_dir/'train'}.")
log.info(f"Training data loaded (N={len(train_images)}).")

log.info("Loading validation data...")
val_images, val_spots = get_data(args.data_dir / "val", is_3d=args.is_3d)
assert len(val_images) == len(val_spots), "Number of images and spots do not match."
assert (
len(val_images) > 0
), "No images were found in the 'val' subfolder of the given directory."
if len(val_images) != len(val_spots):
raise ValueError(f"Number of images and spots in {args.data_dir/'val'} do not match.")
if len(val_images) == 0:
raise ValueError(f"No images were found in the {args.data_dir/'val'}.")
log.info(f"Validation data loaded (N={len(val_images)}).")

if args.finetune_from is None:
Expand Down Expand Up @@ -236,12 +241,14 @@ def main():
verbose=True,
)
else:
assert (
Path(args.finetune_from) != args.save_dir
), "The save directory cannot be the same as the pre-trained model to be finetuned!"
assert (
Path(args.finetune_from).exists()
), f"Given pre-trained model '{args.finetune_from}' does not exist!"
if Path(args.finetune_from) == args.outdir:
err_msg = "The save directory cannot be the same as the pre-trained model to be finetuned!"
raise ValueError(err_msg)
if not Path(args.finetune_from).is_dir():
err_msg = f"Given pre-trained model '{args.finetune_from}' does not exist! Please provide either one of the pre-trained models ({', '.join(PRETRAINED_MODELS)}) or a valid directory containing a model.".strip().replace(
"\n", " "
)
raise ValueError(err_msg)
log.info(f"Finetuning local model '{args.finetune_from}' to be fine-tuned.")
model = Spotiflow.from_folder(
args.finetune_from,
Expand All @@ -256,8 +263,9 @@ def main():
train_spots,
val_images,
val_spots,
save_dir=args.save_dir,
save_dir=args.outdir,
device=args.device,
logger=args.logger,
train_config={
"batch_size": args.batch_size,
"crop_size": args.crop_size,
Expand Down

0 comments on commit 87b3f4b

Please sign in to comment.