diff --git a/src/autoseg/train_job.py b/src/autoseg/train_job.py index d90e7a1..58941c5 100644 --- a/src/autoseg/train_job.py +++ b/src/autoseg/train_job.py @@ -1,3 +1,4 @@ +from more_itertools import raise_ from .train import mtlsd_train, aclsd_train, stelarr_train from .utils import tiff_to_zarr, create_masks @@ -17,14 +18,20 @@ def train_model( # TODO: add util funcs for generating masks, pulling paintings if raw_file.endswith(".tiff") or raw_file.endswith(".tif"): - tiff_to_zarr(tiff_file=raw_file, - out_file=rewrite_file, - out_ds=rewrite_ds) - raw_file: str = rewrite_file + try: + tiff_to_zarr(tiff_file=raw_file, + out_file=rewrite_file, + out_ds=rewrite_ds) + raw_file: str = rewrite_file + except: + raise("Could not convert TIFF file to zarr volume") if generate_masks: - create_masks(raw_file, "volumes/training_gt_labels") - + try: + create_masks(raw_file, "volumes/training_gt_labels") + except: + raise("Could not generate masks - check to make sure a painting labels volume exists") + model_type: str = model_type.lower() if model_type == "mtlsd": mtlsd_train(