diff --git a/src/autoseg/train_job.py b/src/autoseg/train_job.py index 33baafe..d90e7a1 100644 --- a/src/autoseg/train_job.py +++ b/src/autoseg/train_job.py @@ -1,5 +1,5 @@ from .train import mtlsd_train, aclsd_train, stelarr_train -from .utils import tiff_to_zarr +from .utils import tiff_to_zarr, create_masks def train_model( @@ -8,8 +8,9 @@ def train_model( warmup: int = 100000, raw_file: str = "path/to/.zarr/or/.n5/or/.tiff", rewrite_file: str = "./rewritten.zarr", - rewrite_ds: str = "training_raw", + rewrite_ds: str = "volumes/training_raw", out_file: str = "./raw_predictions.zarr", + generate_masks: bool = False, voxel_size: int = 33, save_every=2500, ) -> None: @@ -21,7 +22,9 @@ def train_model( out_ds=rewrite_ds) raw_file: str = rewrite_file - + if generate_masks: + create_masks(raw_file, "volumes/training_gt_labels") + model_type: str = model_type.lower() if model_type == "mtlsd": mtlsd_train(