diff --git a/src/autoseg/train_job.py b/src/autoseg/train_job.py index 58941c5..84c3a20 100644 --- a/src/autoseg/train_job.py +++ b/src/autoseg/train_job.py @@ -1,6 +1,6 @@ from more_itertools import raise_ from .train import mtlsd_train, aclsd_train, stelarr_train -from .utils import tiff_to_zarr, create_masks +from .utils import tiff_to_zarr, create_masks, wkw_seg_to_zarr, download_wk_skeleton, rasterize_skeleton def train_model( @@ -11,9 +11,14 @@ def train_model( rewrite_file: str = "./rewritten.zarr", rewrite_ds: str = "volumes/training_raw", out_file: str = "./raw_predictions.zarr", + get_labels: bool = False, + get_rasters: bool = False, generate_masks: bool = False, voxel_size: int = 33, - save_every=2500, + save_every: int =2500, + annotation_id: str = None, + wk_token="YqSgxzFJpP2eyjtqymCTPg", + ) -> None: # TODO: add util funcs for generating masks, pulling paintings @@ -26,6 +31,24 @@ def train_model( except: raise("Could not convert TIFF file to zarr volume") + if get_labels: + try: + wkw_seg_to_zarr(annotation_id=annotation_id, + save_path=".", + zarr_path=raw_file, + wk_token=wk_token, + gt_name="training_labels") + except: + raise("Could not fetch and convert paintings to zarr format") + + if get_rasters: + try: + zip_path: str = download_wk_skeleton(annotation_id=annotation_id, + token=wk_token) + rasterize_skeleton(zip_path=zip_path, raw_file=raw_file) + except: + raise("Could not fetch and convert skeletons to zarr format") + if generate_masks: try: create_masks(raw_file, "volumes/training_gt_labels") diff --git a/src/autoseg/utils.py b/src/autoseg/utils.py index 842c167..a48ff72 100644 --- a/src/autoseg/utils.py +++ b/src/autoseg/utils.py @@ -177,7 +177,7 @@ def download_wk_skeleton( url="http://catmaid2.hms.harvard.edu:9000", annotation_id=None, token = None, - overwrite=None, + overwrite=True, zip_suffix=None, ): # print(f"Downloading {wk_url}/annotations/Explorational/{annotation_ID}...") @@ -441,14 +441,14 @@ def wkw_seg_to_zarr( annotation_id, save_path, zarr_path, - raw_name="volumes/raw", + raw_name="volumes/training_raw", wk_url="http://catmaid2.hms.harvard.edu:9000", wk_token="YqSgxzFJpP2eyjtqymCTPg", gt_name=None, gt_name_prefix="volumes/", overwrite=None, ): - print(f"Downloading {annotation_ID} from {wk_url}...") + print(f"Downloading {annotation_id} from {wk_url}...") with wk.webknossos_context(token=wk_token, url=wk_url): annotation = wk.Annotation.download( annotation_id