Skip to content

Commit

Permalink
Generate masks for training volume if denoted
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Nov 30, 2023
1 parent 7245e30 commit 0a7e95d
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/autoseg/train_job.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .train import mtlsd_train, aclsd_train, stelarr_train
from .utils import tiff_to_zarr
from .utils import tiff_to_zarr, create_masks


def train_model(
Expand All @@ -8,8 +8,9 @@ def train_model(
warmup: int = 100000,
raw_file: str = "path/to/.zarr/or/.n5/or/.tiff",
rewrite_file: str = "./rewritten.zarr",
rewrite_ds: str = "training_raw",
rewrite_ds: str = "volumes/training_raw",
out_file: str = "./raw_predictions.zarr",
generate_masks: bool = False,
voxel_size: int = 33,
save_every=2500,
) -> None:
Expand All @@ -21,7 +22,9 @@ def train_model(
out_ds=rewrite_ds)
raw_file: str = rewrite_file


if generate_masks:
create_masks(raw_file, "volumes/training_gt_labels")

model_type: str = model_type.lower()
if model_type == "mtlsd":
mtlsd_train(
Expand Down

0 comments on commit 0a7e95d

Please sign in to comment.