diff --git a/scripts/aggregate_data_information.py b/scripts/aggregate_data_information.py
index 7086b23..6ce446a 100644
--- a/scripts/aggregate_data_information.py
+++ b/scripts/aggregate_data_information.py
@@ -200,7 +200,7 @@ def active_zone_train_data():
        "01": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/exported_imod_objects/01_hoi_maus_2020_incomplete",  # noqa
        "04": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/exported_imod_objects/04_hoi_stem_examples",  # noqa
        "06": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/exported_imod_objects/06_hoi_wt_stem750_fm",  # noqa
-        "12": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/2D_data/20241021_imig_2014_data_transfer_exported_grouped",  # noqa
+        "12": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/exported_imod_objects/12_chemical_fix_cryopreparation",  # noqa
    }
 
    test_tomograms = {
@@ -467,11 +467,58 @@ def get_image_sizes_tem_2d():
            print(f["raw"].shape)
 
 
+def mito_train_data():
+    train_root = "/scratch-grete/projects/nim00007/data/mitochondria/cooper/fidi_down_s2"
+    test_tomograms = [
+        "36859_J1_66K_TS_CA3_MF_18_rec_2Kb1dawbp_crop_downscaled.h5",
+        "3.2_downscaled.h5",
+    ]
+    all_tomos = sorted(glob(os.path.join(train_root, "*.h5")))
+
+    tomo_names = []
+    tomo_condition = []
+    tomo_mitos = []
+    tomo_resolution = []
+    tomo_train = []
+
+    for tomo in all_tomos:
+        fname = os.path.basename(tomo)
+        split = "test" if fname in test_tomograms else "train/val"
+        if "36859" in fname or "37371" in fname:  # This is from the STEM dataset.
+            condition = stem
+            resolution = 2 * 0.868
+        else:  # This is from the TEM single-axis dataset.
+            condition = single_ax_tem
+            # These were scaled as well, despite the resolution mismatch.
+            resolution = 2 * 1.554
+
+        with h5py.File(tomo, "r") as f:
+            seg = f["labels/mitochondria"][:]
+            n_mitos = len(np.unique(seg)) - 1
+
+        tomo_names.append(tomo)
+        tomo_condition.append(condition)
+        tomo_train.append(split)
+        tomo_resolution.append(resolution)
+        tomo_mitos.append(n_mitos)
+
+    df = pd.DataFrame({
+        "tomogram": tomo_names,
+        "condition": tomo_condition,
+        "resolution": tomo_resolution,
+        "used_for": tomo_train,
+        "mito_count_all": tomo_mitos,
+    })
+
+    os.makedirs("data_summary", exist_ok=True)
+    df.to_excel("./data_summary/mitochondria.xlsx", index=False)
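+
+    # Sketch for inspecting the exported summary (assumes the export above ran;
+    # `pd` is the pandas import already used in this script):
+    #     summary = pd.read_excel("./data_summary/mitochondria.xlsx")
+    #     print(summary.groupby("condition")["mito_count_all"].sum())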
+
+
 def main():
     # active_zone_train_data()
     # compartment_train_data()
-    # mito_train_data()
-    vesicle_train_data()
+    mito_train_data()
+    # vesicle_train_data()
     # vesicle_domain_adaptation_data()
     # get_n_images_frog()
diff --git a/scripts/cooper/full_reconstruction/segment_mitochondria.py b/scripts/cooper/full_reconstruction/segment_mitochondria.py
index 395de78..cb82275 100644
--- a/scripts/cooper/full_reconstruction/segment_mitochondria.py
+++ b/scripts/cooper/full_reconstruction/segment_mitochondria.py
@@ -8,23 +8,53 @@
 ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/04_full_reconstruction"  # noqa
 MODEL_PATH = "/scratch-grete/projects/nim00007/models/exports_for_cooper/mito_model_s2.pt"  # noqa
+# MODEL_PATH = "/scratch-grete/projects/nim00007/models/luca/mito/source_domain"
 
 
 def run_seg(path):
+    out_folder = "./mito_seg"
+    ds, fname = os.path.split(path)
+    ds = os.path.basename(ds)
+
+    os.makedirs(os.path.join(out_folder, ds), exist_ok=True)
+    out_path = os.path.join(out_folder, ds, fname)
+    if os.path.exists(out_path):
+        return
+
     with h5py.File(path, "r") as f:
-        if "labels/mitochondria" in f:
-            return
         raw = f["raw"][:]
 
     scale = (0.5, 0.5, 0.5)
     seg = segment_mitochondria(raw, model_path=MODEL_PATH, scale=scale, verbose=False)
 
-    with h5py.File(path, "a") as f:
+    with h5py.File(out_path, "a") as f:
+        f.create_dataset("labels/mitochondria", data=seg, compression="gzip")
+
+
+def run_seg_and_pred(path):
+    with h5py.File(path, "r") as f:
+        raw = f["raw"][:]
+
+    scale = (0.5, 0.5, 0.5)
+    seg, pred = segment_mitochondria(
+        raw, model_path=MODEL_PATH, scale=scale, verbose=False, return_predictions=True
+    )
+
+    out_folder = "./mito_pred"
+    os.makedirs(out_folder, exist_ok=True)
+    out_path = os.path.join(out_folder, os.path.basename(path))
+
+    with h5py.File(out_path, "a") as f:
+        f.create_dataset("raw", data=raw[::2, ::2, ::2])
+        f.create_dataset("labels/mitochondria", data=seg, compression="gzip")
+        f.create_dataset("pred", data=pred, compression="gzip")
 
 
 def main():
     paths = sorted(glob(os.path.join(ROOT, "**/*.h5"), recursive=True))
     for path in tqdm(paths):
         run_seg(path)
+        # run_seg_and_pred(path)
 
 
 main()
diff --git a/scripts/data_summary/vesicle_training_data.xlsx b/scripts/data_summary/vesicle_training_data.xlsx
index 0f9ee1e..57fb145 100644
Binary files a/scripts/data_summary/vesicle_training_data.xlsx and b/scripts/data_summary/vesicle_training_data.xlsx differ
diff --git a/scripts/prepare_zenodo_uploads.py b/scripts/prepare_zenodo_uploads.py
new file mode 100644
index 0000000..b642c07
--- /dev/null
+++ b/scripts/prepare_zenodo_uploads.py
@@ -0,0 +1,246 @@
+import os
+from glob import glob
+from shutil import copyfile
+
+import h5py
+from tqdm import tqdm
+
+OUTPUT_ROOT = "./data_summary/for_zenodo"
+
+
+def _copy_vesicles(tomos, out_folder):
+    label_key = "labels/vesicles/combined_vesicles"
+    os.makedirs(out_folder, exist_ok=True)
+    for tomo in tqdm(tomos, desc="Export tomos"):
+        out_path = os.path.join(out_folder, os.path.basename(tomo))
+        if os.path.exists(out_path):
+            continue
+
+        with h5py.File(tomo, "r") as f:
+            raw = f["raw"][:]
+            labels = f[label_key][:]
+            try:
+                fname = f.attrs["filename"]
+            except KeyError:
+                fname = None
+
+        with h5py.File(out_path, "a") as f:
+            f.create_dataset("raw", data=raw, compression="gzip")
+            f.create_dataset("labels/vesicles", data=labels, compression="gzip")
+            if fname is not None:
+                f.attrs["filename"] = fname
+
+
+def _export_vesicles(train_root, test_root, name):
+    train_tomograms = sorted(glob(os.path.join(train_root, "*.h5")))
+    test_tomograms = sorted(glob(os.path.join(test_root, "*.h5")))
+    print(f"Vesicle data for {name}:")
+    print(len(train_tomograms), len(test_tomograms), len(train_tomograms) + len(test_tomograms))
+
+    train_out = os.path.join(OUTPUT_ROOT, "synapse-net", "vesicles", "train", name)
+    _copy_vesicles(train_tomograms, train_out)
+
+    test_out = os.path.join(OUTPUT_ROOT, "synapse-net", "vesicles", "test", name)
+    _copy_vesicles(test_tomograms, test_out)
+
+
+def _export_az(train_root, test_tomos, name):
+    tomograms = sorted(glob(os.path.join(train_root, "*.h5")))
+    print(f"AZ data for {name}:")
+
+    train_out = os.path.join(OUTPUT_ROOT, "synapse-net", "active_zones", "train", name)
+    test_out = os.path.join(OUTPUT_ROOT, "synapse-net", "active_zones", "test", name)
+
+    os.makedirs(train_out, exist_ok=True)
+    os.makedirs(test_out, exist_ok=True)
+
+    for tomo in tqdm(tomograms):
+        fname = os.path.basename(tomo)
+        if fname in test_tomos:
+            out_path = os.path.join(test_out, fname)
+        else:
+            out_path = os.path.join(train_out, fname)
+        if os.path.exists(out_path):
+            continue
+
+        with h5py.File(tomo, "r") as f:
+            raw = f["raw"][:]
+            az = f["labels/AZ"][:]
+
+        with h5py.File(out_path, "a") as f:
+            f.create_dataset("raw", data=raw, compression="gzip")
+            f.create_dataset("labels/AZ", data=az, compression="gzip")
+
+
+# NOTE: we have very few mito annotations from 01, so we don't include them here.
+def prepare_single_ax_stem_chemical_fix():
+    # single-axis-tem: vesicles
+    train_root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/01_hoi_maus_2020_incomplete"  # noqa
+    test_root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/01_hoi_maus_2020_incomplete"  # noqa
+    _export_vesicles(train_root, test_root, name="single_axis_tem")
+
+    # single-axis-tem: active zones
+    train_root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/exported_imod_objects/01_hoi_maus_2020_incomplete"  # noqa
+    test_tomos = [
+        "WT_MF_DIV28_01_MS_09204_F1.h5", "WT_MF_DIV14_01_MS_B2_09175_CA3.h5", "M13_CTRL_22723_O2_05_DIV29_5.2.h5", "WT_Unt_SC_09175_D4_05_DIV14_mtk_05.h5",  # noqa
+        "20190805_09002_B4_SC_11_SP.h5", "20190807_23032_D4_SC_01_SP.h5", "M13_DKO_22723_A1_03_DIV29_03_MS.h5", "WT_MF_DIV28_05_MS_09204_F1.h5", "M13_CTRL_09201_S2_06_DIV31_06_MS.h5",  # noqa
+        "WT_MF_DIV28_1.2_MS_09002_B1.h5", "WT_Unt_SC_09175_C4_04_DIV15_mtk_04.h5", "M13_DKO_22723_A4_10_DIV29_10_MS.h5", "WT_MF_DIV14_3.2_MS_D2_09175_CA3.h5",  # noqa
+        "20190805_09002_B4_SC_10_SP.h5", "M13_CTRL_09201_S2_02_DIV31_02_MS.h5", "WT_MF_DIV14_04_MS_E1_09175_CA3.h5", "WT_MF_DIV28_10_MS_09002_B3.h5", "WT_Unt_SC_05646_D4_02_DIV16_mtk_02.h5", "M13_DKO_22723_A4_08_DIV29_08_MS.h5", "WT_MF_DIV28_04_MS_09204_M1.h5", "WT_MF_DIV28_03_MS_09204_F1.h5", "M13_DKO_22723_A1_05_DIV29_05_MS.h5",  # noqa
+        "WT_Unt_SC_09175_C4_06_DIV15_mtk_06.h5", "WT_MF_DIV28_09_MS_09002_B3.h5", "20190524_09204_F4_SC_07_SP.h5",
+        "WT_MF_DIV14_02_MS_C2_09175_CA3.h5", "M13_DKO_23037_K1_01_DIV29_01_MS.h5", "WT_Unt_SC_09175_E2_01_DIV14_mtk_01.h5", "20190807_23032_D4_SC_05_SP.h5", "WT_MF_DIV14_01_MS_E2_09175_CA3.h5", "WT_MF_DIV14_03_MS_B2_09175_CA3.h5", "M13_DKO_09201_O1_01_DIV31_01_MS.h5", "M13_DKO_09201_U1_04_DIV31_04_MS.h5",  # noqa
+        "WT_MF_DIV14_04_MS_E2_09175_CA3_2.h5", "WT_Unt_SC_09175_D5_01_DIV14_mtk_01.h5",
+        "M13_CTRL_22723_O2_05_DIV29_05_MS_.h5", "WT_MF_DIV14_02_MS_B2_09175_CA3.h5", "WT_MF_DIV14_01.2_MS_D1_09175_CA3.h5",  # noqa
+    ]
+    _export_az(train_root, test_tomos, name="single_axis_tem")
+
+    # chemical-fixation: vesicles
+    train_root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/12_chemical_fix_cryopreparation"  # noqa
+    test_root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/12_chemical_fix_cryopreparation"  # noqa
+    _export_vesicles(train_root, test_root, name="chemical_fixation")
+
+    # chemical-fixation: active zones
+    train_root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/exported_imod_objects/12_chemical_fix_cryopreparation"  # noqa
+    test_tomos = ["20180305_09_MS.h5", "20180305_04_MS.h5", "20180305_08_MS.h5",
+                  "20171113_04_MS.h5", "20171006_05_MS.h5", "20180305_01_MS.h5"]
+    _export_az(train_root, test_tomos, name="chemical_fixation")
+
+
+def prepare_ier():
+    root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/moser/other_tomograms"
+    sets = {
+        "01_vesicle_pools": "vesicle_pools",
+        "02_tether": "tether",
+        "03_ratten_tomos": "rat",
+    }
+
+    output_folder = os.path.join(OUTPUT_ROOT, "IER")
+    label_names = {
+        "ribbons": "ribbon",
+        "membrane": "membrane",
+        "presynapse": "PD",
+        "postsynapse": "PSD",
+        "vesicles": "vesicles",
+    }
+
+    for name, output_name in sets.items():
+        out_set = os.path.join(output_folder, output_name)
+        os.makedirs(out_set, exist_ok=True)
+        tomos = sorted(glob(os.path.join(root, name, "*.h5")))
+
+        print("Export", output_name)
+        for tomo in tqdm(tomos):
+            with h5py.File(tomo, "r") as f:
+                try:
+                    fname = os.path.split(f.attrs["filename"])[1][:-4]
+                except KeyError:
+                    fname = f.attrs["path"][1]
+                    fname = "_".join(fname.split("/")[-2:])
+
+                out_path = os.path.join(out_set, os.path.basename(tomo))
+                if os.path.exists(out_path):
+                    continue
+
+                raw = f["raw"][:]
+                labels = {}
+                for label_name, out_name in label_names.items():
+                    key = f"labels/{label_name}"
+                    if key not in f:
+                        continue
+                    labels[out_name] = f[key][:]
+
+            with h5py.File(out_path, "a") as f:
+                f.attrs["filename"] = fname
+                f.create_dataset("raw", data=raw, compression="gzip")
+                for label_name, seg in labels.items():
+                    f.create_dataset(f"labels/{label_name}", data=seg, compression="gzip")
+
+
+def prepare_frog():
+    root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/rizzoli/extracted"
+    train_tomograms = [
+        "block10U3A_three.h5", "block30UB_one_two.h5", "block30UB_two.h5", "block10U3A_one.h5",
+        "block184B_one.h5", "block30UB_three.h5", "block10U3A_two.h5", "block30UB_four.h5",
+        "block30UB_one.h5", "block10U3A_five.h5",
+    ]
+    test_tomograms = ["block10U3A_four.h5", "block30UB_five.h5"]
+
+    output_folder = os.path.join(OUTPUT_ROOT, "frog")
+    output_train = os.path.join(output_folder, "train_unlabeled")
+    os.makedirs(output_train, exist_ok=True)
+
+    for name in train_tomograms:
+        path = os.path.join(root, name)
+        out_path = os.path.join(output_train, name)
+        if os.path.exists(out_path):
+            continue
+        copyfile(path, out_path)
+
+    output_test = os.path.join(output_folder, "test")
+    os.makedirs(output_test, exist_ok=True)
+    for name in test_tomograms:
+        path = os.path.join(root, name)
+        out_path = os.path.join(output_test, name)
+        if os.path.exists(out_path):
+            continue
+        copyfile(path, out_path)
+
+
+def prepare_2d_tem():
+    train_root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/2D_data/maus_2020_tem2d_wt_unt_div14_exported_scaled/good_for_DAtraining/maus_2020_tem2d_wt_unt_div14_exported_scaled"  # noqa
+    test_root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicle_gt_2d/maus_2020_tem2d"  # noqa
+    train_images = [
+        "MF_05649_P-09175-E_06.h5", "MF_05646_C-09175-B_001B.h5", "MF_05649_P-09175-E_07.h5",
+        "MF_05649_G-09175-C_001.h5", "MF_05646_C-09175-B_002.h5", "MF_05649_G-09175-C_04.h5",
+        "MF_05649_P-09175-E_05.h5", "MF_05646_C-09175-B_000.h5", "MF_05646_C-09175-B_001.h5"
+    ]
+    test_images = [
+        "MF_05649_G-09175-C_04B.h5", "MF_05646_C-09175-B_000B.h5",
+        "MF_05649_G-09175-C_03.h5", "MF_05649_G-09175-C_02.h5"
+    ]
+    print(len(train_images) + len(test_images))
+
+    output_folder = os.path.join(OUTPUT_ROOT, "2d_tem")
+
+    output_train = os.path.join(output_folder, "train_unlabeled")
+    os.makedirs(output_train, exist_ok=True)
+    for name in tqdm(train_images, desc="Export train images"):
+        out_path = os.path.join(output_train, name)
+        if os.path.exists(out_path):
+            continue
+        in_path = os.path.join(train_root, name)
+        with h5py.File(in_path, "r") as f:
+            raw = f["raw"][:]
+        with h5py.File(out_path, "a") as f:
+            f.create_dataset("raw", data=raw, compression="gzip")
+
+    output_test = os.path.join(output_folder, "test")
+    os.makedirs(output_test, exist_ok=True)
+    for name in tqdm(test_images, desc="Export test images"):
+        out_path = os.path.join(output_test, name)
+        if os.path.exists(out_path):
+            continue
+        in_path = os.path.join(test_root, name)
+        with h5py.File(in_path, "r") as f:
+            raw = f["data"][:]
+            labels = f["labels/vesicles"][:]
+            mask = f["labels/mask"][:]
+        with h5py.File(out_path, "a") as f:
+            f.create_dataset("raw", data=raw, compression="gzip")
+            f.create_dataset("labels/vesicles", data=labels, compression="gzip")
+            f.create_dataset("labels/mask", data=mask, compression="gzip")
+
+
+def prepare_munc_snap():
+    pass
+
+
+def main():
+    prepare_single_ax_stem_chemical_fix()
+    # prepare_2d_tem()
+    # prepare_frog()
+    # prepare_ier()
+    # prepare_munc_snap()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/synaptic_reconstruction/inference/active_zone.py b/synaptic_reconstruction/inference/active_zone.py
new file mode 100644
index 0000000..d611693
--- /dev/null
+++ b/synaptic_reconstruction/inference/active_zone.py
@@ -0,0 +1,122 @@
+import time
+from typing import Dict, List, Optional, Tuple, Union
+
+import elf.parallel as parallel
+import numpy as np
+import torch
+
+from skimage.segmentation import find_boundaries
+from synaptic_reconstruction.inference.util import get_prediction, _Scaler
+
+
+def find_intersection_boundary(segmented_AZ: np.ndarray, segmented_compartment: np.ndarray) -> np.ndarray:
+    """
+    Find the cumulative intersection of the boundary of each label in segmented_compartment with segmented_AZ.
+
+    Args:
+        segmented_AZ: 3D array representing the active zone (AZ).
+        segmented_compartment: 3D array representing the compartment, with multiple labels.
+
+    Returns:
+        Array with the cumulative intersection of all boundaries of segmented_compartment labels with segmented_AZ.
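+
+    Example:
+        Minimal illustration on synthetic arrays (not real data; the shapes and
+        placements here are arbitrary assumptions):
+
+        >>> comp = np.zeros((8, 8, 8), dtype=int)
+        >>> comp[2:6, 2:6, 2:6] = 1  # a single compartment label
+        >>> az = np.zeros_like(comp)
+        >>> az[1, 2:6, 2:6] = 1  # an AZ sheet just outside one compartment face
+        >>> bool(find_intersection_boundary(az, comp).sum())
+        True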
+    """
+    # Step 0: Initialize an empty array to accumulate intersections
+    cumulative_intersection = np.zeros_like(segmented_AZ, dtype=bool)
+
+    # Step 1: Loop through each unique label in segmented_compartment (excluding 0 if it represents background)
+    labels = np.unique(segmented_compartment)
+    labels = labels[labels != 0]  # Exclude background label (0) if necessary
+
+    for label in labels:
+        # Step 2: Create a binary mask for the current label
+        label_mask = (segmented_compartment == label)
+
+        # Step 3: Find the boundary of the current label's compartment
+        boundary_compartment = find_boundaries(label_mask, mode='outer')
+
+        # Step 4: Find the intersection with the AZ for this label's boundary
+        intersection = np.logical_and(boundary_compartment, segmented_AZ)
+
+        # Step 5: Accumulate intersections for each label
+        cumulative_intersection = np.logical_or(cumulative_intersection, intersection)
+
+    return cumulative_intersection.astype(int)  # Convert boolean array to int (1 for intersecting points, 0 elsewhere)
+
+
+def _run_segmentation(
+    foreground, verbose, min_size,
+    # blocking shapes for parallel computation
+    block_shape=(128, 256, 256),
+):
+    # Get the segmentation via thresholding and connected components.
+    t0 = time.time()
+    seg = parallel.label(foreground > 0.5, block_shape=block_shape, verbose=verbose)
+    if verbose:
+        print("Compute connected components in", time.time() - t0, "s")
+
+    # size filter
+    t0 = time.time()
+    ids, sizes = parallel.unique(seg, return_counts=True, block_shape=block_shape, verbose=verbose)
+    filter_ids = ids[sizes < min_size]
+    seg[np.isin(seg, filter_ids)] = 0
+    if verbose:
+        print("Size filter in", time.time() - t0, "s")
+    seg = np.where(seg > 0, 1, 0)
+    return seg
+
+
+def segment_active_zone(
+    input_volume: np.ndarray,
+    model_path: Optional[str] = None,
+    model: Optional[torch.nn.Module] = None,
+    tiling: Optional[Dict[str, Dict[str, int]]] = None,
+    min_size: int = 500,
+    verbose: bool = True,
+    return_predictions: bool = False,
+    scale: Optional[List[float]] = None,
+    mask: Optional[np.ndarray] = None,
+    compartment: Optional[np.ndarray] = None,
+) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+    """Segment active zones in an input volume.
+
+    Args:
+        input_volume: The input volume to segment.
+        model_path: The path to the model checkpoint if `model` is not provided.
+        model: Pre-loaded model. Either `model_path` or `model` is required.
+        tiling: The tiling configuration for the prediction.
+        min_size: The minimal size of an active zone in the segmentation result.
+        verbose: Whether to print timing information.
+        return_predictions: Whether to return the network prediction in addition to the segmentation.
+        scale: The scale factor to use for rescaling the input volume before prediction.
+        mask: An optional mask that is used to restrict the segmentation.
+        compartment: Optional compartment segmentation. If given, the intersection of the active
+            zone segmentation with the compartment boundaries is returned in addition.
+
+    Returns:
+        The active zone segmentation. If `return_predictions` is set or `compartment` is given,
+        a tuple of the segmentation and the prediction or boundary intersection is returned.
+    """
+    if verbose:
+        print("Segmenting AZ in volume of shape", input_volume.shape)
+    # Create the scaler to handle prediction with a different scaling factor.
+    scaler = _Scaler(scale, verbose)
+    input_volume = scaler.scale_input(input_volume)
+
+    # Rescale the mask if it was given and run prediction.
+    if mask is not None:
+        mask = scaler.scale_input(mask, is_segmentation=True)
+    pred = get_prediction(input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose)
+
+    # Run segmentation and rescale the result if necessary.
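+    # The first channel of the prediction holds the AZ foreground probabilities;
+    # the remaining channels (if any) are not used here.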
+    foreground = pred[0]
+    if verbose:
+        print("Foreground shape:", foreground.shape)
+
+    segmentation = _run_segmentation(foreground, verbose=verbose, min_size=min_size)
+    segmentation = scaler.rescale_output(segmentation, is_segmentation=True)
+
+    # Returning the prediction and the intersection together is not possible at the moment,
+    # but we currently do not need the prediction anyway.
+    if return_predictions:
+        pred = scaler.rescale_output(pred, is_segmentation=False)
+        return segmentation, pred
+
+    if compartment is not None:
+        intersection = find_intersection_boundary(segmentation, compartment)
+        return segmentation, intersection
+
+    return segmentation
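+
+
+if __name__ == "__main__":
+    # Minimal usage sketch, not part of the library API. The tomogram file and
+    # model checkpoint below are placeholder paths (assumptions for illustration).
+    import h5py
+
+    with h5py.File("./tomogram.h5", "r") as f:
+        volume = f["raw"][:]
+    az_seg = segment_active_zone(volume, model_path="./az_model.pt", scale=[0.5, 0.5, 0.5])
+    print("AZ segmentation shape:", az_seg.shape)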
diff --git a/synaptic_reconstruction/inference/util.py b/synaptic_reconstruction/inference/util.py
index cedfb07..5a799f3 100644
--- a/synaptic_reconstruction/inference/util.py
+++ b/synaptic_reconstruction/inference/util.py
@@ -332,7 +332,7 @@ def inference_helper(
         mask_files, _ = _get_file_paths(mask_input_path, mask_input_ext)
         assert len(input_files) == len(mask_files)
 
-    for i, img_path in tqdm(enumerate(input_files), total=len(input_files)):
+    for i, img_path in tqdm(enumerate(input_files), total=len(input_files), desc="Processing files"):
         # Determine the output file name.
         input_folder, input_name = os.path.split(img_path)
 
@@ -350,7 +350,12 @@ def inference_helper(
         # Check if the output path is already present.
         # If it is we skip the prediction, unless force was set to true.
         if os.path.exists(output_path) and not force:
-            continue
+            if output_key is None:
+                continue
+            else:
+                with open_file(output_path, "r") as f:
+                    if output_key in f:
+                        continue
 
         # Load the input volume. If we have extra_files then this concatenates the
         # data across a new first axis (= channel axis).
diff --git a/synaptic_reconstruction/tools/cli.py b/synaptic_reconstruction/tools/cli.py
index bcb3085..54a52a3 100644
--- a/synaptic_reconstruction/tools/cli.py
+++ b/synaptic_reconstruction/tools/cli.py
@@ -83,6 +83,7 @@ def imod_object_cli():
 
 # TODO: handle kwargs
 # TODO: add custom model path
+# TODO: enable autoscaling from input resolution
 def segmentation_cli():
     parser = argparse.ArgumentParser(description="Run segmentation.")
     parser.add_argument(
@@ -117,15 +118,24 @@ def segmentation_cli():
     parser.add_argument(
         "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
     )
+    parser.add_argument(
+        "--segmentation_key", "-s",
+        help="The key (internal dataset name) for saving the segmentation in the output file. "
+             "If given, existing output files are only skipped when they already contain this key."
+    )
+    # TODO enable autoscaling
+    parser.add_argument(
+        "--scale", type=float, default=None,
+        help="The factor for rescaling the input data before prediction, e.g. 0.5 to downsample by half. "
+             "The factor is applied to all three axes."
+    )
     args = parser.parse_args()
 
     model = get_model(args.model)
     tiling = parse_tiling(args.tile_shape, args.halo)
+    scale = None if args.scale is None else 3 * (args.scale,)
 
     segmentation_function = partial(
-        run_segmentation, model=model, model_type=args.model, verbose=False, tiling=tiling,
+        run_segmentation, model=model, model_type=args.model, verbose=False, tiling=tiling, scale=scale
     )
     inference_helper(
         args.input_path, args.output_path, segmentation_function,
         mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
+        output_key=args.segmentation_key,
     )
diff --git a/synaptic_reconstruction/tools/util.py b/synaptic_reconstruction/tools/util.py
index 2d135cc..cb4b67b 100644
--- a/synaptic_reconstruction/tools/util.py
+++ b/synaptic_reconstruction/tools/util.py
@@ -6,8 +6,10 @@ import numpy as np
 
 import pooch
 
-from ..inference.vesicles import segment_vesicles
+from ..inference.active_zone import segment_active_zone
+from ..inference.compartments import segment_compartments
 from ..inference.mitochondria import segment_mitochondria
+from ..inference.vesicles import segment_vesicles
 
 
 def _save_table(save_path, data):
@@ -102,9 +104,9 @@ def run_segmentation(
     elif model_type == "mitochondria":
         segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose)
     elif model_type == "active_zone":
-        raise NotImplementedError
+        segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose)
     elif model_type == "compartments":
-        raise NotImplementedError
+        segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose)
     elif model_type == "inner_ear_structures":
         raise NotImplementedError
     else: