diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml
new file mode 100644
index 0000000..fd93a50
--- /dev/null
+++ b/.github/workflows/run_tests.yaml
@@ -0,0 +1,40 @@
+name: test
+
+on:
+  push:
+    branches:
+      - main
+    tags:
+      - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10
+  pull_request:  # run CI on commits to any open PR
+  workflow_dispatch:  # can manually trigger CI from GitHub actions tab
+
+
+jobs:
+  test:
+    name: ${{ matrix.os }} ${{ matrix.python-version }}
+    runs-on: ${{ matrix.os }}
+    timeout-minutes: 60
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest]
+        python-version: ["3.11"]
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Setup micromamba
+        uses: mamba-org/setup-micromamba@v1
+        with:
+          environment-file: environment_cpu.yaml
+          create-args: >-
+            python=${{ matrix.python-version }}
+
+      - name: Install SynapseNet
+        shell: bash -l {0}
+        run: pip install --no-deps -e .
+
+      - name: Run tests
+        shell: bash -l {0}
+        run: python -m unittest discover -s test -v
diff --git a/README.md b/README.md
index 1a8d937..2af858d 100644
--- a/README.md
+++ b/README.md
@@ -1,21 +1,10 @@
-# Synaptic Reconstruction
+# SynapseNet: Deep Learning for Automatic Synapse Reconstruction

-Reconstruction of synaptic structures in electron microscopy.
+SynapseNet is a tool for segmentation and analysis of synapses in electron microscopy.

-THIS IS WORK IN PROGRESS!
+To learn how to use SynapseNet, check out [the documentation](https://computational-cell-analytics.github.io/synapse-net/).
+To learn more about how it works, check out [our preprint](TODO).

-## Installation
-
-- Make sure conda or mamba is installed.
-  - If you don't have a conda installation yet we recommend [micromamba](https://mamba.readthedocs.io/en/latest/installation/micromamba-installation.html)
-- Create the environment with all required dependencies: `mamba env create -f environment.yaml`
-- Activate the environment: `mamba activate synaptic-reconstruction`
-- Install the package: `pip install -e .`
-
-## Tools
-
-### Segmentation Correction
-
-https://napari.org/stable/howtos/layers/labels.html
-
-### Distance Measurements
+Below you can see an example reconstruction of a mossy fiber synapse created with SynapseNet.
+Automatically segmented synaptic vesicles are rendered in orange, active zones in blue, and two mitochondria in red and cyan.
+![Reconstruction of a mossy fiber synapse](doc/images/synapse-reconstruction.png)
diff --git a/doc/images/synapse-reconstruction.png b/doc/images/synapse-reconstruction.png
new file mode 100644
index 0000000..b7cf058
Binary files /dev/null and b/doc/images/synapse-reconstruction.png differ
diff --git a/doc/start_page.md b/doc/start_page.md
index db78bfb..3e8d809 100644
--- a/doc/start_page.md
+++ b/doc/start_page.md
@@ -1,2 +1,87 @@
-# Synaptic Reconstruction
-lorem ipsum...
\ No newline at end of file
+# SynapseNet: Deep Learning for Automatic Synapse Reconstruction
+
+SynapseNet is a tool for automatic segmentation and analysis of synapses in electron micrographs.
+It provides deep neural networks for:
+- Synaptic vesicle segmentation in ssTEM (2D data) and (cryo-)electron tomography (3D data)
+- Active zone membrane segmentation in electron tomography
+- Mitochondrion segmentation in electron tomography
+- Synaptic compartment segmentation in electron tomography
+- Synaptic ribbon and pre-synaptic density segmentation for ribbon synapses in electron tomography
+It also offers functionality for quantifying synaptic ultrastructure based on segmentation results, for example by measuring vesicle or structure morphology, measuring distances between vesicles and structures, or assigning vesicles to different pools.
+SynapseNet mainly targets electron tomography, but it can also be applied to other types of electron microscopy,
+especially through the [domain adaptation](domain-adaptation) functionality.
+
+SynapseNet offers a [napari plugin](napari-plugin), [command line interface](command-line-interface), and [python library](python-library).
+Please cite our [bioRxiv preprint](TODO) if you use it in your research.
+
+
+## Requirements & Installation
+
+SynapseNet was developed and tested on Linux. It should be possible to install and use it on Mac or Windows, but we have not tested this.
+Furthermore, SynapseNet requires a GPU for segmentation of 3D volumes.
+
+You need a [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html) or [mamba](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html) installation. Follow the instructions at the respective links if you have neither installed yet. We assume you have `conda` for the rest of the instructions; once it is installed, the `conda` command is available in your terminal.
+
+To install SynapseNet, follow these steps:
+- First, download the SynapseNet repository via
+```bash
+git clone https://github.com/computational-cell-analytics/synapse-net
+```
+- Then, enter the `synapse-net` folder:
+```bash
+cd synapse-net
+```
+- Now you can install the environment for SynapseNet with `conda` from the environment file we provide:
+```bash
+conda env create -f environment.yaml
+```
+- You will need to confirm this step, and it will take a while. Afterwards you can activate the environment:
+```bash
+conda activate synapse-net
+```
+- Finally, install SynapseNet itself into the environment:
+```bash
+pip install -e .
+```
+
+Now you can use all SynapseNet features. From now on, just activate the environment via
+```bash
+conda activate synapse-net
+```
+to use them.
+
+> Note: If you use `mamba` instead of `conda`, just replace `conda` in the commands above with `mamba`.
+
+> Note: We also provide an environment for a CPU version of SynapseNet. You can install it by replacing `environment.yaml` with `environment_cpu.yaml` in the respective command above. This version can be used for 2D vesicle segmentation, but it does not work for 3D segmentation.
+
+> Note: If you have issues with the CUDA version, install a PyTorch version that matches your NVIDIA drivers. See [pytorch.org](https://pytorch.org/) for details.
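+
+After the installation, you can quickly check whether PyTorch can access your GPU. This is a minimal sketch, not part of the official instructions; it assumes the `synapse-net` environment is active and only relies on PyTorch, which is installed as a dependency of SynapseNet:
+```python
+import torch
+
+# Prints True if PyTorch can access a CUDA-capable GPU.
+# The CPU environment (environment_cpu.yaml) will print False, which is expected.
+print("CUDA available:", torch.cuda.is_available())
+```
+If this prints `False` even though you installed the GPU environment, the PyTorch/CUDA note above applies.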
+ + +## Napari Plugin + +**The rest of the documentation will be updated in the next days!** + + +## Command Line Functionality + +- segmentation cli +- export to imod + - vesicles / spheres + - objects + + +## Python Library + +- segmentation functions +- distance and morphology measurements +- imod + +### Domain Adaptation + +- explain domain adaptation +- link to the example script + +### Network Training + +- explain / diff to domain adaptation +- link to the example script diff --git a/environment.yaml b/environment.yaml index 82644b0..b669bbd 100644 --- a/environment.yaml +++ b/environment.yaml @@ -1,17 +1,22 @@ channels: + - pytorch + - nvidia - conda-forge name: - synaptic-reconstruction + synapse-net dependencies: - - python-elf + - bioimageio.core + - kornia + - magicgui - napari - pip - pyqt - - magicgui + - python-elf - pytorch - - bioimageio.core - - kornia + - pytorch-cuda=12.4 - tensorboard + - torch_em + - torchvision - trimesh - pip: - napari-skimage-regionprops diff --git a/environment_cpu.yaml b/environment_cpu.yaml new file mode 100644 index 0000000..5c68976 --- /dev/null +++ b/environment_cpu.yaml @@ -0,0 +1,18 @@ +channels: + - conda-forge +name: + synapse-net +dependencies: + - bioimageio.core + - kornia + - magicgui + - napari + - pip + - pyqt + - python-elf + - pytorch + - tensorboard + - torch_em + - trimesh + - pip: + - napari-skimage-regionprops diff --git a/plot_distances.sh b/plot_distances.sh deleted file mode 100755 index 5e4b1b1..0000000 --- a/plot_distances.sh +++ /dev/null @@ -1 +0,0 @@ -sr_tools.measure_distances -i /home/pape/Work/data/moser/lipids-julia/corrected_tomos_mrc/TS01.mrc_10.00Apx_corrected.mrc -s /home/pape/Work/data/moser/lipids-julia/results/v1/labels-center-membrane/TS01.mrc_10.00Apx_corrected.tif -m /home/pape/Work/data/moser/lipids-julia/results/v1/distance_measurements/TS01.mrc_10.00Apx_corrected.npz diff --git a/scripts/aggregate_data_information.py b/scripts/aggregate_data_information.py index d90ec8c..6ce446a 100644 --- a/scripts/aggregate_data_information.py +++ b/scripts/aggregate_data_information.py @@ -12,30 +12,24 @@ stem = "STEM" -def aggregate_vesicle_train_data(roots, test_tomograms, conditions, resolutions): +def aggregate_vesicle_train_data(roots, conditions, resolutions): tomo_names = [] - tomo_vesicles = [] + tomo_vesicles_all, tomo_vesicles_imod = [], [] tomo_condition = [] tomo_resolution = [] tomo_train = [] - for ds, root in roots.items(): - print("Aggregate data for", ds) - train_root = root["train"] - if train_root == "": - test_root = root["test"] - tomograms = sorted(glob(os.path.join(test_root, "2024**", "*.h5"), recursive=True)) - this_test_tomograms = [os.path.basename(tomo) for tomo in tomograms] + def aggregate_split(ds, split_root, split): + if ds.startswith("04"): + tomograms = sorted(glob(os.path.join(split_root, "2024**", "*.h5"), recursive=True)) else: - # This is only the case for 04, which is also nested - tomograms = sorted(glob(os.path.join(train_root, "*.h5"))) - this_test_tomograms = test_tomograms[ds] + tomograms = sorted(glob(os.path.join(split_root, "*.h5"))) assert len(tomograms) > 0, ds this_condition = conditions[ds] this_resolution = resolutions[ds][0] - for tomo_path in tqdm(tomograms): + for tomo_path in tqdm(tomograms, desc=f"Aggregate {split}"): fname = os.path.basename(tomo_path) with h5py.File(tomo_path, "r") as f: try: @@ -43,24 +37,39 @@ def aggregate_vesicle_train_data(roots, test_tomograms, conditions, resolutions) except KeyError: tomo_name = fname - n_label_sets = 
len(f["labels"]) - if n_label_sets > 2: - print(tomo_path, "contains the following labels:", list(f["labels"].keys())) - seg = f["labels/vesicles"][:] - n_vesicles = len(np.unique(seg)) - 1 + if "labels/vesicles/combined_vesicles" in f: + all_vesicles = f["labels/vesicles/combined_vesicles"][:] + imod_vesicles = f["labels/vesicles/masked_vesicles"][:] + n_vesicles_all = len(np.unique(all_vesicles)) - 1 + n_vesicles_imod = len(np.unique(imod_vesicles)) - 2 + else: + vesicles = f["labels/vesicles"][:] + n_vesicles_all = len(np.unique(vesicles)) - 1 + n_vesicles_imod = n_vesicles_all tomo_names.append(tomo_name) - tomo_vesicles.append(n_vesicles) + tomo_vesicles_all.append(n_vesicles_all) + tomo_vesicles_imod.append(n_vesicles_imod) tomo_condition.append(this_condition) tomo_resolution.append(this_resolution) - tomo_train.append("test" if fname in this_test_tomograms else "train/val") + tomo_train.append(split) + + for ds, root in roots.items(): + print("Aggregate data for", ds) + train_root = root["train"] + if train_root != "": + aggregate_split(ds, train_root, "train/val") + test_root = root["test"] + if test_root != "": + aggregate_split(ds, test_root, "test") df = pd.DataFrame({ "tomogram": tomo_names, "condition": tomo_condition, "resolution": tomo_resolution, "used_for": tomo_train, - "vesicle_count": tomo_vesicles, + "vesicle_count_all": tomo_vesicles_all, + "vesicle_count_imod": tomo_vesicles_imod, }) os.makedirs("data_summary", exist_ok=True) @@ -70,15 +79,15 @@ def aggregate_vesicle_train_data(roots, test_tomograms, conditions, resolutions) def vesicle_train_data(): roots = { "01": { - "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/extracted/20240909_cp_datatransfer/01_hoi_maus_2020_incomplete", # noqa + "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/01_hoi_maus_2020_incomplete", # noqa "test": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/01_hoi_maus_2020_incomplete", # noqa }, "02": { - "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/extracted/20240909_cp_datatransfer/02_hcc_nanogold", # noqa + "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/02_hcc_nanogold", # noqa "test": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/02_hcc_nanogold", # noqa }, "03": { - "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/extracted/20240909_cp_datatransfer/03_hog_cs1sy7", # noqa + "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/03_hog_cs1sy7", # noqa "test": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/03_hog_cs1sy7", # noqa }, "04": { @@ -86,44 +95,31 @@ def vesicle_train_data(): "test": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/ground_truth/04Dataset_for_vesicle_eval/", # noqa }, "05": { - "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/extracted/20240909_cp_datatransfer/05_stem750_sv_training", # noqa + "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/05_stem750_sv_training", # noqa "test": 
"/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/05_stem750_sv_training", # noqa }, "07": { - "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/extracted/20240909_cp_datatransfer/07_hoi_s1sy7_tem250_ihgp", # noqa + "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/07_hoi_s1sy7_tem250_ihgp", # noqa "test": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/07_hoi_s1sy7_tem250_ihgp", # noqa }, "09": { - "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/extracted/20240909_cp_datatransfer/09_stem750_66k", # noqa + "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/09_stem750_66k", # noqa "test": "", }, "10": { - "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/extracted/20240909_cp_datatransfer/10_tem_single_release", # noqa + "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/10_tem_single_release", # noqa "test": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/10_tem_single_release", # noqa }, "11": { - "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/extracted/20240909_cp_datatransfer/11_tem_multiple_release", # noqa + "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/11_tem_multiple_release", # noqa "test": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/11_tem_multiple_release", # noqa }, "12": { - "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/extracted/20240909_cp_datatransfer/12_chemical_fix_cryopreparation", # noqa + "train": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/12_chemical_fix_cryopreparation", # noqa "test": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/12_chemical_fix_cryopreparation", # noqa }, } - test_tomograms = { - "01": ["tomogram-009.h5", "tomogram-038.h5", "tomogram-049.h5", "tomogram-052.h5", "tomogram-057.h5", "tomogram-060.h5", "tomogram-067.h5", "tomogram-074.h5", "tomogram-076.h5", "tomogram-083.h5", "tomogram-133.h5", "tomogram-136.h5", "tomogram-145.h5", "tomogram-149.h5", "tomogram-150.h5"], # noqa - "02": ["tomogram-004.h5", "tomogram-008.h5"], - "03": ["tomogram-003.h5", "tomogram-004.h5", "tomogram-008.h5",], - "04": [], # all used for test - "05": ["tomogram-003.h5", "tomogram-005.h5",], - "07": ["tomogram-006.h5", "tomogram-017.h5",], - "09": [], # no test data - "10": ["tomogram-001.h5", "tomogram-002.h5", "tomogram-007.h5"], - "11": ["tomogram-001.h5 tomogram-007.h5 tomogram-008.h5"], - "12": ["tomogram-004.h5", "tomogram-021.h5", "tomogram-022.h5",], - } - conditions = { "01": single_ax_tem, "02": dual_ax_tem, @@ -150,7 +146,7 @@ def vesicle_train_data(): "12": (1.554, 1.554, 1.554) } - aggregate_vesicle_train_data(roots, test_tomograms, conditions, resolutions) + aggregate_vesicle_train_data(roots, conditions, resolutions) def aggregate_az_train_data(roots, test_tomograms, conditions, resolutions): @@ -204,7 +200,7 @@ def active_zone_train_data(): "01": 
"/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/exported_imod_objects/01_hoi_maus_2020_incomplete", # noqa "04": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/exported_imod_objects/04_hoi_stem_examples", # noqa "06": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/exported_imod_objects/06_hoi_wt_stem750_fm", # noqa - "12": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/2D_data/20241021_imig_2014_data_transfer_exported_grouped", # noqa + "12": "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/exported_imod_objects/12_chemical_fix_cryopreparation", # noqa } test_tomograms = { @@ -397,6 +393,11 @@ def vesicle_domain_adaptation_data(): "MF_05649_P-09175-E_06.h5", "MF_05646_C-09175-B_001B.h5", "MF_05649_P-09175-E_07.h5", "MF_05649_G-09175-C_001.h5", "MF_05646_C-09175-B_002.h5", "MF_05649_G-09175-C_04.h5", "MF_05649_P-09175-E_05.h5", "MF_05646_C-09175-B_000.h5", "MF_05646_C-09175-B_001.h5" + ], + "frog": [ + "block10U3A_three.h5", "block30UB_one_two.h5", "block30UB_two.h5", "block10U3A_one.h5", + "block184B_one.h5", "block30UB_three.h5", "block10U3A_two.h5", "block30UB_four.h5", + "block30UB_one.h5", "block10U3A_five.h5", ] } @@ -439,13 +440,89 @@ def vesicle_domain_adaptation_data(): aggregate_da(roots, train_tomograms, test_tomograms, resolutions) +def get_n_images_frog(): + root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/rizzoli/extracted/upsampled_by2" + tomos = ["block10U3A_three.h5", "block30UB_one_two.h5", "block30UB_two.h5", "block10U3A_one.h5", + "block184B_one.h5", "block30UB_three.h5", "block10U3A_two.h5", "block30UB_four.h5", + "block30UB_one.h5", "block10U3A_five.h5"] + + n_images = 0 + for tomo in tomos: + path = os.path.join(root, tomo) + with h5py.File(path, "r") as f: + n_images += f["raw"].shape[0] + print(n_images) + + +def get_image_sizes_tem_2d(): + root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/2D_data/maus_2020_tem2d_wt_unt_div14_exported_scaled/good_for_DAtraining/maus_2020_tem2d_wt_unt_div14_exported_scaled" # noqa + tomos = [ + "MF_05649_P-09175-E_06.h5", "MF_05646_C-09175-B_001B.h5", "MF_05649_P-09175-E_07.h5", + "MF_05649_G-09175-C_001.h5", "MF_05646_C-09175-B_002.h5", "MF_05649_G-09175-C_04.h5", + "MF_05649_P-09175-E_05.h5", "MF_05646_C-09175-B_000.h5", "MF_05646_C-09175-B_001.h5" + ] + for tomo in tomos: + path = os.path.join(root, tomo) + with h5py.File(path, "r") as f: + print(f["raw"].shape) + + +def mito_train_data(): + train_root = "/scratch-grete/projects/nim00007/data/mitochondria/cooper/fidi_down_s2" + test_tomograms = [ + "36859_J1_66K_TS_CA3_MF_18_rec_2Kb1dawbp_crop_downscaled.h5", + "3.2_downscaled.h5", + ] + all_tomos = sorted(glob(os.path.join(train_root, "*.h5"))) + + tomo_names = [] + tomo_condition = [] + tomo_mitos = [] + tomo_resolution = [] + tomo_train = [] + + for tomo in all_tomos: + fname = os.path.basename(tomo) + split = "test" if fname in test_tomograms else "train/val" + if "36859" in fname or "37371" in fname: # This is from the STEM dataset. 
+ condition = stem + resolution = 2 * 0.868 + else: # This is from the TEM Single-Axis Dataset + condition = single_ax_tem + # These were scaled, despite the resolution mismatch + resolution = 2 * 1.554 + + with h5py.File(tomo, "r") as f: + seg = f["labels/mitochondria"][:] + n_mitos = len(np.unique(seg)) - 1 + + tomo_names.append(tomo) + tomo_condition.append(condition) + tomo_train.append(split) + tomo_resolution.append(resolution) + tomo_mitos.append(n_mitos) + + df = pd.DataFrame({ + "tomogram": tomo_names, + "condition": tomo_condition, + "resolution": tomo_resolution, + "used_for": tomo_train, + "mito_count_all": tomo_mitos, + }) + + os.makedirs("data_summary", exist_ok=True) + df.to_excel("./data_summary/mitochondria.xlsx", index=False) + + def main(): # active_zone_train_data() # compartment_train_data() - # mito_train_data() + mito_train_data() # vesicle_train_data() - vesicle_domain_adaptation_data() + # vesicle_domain_adaptation_data() + # get_n_images_frog() + # get_image_sizes_tem_2d() main() diff --git a/scripts/cooper/.gitignore b/scripts/cooper/.gitignore index 43efa15..5fd959a 100644 --- a/scripts/cooper/.gitignore +++ b/scripts/cooper/.gitignore @@ -1 +1,4 @@ pwd.txt +debug/ +mito/ +synapse-examples/ diff --git a/scripts/cooper/analysis/.gitignore b/scripts/cooper/analysis/.gitignore new file mode 100644 index 0000000..b6de208 --- /dev/null +++ b/scripts/cooper/analysis/.gitignore @@ -0,0 +1,8 @@ +screenshots/ +20241108_3D_Imig_DATA_2014/ +*az*/ +mrc_files/ +imig_data/ +results/ +*.xlsx +*.tsv diff --git a/scripts/cooper/analysis/active_zone_analysis.py b/scripts/cooper/analysis/active_zone_analysis.py index d2234c9..bb13ac5 100644 --- a/scripts/cooper/analysis/active_zone_analysis.py +++ b/scripts/cooper/analysis/active_zone_analysis.py @@ -3,15 +3,22 @@ import h5py import numpy as np +import napari +import pandas as pd from scipy.ndimage import binary_closing from skimage.measure import label from synaptic_reconstruction.ground_truth.shape_refinement import edge_filter +from synaptic_reconstruction.morphology import skeletonize_object +from synaptic_reconstruction.distance_measurements import measure_segmentation_to_object_distances from tqdm import tqdm -ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/20241102_TOMO_DATA_Imig2014/final_Imig2014_seg_autoComp" # noqa +from compute_skeleton_area import calculate_surface_area -OUTPUT_AZ = "./boundary_az" +ROOT = "./imig_data" # noqa +OUTPUT_AZ = "./az_segmentation" + +RESOLUTION = (1.554,) * 3 def filter_az(path): @@ -20,6 +27,7 @@ def filter_az(path): ds = os.path.basename(ds) out_path = os.path.join(OUTPUT_AZ, ds, fname) os.makedirs(os.path.join(OUTPUT_AZ, ds), exist_ok=True) + if os.path.exists(out_path): return @@ -56,11 +64,192 @@ def filter_az(path): f.create_dataset("filtered_az", data=az_filtered, compression="gzip") -def main(): +def filter_all_azs(): files = sorted(glob(os.path.join(ROOT, "**/*.h5"), recursive=True)) - for ff in tqdm(files): + for ff in tqdm(files, desc="Filter AZ segmentations."): filter_az(ff) +def process_az(path, view=True): + key = "thin_az" + + with h5py.File(path, "r") as f: + if key in f and not view: + return + az_seg = f["filtered_az"][:] + + az_thin = skeletonize_object(az_seg) + + if view: + ds, fname = os.path.split(path) + ds = os.path.basename(ds) + raw_path = os.path.join(ROOT, ds, fname) + with h5py.File(raw_path, "r") as f: + raw = f["raw"][:] + v = napari.Viewer() + v.add_image(raw) + v.add_labels(az_seg) + v.add_labels(az_thin) + napari.run() 
+ else: + with h5py.File(path, "a") as f: + f.create_dataset(key, data=az_thin, compression="gzip") + + +# Apply thinning to all active zones to obtain 1d surface. +def process_all_azs(): + files = sorted(glob(os.path.join(OUTPUT_AZ, "**/*.h5"), recursive=True)) + for ff in tqdm(files, desc="Thin AZ segmentations."): + process_az(ff, view=False) + + +def measure_az_area(path): + from skimage import measure + + with h5py.File(path, "r") as f: + seg = f["thin_az"][:] + + # Try via surface mesh. + verts, faces, normals, values = measure.marching_cubes(seg, spacing=RESOLUTION) + surface_area1 = measure.mesh_surface_area(verts, faces) + + # Try via custom function. + surface_area2 = calculate_surface_area(seg, voxel_size=RESOLUTION) + + ds, fname = os.path.split(path) + ds = os.path.basename(ds) + + return pd.DataFrame({ + "Dataset": [ds], + "Tomogram": [fname], + "surface_mesh [nm^2]": [surface_area1], + "surface_custom [nm^2]": [surface_area2], + }) + + +# Measure the AZ surface areas. +def measure_all_areas(): + save_path = "./results/area_measurements.xlsx" + if os.path.exists(save_path): + return + + files = sorted(glob(os.path.join(OUTPUT_AZ, "**/*.h5"), recursive=True)) + area_table = [] + for ff in tqdm(files, desc="Measure AZ areas."): + area = measure_az_area(ff) + area_table.append(area) + area_table = pd.concat(area_table) + area_table.to_excel(save_path, index=False) + + manual_results = "/home/pape/Work/my_projects/synaptic-reconstruction/scripts/cooper/debug/surface/manualAZ_surface_area.xlsx" # noqa + manual_results = pd.read_excel(manual_results)[["Dataset", "Tomogram", "manual"]] + comparison_table = pd.merge(area_table, manual_results, on=["Dataset", "Tomogram"], how="inner") + comparison_table.to_excel("./results/area_comparison.xlsx", index=False) + + +def analyze_areas(): + import seaborn as sns + import matplotlib.pyplot as plt + + table = pd.read_excel("./results/area_comparison.xlsx") + + fig, axes = plt.subplots(2) + sns.scatterplot(data=table, x="manual", y="surface_mesh [nm^2]", ax=axes[0]) + sns.scatterplot(data=table, x="manual", y="surface_custom [nm^2]", ax=axes[1]) + plt.show() + + +def measure_distances(ves_path, az_path): + with h5py.File(az_path, "r") as f: + az = f["thin_az"][:] + + with h5py.File(ves_path, "r") as f: + vesicles = f["vesicles/segment_from_combined_vesicles"][:] + + distances, _, _, _ = measure_segmentation_to_object_distances(vesicles, az, resolution=RESOLUTION) + + ds, fname = os.path.split(az_path) + ds = os.path.basename(ds) + + return pd.DataFrame({ + "Dataset": [ds] * len(distances), + "Tomogram": [fname] * len(distances), + "Distance": distances, + }) + + +# Measure the AZ vesicle distances for all vesicles. 
+def measure_all_distances(): + save_path = "./results/vesicle_az_distances.xlsx" + if os.path.exists(save_path): + return + + ves_files = sorted(glob(os.path.join(ROOT, "**/*.h5"), recursive=True)) + az_files = sorted(glob(os.path.join(OUTPUT_AZ, "**/*.h5"), recursive=True)) + assert len(ves_files) == len(az_files) + + dist_table = [] + for ves_file, az_file in tqdm(zip(ves_files, az_files), total=len(az_files), desc="Measure distances."): + dist = measure_distances(ves_file, az_file) + dist_table.append(dist) + dist_table = pd.concat(dist_table) + + dist_table.to_excel(save_path, index=False) + + +def reformat_distances(): + tab = pd.read_excel("./results/vesicle_az_distances.xlsx") + + munc_ko = {} + munc_ctrl = {} + + snap_ko = {} + snap_ctrl = {} + + for _, row in tab.iterrows(): + ds = row.Dataset + tomo = row.Tomogram + + if ds == "Munc13DKO": + if "CTRL" in tomo: + group = munc_ctrl + else: + group = munc_ko + else: + assert ds == "SNAP25" + if "CTRL" in tomo: + group = snap_ctrl + else: + group = snap_ko + + name = os.path.splitext(tomo)[0] + val = row["Distance [nm]"] + if name in group: + group[name].append(val) + else: + group[name] = [val] + + def save_tab(group, path): + n_ves_max = max(len(v) for v in group.values()) + group = {k: v + [np.nan] * (n_ves_max - len(v)) for k, v in group.items()} + group_tab = pd.DataFrame(group) + group_tab.to_excel(path, index=False) + + os.makedirs("./results/distances_formatted", exist_ok=True) + save_tab(munc_ko, "./results/distances_formatted/munc_ko.xlsx") + save_tab(munc_ctrl, "./results/distances_formatted/munc_ctrl.xlsx") + save_tab(snap_ko, "./results/distances_formatted/snap_ko.xlsx") + save_tab(snap_ctrl, "./results/distances_formatted/snap_ctrl.xlsx") + + +def main(): + # filter_all_azs() + # process_all_azs() + # measure_all_areas() + # analyze_areas() + # measure_all_distances() + reformat_distances() + + if __name__ == "__main__": main() diff --git a/scripts/cooper/analysis/az_postprocessing.py b/scripts/cooper/analysis/az_postprocessing.py new file mode 100644 index 0000000..93ef5dc --- /dev/null +++ b/scripts/cooper/analysis/az_postprocessing.py @@ -0,0 +1,116 @@ +import os +from glob import glob + +import h5py +import napari +import numpy as np + +from magicgui import magicgui +from scipy.ndimage import binary_dilation, binary_opening +from skimage.measure import label + + +def postprocess_az(thin_az_seg): + # seg = binary_dilation(thin_az_seg, iterations=1) + # seg = binary_opening(seg) + seg = label(thin_az_seg) + + ids, sizes = np.unique(seg, return_counts=True) + ids, sizes = ids[1:], sizes[1:] + seg = seg == ids[np.argmax(sizes)].astype("uint8") + return seg + + +def process_az(raw_path, az_path): + with h5py.File(raw_path, "r") as f: + raw = f["raw"][:] + + with h5py.File(az_path, "r") as f: + seg = f["thin_az"][:] + + seg_pp = postprocess_az(seg) + + v = napari.Viewer() + v.add_image(raw) + v.add_labels(seg, opacity=1, visible=True) + segl = v.add_labels(seg_pp, opacity=1) + segl.new_colormap() + v.title = raw_path + napari.run() + + +def check_all_postprocessed(): + raw_paths = sorted(glob(os.path.join("imig_data/**/*.h5"), recursive=True)) + seg_paths = sorted(glob(os.path.join("az_segmentation/**/*.h5"), recursive=True)) + assert len(raw_paths) == len(seg_paths) + for raw_path, seg_path in zip(raw_paths, seg_paths): + process_az(raw_path, seg_path) + + +def proofread_file(raw_path, az_path, out_root): + ds, fname = os.path.split(raw_path) + ds = os.path.basename(ds) + + out_folder = os.path.join(out_root, ds) + 
os.makedirs(out_folder, exist_ok=True) + out_path = os.path.join(out_folder, fname) + + if os.path.exists(out_path): + return + + with h5py.File(raw_path, "r") as f: + raw = f["raw"][:] + + with h5py.File(az_path, "r") as f: + seg = f["thin_az"][:] + + seg_pp = postprocess_az(seg) + + v = napari.Viewer() + v.add_image(raw) + v.add_labels(seg, opacity=1, visible=True, name="original") + segl = v.add_labels(seg_pp, opacity=1, name="postprocessed") + segl.new_colormap() + + v.title = raw_path + + @magicgui(call_button="Postprocess") + def postprocess(): + seg = v.layers["postprocessed"].data + seg = postprocess_az(seg) + v.layers["postprocessed"].data = seg + + @magicgui(call_button="Save") + def save(): + seg = v.layers["postprocessed"].data + with h5py.File(out_path, "a") as f: + f.create_dataset("az_thin_proofread", data=seg, compression="gzip") + print("Save done!") + + v.window.add_dock_widget(postprocess) + v.window.add_dock_widget(save) + + napari.run() + + +def proofread_az(out_folder): + raw_paths = sorted(glob(os.path.join("imig_data/**/*.h5"), recursive=True)) + seg_paths = sorted(glob(os.path.join("az_segmentation/**/*.h5"), recursive=True)) + assert len(raw_paths) == len(seg_paths) + os.makedirs(out_folder, exist_ok=True) + for i, (raw_path, seg_path) in enumerate(zip(raw_paths, seg_paths)): + print(i, "/", len(seg_paths)) + proofread_file(raw_path, seg_path, out_folder) + + +def main(): + # check_all_postprocessed() + # process_az( + # "./imig_data/Munc13DKO/A_M13DKO_060212_DKO1.1_crop.h5", + # "./az_segmentation/Munc13DKO/A_M13DKO_060212_DKO1.1_crop.h5" + # ) + proofread_az("./proofread_az") + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/analysis/check_size_export.py b/scripts/cooper/analysis/check_size_export.py new file mode 100644 index 0000000..b00fdfa --- /dev/null +++ b/scripts/cooper/analysis/check_size_export.py @@ -0,0 +1,22 @@ +from elf.io import open_file + + +def test_export(): + from synaptic_reconstruction.imod.to_imod import write_segmentation_to_imod_as_points + from subprocess import run + + mrc_path = "20241108_3D_Imig_DATA_2014/!_M13DKO_TOMO_DATA_Imig2014_mrc-mod-FM/A_M13DKO_080212_CTRL4.8_crop/A_M13DKO_080212_CTRL4.8_crop.mrc" # noqa + seg_path = "imig_data/Munc13DKO/A_M13DKO_080212_CTRL4.8_crop.h5" + out_path = "exported_vesicles.mod" + + with open_file(seg_path, "r") as f: + seg = f["vesicles/segment_from_combined_vesicles"][:] + + # !!!! 0.7 + write_segmentation_to_imod_as_points( + mrc_path, seg, out_path, min_radius=10, radius_factor=0.7 + ) + run(["imod", mrc_path, out_path]) + + +test_export() diff --git a/scripts/cooper/analysis/compute_skeleton_area.py b/scripts/cooper/analysis/compute_skeleton_area.py new file mode 100644 index 0000000..6fb05d0 --- /dev/null +++ b/scripts/cooper/analysis/compute_skeleton_area.py @@ -0,0 +1,44 @@ +import numpy as np + + +def calculate_surface_area(skeleton, voxel_size=(1.0, 1.0, 1.0)): + """ + Calculate the surface area of a 3D skeletonized object. + + Parameters: + skeleton (3D array): Binary 3D skeletonized array. + voxel_size (tuple): Physical size of voxels (z, y, x). + + Returns: + float: Approximate surface area of the skeleton. 
+ """ + # Define the voxel dimensions + voxel_area = ( + voxel_size[1] * voxel_size[2], # yz-face area + voxel_size[0] * voxel_size[2], # xz-face area + voxel_size[0] * voxel_size[1], # xy-face area + ) + + # Compute the number of exposed faces for each voxel + exposed_faces = 0 + directions = [ + (1, 0, 0), (-1, 0, 0), # x-axis neighbors + (0, 1, 0), (0, -1, 0), # y-axis neighbors + (0, 0, 1), (0, 0, -1), # z-axis neighbors + ] + + # Iterate over all voxels in the skeleton + for z, y, x in np.argwhere(skeleton): + for i, (dz, dy, dx) in enumerate(directions): + neighbor = (z + dz, y + dy, x + dx) + # Check if the neighbor is outside the volume or not part of the skeleton + if ( + 0 <= neighbor[0] < skeleton.shape[0] and + 0 <= neighbor[1] < skeleton.shape[1] and + 0 <= neighbor[2] < skeleton.shape[2] and + skeleton[neighbor] == 1 + ): + continue + exposed_faces += voxel_area[i // 2] + + return exposed_faces diff --git a/scripts/cooper/analysis/correct_manual_azs.py b/scripts/cooper/analysis/correct_manual_azs.py new file mode 100644 index 0000000..7a88b4d --- /dev/null +++ b/scripts/cooper/analysis/correct_manual_azs.py @@ -0,0 +1,48 @@ +import os + +import h5py +import napari + +from magicgui import magicgui + + +def correct_manual_az(raw_path, seg_path): + with h5py.File(raw_path, "r") as f: + raw = f["raw"][:] + + seg_key = "az_thin_proofread" + with h5py.File(seg_path, "r") as f: + seg = f[seg_key][:] + + v = napari.Viewer() + v.add_image(raw) + v.add_labels(seg) + + @magicgui(call_button="save") + def save(): + seg = v.layers["seg"].data + with h5py.File(seg_path, "a") as f: + f[seg_key][:] = seg + + v.window.add_dock_widget(save) + + napari.run() + + +def main(): + to_correct = [ + # ("Munc13DKO", "B_M13DKO_080212_CTRL4.8_crop"), + # ("SNAP25", "A_SNAP25_12082_KO1.2_6_crop"), + # ("SNAP25", "B_SNAP25_120812_CTRL1.3_13_crop"), + ("SNAP25", "B_SNAP25_12082_KO1.2_6_crop") + ] + for ds, fname in to_correct: + raw_path = os.path.join("imig_data", ds, f"{fname}.h5") + seg_path = os.path.join("proofread_az", ds, f"{fname}.h5") + assert os.path.exists(raw_path) + assert os.path.exists(seg_path) + correct_manual_az(raw_path, seg_path) + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/analysis/export_az_to_imod.py b/scripts/cooper/analysis/export_az_to_imod.py new file mode 100644 index 0000000..fe72f01 --- /dev/null +++ b/scripts/cooper/analysis/export_az_to_imod.py @@ -0,0 +1,152 @@ +import os +import tempfile +from glob import glob +from subprocess import run +from shutil import copyfile + +import h5py +import pandas as pd + +from synaptic_reconstruction.imod.to_imod import write_segmentation_to_imod +from scipy.ndimage import binary_dilation, binary_closing + + +def check_imod(tomo_path, mod_path): + run(["imod", tomo_path, mod_path]) + + +def export_all_to_imod(check_input=True, check_export=True): + files = sorted(glob("./proofread_az/**/*.h5", recursive=True)) + mrc_root = "./mrc_files" + output_folder = "./az_export/initial_model" + + ratings = pd.read_excel("quality_ratings/az_quality_clean_FM.xlsx") + + for ff in files: + ds, fname = os.path.split(ff) + ds = os.path.basename(ds) + out_folder = os.path.join(output_folder, ds) + out_path = os.path.join(out_folder, fname.replace(".h5", ".mod")) + if os.path.exists(out_path): + continue + + restrict_to_good_azs = False + if restrict_to_good_azs: + tomo_name = os.path.splitext(fname)[0] + rating = ratings[ + (ratings["Dataset"] == ds) & (ratings["Tomogram"] == tomo_name) + ]["Rating"].values[0] + if rating != "Good": + 
print("Skipping", ds, tomo_name, "due to", rating) + continue + + os.makedirs(out_folder, exist_ok=True) + mrc_path = os.path.join(mrc_root, ds, fname.replace(".h5", ".rec")) + assert os.path.exists(mrc_path), mrc_path + + with h5py.File(ff, "r") as f: + if "thin_az_corrected" in f: + print("Loading corrected az!") + seg = f["thin_az_corrected"][:] + else: + seg = f["az_thin_proofread"][:] + + seg = binary_dilation(seg, iterations=2) + seg = binary_closing(seg, iterations=2) + + write_segmentation_to_imod(mrc_path, seg, out_path) + + if check_input: + import napari + from elf.io import open_file + with open_file(mrc_path, "r") as f: + raw = f["data"][:] + v = napari.Viewer() + v.add_image(raw) + v.add_labels(seg) + napari.run() + + if check_export: + check_imod(mrc_path, out_path) + + +# https://bio3d.colorado.edu/imod/doc/man/reducecont.html +def reduce_all_contours(): + pass + + +# https://bio3d.colorado.edu/imod/doc/man/smoothsurf.html#TOP +def smooth_all_surfaces(check_output=True): + input_files = sorted(glob("./az_export/initial_model/**/*.mod", recursive=True)) + + mrc_root = "./mrc_files" + output_folder = "./az_export/smoothed_model" + for ff in input_files: + ds, fname = os.path.split(ff) + ds = os.path.basename(ds) + out_folder = os.path.join(output_folder, ds) + out_file = os.path.join(out_folder, fname) + if os.path.exists(out_file): + continue + + os.makedirs(out_folder, exist_ok=True) + run(["smoothsurf", ff, out_file]) + if check_output: + mrc_path = os.path.join(mrc_root, ds, fname.replace(".mod", ".rec")) + assert os.path.exists(mrc_path), mrc_path + check_imod(mrc_path, out_file) + + +def measure_surfaces(): + input_files = sorted(glob("./az_export/smoothed_model/**/*.mod", recursive=True)) + + result = { + "Dataset": [], + "Tomogram": [], + "AZ Surface": [], + } + for ff in input_files: + ds, fname = os.path.split(ff) + ds = os.path.basename(ds) + fname = os.path.splitext(fname)[0] + + with tempfile.NamedTemporaryFile() as f_mesh, tempfile.NamedTemporaryFile() as f_mod: + tmp_path_mesh = f_mesh.name + tmp_path_mod = f_mod.name + copyfile(ff, tmp_path_mesh) + run(["imodmesh", tmp_path_mesh]) + run(["imodinfo", "-f", tmp_path_mod, tmp_path_mesh]) + area = None + with open(tmp_path_mod, "r") as f: + for line in f.readlines(): + line = line.strip() + if line.startswith("Total mesh surface area"): + area = float(line.split(" ")[-1]) + assert area is not None + area /= 2 + + result["Dataset"].append(ds) + result["Tomogram"].append(fname) + result["AZ Surface"].append(area) + + result = pd.DataFrame(result) + result.to_excel("./results/az_areas_all.xlsx", index=False) + + +def filter_surfaces(): + all_results = pd.read_excel("./results/az_areas_all.xlsx") + man_tomos = pd.read_csv("./man_tomos.tsv") + + man_results = all_results.merge(man_tomos[["Dataset", "Tomogram"]], on=["Dataset", "Tomogram"], how="inner") + man_results.to_excel("./results/az_areas_manual.xlsx", index=False) + + +def main(): + export_all_to_imod(False, False) + smooth_all_surfaces(False) + measure_surfaces() + filter_surfaces() + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/analysis/filter_vesicle_sizes.py b/scripts/cooper/analysis/filter_vesicle_sizes.py new file mode 100644 index 0000000..042bd56 --- /dev/null +++ b/scripts/cooper/analysis/filter_vesicle_sizes.py @@ -0,0 +1,50 @@ +import os +from glob import glob + +import numpy as np +import pandas as pd + + +def filter_sizes_by_distance(size_table, distance_table, out_dir, max_distance=100): + fname = os.path.basename(size_table) 
+ print("Filtering vesicles for", fname) + + size_table = pd.read_csv(size_table) + distance_table = pd.read_csv(distance_table) + assert (size_table.columns == distance_table.columns).all() + out_columns = {} + n_tot, n_filtered = 0, 0 + all_values = [] + for col_name in size_table.columns: + size_values = size_table[col_name].values + distance_values = distance_table[col_name].values + size_values, distance_values = ( + size_values[np.isfinite(size_values)], + distance_values[np.isfinite(distance_values)] + ) + assert len(size_values) == len(distance_values) + n_tot += len(size_values) + mask = distance_values < max_distance + out_columns[col_name] = size_values[mask] + n_filtered += mask.sum() + all_values.extend(size_values[mask].tolist()) + + print("Total number of vesicles:", n_tot) + print("Number of vesicles after filtering:", n_filtered) + print("Average diameter:", np.mean(all_values)) + os.makedirs(out_dir, exist_ok=True) + out_path = os.path.join(out_dir, fname) + + filtered_sizes = pd.DataFrame.from_dict(out_columns, orient='index').transpose() + filtered_sizes.to_csv(out_path, index=False) + + +def main(): + size_tables = sorted(glob("./results/diameters/*.csv")) + distance_tables = sorted(glob("./results/distances/*.csv")) + for size_tab, distance_tab in zip(size_tables, distance_tables): + filter_sizes_by_distance(size_tab, distance_tab, "./results/filtered_diameters") + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/analysis/filter_vesicles.py b/scripts/cooper/analysis/filter_vesicles.py new file mode 100644 index 0000000..cef1d78 --- /dev/null +++ b/scripts/cooper/analysis/filter_vesicles.py @@ -0,0 +1,106 @@ +import os +from glob import glob + +import napari +import numpy as np +import h5py + +from skimage.measure import regionprops +from skimage.morphology import remove_small_holes +from tqdm import tqdm + + +def fill_and_filter_vesicles(vesicles): + ids, sizes = np.unique(vesicles, return_counts=True) + ids, sizes = ids[1:], sizes[1:] + + # import matplotlib.pyplot as plt + # n, bins, patches = plt.hist(sizes, bins=32) + # print(bins[:5]) + # plt.show() + + min_size = 2500 + vesicles_pp = vesicles.copy() + filter_ids = ids[sizes < min_size] + vesicles_pp[np.isin(vesicles, filter_ids)] = 0 + + props = regionprops(vesicles_pp) + for prop in props: + bb = prop.bbox + bb = np.s_[ + bb[0]:bb[3], bb[1]:bb[4], bb[2]:bb[5] + ] + mask = vesicles_pp[bb] == prop.label + mask = remove_small_holes(mask, area_threshold=1000) + vesicles_pp[bb][mask] = prop.label + + return vesicles_pp + + +# Filter out the vesicles so that only the ones overlapping with the max compartment are taken. +def process_tomogram(path, out_path): + with h5py.File(out_path, "r") as f: + if "vesicles" in f: + return + + with h5py.File(path, "r") as f: + raw = f["raw"][:] + compartments = f["compartments/segment_from_3Dmodel_v2"][:] + vesicles = f["vesicles/segment_from_combined_vesicles"][:] + + # Fill out small holes in vesicles and then apply a size filter. + vesicles_pp = fill_and_filter_vesicles(vesicles) + + def n_vesicles(mask, ves): + return len(np.unique(ves[mask])) - 1 + + # Find the segment with most vesicles. 
+ props = regionprops(compartments, intensity_image=vesicles_pp, extra_properties=[n_vesicles]) + compartment_ids = [prop.label for prop in props] + vesicle_counts = [prop.n_vesicles for prop in props] + if len(compartment_ids) == 0: + mask = np.ones(compartments.shape, dtype="bool") + else: + mask = (compartments == compartment_ids[np.argmax(vesicle_counts)]).astype("uint8") + + # Filter all vesicles that are not in the mask. + props = regionprops(vesicles_pp, mask) + filter_ids = [prop.label for prop in props if prop.max_intensity == 0] + + name = os.path.basename(path) + print(name) + + no_filter = ["C_M13DKO_080212_CTRL6.7B_crop.h5", "E_M13DKO_080212_DKO1.2_crop.h5", + "G_M13DKO_080212_CTRL6.7B_crop.h5", "A_SNAP25_120812_CTRL2.3_14_crop.h5", + "A_SNAP25_12082_KO2.1_6_crop.h5", "B_SNAP25_120812_CTRL2.3_14_crop.h5", + "B_SNAP25_12082_CTRL2.3_5_crop.h5", "D_SNAP25_120812_CTRL2.3_14_crop.h5", + "G_SNAP25_12.08.12_KO1.1_3_crop.h5"] + # Don't filter for wrong masks (visual inspection) + if name not in no_filter: + vesicles_pp[np.isin(vesicles_pp, filter_ids)] = 0 + + v = napari.Viewer() + v.add_image(raw) + v.add_labels(compartments, visible=False) + v.add_labels(vesicles, visible=False) + v.add_labels(vesicles_pp) + v.add_labels(mask) + v.title = name + napari.run() + + with h5py.File(out_path, "a") as f: + f.create_dataset("vesicles", data=vesicles_pp, compression="gzip") + f.create_dataset("mask", data=mask, compression="gzip") + + +def main(): + files = sorted(glob("imig_data/**/*.h5", recursive=True)) + out_files = sorted(glob("proofread_az/**/*.h5", recursive=True)) + + # for path, out_path in zip(files, out_files): + for path, out_path in tqdm(zip(files, out_files), total=len(files)): + process_tomogram(path, out_path) + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/analysis/measure_distances.py b/scripts/cooper/analysis/measure_distances.py new file mode 100644 index 0000000..e5373c3 --- /dev/null +++ b/scripts/cooper/analysis/measure_distances.py @@ -0,0 +1,67 @@ +import os +from glob import glob + +import h5py +import pandas as pd +from tqdm import tqdm + +from synaptic_reconstruction.distance_measurements import measure_segmentation_to_object_distances + + +RESOLUTION = (1.554,) * 3 + + +def measure_distances(path, ds, fname): + with h5py.File(path, "r") as f: + if "thin_az_corrected" in f: + print("Loading corrected az!") + az = f["thin_az_corrected"][:] + else: + az = f["/az_thin_proofread"][:] + vesicles = f["vesicles"][:] + distances, _, _, _ = measure_segmentation_to_object_distances(vesicles, az, resolution=RESOLUTION) + return distances + + +def main(): + ratings = pd.read_excel("quality_ratings/az_quality_clean_FM.xlsx") + + dataset_results = { + ds: {"CTRL": {}, "DKO": {}} for ds in pd.unique(ratings["Dataset"]) + } + + restrict_to_good_azs = False + + paths = sorted(glob("proofread_az/**/*.h5", recursive=True)) + for path in tqdm(paths): + + ds, fname = os.path.split(path) + ds = os.path.split(ds)[1] + fname = os.path.splitext(fname)[0] + + category = "CTRL" if "CTRL" in fname else "DKO" + + if restrict_to_good_azs: + rating = ratings[ + (ratings["Dataset"] == ds) & (ratings["Tomogram"] == fname) + ]["Rating"].values[0] + if rating != "Good": + continue + + distances = measure_distances(path, ds, fname) + dataset_results[ds][category][fname] = distances + + for ds, categories in dataset_results.items(): + for category, tomogram_data in categories.items(): + sorted_data = dict(sorted(tomogram_data.items())) # Sort by tomogram names + result_df = 
pd.DataFrame.from_dict(sorted_data, orient='index').transpose() + + os.makedirs("./results/distances", exist_ok=True) + output_path = os.path.join("./results/distances", f"distance_analysis_for_{ds}_{category}.csv") + + # Save the DataFrame to CSV + result_df.to_csv(output_path, index=False) + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/analysis/measure_vesicle_sizes.py b/scripts/cooper/analysis/measure_vesicle_sizes.py new file mode 100644 index 0000000..c4dc5e7 --- /dev/null +++ b/scripts/cooper/analysis/measure_vesicle_sizes.py @@ -0,0 +1,66 @@ +import os +from glob import glob + +import h5py +import pandas as pd +from tqdm import tqdm + +from synaptic_reconstruction.imod.to_imod import convert_segmentation_to_spheres + +RESOLUTION = (1.554,) * 3 + + +def measure_diameters(path, ds, fname): + with h5py.File(path, "r") as f: + vesicles = f["vesicles"][:] + + coordinates, radii = convert_segmentation_to_spheres( + vesicles, resolution=RESOLUTION, radius_factor=0.7, estimate_radius_2d=True + ) + # We need to redo the voxelscaling to go back to the pixel size in nanometer. + radii *= RESOLUTION[0] + + diams = radii * 2 + return diams + + +def main(): + ratings = pd.read_excel("quality_ratings/az_quality_clean_FM.xlsx") + + dataset_results = { + ds: {"CTRL": {}, "DKO": {}} for ds in pd.unique(ratings["Dataset"]) + } + + restrict_to_good_azs = False + paths = sorted(glob("proofread_az/**/*.h5", recursive=True)) + for path in tqdm(paths): + + ds, fname = os.path.split(path) + ds = os.path.split(ds)[1] + fname = os.path.splitext(fname)[0] + category = "CTRL" if "CTRL" in fname else "DKO" + + if restrict_to_good_azs: + rating = ratings[ + (ratings["Dataset"] == ds) & (ratings["Tomogram"] == fname) + ]["Rating"].values[0] + if rating != "Good": + continue + + diameters = measure_diameters(path, ds, fname) + dataset_results[ds][category][fname] = diameters + + for ds, categories in dataset_results.items(): + for category, tomogram_data in categories.items(): + sorted_data = dict(sorted(tomogram_data.items())) # Sort by tomogram names + result_df = pd.DataFrame.from_dict(sorted_data, orient='index').transpose() + + os.makedirs("./results/diameters", exist_ok=True) + output_path = os.path.join("./results/diameters", f"size_analysis_for_{ds}_{category}.csv") + + # Save the DataFrame to CSV + result_df.to_csv(output_path, index=False) + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/analysis/proofread_bad_azs.py b/scripts/cooper/analysis/proofread_bad_azs.py new file mode 100644 index 0000000..e23b068 --- /dev/null +++ b/scripts/cooper/analysis/proofread_bad_azs.py @@ -0,0 +1,83 @@ +import os +from glob import glob + +import napari +import numpy as np +import pandas as pd +import h5py + +from magicgui import magicgui +from tqdm import tqdm +from synaptic_reconstruction.morphology import skeletonize_object +from synaptic_reconstruction.ground_truth.shape_refinement import edge_filter + + +def proofread_az(raw_path, seg_path): + assert os.path.exists(raw_path), raw_path + assert os.path.exists(seg_path), seg_path + + with h5py.File(seg_path, "r") as f: + if "thin_az_corrected" in f: + return + seg = f["/az_thin_proofread"][:] + with h5py.File(raw_path, "r") as f: + raw = f["raw"][:] + + hmap = edge_filter(raw, sigma=1.0, method="sato", per_slice=True, n_threads=8) + membrane_mask = hmap > 0.5 + + v = napari.Viewer() + v.add_image(raw) + v.add_labels(seg, colormap={1: "orange"}, opacity=1) + v.add_labels(np.zeros_like(seg), name="canvas") + 
v.add_labels(membrane_mask, visible=False) + + @magicgui(call_button="skeletonize") + def skeletonize(): + data = v.layers["canvas"].data + data = np.logical_and(data, membrane_mask) + data = skeletonize_object(data) + new_mask = data != 0 + v.layers["seg"].data[new_mask] = data[new_mask] + + @magicgui(call_button="save") + def save(): + seg = v.layers["seg"].data + + with h5py.File(seg_path, "a") as f: + f.create_dataset("thin_az_corrected", data=seg, compression="gzip") + + v.window.add_dock_widget(skeletonize) + v.window.add_dock_widget(save) + + napari.run() + + +def main(): + ratings = pd.read_excel("quality_ratings/az_quality_clean_FM.xlsx") + + paths = sorted(glob("proofread_az/**/*.h5", recursive=True)) + for path in tqdm(paths): + + ds, fname = os.path.split(path) + ds = os.path.split(ds)[1] + fname = os.path.splitext(fname)[0] + + try: + rating = ratings[ + (ratings["Dataset"] == ds) & (ratings["Tomogram"] == fname) + ]["Rating"].values[0] + except IndexError: + breakpoint() + if rating == "Good": + continue + + print(rating) + print(ds, fname) + + raw_path = os.path.join("imig_data", ds, f"{fname}.h5") + proofread_az(raw_path, path) + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/analysis/rate_az.py b/scripts/cooper/analysis/rate_az.py new file mode 100644 index 0000000..ee3db9a --- /dev/null +++ b/scripts/cooper/analysis/rate_az.py @@ -0,0 +1,75 @@ +import os +from glob import glob + +import h5py +import pandas as pd + +import napari + +from magicgui.widgets import PushButton, Container +from scipy.ndimage import binary_dilation, binary_closing +from tqdm import tqdm + + +# Create the widget +def create_widget(tab_path, ds, fname): + if os.path.exists(tab_path): + tab = pd.read_excel(tab_path) + else: + tab = None + + # Create buttons + good_button = PushButton(label="Good") + avg_button = PushButton(label="Avg") + bad_button = PushButton(label="Bad") + + def _update_table(rating): + nonlocal tab + + this_tab = pd.DataFrame( + {"Dataset": [ds], "Tomogram": [fname], "Rating": [rating]} + ) + if tab is None: + tab = this_tab + else: + tab = pd.concat([tab, this_tab]) + tab.to_excel(tab_path, index=False) + + # Connect actions to button clicks + good_button.clicked.connect(lambda: _update_table("Good")) + avg_button.clicked.connect(lambda: _update_table("Average")) + bad_button.clicked.connect(lambda: _update_table("Bad")) + + # Arrange buttons in a vertical container + container = Container(widgets=[good_button, avg_button, bad_button]) + return container + + +def rate_az(): + raw_paths = sorted(glob(os.path.join("imig_data/**/*.h5"), recursive=True)) + seg_paths = sorted(glob(os.path.join("proofread_az/**/*.h5"), recursive=True)) + + tab_path = "./az_quality.xlsx" + for rp, sp in tqdm(zip(raw_paths, seg_paths), total=len(raw_paths)): + with h5py.File(rp, "r") as f: + raw = f["raw"][:] + with h5py.File(sp, "r") as f: + seg = f["az_thin_proofread"][:] + + seg_pp = binary_dilation(seg, iterations=2) + seg_pp = binary_closing(seg_pp, iterations=2) + + ds, fname = os.path.split(rp) + ds = os.path.basename(ds) + fname = os.path.splitext(fname)[0] + widget = create_widget(tab_path, ds, fname) + + v = napari.Viewer() + v.add_image(raw) + v.add_labels(seg, colormap={1: "green"}, opacity=1) + v.add_labels(seg_pp) + v.window.add_dock_widget(widget, area="right") + napari.run() + + +rate_az() diff --git a/scripts/cooper/export_mask_to_imod.py b/scripts/cooper/export_mask_to_imod.py index 98b4b2f..4273707 100644 --- a/scripts/cooper/export_mask_to_imod.py +++ 
b/scripts/cooper/export_mask_to_imod.py @@ -4,19 +4,11 @@ def export_mask_to_imod(args): - # Test script - # write_segmentation_to_imod( - # "synapse-examples/36859_J1_66K_TS_CA3_PS_26_rec_2Kb1dawbp_crop.mrc", - # "synapse-examples/36859_J1_66K_TS_CA3_PS_26_rec_2Kb1dawbp_crop_mitos.tif", - # "synapse-examples/mito.mod" - # ) write_segmentation_to_imod(args.input_path, args.segmentation_path, args.output_path) def main(): parser = argparse.ArgumentParser() - - args = parser.parse_args() parser.add_argument( "-i", "--input_path", required=True, help="The filepath to the mrc file containing the data." diff --git a/scripts/cooper/full_reconstruction/.gitignore b/scripts/cooper/full_reconstruction/.gitignore new file mode 100644 index 0000000..494542f --- /dev/null +++ b/scripts/cooper/full_reconstruction/.gitignore @@ -0,0 +1,2 @@ +04_full_reconstruction/ +mito_seg/ diff --git a/scripts/cooper/full_reconstruction/qualitative_evaluation.py b/scripts/cooper/full_reconstruction/qualitative_evaluation.py new file mode 100644 index 0000000..7cd8c20 --- /dev/null +++ b/scripts/cooper/full_reconstruction/qualitative_evaluation.py @@ -0,0 +1,135 @@ +import os + +import h5py +import numpy as np +import pandas as pd +import napari + +from skimage.measure import label + +from tqdm import tqdm + +val_table = "/home/pape/Desktop/sfb1286/mboc_synapse/qualitative-stem-eval.xlsx" +val_table = pd.read_excel(val_table) + + +def _get_n_azs(path): + access = np.s_[::2, ::2, ::2] + with h5py.File(path, "r") as f: + az = f["labels/active_zone"][access] + az = label(az) + ids, sizes = np.unique(az, return_counts=True) + ids, sizes = ids[1:], sizes[1:] + n_azs = np.sum(sizes > 10000) + return n_azs, n_azs + + +def eval_az(): + azs_found = [] + azs_total = [] + + # for the "all" tomograms load the prediction, measure number components, + # size filter and count these as found and as total + for i, row in tqdm(val_table.iterrows(), total=len(val_table)): + az_found = row["AZ Found"] + if az_found == "all": + path = os.path.join("04_full_reconstruction", row.dataset, row.tomogram) + assert os.path.exists(path) + az_found, az_total = _get_n_azs(path) + else: + az_total = row["AZ Total"] + + azs_found.append(az_found) + azs_total.append(az_total) + + n_found = np.sum(azs_found) + n_azs = np.sum(azs_total) + + print("AZ Evaluation:") + print("Number of correctly identified AZs:", n_found, "/", n_azs, f"({float(n_found)/n_azs}%)") + + +# measure in how many pieces each compartment was split +def eval_compartments(): + pieces_per_compartment = [] + for i, row in val_table.iterrows(): + for comp in [ + "Compartment 1", + "Compartment 2", + "Compartment 3", + "Compartment 4", + ]: + n_pieces = row[comp] + if isinstance(n_pieces, str): + n_pieces = len(n_pieces.split(",")) + elif np.isnan(n_pieces): + continue + else: + assert isinstance(n_pieces, (float, int)) + n_pieces = 1 + pieces_per_compartment.append(n_pieces) + + avg = np.mean(pieces_per_compartment) + std = np.std(pieces_per_compartment) + max_ = np.max(pieces_per_compartment) + print("Compartment Evaluation:") + print("Avergage pieces per compartment:", avg, "+-", std) + print("Max pieces per compartment:", max_) + print("Number of compartments:", len(pieces_per_compartment)) + + +def eval_mitos(): + mito_correct = [] + mito_split = [] + mito_merged = [] + mito_total = [] + wrong_object = [] + + mito_table = val_table.fillna(0) + # measure % of mito correct, mito split and mito merged + for i, row in mito_table.iterrows(): + mito_correct.append(row["Mito Correct"]) + 
mito_split.append(row["Mito Split"]) + mito_merged.append(row["Mito Merged"]) + mito_total.append(row["Mito Total"]) + wrong_object.append(row["Wrong Object"]) + + n_mitos = np.sum(mito_total) + n_correct = np.sum(mito_correct) + print("Mito Evaluation:") + print("Number of correctly identified mitos:", n_correct, "/", n_mitos, f"({float(n_correct)/n_mitos}%)") + print("Number of merged mitos:", np.sum(mito_merged)) + print("Number of split mitos:", np.sum(mito_split)) + print("Number of wrongly identified objects:", np.sum(wrong_object)) + + +def check_mitos(): + scale = 3 + access = np.s_[::scale, ::scale, ::scale] + + root = "./04_full_reconstruction" + for i, row in tqdm(val_table.iterrows(), total=len(val_table)): + ds, fname = row.dataset, row.tomogram + path = os.path.join(root, ds, fname) + with h5py.File(path, "r") as f: + raw = f["raw"][access] + mitos = f["labels/mitochondria"][access] + + # ids, sizes = np.unique(mitos, return_counts=True) + v = napari.Viewer() + v.add_image(raw) + v.add_labels(mitos) + napari.run() + + +def main(): + # check_mitos() + + # eval_mitos() + # print() + eval_compartments() + # print() + # eval_az() + + +main() diff --git a/scripts/cooper/full_reconstruction/segment_mitochondria.py b/scripts/cooper/full_reconstruction/segment_mitochondria.py index 395de78..cb82275 100644 --- a/scripts/cooper/full_reconstruction/segment_mitochondria.py +++ b/scripts/cooper/full_reconstruction/segment_mitochondria.py @@ -8,23 +8,53 @@ ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/04_full_reconstruction" # noqa MODEL_PATH = "/scratch-grete/projects/nim00007/models/exports_for_cooper/mito_model_s2.pt" # noqa +# MODEL_PATH = "/scratch-grete/projects/nim00007/models/luca/mito/source_domain" + def run_seg(path): + + out_folder = "./mito_seg" + ds, fname = os.path.split(path) + ds = os.path.basename(ds) + + os.makedirs(os.path.join(out_folder, ds), exist_ok=True) + out_path = os.path.join(out_folder, ds, fname) + if os.path.exists(out_path): + return + with h5py.File(path, "r") as f: - if "labels/mitochondria" in f: - return raw = f["raw"][:] scale = (0.5, 0.5, 0.5) seg = segment_mitochondria(raw, model_path=MODEL_PATH, scale=scale, verbose=False) - with h5py.File(path, "a") as f: + with h5py.File(out_path, "a") as f: + f.create_dataset("labels/mitochondria", data=seg, compression="gzip") + + +def run_seg_and_pred(path): + with h5py.File(path, "r") as f: + raw = f["raw"][:] + + scale = (0.5, 0.5, 0.5) + seg, pred = segment_mitochondria( + raw, model_path=MODEL_PATH, scale=scale, verbose=False, return_predictions=True + ) + + out_folder = "./mito_pred" + os.makedirs(out_folder, exist_ok=True) + out_path = os.path.join(out_folder, os.path.basename(path)) + + with h5py.File(out_path, "a") as f: + f.create_dataset("raw", data=raw[::2, ::2, ::2]) f.create_dataset("labels/mitochondria", data=seg, compression="gzip") + f.create_dataset("pred", data=pred, compression="gzip") def main(): paths = sorted(glob(os.path.join(ROOT, "**/*.h5"), recursive=True)) for path in tqdm(paths): run_seg(path) + # run_seg_and_pred(path) main() diff --git a/scripts/cooper/full_reconstruction/the_most_beautiful_synapse/.gitignore b/scripts/cooper/full_reconstruction/the_most_beautiful_synapse/.gitignore new file mode 100644 index 0000000..8fce603 --- /dev/null +++ b/scripts/cooper/full_reconstruction/the_most_beautiful_synapse/.gitignore @@ -0,0 +1 @@ +data/ diff --git a/scripts/cooper/full_reconstruction/the_most_beautiful_synapse/visualize_synapse.py 
b/scripts/cooper/full_reconstruction/the_most_beautiful_synapse/visualize_synapse.py new file mode 100644 index 0000000..c5d2467 --- /dev/null +++ b/scripts/cooper/full_reconstruction/the_most_beautiful_synapse/visualize_synapse.py @@ -0,0 +1,41 @@ +import napari +import numpy as np + +import imageio.v3 as imageio +from elf.io import open_file +from skimage.filters import gaussian + + +def visualize_synapse(): + scale = 3 + access = np.s_[::scale, ::scale, ::scale] + resolution = (scale * 0.868,) * 3 + + tomo_path = "./data/36859_J2_66K_TS_R04_MF05_rec_2Kb1dawbp_crop.mrc" + with open_file(tomo_path, "r") as f: + raw = f["data"][access] + raw = gaussian(raw) + + compartment = imageio.imread("./data/segmented/compartment.tif") + vesicles = imageio.imread("./data/segmented/vesicles.tif") + mitos = imageio.imread("./data/segmented/mitos.tif") + active_zone = imageio.imread("./data/segmented/active_zone.tif") + vesicle_ids = np.unique(vesicles)[1:] + + v = napari.Viewer() + v.add_image(raw[:, ::-1], scale=resolution) + v.add_labels(mitos[:, ::-1], scale=resolution) + v.add_labels(vesicles[:, ::-1], colormap={ves_id: "orange" for ves_id in vesicle_ids}, scale=resolution) + v.add_labels(active_zone[:, ::-1], colormap={1: "blue"}, scale=resolution) + v.add_labels(compartment[:, ::-1], colormap={1: "red"}, scale=resolution) + v.scale_bar.visible = True + v.scale_bar.unit = "nm" + v.scale_bar.font_size = 16 + napari.run() + + +def main(): + visualize_synapse() + + +main() diff --git a/scripts/cooper/full_reconstruction/visualize_results.py b/scripts/cooper/full_reconstruction/visualize_results.py index 5e3f596..dba2787 100644 --- a/scripts/cooper/full_reconstruction/visualize_results.py +++ b/scripts/cooper/full_reconstruction/visualize_results.py @@ -6,11 +6,13 @@ import numpy as np import pandas as pd +from skimage.filters import gaussian + ROOT = "./04_full_reconstruction" TABLE = "/home/pape/Desktop/sfb1286/mboc_synapse/draft_figures/full_reconstruction.xlsx" # Skip datasets for which all figures were already done. 
-SKIP_DS = ["20241019_Tomo-eval_MF_Synapse"] +SKIP_DS = ["20241019_Tomo-eval_MF_Synapse", "20241019_Tomo-eval_PS_Synapse"] def _get_name_and_row(path, table): @@ -46,13 +48,14 @@ def visualize_result(path, table): if ds_name in SKIP_DS: return - # if row["Use for vis"].values[0] == "yes": - if row["Use for vis"].values[0] in ("yes", "no"): + if row["Use for Vis"].values[0] == "no": return compartment_ids = _get_compartment_ids(row) # access = np.s_[:] - access = np.s_[::2, ::2, ::2] + scale = 3 + access = np.s_[::scale, ::scale, ::scale] + resolution = (scale * 0.868,) * 3 with h5py.File(path, "r") as f: raw = f["raw"][access] @@ -60,6 +63,10 @@ def visualize_result(path, table): active_zone = f["labels/active_zone"][access] mitos = f["labels/mitochondria"][access] compartments = f["labels/compartments"][access] + print("Loading done") + + raw = gaussian(raw) + print("Gaussian done") if any(comp_ids is not None for comp_ids in compartment_ids): mask = np.zeros(raw.shape, dtype="bool") @@ -78,13 +85,26 @@ def visualize_result(path, table): mitos[~mask] = 0 compartments = compartments_new + vesicle_ids = np.unique(vesicles)[1:] + + transpose = False + if transpose: + raw = raw[:, ::-1] + active_zone = active_zone[:, ::-1] + mitos = mitos[:, ::-1] + vesicles = vesicles[:, ::-1] + compartments = compartments[:, ::-1] + v = napari.Viewer() - v.add_image(raw) - v.add_labels(mitos) - v.add_labels(vesicles) - v.add_labels(compartments) - v.add_labels(active_zone) + v.add_image(raw, scale=resolution) + v.add_labels(mitos, scale=resolution) + v.add_labels(vesicles, colormap={ves_id: "orange" for ves_id in vesicle_ids}, scale=resolution) + v.add_labels(compartments, colormap={1: "red", 2: "green", 3: "orange"}, scale=resolution) + v.add_labels(active_zone, colormap={1: "blue"}, scale=resolution) v.title = f"{ds_name}/{name}" + v.scale_bar.visible = True + v.scale_bar.unit = "nm" + v.scale_bar.font_size = 16 napari.run() @@ -115,6 +135,7 @@ def main(): paths = sorted(glob(os.path.join(ROOT, "**/*.h5"), recursive=True)) table = pd.read_excel(TABLE) for path in paths: + print(path) visualize_result(path, table) # visualize_only_compartment(path, table) diff --git a/scripts/cooper/ground_truth/2D-data/.gitignore b/scripts/cooper/ground_truth/2D-data/.gitignore new file mode 100644 index 0000000..91bf576 --- /dev/null +++ b/scripts/cooper/ground_truth/2D-data/.gitignore @@ -0,0 +1,2 @@ +data/ +exported/ diff --git a/scripts/cooper/ground_truth/2D-data/extract_vesicles.py b/scripts/cooper/ground_truth/2D-data/extract_vesicles.py new file mode 100644 index 0000000..398c513 --- /dev/null +++ b/scripts/cooper/ground_truth/2D-data/extract_vesicles.py @@ -0,0 +1,80 @@ +import os +from glob import glob +from pathlib import Path + +import napari +import numpy as np +from elf.io import open_file +from magicgui import magicgui +from synaptic_reconstruction.imod.export import export_point_annotations + +EXPORT_FOLDER = "./exported" + + +def export_vesicles(mrc, mod): + os.makedirs(EXPORT_FOLDER, exist_ok=True) + + fname = Path(mrc).stem + output_path = os.path.join(EXPORT_FOLDER, f"{fname}.h5") + if os.path.exists(output_path): + return + + resolution = 0.592 + with open_file(mrc, "r") as f: + data = f["data"][:] + + segmentation, labels, label_names = export_point_annotations( + mod, shape=data.shape, resolution=resolution, exclude_labels=[7, 14] + ) + data, segmentation = data[0], segmentation[0] + + with open_file(output_path, "a") as f: + f.create_dataset("data", data=data, compression="gzip") + 
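+        # Store the vesicle annotations exported from the IMOD point model as a label
+        # volume next to the raw data, so that both can be read from the same HDF5 file.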
f.create_dataset("labels/vesicles", data=segmentation, compression="gzip") + + +def export_all_vesicles(): + mrc_files = sorted(glob(os.path.join("./data/*.mrc"))) + mod_files = sorted(glob(os.path.join("./data/*.mod"))) + for mrc, mod in zip(mrc_files, mod_files): + export_vesicles(mrc, mod) + + +def create_mask(file_path): + with open_file(file_path, "r") as f: + if "labels/mask" in f: + return + + data = f["data"][:] + vesicles = f["labels/vesicles"][:] + + mask = np.zeros_like(vesicles) + + v = napari.Viewer() + v.add_image(data) + v.add_labels(vesicles) + v.add_labels(mask) + + @magicgui(call_button="Save Mask") + def save_mask(v: napari.Viewer): + mask = v.layers["mask"].data.astype("uint8") + with open_file(file_path, "a") as f: + f.create_dataset("labels/mask", data=mask, compression="gzip") + + v.window.add_dock_widget(save_mask) + napari.run() + + +def create_all_masks(): + files = sorted(glob(os.path.join(EXPORT_FOLDER, "*.h5"))) + for ff in files: + create_mask(ff) + + +def main(): + export_all_vesicles() + create_all_masks() + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/ground_truth/az/.gitignore b/scripts/cooper/ground_truth/az/.gitignore new file mode 100644 index 0000000..6d52b05 --- /dev/null +++ b/scripts/cooper/ground_truth/az/.gitignore @@ -0,0 +1,3 @@ +AZ_segmentation/ +postprocessed_AZ/ +az_eval.xlsx diff --git a/scripts/cooper/ground_truth/az/check_proofread.py b/scripts/cooper/ground_truth/az/check_proofread.py new file mode 100644 index 0000000..52205ef --- /dev/null +++ b/scripts/cooper/ground_truth/az/check_proofread.py @@ -0,0 +1,52 @@ +import os + +import h5py +import napari + +from tqdm import tqdm + + +def check_proofread(raw_path, seg_path): + with h5py.File(seg_path, "r") as f: + seg1 = f["labels_pp/thin_az"][:] + seg2 = f["labels_pp/filtered_az"][:] + with h5py.File(raw_path, "r") as f: + raw = f["raw"][:] + + v = napari.Viewer() + v.add_image(raw) + v.add_labels(seg1) + v.add_labels(seg2) + napari.run() + + +def main(): + # FIXME something wrong in the zenodo upload + root_raw = "/home/pape/Work/my_projects/synaptic-reconstruction/scripts/data_summary/for_zenodo/synapse-net/active_zones/train" # noqa + root_seg = "./postprocessed_AZ" + + test_tomograms = { + "01": [ + "WT_MF_DIV28_01_MS_09204_F1.h5", "WT_MF_DIV14_01_MS_B2_09175_CA3.h5", "M13_CTRL_22723_O2_05_DIV29_5.2.h5", "WT_Unt_SC_09175_D4_05_DIV14_mtk_05.h5", # noqa + "20190805_09002_B4_SC_11_SP.h5", "20190807_23032_D4_SC_01_SP.h5", "M13_DKO_22723_A1_03_DIV29_03_MS.h5", "WT_MF_DIV28_05_MS_09204_F1.h5", "M13_CTRL_09201_S2_06_DIV31_06_MS.h5", # noqa + "WT_MF_DIV28_1.2_MS_09002_B1.h5", "WT_Unt_SC_09175_C4_04_DIV15_mtk_04.h5", "M13_DKO_22723_A4_10_DIV29_10_MS.h5", "WT_MF_DIV14_3.2_MS_D2_09175_CA3.h5", # noqa + "20190805_09002_B4_SC_10_SP.h5", "M13_CTRL_09201_S2_02_DIV31_02_MS.h5", "WT_MF_DIV14_04_MS_E1_09175_CA3.h5", "WT_MF_DIV28_10_MS_09002_B3.h5", "WT_Unt_SC_05646_D4_02_DIV16_mtk_02.h5", "M13_DKO_22723_A4_08_DIV29_08_MS.h5", "WT_MF_DIV28_04_MS_09204_M1.h5", "WT_MF_DIV28_03_MS_09204_F1.h5", "M13_DKO_22723_A1_05_DIV29_05_MS.h5", # noqa + "WT_Unt_SC_09175_C4_06_DIV15_mtk_06.h5", "WT_MF_DIV28_09_MS_09002_B3.h5", "20190524_09204_F4_SC_07_SP.h5", + "WT_MF_DIV14_02_MS_C2_09175_CA3.h5", "M13_DKO_23037_K1_01_DIV29_01_MS.h5", "WT_Unt_SC_09175_E2_01_DIV14_mtk_01.h5", "20190807_23032_D4_SC_05_SP.h5", "WT_MF_DIV14_01_MS_E2_09175_CA3.h5", "WT_MF_DIV14_03_MS_B2_09175_CA3.h5", "M13_DKO_09201_O1_01_DIV31_01_MS.h5", "M13_DKO_09201_U1_04_DIV31_04_MS.h5", # noqa + 
"WT_MF_DIV14_04_MS_E2_09175_CA3_2.h5", "WT_Unt_SC_09175_D5_01_DIV14_mtk_01.h5", + "M13_CTRL_22723_O2_05_DIV29_05_MS_.h5", "WT_MF_DIV14_02_MS_B2_09175_CA3.h5", "WT_MF_DIV14_01.2_MS_D1_09175_CA3.h5", # noqa + ], + "12": ["20180305_09_MS.h5", "20180305_04_MS.h5", "20180305_08_MS.h5", + "20171113_04_MS.h5", "20171006_05_MS.h5", "20180305_01_MS.h5"], + } + + for ds, test_tomos in test_tomograms.items(): + ds_name_raw = "single_axis_tem" if ds == "01" else "chemical-fixation" + ds_name_seg = "01_hoi_maus_2020_incomplete" if ds == "01" else "12_chemical_fix_cryopreparation" + for tomo in tqdm(test_tomos, desc=f"Proofread {ds}"): + raw_path = os.path.join(root_raw, ds_name_raw, tomo) + seg_path = os.path.join(root_seg, ds_name_seg, tomo) + check_proofread(raw_path, seg_path) + + +main() diff --git a/scripts/cooper/ground_truth/az/evaluate_az.py b/scripts/cooper/ground_truth/az/evaluate_az.py new file mode 100644 index 0000000..63ca765 --- /dev/null +++ b/scripts/cooper/ground_truth/az/evaluate_az.py @@ -0,0 +1,92 @@ +import os + +import h5py +import pandas as pd +from elf.evaluation import dice_score + +from scipy.ndimage import binary_dilation, binary_closing +from tqdm import tqdm + + +def _expand_AZ(az): + return binary_closing( + binary_dilation(az, iterations=3), iterations=3 + ) + + +def eval_az(seg_path, gt_path, seg_key, gt_key): + with h5py.File(seg_path, "r") as f: + seg = f[seg_key][:] + with h5py.File(gt_path, "r") as f: + gt = f[gt_key][:] + assert seg.shape == gt.shape + + seg = _expand_AZ(seg) + gt = _expand_AZ(gt) + score = dice_score(seg, gt) + + # import napari + # v = napari.Viewer() + # v.add_labels(seg) + # v.add_labels(gt) + # v.title = f"Dice = {score}, {seg_path}" + # napari.run() + + return score + + +def main(): + res_path = "./az_eval.xlsx" + if not os.path.exists(res_path): + seg_root = "AZ_segmentation/postprocessed_AZ" + gt_root = "postprocessed_AZ" + + # Removed WT_Unt_SC_05646_D4_02_DIV16_mtk_02.h5 from the eval set because of contrast issues + test_tomograms = { + "01": [ + "WT_MF_DIV28_01_MS_09204_F1.h5", "WT_MF_DIV14_01_MS_B2_09175_CA3.h5", "M13_CTRL_22723_O2_05_DIV29_5.2.h5", "WT_Unt_SC_09175_D4_05_DIV14_mtk_05.h5", # noqa + "20190805_09002_B4_SC_11_SP.h5", "20190807_23032_D4_SC_01_SP.h5", "M13_DKO_22723_A1_03_DIV29_03_MS.h5", "WT_MF_DIV28_05_MS_09204_F1.h5", "M13_CTRL_09201_S2_06_DIV31_06_MS.h5", # noqa + "WT_MF_DIV28_1.2_MS_09002_B1.h5", "WT_Unt_SC_09175_C4_04_DIV15_mtk_04.h5", "M13_DKO_22723_A4_10_DIV29_10_MS.h5", "WT_MF_DIV14_3.2_MS_D2_09175_CA3.h5", # noqa + "20190805_09002_B4_SC_10_SP.h5", "M13_CTRL_09201_S2_02_DIV31_02_MS.h5", "WT_MF_DIV14_04_MS_E1_09175_CA3.h5", "WT_MF_DIV28_10_MS_09002_B3.h5", "M13_DKO_22723_A4_08_DIV29_08_MS.h5", "WT_MF_DIV28_04_MS_09204_M1.h5", "WT_MF_DIV28_03_MS_09204_F1.h5", "M13_DKO_22723_A1_05_DIV29_05_MS.h5", # noqa + "WT_Unt_SC_09175_C4_06_DIV15_mtk_06.h5", "WT_MF_DIV28_09_MS_09002_B3.h5", "20190524_09204_F4_SC_07_SP.h5", + "WT_MF_DIV14_02_MS_C2_09175_CA3.h5", "M13_DKO_23037_K1_01_DIV29_01_MS.h5", "WT_Unt_SC_09175_E2_01_DIV14_mtk_01.h5", "20190807_23032_D4_SC_05_SP.h5", "WT_MF_DIV14_01_MS_E2_09175_CA3.h5", "WT_MF_DIV14_03_MS_B2_09175_CA3.h5", "M13_DKO_09201_O1_01_DIV31_01_MS.h5", "M13_DKO_09201_U1_04_DIV31_04_MS.h5", # noqa + "WT_MF_DIV14_04_MS_E2_09175_CA3_2.h5", "WT_Unt_SC_09175_D5_01_DIV14_mtk_01.h5", + "M13_CTRL_22723_O2_05_DIV29_05_MS_.h5", "WT_MF_DIV14_02_MS_B2_09175_CA3.h5", "WT_MF_DIV14_01.2_MS_D1_09175_CA3.h5", # noqa + ], + "12": ["20180305_09_MS.h5", "20180305_04_MS.h5", "20180305_08_MS.h5", + "20171113_04_MS.h5", 
"20171006_05_MS.h5", "20180305_01_MS.h5"], + } + + scores = { + "Dataset": [], + "Tomogram": [], + "Dice": [] + } + for ds, test_tomos in test_tomograms.items(): + ds_name = "01_hoi_maus_2020_incomplete" if ds == "01" else "12_chemical_fix_cryopreparation" + for tomo in tqdm(test_tomos): + seg_path = os.path.join(seg_root, ds_name, tomo) + gt_path = os.path.join(gt_root, ds_name, tomo) + score = eval_az(seg_path, gt_path, seg_key="AZ/thin_az", gt_key="labels_pp/filtered_az") + + scores["Dataset"].append(ds_name) + scores["Tomogram"].append(tomo) + scores["Dice"].append(score) + + scores = pd.DataFrame(scores) + scores.to_excel(res_path, index=False) + + else: + scores = pd.read_excel(res_path) + + print("Evaluation for the datasets:") + for ds in pd.unique(scores.Dataset): + print(ds) + ds_scores = scores[scores.Dataset == ds]["Dice"] + print(ds_scores.mean(), "+-", ds_scores.std()) + + print("Total:") + print(scores["Dice"].mean(), "+-", scores["Dice"].std()) + + +main() diff --git a/scripts/cooper/ground_truth/az/proofread_az.py b/scripts/cooper/ground_truth/az/proofread_az.py new file mode 100644 index 0000000..5522ff2 --- /dev/null +++ b/scripts/cooper/ground_truth/az/proofread_az.py @@ -0,0 +1,98 @@ +import os + +import napari +import numpy as np +import h5py + +from tqdm import tqdm +from magicgui import magicgui + +from scipy.ndimage import binary_dilation, binary_closing +from synaptic_reconstruction.morphology import skeletonize_object +from synaptic_reconstruction.ground_truth.shape_refinement import edge_filter + + +def proofread_az(raw_path, seg_path): + assert os.path.exists(raw_path), raw_path + assert os.path.exists(seg_path), seg_path + + with h5py.File(seg_path, "r") as f: + if "labels_pp" in f: + return + seg = f["labels/thin_az"][:] + with h5py.File(raw_path, "r") as f: + raw = f["raw"][:] + + hmap = edge_filter(raw, sigma=1.0, method="sato", per_slice=True, n_threads=8) + membrane_mask = hmap > 0.5 + + v = napari.Viewer() + v.add_image(raw) + v.add_labels(seg, colormap={1: "orange"}, opacity=1) + v.add_labels(seg, colormap={1: "yellow"}, opacity=1, name="seg_pp") + v.add_labels(np.zeros_like(seg), name="canvas") + v.add_labels(membrane_mask, visible=False) + + @magicgui(call_button="split") + def split(): + data = v.layers["seg"].data + ids, sizes = np.unique(data, return_counts=True) + ids, sizes = ids[1:], sizes[1:] + data = data == ids[np.argmax(sizes)] + v.layers["seg_pp"].data = data + + @magicgui(call_button="skeletonize") + def skeletonize(): + data = v.layers["canvas"].data + data = np.logical_and(data, membrane_mask) + data = skeletonize_object(data) + v.layers["seg_pp"].data = data + + @magicgui(call_button="save") + def save(): + seg_pp = v.layers["seg_pp"].data + seg_pp_dilated = binary_dilation(seg_pp, iterations=2) + seg_pp_dilated = binary_closing(seg_pp, iterations=2) + + with h5py.File(seg_path, "a") as f: + f.create_dataset("labels_pp/thin_az", data=seg_pp, compression="gzip") + f.create_dataset("labels_pp/filtered_az", data=seg_pp_dilated, compression="gzip") + + v.window.add_dock_widget(split) + v.window.add_dock_widget(skeletonize) + v.window.add_dock_widget(save) + + napari.run() + + +def main(): + # FIXME something wrong in the zenodo upload + root_raw = "/home/pape/Work/my_projects/synaptic-reconstruction/scripts/data_summary/for_zenodo/synapse-net/active_zones/train" # noqa + root_seg = "./postprocessed_AZ" + + test_tomograms = { + "01": [ + "WT_MF_DIV28_01_MS_09204_F1.h5", "WT_MF_DIV14_01_MS_B2_09175_CA3.h5", 
"M13_CTRL_22723_O2_05_DIV29_5.2.h5", "WT_Unt_SC_09175_D4_05_DIV14_mtk_05.h5", # noqa + "20190805_09002_B4_SC_11_SP.h5", "20190807_23032_D4_SC_01_SP.h5", "M13_DKO_22723_A1_03_DIV29_03_MS.h5", "WT_MF_DIV28_05_MS_09204_F1.h5", "M13_CTRL_09201_S2_06_DIV31_06_MS.h5", # noqa + "WT_MF_DIV28_1.2_MS_09002_B1.h5", "WT_Unt_SC_09175_C4_04_DIV15_mtk_04.h5", "M13_DKO_22723_A4_10_DIV29_10_MS.h5", "WT_MF_DIV14_3.2_MS_D2_09175_CA3.h5", # noqa + "20190805_09002_B4_SC_10_SP.h5", "M13_CTRL_09201_S2_02_DIV31_02_MS.h5", "WT_MF_DIV14_04_MS_E1_09175_CA3.h5", "WT_MF_DIV28_10_MS_09002_B3.h5", "WT_Unt_SC_05646_D4_02_DIV16_mtk_02.h5", "M13_DKO_22723_A4_08_DIV29_08_MS.h5", "WT_MF_DIV28_04_MS_09204_M1.h5", "WT_MF_DIV28_03_MS_09204_F1.h5", "M13_DKO_22723_A1_05_DIV29_05_MS.h5", # noqa + "WT_Unt_SC_09175_C4_06_DIV15_mtk_06.h5", "WT_MF_DIV28_09_MS_09002_B3.h5", "20190524_09204_F4_SC_07_SP.h5", + "WT_MF_DIV14_02_MS_C2_09175_CA3.h5", "M13_DKO_23037_K1_01_DIV29_01_MS.h5", "WT_Unt_SC_09175_E2_01_DIV14_mtk_01.h5", "20190807_23032_D4_SC_05_SP.h5", "WT_MF_DIV14_01_MS_E2_09175_CA3.h5", "WT_MF_DIV14_03_MS_B2_09175_CA3.h5", "M13_DKO_09201_O1_01_DIV31_01_MS.h5", "M13_DKO_09201_U1_04_DIV31_04_MS.h5", # noqa + "WT_MF_DIV14_04_MS_E2_09175_CA3_2.h5", "WT_Unt_SC_09175_D5_01_DIV14_mtk_01.h5", + "M13_CTRL_22723_O2_05_DIV29_05_MS_.h5", "WT_MF_DIV14_02_MS_B2_09175_CA3.h5", "WT_MF_DIV14_01.2_MS_D1_09175_CA3.h5", # noqa + ], + "12": ["20180305_09_MS.h5", "20180305_04_MS.h5", "20180305_08_MS.h5", + "20171113_04_MS.h5", "20171006_05_MS.h5", "20180305_01_MS.h5"], + } + + for ds, test_tomos in test_tomograms.items(): + ds_name_raw = "single_axis_tem" if ds == "01" else "chemical_fixation" + ds_name_seg = "01_hoi_maus_2020_incomplete" if ds == "01" else "12_chemical_fix_cryopreparation" + for tomo in tqdm(test_tomos, desc=f"Proofread {ds}"): + raw_path = os.path.join(root_raw, ds_name_raw, tomo) + seg_path = os.path.join(root_seg, ds_name_seg, tomo) + proofread_az(raw_path, seg_path) + + +if __name__ == "__main__": + main() diff --git a/scripts/data_summary/active_zone_training_data.xlsx b/scripts/data_summary/active_zone_training_data.xlsx new file mode 100644 index 0000000..b193653 Binary files /dev/null and b/scripts/data_summary/active_zone_training_data.xlsx differ diff --git a/scripts/data_summary/compartment_training_data.xlsx b/scripts/data_summary/compartment_training_data.xlsx new file mode 100644 index 0000000..e141f0b Binary files /dev/null and b/scripts/data_summary/compartment_training_data.xlsx differ diff --git a/scripts/data_summary/vesicle_domain_adaptation_data.xlsx b/scripts/data_summary/vesicle_domain_adaptation_data.xlsx new file mode 100644 index 0000000..8a47219 Binary files /dev/null and b/scripts/data_summary/vesicle_domain_adaptation_data.xlsx differ diff --git a/scripts/data_summary/vesicle_training_data.xlsx b/scripts/data_summary/vesicle_training_data.xlsx new file mode 100644 index 0000000..57fb145 Binary files /dev/null and b/scripts/data_summary/vesicle_training_data.xlsx differ diff --git a/scripts/inner_ear/analysis/.gitignore b/scripts/inner_ear/analysis/.gitignore new file mode 100644 index 0000000..cbad005 --- /dev/null +++ b/scripts/inner_ear/analysis/.gitignore @@ -0,0 +1,3 @@ +panels/ +auto_seg_export/ +*.zip diff --git a/scripts/inner_ear/analysis/analyze_distances.py b/scripts/inner_ear/analysis/analyze_distances.py new file mode 100644 index 0000000..e3921ba --- /dev/null +++ b/scripts/inner_ear/analysis/analyze_distances.py @@ -0,0 +1,160 @@ +import os + +import matplotlib.pyplot as plt +import numpy as 
np +import pandas as pd +import seaborn as sns + +from common import get_all_measurements, get_measurements_with_annotation + +POOL_DICT = {"Docked-V": "MP-V", "MP-V": "MP-V", "RA-V": "RA-V"} + + +def _plot_all(distances): + pools = pd.unique(distances["pool"]) + dist_cols = ["ribbon_distance [nm]", "pd_distance [nm]", "boundary_distance [nm]"] + + fig, axes = plt.subplots(3, 3) + + # multiple = "stack" + multiple = "layer" + + structures = ["Ribbon", "PD", "Boundary"] + for i, pool in enumerate(pools): + pool_distances = distances[distances["pool"] == pool] + for j, dist_col in enumerate(dist_cols): + ax = axes[i, j] + ax.set_title(f"{pool} to {structures[j]}") + sns.histplot( + data=pool_distances, x=dist_col, hue="approach", multiple=multiple, kde=False, ax=ax + ) + ax.set_xlabel("distance [nm]") + + fig.tight_layout() + plt.show() + + +# We only care about the following distances: +# - MP-V -> PD, AZ (Boundary) +# - Docked-V -> PD, AZ +# - RA-V -> Ribbon +def _plot_selected(distances, save_path=None): + fig, axes = plt.subplots(2, 2) + multiple = "layer" + + if save_path is not None and os.path.exists(save_path): + os.remove(save_path) + + def _plot(pool_name, distance_col, structure_name, ax): + + this_distances = distances[distances["combined_pool"] == pool_name][["tomogram", "approach", distance_col]] + + ax.set_title(f"{pool_name} to {structure_name}") + sns.histplot( + data=this_distances, x=distance_col, hue="approach", multiple=multiple, kde=False, ax=ax + ) + ax.set_xlabel("distance [nm]") + + if save_path is not None: + approaches = pd.unique(this_distances["approach"]) + tomo_names = pd.unique(this_distances["tomogram"]) + + tomograms = [] + distance_values = {approach: [] for approach in approaches} + + for tomo in tomo_names: + tomo_dists = this_distances[this_distances["tomogram"] == tomo] + max_vesicles = 0 + for approach in approaches: + n_vesicles = len(tomo_dists[tomo_dists["approach"] == approach].values) + if n_vesicles > max_vesicles: + max_vesicles = n_vesicles + + for approach in approaches: + app_dists = tomo_dists[tomo_dists["approach"] == approach][distance_col].values.tolist() + app_dists = app_dists + [np.nan] * (max_vesicles - len(app_dists)) + distance_values[approach].extend(app_dists) + tomograms.extend([tomo] * max_vesicles) + + save_distances = {"tomograms": tomograms} + save_distances.update(distance_values) + save_distances = pd.DataFrame(save_distances) + + sheet_name = f"{pool_name}_{structure_name}" + if os.path.exists(save_path): + with pd.ExcelWriter(save_path, engine="openpyxl", mode="a") as writer: + save_distances.to_excel(writer, sheet_name=sheet_name, index=False) + else: + save_distances.to_excel(save_path, index=False, sheet_name=sheet_name) + + # NOTE: we over-ride a plot here, should not do this in the actual version + _plot("MP-V", "pd_distance [nm]", "PD", axes[0, 0]) + _plot("MP-V", "boundary_distance [nm]", "AZ Membrane", axes[0, 1]) + # _plot("Docked-V", "pd_distance [nm]", "PD", axes[1, 0]) + # _plot("Docked-V", "boundary_distance [nm]", "AZ Membrane", axes[1, 0]) + _plot("RA-V", "ribbon_distance [nm]", "Ribbon", axes[1, 1]) + + fig.tight_layout() + plt.show() + + +def for_tomos_with_annotation(plot_all=True): + manual_assignments, semi_automatic_assignments, proofread_assignments = get_measurements_with_annotation() + + manual_distances = manual_assignments[ + ["tomogram", "pool", "ribbon_distance [nm]", "pd_distance [nm]", "boundary_distance [nm]"] + ] + manual_distances["approach"] = ["manual"] * len(manual_distances) + 
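+    # POOL_DICT (defined at the top of this file) maps Docked-V onto MP-V, so the
+    # "combined_pool" column lets docked vesicles be analysed together with the MP-V pool.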
manual_distances.insert(1, "combined_pool", manual_distances["pool"].replace(POOL_DICT)) + + semi_automatic_distances = semi_automatic_assignments[ + ["tomogram", "pool", "ribbon_distance [nm]", "pd_distance [nm]", "boundary_distance [nm]"] + ] + semi_automatic_distances["approach"] = ["semi_automatic"] * len(semi_automatic_distances) + semi_automatic_distances.insert(1, "combined_pool", semi_automatic_distances["pool"].replace(POOL_DICT)) + + proofread_distances = proofread_assignments[ + ["tomogram", "pool", "ribbon_distance [nm]", "pd_distance [nm]", "boundary_distance [nm]"] + ] + proofread_distances["approach"] = ["proofread"] * len(proofread_distances) + proofread_distances.insert(1, "combined_pool", proofread_distances["pool"].replace(POOL_DICT)) + + distances = pd.concat([manual_distances, semi_automatic_distances, proofread_distances]) + if plot_all: + distances.to_excel("./results/distances_tomos_with_manual_annotations.xlsx", index=False) + _plot_all(distances) + else: + _plot_selected(distances, save_path="./results/selected_distances_tomos_with_manual_annotations.xlsx") + + +def for_all_tomos(plot_all=True): + semi_automatic_assignments, proofread_assignments = get_all_measurements() + + semi_automatic_distances = semi_automatic_assignments[ + ["tomogram", "pool", "ribbon_distance [nm]", "pd_distance [nm]", "boundary_distance [nm]"] + ] + semi_automatic_distances["approach"] = ["semi_automatic"] * len(semi_automatic_distances) + semi_automatic_distances.insert(1, "combined_pool", semi_automatic_distances["pool"].replace(POOL_DICT)) + + proofread_distances = proofread_assignments[ + ["tomogram", "pool", "ribbon_distance [nm]", "pd_distance [nm]", "boundary_distance [nm]"] + ] + proofread_distances["approach"] = ["proofread"] * len(proofread_distances) + proofread_distances.insert(1, "combined_pool", proofread_distances["pool"].replace(POOL_DICT)) + + distances = pd.concat([semi_automatic_distances, proofread_distances]) + if plot_all: + distances.to_excel("./results/distances_all_tomos.xlsx", index=False) + _plot_all(distances) + else: + _plot_selected(distances, save_path="./results/selected_distances_all_tomos.xlsx") + + +def main(): + plot_all = False + for_tomos_with_annotation(plot_all=plot_all) + for_all_tomos(plot_all=plot_all) + + +if __name__ == "__main__": + main() diff --git a/scripts/inner_ear/analysis/analyze_vesicle_diameters.py b/scripts/inner_ear/analysis/analyze_vesicle_diameters.py new file mode 100644 index 0000000..d41dbe6 --- /dev/null +++ b/scripts/inner_ear/analysis/analyze_vesicle_diameters.py @@ -0,0 +1,196 @@ +import os +import sys + +from glob import glob + +import mrcfile +import pandas as pd +from tqdm import tqdm + +from synaptic_reconstruction.imod.export import load_points_from_imodinfo +from synaptic_reconstruction.file_utils import get_data_path + +from common import get_finished_tomos + +POOL_DICT = {"Docked-V": "MP-V", "MP-V": "MP-V", "RA-V": "RA-V"} + +sys.path.append("../processing") + + +def aggregate_diameters( + data_root, table, save_path, get_tab, include_names, method, subset, radius_factor +): + radius_table = [] + for _, row in tqdm(table.iterrows(), total=len(table), desc="Collect tomo information"): + folder = row["Local Path"] + if folder == "": + continue + + tomo_name = os.path.relpath(folder, os.path.join(data_root, "Electron-Microscopy-Susi/Analyse")) + if ( + tomo_name in ("WT strong stim/Mouse 1/modiolar/1", "WT strong stim/Mouse 1/modiolar/2") and + (row["EM alt vs. 
Neu"] == "neu") + ): + continue + if tomo_name not in include_names: + continue + + tab_path = get_tab(folder) + if tab_path is None: + continue + + tab = pd.read_excel(tab_path) + this_tab = tab[["pool", "radius [nm]"]] + this_tab.loc[:, "radius [nm]"] = this_tab["radius [nm]"] * radius_factor + this_tab.insert(0, "tomogram", [tomo_name] * len(this_tab)) + this_tab.insert(3, "diameter [nm]", this_tab["radius [nm]"] * 2) + radius_table.append(this_tab) + + radius_table = pd.concat(radius_table) + radius_table.insert(1, "combined_pool", radius_table["pool"].replace(POOL_DICT)) + breakpoint() + + print("Saving table for", len(radius_table), "vesicles to", save_path, sheet_name) + if os.path.exists(save_path): + with pd.ExcelWriter(save_path, engine="openpyxl", mode="a") as writer: + radius_table.to_excel(writer, sheet_name=sheet_name, index=False) + else: + radius_table.to_excel(save_path, sheet_name=sheet_name, index=False) + + tomos = pd.unique(radius_table.tomogram) + return tomos + + +def aggregate_diameters_imod(data_root, table, save_path, include_names, sheet_name): + radius_table = [] + for _, row in tqdm(table.iterrows(), total=len(table), desc="Collect tomo information"): + folder = row["Local Path"] + if folder == "": + continue + + tomo_name = os.path.relpath(folder, os.path.join(data_root, "Electron-Microscopy-Susi/Analyse")) + tomo_name = os.path.relpath(folder, os.path.join(data_root, "Electron-Microscopy-Susi/Analyse")) + if ( + tomo_name in ("WT strong stim/Mouse 1/modiolar/1", "WT strong stim/Mouse 1/modiolar/2") and + (row["EM alt vs. Neu"] == "neu") + ): + continue + if tomo_name not in include_names: + continue + + annotation_folder = os.path.join(folder, "manuell") + if not os.path.exists(annotation_folder): + annotation_folder = os.path.join(folder, "Manuell") + if not os.path.exists(annotation_folder): + continue + + annotations = glob(os.path.join(annotation_folder, "*.mod")) + annotation_file = [ann for ann in annotations if ("vesikel" in ann.lower()) or ("vesicle" in ann.lower())] + if len(annotation_file) != 1: + continue + annotation_file = annotation_file[0] + + tomo_file = get_data_path(folder) + # with mrcfile.open(tomo_file) as f: + # shape = f.data.shape + # resolution = list(f.voxel_size.item()) + # resolution = [res / 10 for res in resolution][0] + + try: + _, radii, labels, label_names = load_points_from_imodinfo( + annotation_file, shape, resolution=[1.0, 1.0, 1.0] + ) + except AssertionError: + continue + + # Determined from matching the size of vesicles in IMOD. 
+ this_tab = pd.DataFrame({ + "tomogram": [tomo_name] * len(radii), + "pool": [label_names[label_id] for label_id in labels], + "radius [nm]": radii, + "diameter [nm]": 2 * radii, + }) + radius_table.append(this_tab) + + radius_table = pd.concat(radius_table) + print("Saving table for", len(radius_table), "vesicles to", save_path, sheet_name) + radius_table.to_excel(save_path, index=False, sheet_name=sheet_name) + + man_tomos = pd.unique(radius_table.tomogram) + return man_tomos + + +def get_tab_semi_automatic(folder): + tab_name = "measurements_uncorrected_assignments.xlsx" + res_path = os.path.join(folder, "korrektur", tab_name) + if not os.path.exists(res_path): + res_path = os.path.join(folder, "Korrektur", tab_name) + if not os.path.exists(res_path): + res_path = None + return res_path + + +def get_tab_proofread(folder): + tab_name = "measurements.xlsx" + res_path = os.path.join(folder, "korrektur", tab_name) + if not os.path.exists(res_path): + res_path = os.path.join(folder, "Korrektur", tab_name) + if not os.path.exists(res_path): + res_path = None + return res_path + + +def get_tab_manual(folder): + tab_name = "measurements.xlsx" + res_path = os.path.join(folder, "manuell", tab_name) + if not os.path.exists(res_path): + res_path = os.path.join(folder, "Manuell", tab_name) + if not os.path.exists(res_path): + res_path = None + return res_path + + +def main(): + from parse_table import parse_table, get_data_root + + data_root = get_data_root() + table_path = os.path.join(data_root, "Electron-Microscopy-Susi", "Ãœbersicht.xlsx") + table = parse_table(table_path, data_root) + + all_tomos = get_finished_tomos() + + radius_factor = 0.85 + + print("All tomograms") + save_path = "./results/vesicle_diameters_all_tomos.xlsx" + aggregate_diameters( + data_root, table, save_path=save_path, get_tab=get_tab_semi_automatic, + include_names=all_tomos, + method="Semi-automatic", + subset="manual", + radius_factor=radius_factor, + ) + return + aggregate_diameters( + data_root, table, save_path=save_path, get_tab=get_tab_proofread, include_names=all_tomos, + sheet_name="Proofread", radius_factor=radius_factor, + ) + + print() + print("Tomograms with manual annotations") + save_path = "./results/vesicle_diameters_tomos_with_manual_annotations.xlsx" + man_tomos = aggregate_diameters_imod( + data_root, table, save_path=save_path, include_names=all_tomos, sheet_name="Manual", + ) + aggregate_diameters( + data_root, table, save_path=save_path, get_tab=get_tab_semi_automatic, include_names=man_tomos, + sheet_name="Semi-automatic", radius_factor=radius_factor, + ) + aggregate_diameters( + data_root, table, save_path=save_path, get_tab=get_tab_proofread, include_names=man_tomos, + sheet_name="Proofread", radius_factor=radius_factor, + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/inner_ear/analysis/analyze_vesicle_pools.py b/scripts/inner_ear/analysis/analyze_vesicle_pools.py new file mode 100644 index 0000000..f27a5c2 --- /dev/null +++ b/scripts/inner_ear/analysis/analyze_vesicle_pools.py @@ -0,0 +1,103 @@ +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns + +from common import get_all_measurements, get_measurements_with_annotation + + +def plot_pools(data, errors): + data_for_plot = pd.melt(data, id_vars="Pool", var_name="Method", value_name="Measurement") + + # Plot using seaborn + plt.figure(figsize=(8, 6)) + sns.barplot(data=data_for_plot, x="Pool", y="Measurement", hue="Method") + + # FIXME + # error_for_plot = pd.melt(errors, id_vars="Pool", 
var_name="Method", value_name="Error") + # # Add error bars manually + # for i, bar in enumerate(plt.gca().patches): + # # Get Standard Deviation for the current bar + # err = error_for_plot.iloc[i % len(error_for_plot)]["Error"] + # bar_x = bar.get_x() + bar.get_width() / 2 + # bar_y = bar.get_height() + # plt.errorbar(bar_x, bar_y, yerr=err, fmt="none", c="black", capsize=4) + + # Customize the chart + plt.title("Different measurements for vesicles per pool") + plt.xlabel("Vesicle Pools") + plt.ylabel("Vesicles per Tomogram") + plt.grid(axis="y", linestyle="--", alpha=0.7) + plt.legend(title="Approaches") + + # Show the plot + plt.tight_layout() + plt.show() + + +def for_tomos_with_annotation(): + manual_assignments, semi_automatic_assignments, proofread_assignments = get_measurements_with_annotation() + + manual_counts = manual_assignments.groupby(["tomogram", "pool"]).size().unstack(fill_value=0) + semi_automatic_counts = semi_automatic_assignments.groupby(["tomogram", "pool"]).size().unstack(fill_value=0) + proofread_counts = proofread_assignments.groupby(["tomogram", "pool"]).size().unstack(fill_value=0) + + manual_stats = manual_counts.agg(["mean", "std"]).transpose().reset_index() + semi_automatic_stats = semi_automatic_counts.agg(["mean", "std"]).transpose().reset_index() + proofread_stats = proofread_counts.agg(["mean", "std"]).transpose().reset_index() + + data = pd.DataFrame({ + "Pool": manual_stats["pool"], + "Semi-automatic": semi_automatic_stats["mean"], + "Proofread": proofread_stats["mean"], + "Manual": manual_stats["mean"], + }) + errors = pd.DataFrame({ + "Pool": manual_stats["pool"], + "Semi-automatic": semi_automatic_stats["std"], + "Proofread": proofread_stats["std"], + "Manual": manual_stats["std"], + }) + + plot_pools(data, errors) + + output_path = "./results/vesicle_pools_tomos_with_manual_annotations.xlsx" + data.to_excel(output_path, index=False, sheet_name="Average") + with pd.ExcelWriter(output_path, engine="openpyxl", mode="a") as writer: + errors.to_excel(writer, sheet_name="StandardDeviation", index=False) + + +def for_all_tomos(): + semi_automatic_assignments, proofread_assignments = get_all_measurements() + + proofread_counts = proofread_assignments.groupby(["tomogram", "pool"]).size().unstack(fill_value=0) + proofread_stats = proofread_counts.agg(["mean", "std"]).transpose().reset_index() + + semi_automatic_counts = semi_automatic_assignments.groupby(["tomogram", "pool"]).size().unstack(fill_value=0) + semi_automatic_stats = semi_automatic_counts.agg(["mean", "std"]).transpose().reset_index() + + data = pd.DataFrame({ + "Pool": proofread_stats["pool"], + "Semi-automatic": semi_automatic_stats["mean"], + "Proofread": proofread_stats["mean"], + }) + errors = pd.DataFrame({ + "Pool": proofread_stats["pool"], + "Semi-automatic": semi_automatic_stats["std"], + "Proofread": proofread_stats["std"], + }) + + plot_pools(data, errors) + + output_path = "./results/vesicle_pools_all_tomos.xlsx" + data.to_excel(output_path, index=False, sheet_name="Average") + with pd.ExcelWriter(output_path, engine="openpyxl", mode="a") as writer: + errors.to_excel(writer, sheet_name="StandardDeviation", index=False) + + +def main(): + for_tomos_with_annotation() + for_all_tomos() + + +if __name__ == "__main__": + main() diff --git a/scripts/inner_ear/analysis/combine_fully_automatic_results.py b/scripts/inner_ear/analysis/combine_fully_automatic_results.py new file mode 100644 index 0000000..54bdbc1 --- /dev/null +++ b/scripts/inner_ear/analysis/combine_fully_automatic_results.py 
@@ -0,0 +1,69 @@ +import os +import sys + +import pandas as pd + +sys.path.append("..") +sys.path.append("../processing") + + +def combine_fully_auto_results(table, data_root, output_path): + from combine_measurements import combine_results + + val_table_path = os.path.join(data_root, "Electron-Microscopy-Susi", "Validierungs-Tabelle-v3.xlsx") + val_table = pd.read_excel(val_table_path) + + results = {} + for _, row in table.iterrows(): + folder = row["Local Path"] + if folder == "": + continue + + row_selection = (val_table.Bedingung == row.Bedingung) &\ + (val_table.Maus == row.Maus) &\ + (val_table["Ribbon-Orientierung"] == row["Ribbon-Orientierung"]) &\ + (val_table["OwnCloud-Unterordner"] == row["OwnCloud-Unterordner"]) + complete_vals = val_table[row_selection]["Fertig!"].values + is_complete = (complete_vals == "ja").all() + if not is_complete: + continue + + micro = row["EM alt vs. Neu"] + + tomo_name = os.path.relpath(folder, os.path.join(data_root, "Electron-Microscopy-Susi/Analyse")) + tab_name = "measurements_uncorrected_assignments.xlsx" + res_path = os.path.join(folder, "korrektur", tab_name) + if not os.path.exists(res_path): + res_path = os.path.join(folder, "Korrektur", tab_name) + assert os.path.exists(res_path), res_path + results[tomo_name] = (res_path, "alt" if micro == "beides" else micro) + + if micro == "beides": + micro = "neu" + + new_root = os.path.join(folder, "neues EM") + if not os.path.exists(new_root): + new_root = os.path.join(folder, "Tomo neues EM") + assert os.path.exists(new_root) + + res_path = os.path.join(new_root, "korrektur", "measurements.xlsx") + if not os.path.exists(res_path): + res_path = os.path.join(new_root, "Korrektur", "measurements.xlsx") + assert os.path.exists(res_path), res_path + results[tomo_name] = (res_path, "alt" if micro == "beides" else micro) + + combine_results(results, output_path, sheet_name="vesicles") + + +def main(): + from parse_table import parse_table, get_data_root + + data_root = get_data_root() + table_path = os.path.join(data_root, "Electron-Microscopy-Susi", "Ãœbersicht.xlsx") + table = parse_table(table_path, data_root) + + res_path = "../results/fully_automatic_analysis_results.xlsx" + combine_fully_auto_results(table, data_root, output_path=res_path) + + +main() diff --git a/scripts/inner_ear/analysis/common.py b/scripts/inner_ear/analysis/common.py new file mode 100644 index 0000000..772cd31 --- /dev/null +++ b/scripts/inner_ear/analysis/common.py @@ -0,0 +1,88 @@ +# import os +import sys + +import numpy as np +import pandas as pd + +sys.path.append("../processing") + +from parse_table import get_data_root # noqa + + +def get_finished_tomos(): + # data_root = get_data_root() + # val_table = os.path.join(data_root, "Electron-Microscopy-Susi", "Validierungs-Tabelle-v3.xlsx") + + val_table = "/home/pape/Desktop/sfb1286/mboc_synapse/misc/Validierungs-Tabelle-v3-passt.xlsx" + val_table = pd.read_excel(val_table) + + val_table = val_table[val_table["Kommentar 22.11.24"] == "passt"] + n_tomos = len(val_table) + assert n_tomos > 0 + + tomo_names = [] + for _, row in val_table.iterrows(): + name = "/".join([ + row.Bedingung, f"Mouse {int(row.Maus)}", + row["Ribbon-Orientierung"].lower().rstrip("?"), + str(int(row["OwnCloud-Unterordner"]))] + ) + tomo_names.append(name) + + return tomo_names + + +def get_manual_assignments(): + result_path = "../results/20241124_1/fully_manual_analysis_results.xlsx" + results = pd.read_excel(result_path) + return results + + +def get_proofread_assignments(tomograms): + result_path = 
"../results/20241124_1/automatic_analysis_results.xlsx" + results = pd.read_excel(result_path) + results = results[results["tomogram"].isin(tomograms)] + return results + + +def get_semi_automatic_assignments(tomograms): + result_path = "../results/fully_automatic_analysis_results.xlsx" + results = pd.read_excel(result_path) + results = results[results["tomogram"].isin(tomograms)] + return results + + +def get_measurements_with_annotation(): + manual_assignments = get_manual_assignments() + + # Get the tomos with manual annotations and the ones which are fully done in proofreading. + manual_tomos = pd.unique(manual_assignments["tomogram"]) + finished_tomos = get_finished_tomos() + # Intersect them to get the tomos we are using. + tomos = np.intersect1d(manual_tomos, finished_tomos) + + manual_assignments = manual_assignments[manual_assignments["tomogram"].isin(tomos)] + semi_automatic_assignments = get_semi_automatic_assignments(tomos) + proofread_assignments = get_proofread_assignments(tomos) + + print("Tomograms with manual annotations:", len(tomos)) + return manual_assignments, semi_automatic_assignments, proofread_assignments + + +def get_all_measurements(): + tomos = get_finished_tomos() + print("All tomograms:", len(tomos)) + + semi_automatic_assignments = get_semi_automatic_assignments(tomos) + proofread_assignments = get_proofread_assignments(tomos) + + return semi_automatic_assignments, proofread_assignments + + +def main(): + get_measurements_with_annotation() + get_all_measurements() + + +if __name__ == "__main__": + main() diff --git a/scripts/inner_ear/analysis/export_seg_to_imod.py b/scripts/inner_ear/analysis/export_seg_to_imod.py new file mode 100644 index 0000000..eea4b14 --- /dev/null +++ b/scripts/inner_ear/analysis/export_seg_to_imod.py @@ -0,0 +1,128 @@ +import os +from shutil import copyfile +from subprocess import run + +import imageio.v3 as imageio +import mrcfile +import napari +import numpy as np +import pandas as pd +from elf.io import open_file +from skimage.transform import resize +from synaptic_reconstruction.imod.to_imod import write_segmentation_to_imod, write_segmentation_to_imod_as_points + +out_folder = "./auto_seg_export" +os.makedirs(out_folder, exist_ok=True) + + +def _resize(seg, tomo_path): + with open_file(tomo_path, "r") as f: + shape = f["data"].shape + + if shape != seg.shape: + seg = resize(seg, shape, order=0, anti_aliasing=False, preserve_range=True).astype(seg.dtype) + assert seg.shape == shape + return seg + + +def check_imod(tomo_path, mod_path): + run(["imod", tomo_path, mod_path]) + + +def export_pool(pool_name, pool_seg, tomo_path): + seg_path = f"./auto_seg_export/{pool_name}.tif" + pool_seg = _resize(pool_seg, tomo_path) + imageio.imwrite(seg_path, pool_seg, compression="zlib") + + output_path = f"./auto_seg_export/{pool_name}.mod" + write_segmentation_to_imod_as_points(tomo_path, seg_path, output_path, min_radius=5) + + check_imod(tomo_path, output_path) + + +def export_vesicles(folder, tomo_path): + vesicle_pool_path = os.path.join(folder, "Korrektur", "vesicle_pools.tif") + # pool_correction_path = os.path.join(folder, "Korrektur", "pool_correction.tif") + # pool_correction = imageio.imread(pool_correction_path) + + assignment_path = os.path.join(folder, "Korrektur", "measurements.xlsx") + assignments = pd.read_excel(assignment_path) + + vesicles = imageio.imread(vesicle_pool_path) + + pools = {} + for pool_name in pd.unique(assignments.pool): + pool_ids = assignments[assignments.pool == pool_name].id.values + pool_seg = 
vesicles.copy() + pool_seg[~np.isin(vesicles, pool_ids)] = 0 + pools[pool_name] = pool_seg + + view = False + if view: + v = napari.Viewer() + v.add_labels(vesicles, visible=False) + for pool_name, pool_seg in pools.items(): + v.add_labels(pool_seg, name=pool_name) + napari.run() + else: + for pool_name, pool_seg in pools.items(): + export_pool(pool_name, pool_seg, tomo_path) + + +def export_structure(folder, tomo, name, view=False): + path = os.path.join(folder, "Korrektur", f"{name}.tif") + seg = imageio.imread(path) + seg = _resize(seg, tomo) + + if view: + with open_file(tomo, "r") as f: + raw = f["data"][:] + + v = napari.Viewer() + v.add_image(raw) + v.add_labels(seg) + napari.run() + + return + + seg_path = f"./auto_seg_export/{name}.tif" + imageio.imwrite(seg_path, seg, compression="zlib") + output_path = f"./auto_seg_export/{name}.mod" + write_segmentation_to_imod(tomo, seg_path, output_path) + check_imod(tomo, output_path) + + +def remove_scale(tomo): + new_path = "./auto_seg_export/Emb71M1aGridA1sec1mod7.rec.rec" + if os.path.exists(new_path): + return new_path + + copyfile(tomo, new_path) + + with mrcfile.open(new_path, "r+") as f: + # Set the origin to (0, 0, 0) + f.header.nxstart = 0 + f.header.nystart = 0 + f.header.nzstart = 0 + f.header.origin = (0.0, 0.0, 0.0) + + # Save changes + f.flush() + + return new_path + + +def main(): + folder = "/home/pape/Work/data/moser/em-synapses/Electron-Microscopy-Susi/Analyse/WT strong stim/Mouse 1/modiolar/1" + tomo = os.path.join(folder, "Emb71M1aGridA1sec1mod7.rec.rec") + + tomo = remove_scale(tomo) + + # export_vesicles(folder, tomo) + # export_structure(folder, tomo, "ribbon", view=False) + # export_structure(folder, tomo, "membrane", view=False) + export_structure(folder, tomo, "PD", view=False) + + +if __name__ == "__main__": + main() diff --git a/scripts/inner_ear/analysis/extract_ribbon_stats.py b/scripts/inner_ear/analysis/extract_ribbon_stats.py new file mode 100644 index 0000000..8ee9e12 --- /dev/null +++ b/scripts/inner_ear/analysis/extract_ribbon_stats.py @@ -0,0 +1,36 @@ +import numpy as np +import pandas as pd + + +def main(): + man_path = "../results/20240917_1/fully_manual_analysis_results.xlsx" + auto_path = "../results/20240917_1/automatic_analysis_results.xlsx" + + man_measurements = pd.read_excel(man_path, sheet_name="morphology") + man_measurements = man_measurements[man_measurements.structure == "ribbon"][ + ["tomogram", "surface [nm^2]", "volume [nm^3]"] + ] + + auto_measurements = pd.read_excel(auto_path, sheet_name="morphology") + auto_measurements = auto_measurements[auto_measurements.structure == "ribbon"][ + ["tomogram", "surface [nm^2]", "volume [nm^3]"] + ] + + # save all the automatic measurements + auto_measurements.to_excel("./results/ribbon_morphology_auto.xlsx", index=False) + + man_tomograms = pd.unique(man_measurements["tomogram"]) + auto_tomograms = pd.unique(auto_measurements["tomogram"]) + tomos = np.intersect1d(man_tomograms, auto_tomograms) + + man_measurements = man_measurements[man_measurements.tomogram.isin(tomos)] + auto_measurements = auto_measurements[auto_measurements.tomogram.isin(tomos)] + + save_path = "./results/ribbon_morphology_man-v-auto.xlsx" + man_measurements.to_excel(save_path, sheet_name="manual", index=False) + with pd.ExcelWriter(save_path, engine="openpyxl", mode="a") as writer: + auto_measurements.to_excel(writer, sheet_name="auto", index=False) + + +if __name__ == "__main__": + main() diff --git a/scripts/inner_ear/processing/run_analyis.py 
b/scripts/inner_ear/processing/run_analyis.py index baeade1..fbb00f1 100644 --- a/scripts/inner_ear/processing/run_analyis.py +++ b/scripts/inner_ear/processing/run_analyis.py @@ -10,8 +10,9 @@ from synaptic_reconstruction.file_utils import get_data_path from synaptic_reconstruction.distance_measurements import ( - measure_segmentation_to_object_distances, filter_blocked_segmentation_to_object_distances, + load_distances, + measure_segmentation_to_object_distances, ) from synaptic_reconstruction.morphology import compute_radii, compute_object_morphology @@ -52,7 +53,7 @@ def _load_segmentation(seg_path, tomo_shape): return seg -def compute_distances(segmentation_paths, save_folder, resolution, force, tomo_shape): +def compute_distances(segmentation_paths, save_folder, resolution, force, tomo_shape, use_corrected_vesicles=True): os.makedirs(save_folder, exist_ok=True) vesicles = None @@ -61,9 +62,10 @@ def _require_vesicles(): vesicle_path = segmentation_paths["vesicles"] if vesicles is None: - vesicle_pool_path = os.path.join(os.path.split(save_folder)[0], "vesicle_pools.tif") - if os.path.exists(vesicle_pool_path): - vesicle_path = vesicle_pool_path + if use_corrected_vesicles: + vesicle_pool_path = os.path.join(os.path.split(save_folder)[0], "vesicle_pools.tif") + if os.path.exists(vesicle_pool_path): + vesicle_path = vesicle_pool_path return _load_segmentation(vesicle_path, tomo_shape) else: @@ -171,8 +173,9 @@ def load_dist(measurement_path, seg_ids=None): # Filter out the blocked vesicles. if apply_extra_filters: + rav_dists, ep1, ep2, all_rav_ids = load_distances(distance_paths["ribbon"]) rav_ids = filter_blocked_segmentation_to_object_distances( - vesicles, distance_paths["ribbon"], seg_ids=rav_ids, line_dilation=4, verbose=True, + vesicles, rav_dists, ep1, ep2, all_rav_ids, filter_seg_ids=rav_ids, line_dilation=4, verbose=True, ) rav_ids = filter_border_vesicles(vesicles, seg_ids=rav_ids) @@ -334,8 +337,7 @@ def _insert_missing_vesicles(vesicle_path, original_vesicle_path, pool_correctio imageio.imwrite(vesicle_path, vesicles) -# TODO adapt to segmentation without PD -def analyze_folder(folder, version, n_ribbons, force): +def analyze_folder(folder, version, n_ribbons, force, use_corrected_vesicles): data_path = get_data_path(folder) output_folder = os.path.join(folder, "automatisch", f"v{version}") @@ -352,12 +354,20 @@ def analyze_folder(folder, version, n_ribbons, force): correction_folder = _match_correction_folder(folder) if os.path.exists(correction_folder): output_folder = correction_folder - result_path = os.path.join(output_folder, "measurements.xlsx") + + if use_corrected_vesicles: + result_path = os.path.join(output_folder, "measurements.xlsx") + else: + result_path = os.path.join(output_folder, "measurements_uncorrected_assignments.xlsx") + if os.path.exists(result_path) and not force: return print("Analyse the corrected segmentations from", correction_folder) for seg_name in segmentation_names: + if seg_name == "vesicles" and not use_corrected_vesicles: + continue + seg_path = _match_correction_file(correction_folder, seg_name) if os.path.exists(seg_path): @@ -371,7 +381,10 @@ def analyze_folder(folder, version, n_ribbons, force): segmentation_paths[seg_name] = seg_path - result_path = os.path.join(output_folder, "measurements.xlsx") + if use_corrected_vesicles: + result_path = os.path.join(output_folder, "measurements.xlsx") + else: + result_path = os.path.join(output_folder, "measurements_uncorrected_assignments.xlsx") if os.path.exists(result_path) and not force: 
return @@ -384,21 +397,29 @@ def analyze_folder(folder, version, n_ribbons, force): with open_file(data_path, "r") as f: tomo_shape = f["data"].shape - out_distance_folder = os.path.join(output_folder, "distances") + if use_corrected_vesicles: + out_distance_folder = os.path.join(output_folder, "distances") + else: + out_distance_folder = os.path.join(output_folder, "distances_uncorrected") distance_paths, skip = compute_distances( segmentation_paths, out_distance_folder, resolution, force=force, tomo_shape=tomo_shape, + use_corrected_vesicles=use_corrected_vesicles ) if skip: return if force or not os.path.exists(result_path): + + if not use_corrected_vesicles: + pool_correction_path = None + analyze_distances( segmentation_paths, distance_paths, resolution, result_path, tomo_shape, pool_correction_path=pool_correction_path ) -def run_analysis(table, version, force=False, val_table=None): +def run_analysis(table, version, force=False, val_table=None, use_corrected_vesicles=True): for i, row in tqdm(table.iterrows(), total=len(table)): folder = row["Local Path"] if folder == "": @@ -426,19 +447,19 @@ def run_analysis(table, version, force=False, val_table=None): micro = row["EM alt vs. Neu"] if micro == "beides": - analyze_folder(folder, version, n_ribbons, force=force) + analyze_folder(folder, version, n_ribbons, force=force, use_corrected_vesicles=use_corrected_vesicles) folder_new = os.path.join(folder, "Tomo neues EM") if not os.path.exists(folder_new): folder_new = os.path.join(folder, "neues EM") assert os.path.exists(folder_new), folder_new - analyze_folder(folder_new, version, n_ribbons, force=force) + analyze_folder(folder_new, version, n_ribbons, force=force, use_corrected_vesicles=use_corrected_vesicles) elif micro == "alt": - analyze_folder(folder, version, n_ribbons, force=force) + analyze_folder(folder, version, n_ribbons, force=force, use_corrected_vesicles=use_corrected_vesicles) elif micro == "neu": - analyze_folder(folder, version, n_ribbons, force=force) + analyze_folder(folder, version, n_ribbons, force=force, use_corrected_vesicles=use_corrected_vesicles) def main(): @@ -447,13 +468,16 @@ def main(): table = parse_table(table_path, data_root) version = 2 - force = True + force = False + use_corrected_vesicles = False - val_table_path = os.path.join(data_root, "Electron-Microscopy-Susi", "Validierungs-Tabelle-v3.xlsx") - val_table = pandas.read_excel(val_table_path) - # val_table = None + # val_table_path = os.path.join(data_root, "Electron-Microscopy-Susi", "Validierungs-Tabelle-v3.xlsx") + # val_table = pandas.read_excel(val_table_path) + val_table = None - run_analysis(table, version, force=force, val_table=val_table) + run_analysis( + table, version, force=force, val_table=val_table, use_corrected_vesicles=use_corrected_vesicles + ) if __name__ == "__main__": diff --git a/scripts/inner_ear/training/postprocessing_and_evaluation.py b/scripts/inner_ear/training/postprocessing_and_evaluation.py index 30c9e42..30c1313 100644 --- a/scripts/inner_ear/training/postprocessing_and_evaluation.py +++ b/scripts/inner_ear/training/postprocessing_and_evaluation.py @@ -13,8 +13,8 @@ from train_structure_segmentation import get_train_val_test_split -ROOT = "/home/pape/Work/data/synaptic_reconstruction/moser" -# ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/moser" +# ROOT = "/home/pape/Work/data/synaptic_reconstruction/moser" +ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/moser" MODEL_PATH = 
"/mnt/lustre-emmy-hdd/projects/nim00007/models/synaptic-reconstruction/vesicle-DA-inner_ear-v2" OUTPUT_ROOT = "./predictions" @@ -187,8 +187,8 @@ def segment_train_domain(): name = "train_domain" run_vesicle_segmentation(paths, MODEL_PATH, name, is_nested=True) postprocess_structures(paths, name, is_nested=True) - visualize(paths, name, is_nested=True) - results = evaluate(paths, name, is_nested=True, save_path="./results/train_domain_postprocessed.csv") + # visualize(paths, name, is_nested=True) + results = evaluate(paths, name, is_nested=True, save_path="./results/train_domain_postprocessed_v2.csv") print(results) print("Ribbon segmentation:", results["ribbon"].mean(), "+-", results["ribbon"].std()) print("PD segmentation:", results["PD"].mean(), "+-", results["PD"].std()) diff --git a/scripts/inner_ear/training/structure_prediction_and_evaluation.py b/scripts/inner_ear/training/structure_prediction_and_evaluation.py index cb174c7..7ed89a9 100644 --- a/scripts/inner_ear/training/structure_prediction_and_evaluation.py +++ b/scripts/inner_ear/training/structure_prediction_and_evaluation.py @@ -143,10 +143,10 @@ def predict_and_evaluate_train_domain(): print("Run evaluation on", len(paths), "tomos") name = "train_domain" - model_path = "./checkpoints/inner_ear_structure_model" + model_path = "./checkpoints/inner_ear_structure_model_v2" run_prediction(paths, model_path, name, is_nested=True) - evaluate(paths, name, is_nested=True, save_path="./results/train_domain.csv") + evaluate(paths, name, is_nested=True, save_path="./results/train_domain_v2.csv") visualize(paths, name, is_nested=True) @@ -187,9 +187,9 @@ def predict_and_evaluate_rat(): def main(): - # predict_and_evaluate_train_domain() + predict_and_evaluate_train_domain() # predict_and_evaluate_vesicle_pools() - predict_and_evaluate_rat() + # predict_and_evaluate_rat() if __name__ == "__main__": diff --git a/scripts/prepare_zenodo_uploads.py b/scripts/prepare_zenodo_uploads.py new file mode 100644 index 0000000..b642c07 --- /dev/null +++ b/scripts/prepare_zenodo_uploads.py @@ -0,0 +1,246 @@ +import os +from glob import glob +from shutil import copyfile + +import h5py +from tqdm import tqdm + +OUTPUT_ROOT = "./data_summary/for_zenodo" + + +def _copy_vesicles(tomos, out_folder): + label_key = "labels/vesicles/combined_vesicles" + os.makedirs(out_folder, exist_ok=True) + for tomo in tqdm(tomos, desc="Export tomos"): + out_path = os.path.join(out_folder, os.path.basename(tomo)) + if os.path.exists(out_path): + continue + + with h5py.File(tomo, "r") as f: + raw = f["raw"][:] + labels = f[label_key][:] + try: + fname = f.attrs["filename"] + except KeyError: + fname = None + + with h5py.File(out_path, "a") as f: + f.create_dataset("raw", data=raw, compression="gzip") + f.create_dataset("labels/vesicles", data=labels, compression="gzip") + if fname is not None: + f.attrs["filename"] = fname + + +def _export_vesicles(train_root, test_root, name): + train_tomograms = sorted(glob(os.path.join(train_root, "*.h5"))) + test_tomograms = sorted(glob(os.path.join(test_root, "*.h5"))) + print(f"Vesicle data for {name}:") + print(len(train_tomograms), len(test_tomograms), len(train_tomograms) + len(test_tomograms)) + + train_out = os.path.join(OUTPUT_ROOT, "synapse-net", "vesicles", "train", name) + _copy_vesicles(train_tomograms, train_out) + + test_out = os.path.join(OUTPUT_ROOT, "synapse-net", "vesicles", "test", name) + _copy_vesicles(test_tomograms, test_out) + + +def _export_az(train_root, test_tomos, name): + tomograms = 
sorted(glob(os.path.join(train_root, "*.h5"))) + print(f"AZ data for {name}:") + + train_out = os.path.join(OUTPUT_ROOT, "synapse-net", "active_zones", "train", name) + test_out = os.path.join(OUTPUT_ROOT, "synapse-net", "active_zones", "test", name) + + os.makedirs(train_out, exist_ok=True) + os.makedirs(test_out, exist_ok=True) + + for tomo in tqdm(tomograms): + fname = os.path.basename(tomo) + if tomo in test_tomos: + out_path = os.path.join(test_out, fname) + else: + out_path = os.path.join(train_out, fname) + if os.path.exists(out_path): + continue + + with h5py.File(tomo, "r") as f: + raw = f["raw"][:] + az = f["labels/AZ"][:] + + with h5py.File(out_path, "a") as f: + f.create_dataset("raw", data=raw, compression="gzip") + f.create_dataset("labels/AZ", data=az, compression="gzip") + + +# NOTE: we have very few mito annotations from 01, so we don't include them in here. +def prepare_single_ax_stem_chemical_fix(): + # single-axis-tem: vesicles + train_root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/01_hoi_maus_2020_incomplete" # noqa + test_root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/01_hoi_maus_2020_incomplete" # noqa + _export_vesicles(train_root, test_root, name="single_axis_tem") + + # single-axis-tem: active zones + train_root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/exported_imod_objects/01_hoi_maus_2020_incomplete" # noqa + test_tomos = [ + "WT_MF_DIV28_01_MS_09204_F1.h5", "WT_MF_DIV14_01_MS_B2_09175_CA3.h5", "M13_CTRL_22723_O2_05_DIV29_5.2.h5", "WT_Unt_SC_09175_D4_05_DIV14_mtk_05.h5", # noqa + "20190805_09002_B4_SC_11_SP.h5", "20190807_23032_D4_SC_01_SP.h5", "M13_DKO_22723_A1_03_DIV29_03_MS.h5", "WT_MF_DIV28_05_MS_09204_F1.h5", "M13_CTRL_09201_S2_06_DIV31_06_MS.h5", # noqa + "WT_MF_DIV28_1.2_MS_09002_B1.h5", "WT_Unt_SC_09175_C4_04_DIV15_mtk_04.h5", "M13_DKO_22723_A4_10_DIV29_10_MS.h5", "WT_MF_DIV14_3.2_MS_D2_09175_CA3.h5", # noqa + "20190805_09002_B4_SC_10_SP.h5", "M13_CTRL_09201_S2_02_DIV31_02_MS.h5", "WT_MF_DIV14_04_MS_E1_09175_CA3.h5", "WT_MF_DIV28_10_MS_09002_B3.h5", "WT_Unt_SC_05646_D4_02_DIV16_mtk_02.h5", "M13_DKO_22723_A4_08_DIV29_08_MS.h5", "WT_MF_DIV28_04_MS_09204_M1.h5", "WT_MF_DIV28_03_MS_09204_F1.h5", "M13_DKO_22723_A1_05_DIV29_05_MS.h5", # noqa + "WT_Unt_SC_09175_C4_06_DIV15_mtk_06.h5", "WT_MF_DIV28_09_MS_09002_B3.h5", "20190524_09204_F4_SC_07_SP.h5", + "WT_MF_DIV14_02_MS_C2_09175_CA3.h5", "M13_DKO_23037_K1_01_DIV29_01_MS.h5", "WT_Unt_SC_09175_E2_01_DIV14_mtk_01.h5", "20190807_23032_D4_SC_05_SP.h5", "WT_MF_DIV14_01_MS_E2_09175_CA3.h5", "WT_MF_DIV14_03_MS_B2_09175_CA3.h5", "M13_DKO_09201_O1_01_DIV31_01_MS.h5", "M13_DKO_09201_U1_04_DIV31_04_MS.h5", # noqa + "WT_MF_DIV14_04_MS_E2_09175_CA3_2.h5", "WT_Unt_SC_09175_D5_01_DIV14_mtk_01.h5", + "M13_CTRL_22723_O2_05_DIV29_05_MS_.h5", "WT_MF_DIV14_02_MS_B2_09175_CA3.h5", "WT_MF_DIV14_01.2_MS_D1_09175_CA3.h5", # noqa + ] + _export_az(train_root, test_tomos, name="single_axis_tem") + + # chemical_fixation: vesicles + train_root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/12_chemical_fix_cryopreparation" # noqa + test_root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets/12_chemical_fix_cryopreparation" # noqa + _export_vesicles(train_root, test_root, name="chemical_fixation") + + # chemical-fixation: active zones + train_root = 
"/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/exported_imod_objects/12_chemical_fix_cryopreparation" # noqa + test_tomos = ["20180305_09_MS.h5", "20180305_04_MS.h5", "20180305_08_MS.h5", + "20171113_04_MS.h5", "20171006_05_MS.h5", "20180305_01_MS.h5"] + _export_az(train_root, test_tomos, name="chemical_fixation") + + +def prepare_ier(): + root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/moser/other_tomograms" + sets = { + "01_vesicle_pools": "vesicle_pools", + "02_tether": "tether", + "03_ratten_tomos": "rat", + } + + output_folder = os.path.join(OUTPUT_ROOT, "IER") + label_names = { + "ribbons": "ribbon", + "membrane": "membrane", + "presynapse": "PD", + "postsynapse": "PSD", + "vesicles": "vesicles", + } + + for name, output_name in sets.items(): + out_set = os.path.join(output_folder, output_name) + os.makedirs(out_set, exist_ok=True) + tomos = sorted(glob(os.path.join(root, name, "*.h5"))) + + print("Export", output_name) + for tomo in tqdm(tomos): + with h5py.File(tomo, "r") as f: + try: + fname = os.path.split(f.attrs["filename"])[1][:-4] + except KeyError: + fname = f.attrs["path"][1] + fname = "_".join(fname.split("/")[-2:]) + + out_path = os.path.join(out_set, os.path.basename(tomo)) + if os.path.exists(out_path): + continue + + raw = f["raw"][:] + labels = {} + for label_name, out_name in label_names.items(): + key = f"labels/{label_name}" + if key not in f: + continue + labels[out_name] = f[key][:] + + with h5py.File(out_path, "a") as f: + f.attrs["filename"] = fname + f.create_dataset("raw", data=raw, compression="gzip") + for label_name, seg in labels.items(): + f.create_dataset(f"labels/{label_name}", data=seg, compression="gzip") + + +def prepare_frog(): + root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/rizzoli/extracted" + train_tomograms = [ + "block10U3A_three.h5", "block30UB_one_two.h5", "block30UB_two.h5", "block10U3A_one.h5", + "block184B_one.h5", "block30UB_three.h5", "block10U3A_two.h5", "block30UB_four.h5", + "block30UB_one.h5", "block10U3A_five.h5", + ] + test_tomograms = ["block10U3A_four.h5", "block30UB_five.h5"] + + output_folder = os.path.join(OUTPUT_ROOT, "frog") + output_train = os.path.join(output_folder, "train_unlabeled") + os.makedirs(output_train, exist_ok=True) + + for name in train_tomograms: + path = os.path.join(root, name) + out_path = os.path.join(output_train, name) + if os.path.exists(out_path): + continue + copyfile(path, out_path) + + output_test = os.path.join(output_folder, "test") + os.makedirs(output_test, exist_ok=True) + for name in test_tomograms: + path = os.path.join(root, name) + out_path = os.path.join(output_test, name) + if os.path.exists(out_path): + continue + copyfile(path, out_path) + + +def prepare_2d_tem(): + train_root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/2D_data/maus_2020_tem2d_wt_unt_div14_exported_scaled/good_for_DAtraining/maus_2020_tem2d_wt_unt_div14_exported_scaled" # noqa + test_root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicle_gt_2d/maus_2020_tem2d" # noqa + train_images = [ + "MF_05649_P-09175-E_06.h5", "MF_05646_C-09175-B_001B.h5", "MF_05649_P-09175-E_07.h5", + "MF_05649_G-09175-C_001.h5", "MF_05646_C-09175-B_002.h5", "MF_05649_G-09175-C_04.h5", + "MF_05649_P-09175-E_05.h5", "MF_05646_C-09175-B_000.h5", "MF_05646_C-09175-B_001.h5" + ] + test_images = [ + "MF_05649_G-09175-C_04B.h5", "MF_05646_C-09175-B_000B.h5", + "MF_05649_G-09175-C_03.h5", 
"MF_05649_G-09175-C_02.h5" + ] + print(len(train_images) + len(test_images)) + + output_folder = os.path.join(OUTPUT_ROOT, "2d_tem") + + output_train = os.path.join(output_folder, "train_unlabeled") + os.makedirs(output_train, exist_ok=True) + for name in tqdm(train_images, desc="Export train images"): + out_path = os.path.join(output_train, name) + if os.path.exists(out_path): + continue + in_path = os.path.join(train_root, name) + with h5py.File(in_path, "r") as f: + raw = f["raw"][:] + with h5py.File(out_path, "a") as f: + f.create_dataset("raw", data=raw, compression="gzip") + + output_test = os.path.join(output_folder, "test") + os.makedirs(output_test, exist_ok=True) + for name in tqdm(test_images, desc="Export test images"): + out_path = os.path.join(output_test, name) + if os.path.exists(out_path): + continue + in_path = os.path.join(test_root, name) + with h5py.File(in_path, "r") as f: + raw = f["data"][:] + labels = f["labels/vesicles"][:] + mask = f["labels/mask"][:] + with h5py.File(out_path, "a") as f: + f.create_dataset("raw", data=raw, compression="gzip") + f.create_dataset("labels/vesicles", data=labels, compression="gzip") + f.create_dataset("labels/mask", data=mask, compression="gzip") + + +def prepare_munc_snap(): + pass + + +def main(): + prepare_single_ax_stem_chemical_fix() + # prepare_2d_tem() + # prepare_frog() + # prepare_ier() + # prepare_munc_snap() + + +if __name__ == "__main__": + main() diff --git a/scripts/summarize_data.py b/scripts/summarize_data.py new file mode 100644 index 0000000..5c62f17 --- /dev/null +++ b/scripts/summarize_data.py @@ -0,0 +1,210 @@ +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + + +az_train = pd.read_excel("data_summary/active_zone_training_data.xlsx") +compartment_train = pd.read_excel("data_summary/compartment_training_data.xlsx") +mito_train = pd.read_excel("data_summary/mitochondria.xlsx") +vesicle_train = pd.read_excel("data_summary/vesicle_training_data.xlsx") +vesicle_da = pd.read_excel("data_summary/vesicle_domain_adaptation_data.xlsx", sheet_name="cryo") + + +def training_resolutions(): + res_az = np.round(az_train["resolution"].mean(), 2) + res_compartment = np.round(compartment_train["resolution"].mean(), 2) + res_cryo = np.round(vesicle_da["resolution"].mean(), 2) + res_vesicles = np.round(vesicle_train["resolution"].mean(), 2) + res_mitos = np.round(mito_train["resolution"].mean(), 2) + + print("Training resolutions for models:") + print("active_zone:", res_az) + print("compartments:", res_compartment) + print("mitochondria:", 1.0) + print("vesicles_2d:", res_vesicles) + print("vesicles_3d:", res_vesicles) + print("vesicles_cryo:", res_cryo) + print("mito:", res_mitos) + # TODO inner ear + + +def pie_chart(data, count_col, title): + # Plot the pie chart + fig, ax = plt.subplots(figsize=(8, 6)) + wedges, texts, autotexts = plt.pie( + data[count_col], + labels=None, + # labels=data["Condition"], + autopct="%1.1f%%", # Display percentages + startangle=90, # Start at the top + colors=plt.cm.Paired.colors[:len(data)], # Optional: Custom color palette + textprops={"fontsize": 16} + ) + + ax.legend( + handles=wedges, # Use the wedges from the pie chart + labels=data["Condition"].values.tolist(), # Use categories for labels + loc="center left", + bbox_to_anchor=(1, 0.5), # Position the legend outside the chart + fontsize=14, + # title="" + ) + + for autot in autotexts: + autot.set_fontsize(18) + + plt.title(title, fontsize=18) + plt.tight_layout() + plt.show() + + +def summarize_vesicle_train_data(): 
+ condition_summary = { + "Condition": [], + "Tomograms": [], + "Vesicles": [], + } + + conditions = pd.unique(vesicle_train.condition) + for condition in conditions: + ctab = vesicle_train[vesicle_train.condition == condition] + n_tomos = len(ctab) + n_vesicles_all = ctab["vesicle_count_all"].sum() + n_vesicles_imod = ctab["vesicle_count_imod"].sum() + print(condition) + print("Tomograms:", n_tomos) + print("All-Vesicles:", n_vesicles_all) + print("Vesicles-From-Manual:", n_vesicles_imod) + print() + if condition != "Chemical Fixation": + condition += " Tomo" + condition_summary["Condition"].append(condition) + condition_summary["Tomograms"].append(n_tomos) + condition_summary["Vesicles"].append(n_vesicles_all) + condition_summary = pd.DataFrame(condition_summary) + print() + print() + + print("Total:") + print("Tomograms:", len(vesicle_train)) + print("All-Vesicles:", vesicle_train["vesicle_count_all"].sum()) + print("Vesicles-From-Manual:", vesicle_train["vesicle_count_imod"].sum()) + print() + + train_tomos = vesicle_train[vesicle_train.used_for == "train/val"] + print("Training:") + print("Tomograms:", len(train_tomos)) + print("All-Vesicles:", train_tomos["vesicle_count_all"].sum()) + print("Vesicles-From-Manual:", train_tomos["vesicle_count_imod"].sum()) + print() + + test_tomos = vesicle_train[vesicle_train.used_for == "test"] + print("Test:") + print("Tomograms:", len(test_tomos)) + print("All-Vesicles:", test_tomos["vesicle_count_all"].sum()) + print("Vesicles-From-Manual:", test_tomos["vesicle_count_imod"].sum()) + + pie_chart(condition_summary, "Tomograms", "Tomograms per Condition") + pie_chart(condition_summary, "Vesicles", "Vesicles per Condition") + + +def summarize_vesicle_da(): + for name in ("inner_ear", "endbulb", "cryo", "frog", "maus_2d"): + tab = pd.read_excel("data_summary/vesicle_domain_adaptation_data.xlsx", sheet_name=name) + print(name) + print("N-tomograms:", len(tab)) + print("N-test:", (tab["used_for"] == "test").sum()) + print("N-vesicles:", tab["vesicle_count"].sum()) + print() + + +def summarize_az_train(): + conditions = pd.unique(az_train.condition) + print(conditions) + + print("Total:") + print("Tomograms:", len(az_train)) + print("Active Zones:", az_train["az_count"].sum()) + print() + + train_tomos = az_train[az_train.used_for == "train/val"] + print("Training:") + print("Tomograms:", len(train_tomos)) + print("Active Zones:", train_tomos["az_count"].sum()) + print() + + test_tomos = az_train[az_train.used_for == "test"] + print("Test:") + print("Tomograms:", len(test_tomos)) + print("Active Zones:", test_tomos["az_count"].sum()) + + +def summarize_compartment_train(): + conditions = pd.unique(compartment_train.condition) + print(conditions) + + print("Total:") + print("Tomograms:", len(compartment_train)) + print("Compartments:", compartment_train["compartment_count"].sum()) + print() + + train_tomos = compartment_train[compartment_train.used_for == "train/val"] + print("Training:") + print("Tomograms:", len(train_tomos)) + print("Compartments:", train_tomos["compartment_count"].sum()) + print() + + test_tomos = compartment_train[compartment_train.used_for == "test"] + print("Test:") + print("Tomograms:", len(test_tomos)) + print("Compartments:", test_tomos["compartment_count"].sum()) + + +def summarize_inner_ear_data(): + # NOTE: this is not all trainig data, but the data on which we run the analysis + # New tomograms from Sophia. 
+ n_tomos_sophia_tot = 87 + n_tomos_sophia_manual = 33 # noqa + # This is the training data + n_tomos_sohphia_train = "" # TODO # noqa + + # Published tomograms + n_tomos_rat = 19 + n_tomos_tether = 3 + n_tomos_ves_pool = 6 + + # 28 + print("Total published:", n_tomos_rat + n_tomos_tether + n_tomos_ves_pool) + # 115 + print("Total:", n_tomos_rat + n_tomos_tether + n_tomos_ves_pool + n_tomos_sophia_tot) + + +def summarize_mitos(): + conditions = pd.unique(mito_train.condition) + print(conditions) + + print("Total:") + print("Tomograms:", len(mito_train)) + print("Mitos:", mito_train["mito_count_all"].sum()) + print() + + train_tomos = mito_train[mito_train.used_for == "train/val"] + print("Training:") + print("Tomograms:", len(train_tomos)) + print("Mitos:", train_tomos["mito_count_all"].sum()) + print() + + test_tomos = mito_train[mito_train.used_for == "test"] + print("Test:") + print("Tomograms:", len(test_tomos)) + print("Mitos:", test_tomos["mito_count_all"].sum()) + + +# training_resolutions() +# summarize_vesicle_train_data() +# summarize_vesicle_da() +# summarize_az_train() +summarize_compartment_train() +# summarize_inner_ear_data() +# summarize_inner_ear_data() +# summarize_mitos() diff --git a/setup.py b/setup.py index 4b898b3..3722b0d 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,8 @@ entry_points={ "console_scripts": [ "synapse_net.run_segmentation = synaptic_reconstruction.tools.cli:segmentation_cli", + "synapse_net.export_to_imod_points = synaptic_reconstruction.tools.cli:imod_point_cli", + "synapse_net.export_to_imod_objects = synaptic_reconstruction.tools.cli:imod_object_cli", ], "napari.manifest": [ "synaptic_reconstruction = synaptic_reconstruction:napari.yaml", diff --git a/synaptic_reconstruction/distance_measurements.py b/synaptic_reconstruction/distance_measurements.py index cbbc48c..4cf3181 100644 --- a/synaptic_reconstruction/distance_measurements.py +++ b/synaptic_reconstruction/distance_measurements.py @@ -1,6 +1,6 @@ import os import multiprocessing as mp -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -18,7 +18,25 @@ skfmm = None -def compute_geodesic_distances(segmentation, distance_to, resolution=None, unsigned=True): +def compute_geodesic_distances( + segmentation: np.ndarray, + distance_to: np.ndarray, + resolution: Optional[Union[int, float, Tuple[int, int, int]]] = None, + unsigned: bool = True, +) -> np.ndarray: + """Compute the geodesic distances between a segmentation and a distance target. + + This function require scikit-fmm to be installed. + + Args: + segmentation: The binary segmentation. + distance_to: The binary distance target. + resolution: The voxel size of the data, used to scale the distances. + unsigned: Whether to return the unsigned or signed distances. + + Returns: + Array with the geodesic distance values. + """ assert skfmm is not None, "Please install scikit-fmm to use compute_geodesic_distance." 
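# A minimal usage sketch for the function documented above (hypothetical inputs:
# binary numpy masks `vesicle_mask` and `az_mask`, and an assumed voxel size of
# 1.554 nm; none of these names come from this repository):
#
#     geodesic_dists = compute_geodesic_distances(vesicle_mask, az_mask, resolution=1.554)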
invalid = segmentation == 0 @@ -43,14 +61,12 @@ def compute_geodesic_distances(segmentation, distance_to, resolution=None, unsig return distances -# TODO update this def _compute_centroid_distances(segmentation, resolution, n_neighbors): - # TODO enable eccentricity centers instead props = regionprops(segmentation) centroids = np.array([prop.centroid for prop in props]) if resolution is not None: - pass # TODO scale the centroids - + scale_factor = np.array(resolution)[:, None] + centroids *= scale_factor pair_distances = pairwise_distances(centroids) return pair_distances @@ -313,11 +329,13 @@ def create_pairwise_distance_lines( endpoints1: One set of distance end points. endpoints2: The other set of distance end points. seg_ids: The segmentation pair corresponding to each distance. - n_neighbors: ... - pairs: ... - bb: .... - scale: ... - remove_duplicates: ... + n_neighbors: The number of nearest neighbors to take into consideration + for creating the distance lines. + pairs: Optional list of ids to use for creating the distance lines. + bb: Bounding box for restricing the distance line creation. + scale: Scale factor for resizing the distance lines. + Use this if the corresponding segmentations were downscaled for visualization. + remove_duplicates: Remove duplicate id pairs from the distance lines. Returns: The lines for plotting in napari. @@ -386,8 +404,10 @@ def create_object_distance_lines( endpoints1: One set of distance end points. endpoints2: The other set of distance end points. seg_ids: The segmentation ids corresponding to each distance. - max_distance: ... - scale: ... + max_distance: Maximal distance for drawing the distance line. + filter_seg_ids: Segmentation ids to restrict the distance lines. + scale: Scale factor for resizing the distance lines. + Use this if the corresponding segmentations were downscaled for visualization. Returns: The lines for plotting in napari. @@ -416,13 +436,32 @@ def create_object_distance_lines( return lines, properties -def keep_direct_distances(segmentation, measurement_path, line_dilation=0, scale=None): - """Filter out all distances that are not direct. - I.e. distances that cross another segmented object. - """ +def keep_direct_distances( + segmentation: np.ndarray, + distances: np.ndarray, + endpoints1: np.ndarray, + endpoints2: np.ndarray, + seg_ids: np.ndarray, + line_dilation: int = 0, + scale: Optional[Tuple[int, int, int]] = None, +) -> List[List[int]]: + """Filter out all distances that are not direct; distances that are occluded by another segmented object. - distances, ep1, ep2, seg_ids = load_distances(measurement_path) - distance_lines, properties = create_object_distance_lines(distances, ep1, ep2, seg_ids, scale=scale) + Args: + segmentation: The segmentation from which the distances are derived. + distances: The measurd distances. + endpoints1: One set of distance end points. + endpoints2: The other set of distance end points. + seg_ids: The segmentation ids corresponding to each distance. + line_dilation: Dilation factor of the distance lines for determining occlusions. + scale: Scaling factor of the segmentation compared to the distance measurements. + + Returns: + The list of id pairs that are kept. 
+ """ + distance_lines, properties = create_object_distance_lines( + distances, endpoints1, endpoints2, seg_ids, scale=scale + ) ids_a, ids_b = properties["id_a"], properties["id_b"] filtered_ids_a, filtered_ids_b = [], [] @@ -459,10 +498,35 @@ def keep_direct_distances(segmentation, measurement_path, line_dilation=0, scale def filter_blocked_segmentation_to_object_distances( - segmentation, measurement_path, line_dilation=0, scale=None, seg_ids=None, verbose=False, -): - distances, ep1, ep2, seg_ids = load_distances(measurement_path) - distance_lines, properties = create_object_distance_lines(distances, ep1, ep2, seg_ids, scale=scale) + segmentation: np.ndarray, + distances: np.ndarray, + endpoints1: np.ndarray, + endpoints2: np.ndarray, + seg_ids: np.ndarray, + line_dilation: int = 0, + scale: Optional[Tuple[int, int, int]] = None, + filter_seg_ids: Optional[List[int]] = None, + verbose: bool = False, +) -> List[int]: + """Filter out all distances that are not direct; distances that are occluded by another segmented object. + + Args: + segmentation: The segmentation from which the distances are derived. + distances: The measurd distances. + endpoints1: One set of distance end points. + endpoints2: The other set of distance end points. + seg_ids: The segmentation ids corresponding to each distance. + line_dilation: Dilation factor of the distance lines for determining occlusions. + scale: Scaling factor of the segmentation compared to the distance measurements. + filter_seg_ids: Segmentation ids to restrict the distance lines. + verbose: Whether to print progressbar. + + Returns: + The list of id pairs that are kept. + """ + distance_lines, properties = create_object_distance_lines( + distances, endpoints1, endpoints2, seg_ids, scale=scale + ) all_seg_ids = properties["id"] filtered_ids = [] diff --git a/synaptic_reconstruction/file_utils.py b/synaptic_reconstruction/file_utils.py index fc572c8..d88a31d 100644 --- a/synaptic_reconstruction/file_utils.py +++ b/synaptic_reconstruction/file_utils.py @@ -1,7 +1,17 @@ import os +from typing import List, Optional, Union -def get_data_path(folder, n_tomograms=1): +def get_data_path(folder: str, n_tomograms: Optional[int] = 1) -> Union[str, List[str]]: + """Get the path to all tomograms stored as .rec or .mrc files in a folder. + + Args: + folder: The folder with tomograms. + n_tomograms: The expected number of tomograms. + + Returns: + The filepath or list of filepaths of the tomograms in the folder. 
+ """ file_names = os.listdir(folder) tomograms = [] for fname in file_names: @@ -11,7 +21,5 @@ def get_data_path(folder, n_tomograms=1): if n_tomograms is None: return tomograms - assert len(tomograms) == n_tomograms, f"{folder}: {len(tomograms)}, {n_tomograms}" - return tomograms[0] if n_tomograms == 1 else tomograms diff --git a/synaptic_reconstruction/ground_truth/.gitignore b/synaptic_reconstruction/ground_truth/.gitignore new file mode 100644 index 0000000..5c4b094 --- /dev/null +++ b/synaptic_reconstruction/ground_truth/.gitignore @@ -0,0 +1 @@ +edge_filter.py diff --git a/synaptic_reconstruction/ground_truth/matching.py b/synaptic_reconstruction/ground_truth/matching.py index ecfa88f..f085ee5 100644 --- a/synaptic_reconstruction/ground_truth/matching.py +++ b/synaptic_reconstruction/ground_truth/matching.py @@ -4,7 +4,21 @@ from skimage.segmentation import relabel_sequential -def find_additional_objects(ground_truth, segmentation, matching_threshold=0.5): +def find_additional_objects( + ground_truth: np.ndarray, + segmentation: np.ndarray, + matching_threshold: float = 0.5 +) -> np.ndarray: + """Compare ground-truth annotations with a segmentation to find objects not in the annotation. + + Args: + ground_trut: + segmentation: + matching_threshold: + + Returns: + """ + segmentation = relabel_sequential(segmentation)[0] # Match the objects in the segmentation to the ground-truth. diff --git a/synaptic_reconstruction/ground_truth/region_extraction.py b/synaptic_reconstruction/ground_truth/region_extraction.py deleted file mode 100644 index c246b0d..0000000 --- a/synaptic_reconstruction/ground_truth/region_extraction.py +++ /dev/null @@ -1,106 +0,0 @@ -from typing import Optional - -import numpy as np -from sklearn.decomposition import PCA -from scipy.ndimage import affine_transform - - -def rotate_3d_array(arr, rotation_matrix, center, order): - # Translate the array to center it at the origin - translation_to_origin = np.eye(4) - translation_to_origin[:3, 3] = -center - - # Translation back to original position - translation_back = np.eye(4) - translation_back[:3, 3] = center - - # Construct the full transformation matrix: Translation -> Rotation -> Translation back - transformation_matrix = np.eye(4) - transformation_matrix[:3, :3] = rotation_matrix # Apply the PCA rotation - - # Combine the transformations: T_back * R * T_origin - full_transformation = translation_back @ transformation_matrix @ translation_to_origin - - # Apply affine_transform (we extract the 3x3 rotation matrix and the translation vector) - rotated_arr = affine_transform( - arr, - full_transformation[:3, :3], # Rotation part - offset=full_transformation[:3, 3], # Translation part - output_shape=arr.shape, # Keep output shape the same - order=order - ) - return rotated_arr - - -# Find the rotation that aligns the data with the PCA -def _find_rotation(segmentation): - foreground_coords = np.argwhere(segmentation > 0) - - pca = PCA(n_components=3) - pca.fit(foreground_coords) - - rotation_matrix = pca.components_ - - return rotation_matrix - - -def extract_and_align_foreground( - segmentation: np.ndarray, - raw: Optional[np.ndarray] = None, - extract_bb: bool = True, -): - """Extract and align the bounding box containing foreground from the segmentation. - - This function will find the closest fitting, non-axis-aligned rectangular bounding box - that contains the segmentation foreground. It will then rotate the data, so that it is - axis-aligned. - - Args: - segmentation: The input segmentation. - raw: The raw data. 
- extract_bb: Whether to cout out the bounding box. - - Returns: - TODO - """ - rotation_matrix = _find_rotation(segmentation) - - # Calculate the center of the original array. - center = np.array(segmentation.shape) / 2.0 - - # Rotate the array. - segmentation = rotate_3d_array(segmentation, rotation_matrix, center, order=0) - - if extract_bb: - bb = np.where(segmentation != 0) - bb = tuple( - slice(int(b.min()), int(b.max()) + 1) for b in bb - ) - else: - bb = np.s_[:] - - if raw is not None: - raw = rotate_3d_array(raw, rotation_matrix, center, order=1) - - if raw is not None: - return segmentation[bb], raw[bb] - - return segmentation[bb] - - -if __name__ == "__main__": - import h5py - import napari - - segmentation_path = "tomogram-000.h5" - - with h5py.File(segmentation_path, "r") as f: - raw = f["/raw"][:] - segmentation = f["/labels/vesicles"][:] - - segmentation, raw = extract_and_align_foreground(segmentation, raw) - - v = napari.Viewer() - v.add_image(raw) - v.add_labels(segmentation) - napari.run() diff --git a/synaptic_reconstruction/ground_truth/shape_refinement.py b/synaptic_reconstruction/ground_truth/shape_refinement.py index 96d1f2a..8c357ae 100644 --- a/synaptic_reconstruction/ground_truth/shape_refinement.py +++ b/synaptic_reconstruction/ground_truth/shape_refinement.py @@ -52,8 +52,9 @@ def edge_filter( - "sato": Edges are found with a sato-filter, followed by smoothing and leveling. per_slice: Compute the filter per slice instead of for the whole volume. n_threads: Number of threads for parallel computation over the slices. + Returns: - Volume with edge strength. + Edge filter response. """ if method not in FILTERS: raise ValueError(f"Invalid edge filter method: {method}. Expect one of {FILTERS}.") @@ -100,6 +101,7 @@ def check_filters( The filter names must match `method` in `edge_filter`. sigmas: The sigma values to use for the filters. show: Whether to show the filter responses in napari. + Returns: Dictionary with the filter responses. """ @@ -153,6 +155,7 @@ def refine_vesicle_shapes( return_seeds: Whether to return the seeds used for the watershed. compactness: The compactness parameter passed to the watershed function. Higher compactness leads to more regular sized vesicles. + Returns: The refined vesicles. """ diff --git a/synaptic_reconstruction/imod/export.py b/synaptic_reconstruction/imod/export.py index 0607484..5457bf9 100644 --- a/synaptic_reconstruction/imod/export.py +++ b/synaptic_reconstruction/imod/export.py @@ -1,7 +1,7 @@ import shutil import tempfile from subprocess import run -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple import imageio.v3 as imageio import numpy as np @@ -144,7 +144,23 @@ def export_segmentation( imageio.imwrite(output_path, segmentation.astype("uint8"), compression="zlib") -def draw_spheres(coordinates, radii, shape, verbose=True): +def draw_spheres( + coordinates: np.ndarray, + radii: np.ndarray, + shape: Tuple[int, int, int], + verbose: bool = True, +) -> np.ndarray: + """Create a volumetric segmentation by painting spheres around the given coordinates. + + Args: + coordinates: The center coordinates of the spheres. + radii: The radii of the spheres. + shape: The shape of the volume. + verbose: Whether to print the progress bar. + + Returns: + The segmentation volume with painted spheres. 
+ """ labels = np.zeros(shape, dtype="uint32") for label_id, (coord, radius) in tqdm( enumerate(zip(coordinates, radii), start=1), total=len(coordinates), disable=not verbose @@ -166,10 +182,35 @@ def draw_spheres(coordinates, radii, shape, verbose=True): def load_points_from_imodinfo( - imod_path, full_shape, bb=None, - exclude_labels=None, exclude_label_patterns=None, - resolution=None, -): + imod_path: str, + full_shape: Tuple[int, int, int], + bb: Optional[Tuple[slice, slice, slice]] = None, + exclude_labels: Optional[List[int]] = None, + exclude_label_patterns: Optional[List[str]] = None, + resolution: Optional[float] = None, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Dict[int, str]]: + """Load point coordinates, radii and label information from a .mod file. + + The coordinates and sizes returned will be scaled so that they are in + the voxel coordinate space if the 'resolution' parameter is passed. + If it is not passed then the radius will be returned in the physical resolution. + + Args: + imod_path: The filepath to the .mod file. + full_shape: The voxel shape of the volume. + bb: Optional bounding box to limit the extracted points to. + exclude_labels: Label ids to exclude from the export. + This can be used to exclude specific labels / classes, specifying them by their id. + exclude_label_patterns: Label names to exclude from the export. + This can be used to exclude specific labels / classes, specifying them by their name. + resolution: The resolution / voxel size of the data. Will be used to scale the radii. + + Returns: + The center coordinates of the sphere annotations. + The radii of the spheres. + The ids of the semantic labels. + The names of the semantic labels. + """ coordinates, sizes, labels = [], [], [] label_names = {} @@ -274,14 +315,33 @@ def load_points_from_imodinfo( def export_point_annotations( - imod_path, - shape, - bb=None, - exclude_labels=None, - exclude_label_patterns=None, - return_coords_and_radii=False, - resolution=None, -): + imod_path: str, + shape: Tuple[int, int, int], + bb: Optional[Tuple[slice, slice, slice]] = None, + exclude_labels: Optional[List[int]] = None, + exclude_label_patterns: Optional[List[str]] = None, + return_coords_and_radii: bool = False, + resolution: Optional[float] = None, +) -> Tuple[np.ndarray, np.ndarray, Dict[int, str]]: + """Create a segmentation by drawing spheres corresponding to objects from a .mod file. + + Args: + imod_path: The filepath to the .mod file. + shape: The voxel shape of the volume. + bb: Optional bounding box to limit the extracted points to. + exclude_labels: Label ids to exclude from the segmentation. + This can be used to exclude specific labels / classes, specifying them by their id. + exclude_label_patterns: Label names to exclude from the segmentation. + This can be used to exclude specific labels / classes, specifying them by their name. + return_coords_and_radii: Whether to also return the underlying coordinates + and radii of the exported spheres. + resolution: The resolution / voxel size of the data. Will be used to scale the radii. + + Returns: + The exported segmentation. + The label ids for the instance ids in the segmentation. + The map of label ids to corresponding obejct names. 
+ """ coordinates, radii, labels, label_names = load_points_from_imodinfo( imod_path, shape, bb=bb, exclude_labels=exclude_labels, diff --git a/synaptic_reconstruction/imod/to_imod.py b/synaptic_reconstruction/imod/to_imod.py index e8095ee..5832213 100644 --- a/synaptic_reconstruction/imod/to_imod.py +++ b/synaptic_reconstruction/imod/to_imod.py @@ -8,6 +8,7 @@ from subprocess import run from typing import Optional, Tuple, Union +import h5py import imageio.v3 as imageio import mrcfile import numpy as np @@ -16,51 +17,71 @@ from tqdm import tqdm -# FIXME how to bring the data to the IMOD axis convention? -def _to_imod_order(data): - # data = np.swapaxes(data, 0, -1) - # data = np.fliplr(data) - # data = np.swapaxes(data, 0, -1) - return data +def _load_segmentation(segmentation_path, segmentation_key): + assert os.path.exists(segmentation_path), segmentation_path + if segmentation_key is None: + seg = imageio.imread(segmentation_path) + else: + with h5py.File(segmentation_path, "r") as f: + seg = f[segmentation_key][:] + return seg +# TODO: this has still some issues with some tomograms that has an offset info. +# For now, this occurs for the inner ear data tomograms; it works for Fidi's STEM tomograms. +# Ben's theory is that this might be due to data form JEOL vs. ThermoFischer microscopes. +# To test this I can check how it works for data from Maus et al. / Imig et al., which were taken on a JEOL. +# Can also check out the mrc documentation here: https://www.ccpem.ac.uk/mrc_format/mrc2014.php def write_segmentation_to_imod( mrc_path: str, - segmentation_path: str, + segmentation: Union[str, np.ndarray], output_path: str, + segmentation_key: Optional[str] = None, ) -> None: - """Write a segmentation to a mod file as contours. + """Write a segmentation to a mod file as closed contour object(s). Args: - mrc_path: a - segmentation_path: a - output_path: a + mrc_path: The filepath to the mrc file from which the segmentation was derived. + segmentation: The segmentation (either as numpy array or filepath to a .tif file). + output_path: The output path where the mod file will be saved. + segmentation_key: The key to the segmentation data in case the segmentation is stored in hdf5 files. """ cmd = "imodauto" cmd_path = shutil.which(cmd) assert cmd_path is not None, f"Could not find the {cmd} imod command." + # Load the segmentation case a filepath was passed. + if isinstance(segmentation, str): + segmentation = _load_segmentation(segmentation, segmentation_key) + + # Binarize the segmentation and flip its axes to match the IMOD axis convention. + segmentation = (segmentation > 0).astype("uint8") + segmentation = np.flip(segmentation, axis=1) + + # Read the voxel size and origin information from the mrc file. assert os.path.exists(mrc_path) - with mrcfile.open(mrc_path, mode="r+") as f: + with mrcfile.open(mrc_path, mode="r") as f: voxel_size = f.voxel_size + nx, ny, nz = f.header.nxstart, f.header.nystart, f.header.nzstart + origin = f.header.origin + # Write the input for imodauto to a temporary mrc file. with tempfile.NamedTemporaryFile(suffix=".mrc") as f: tmp_path = f.name - seg = (imageio.imread(segmentation_path) > 0).astype("uint8") - seg_ = _to_imod_order(seg) - - # import napari - # v = napari.Viewer() - # v.add_image(seg) - # v.add_labels(seg_) - # napari.run() - - mrcfile.new(tmp_path, data=seg_, overwrite=True) + mrcfile.new(tmp_path, data=segmentation, overwrite=True) + # Write the voxel_size and origin infomration. 
with mrcfile.open(tmp_path, mode="r+") as f: f.voxel_size = voxel_size + + f.header.nxstart = nx + f.header.nystart = ny + f.header.nzstart = nz + f.header.origin = (0.0, 0.0, 0.0) * 3 if origin is None else origin + f.update_header_from_data() + # Run the command. cmd_list = [cmd, "-E", "1", "-u", tmp_path, output_path] run(cmd_list) @@ -88,8 +109,8 @@ def convert_segmentation_to_spheres( props: Optional list of regionprops Returns: - np.array: the center coordinates - np.array: the radii + The center coordinates. + The radii. """ num_workers = multiprocessing.cpu_count() if num_workers is None else num_workers if props is None: @@ -116,7 +137,7 @@ def coords_and_rads(prop): if estimate_radius_2d: if resolution: - dists = np.array([distance_transform_edt(ma, sampling=resolution) for ma in mask]) + dists = np.array([distance_transform_edt(ma, sampling=resolution[1:]) for ma in mask]) else: dists = np.array([distance_transform_edt(ma) for ma in mask]) else: @@ -195,11 +216,12 @@ def _pad(inp, n=3): def write_segmentation_to_imod_as_points( mrc_path: str, - segmentation_path: str, + segmentation: Union[str, np.ndarray], output_path: str, min_radius: Union[int, float], radius_factor: float = 1.0, estimate_radius_2d: bool = True, + segmentation_key: Optional[str] = None, ) -> None: """Write segmentation results to .mod file with imod point annotations. @@ -207,13 +229,14 @@ def write_segmentation_to_imod_as_points( Args: mrc_path: Filepath to the mrc volume that was segmented. - segmentation_path: Filepath to the segmentation stored as .tif. + segmentation: The segmentation (either as numpy array or filepath to a .tif file). output_path: Where to save the .mod file. min_radius: Minimum radius for export. radius_factor: Factor for increasing the radius to account for too small exported spheres. estimate_radius_2d: If true the distance to boundary for determining the centroid and computing the radius will be computed only in 2d rather than in 3d. This can lead to better results in case of deformation across the depth axis. + segmentation_key: The key to the segmentation data in case the segmentation is stored in hdf5 files. """ # Read the resolution information from the mrcfile. @@ -224,7 +247,8 @@ def write_segmentation_to_imod_as_points( resolution = [res / 10 for res in resolution] # Extract the center coordinates and radii from the segmentation. - segmentation = imageio.imread(segmentation_path) + if isinstance(segmentation, str): + segmentation = _load_segmentation(segmentation, segmentation_key) coordinates, radii = convert_segmentation_to_spheres( segmentation, resolution=resolution, radius_factor=radius_factor, estimate_radius_2d=estimate_radius_2d ) @@ -233,16 +257,22 @@ def write_segmentation_to_imod_as_points( write_points_to_imod(coordinates, radii, segmentation.shape, min_radius, output_path) -# TODO we also need to support .rec files ... 
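The reworked export functions above accept either a segmentation array or a file path, with an optional `segmentation_key` for HDF5 inputs. A minimal sketch of exporting a vesicle segmentation as IMOD point annotations (all paths, the radius value, and the key are placeholders, not files from this repository):

```python
from synaptic_reconstruction.imod.to_imod import write_segmentation_to_imod_as_points

write_segmentation_to_imod_as_points(
    mrc_path="tomogram.mrc",                 # tomogram the segmentation was derived from
    segmentation="vesicle_segmentation.h5",  # file path or numpy array
    output_path="vesicles.mod",
    min_radius=10,                           # minimum radius for export
    segmentation_key="/labels/vesicles",     # only needed for hdf5 inputs
)
```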
-def _get_file_paths(input_path, ext=".mrc"): +def _get_file_paths(input_path, ext=(".mrc", ".rec")): if not os.path.exists(input_path): - raise Exception(f"Input path not found {input_path}") + raise Exception(f"Input path not found {input_path}.") + + if isinstance(ext, str): + ext = (ext,) if os.path.isfile(input_path): input_files = [input_path] input_root = None else: - input_files = sorted(glob(os.path.join(input_path, "**", f"*{ext}"), recursive=True)) + input_files = [] + for ex in ext: + input_files.extend( + sorted(glob(os.path.join(input_path, "**", f"*{ex}"), recursive=True)) + ) input_root = input_path return input_files, input_root @@ -254,6 +284,7 @@ def export_helper( output_root: str, export_function: callable, force: bool = False, + segmentation_key: Optional[str] = None, ) -> None: """ Helper function to run imod export for files in a directory. @@ -270,9 +301,10 @@ def export_helper( the path to the segmentation in a .tif file and the output path as only arguments. If you want to pass additional arguments to this function the use 'funtools.partial' force: Whether to rerun segmentation for output files that are already present. + segmentation_key: The key to the segmentation data in case the segmentation is stored in hdf5 files. """ input_files, input_root = _get_file_paths(input_path) - segmentation_files, _ = _get_file_paths(segmentation_path, ext=".tif") + segmentation_files, _ = _get_file_paths(segmentation_path, ext=".tif" if segmentation_key is None else ".h5") assert len(input_files) == len(segmentation_files) for input_path, seg_path in tqdm(zip(input_files, segmentation_files), total=len(input_files)): @@ -291,4 +323,4 @@ def export_helper( continue os.makedirs(os.path.split(output_path)[0], exist_ok=True) - export_function(input_path, seg_path, output_path) + export_function(input_path, seg_path, output_path, segmentation_key=segmentation_key) diff --git a/synaptic_reconstruction/inference/actin.py b/synaptic_reconstruction/inference/actin.py index afc6e54..ff4ed17 100644 --- a/synaptic_reconstruction/inference/actin.py +++ b/synaptic_reconstruction/inference/actin.py @@ -24,8 +24,7 @@ def segment_actin( scale: Optional[List[float]] = None, mask: Optional[np.ndarray] = None, ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: - """ - Segment actin in an input volume. + """Segment actin in an input volume. Args: input_volume: The input volume to segment. diff --git a/synaptic_reconstruction/inference/active_zone.py b/synaptic_reconstruction/inference/active_zone.py new file mode 100644 index 0000000..216deb9 --- /dev/null +++ b/synaptic_reconstruction/inference/active_zone.py @@ -0,0 +1,121 @@ +import time +from typing import Dict, List, Optional, Tuple, Union + +import elf.parallel as parallel +import numpy as np +import torch + +from skimage.segmentation import find_boundaries +from synaptic_reconstruction.inference.util import get_prediction, _Scaler + + +def find_intersection_boundary(segmented_AZ: np.ndarray, segmented_compartment: np.ndarray) -> np.ndarray: + """Find the intersection of the boundaries of each objects in segmented_compartment with segmented_AZ. + + Args: + segmented_AZ: 3D array representing the active zone (AZ). + segmented_compartment: 3D array representing the compartment, with multiple labels. + + Returns: + Array with the cumulative intersection of all boundaries of segmented_compartment labels with segmented_AZ. 
+ """ + # Step 0: Initialize an empty array to accumulate intersections + cumulative_intersection = np.zeros_like(segmented_AZ, dtype=bool) + + # Step 1: Loop through each unique label in segmented_compartment (excluding 0 if it represents background) + labels = np.unique(segmented_compartment) + labels = labels[labels != 0] # Exclude background label (0) if necessary + + for label in labels: + # Step 2: Create a binary mask for the current label + label_mask = (segmented_compartment == label) + + # Step 3: Find the boundary of the current label's compartment + boundary_compartment = find_boundaries(label_mask, mode='outer') + + # Step 4: Find the intersection with the AZ for this label's boundary + intersection = np.logical_and(boundary_compartment, segmented_AZ) + + # Step 5: Accumulate intersections for each label + cumulative_intersection = np.logical_or(cumulative_intersection, intersection) + + return cumulative_intersection.astype(int) # Convert boolean array to int (1 for intersecting points, 0 elsewhere) + + +def _run_segmentation( + foreground, verbose, min_size, + # blocking shapes for parallel computation + block_shape=(128, 256, 256), +): + + # get the segmentation via seeded watershed + t0 = time.time() + seg = parallel.label(foreground > 0.5, block_shape=block_shape, verbose=verbose) + if verbose: + print("Compute connected components in", time.time() - t0, "s") + + # size filter + t0 = time.time() + ids, sizes = parallel.unique(seg, return_counts=True, block_shape=block_shape, verbose=verbose) + filter_ids = ids[sizes < min_size] + seg[np.isin(seg, filter_ids)] = 0 + if verbose: + print("Size filter in", time.time() - t0, "s") + seg = np.where(seg > 0, 1, 0) + return seg + + +def segment_active_zone( + input_volume: np.ndarray, + model_path: Optional[str] = None, + model: Optional[torch.nn.Module] = None, + tiling: Optional[Dict[str, Dict[str, int]]] = None, + min_size: int = 500, + verbose: bool = True, + return_predictions: bool = False, + scale: Optional[List[float]] = None, + mask: Optional[np.ndarray] = None, + compartment: Optional[np.ndarray] = None, +) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + """Segment active zones in an input volume. + + Args: + input_volume: The input volume to segment. + model_path: The path to the model checkpoint if `model` is not provided. + model: Pre-loaded model. Either `model_path` or `model` is required. + tiling: The tiling configuration for the prediction. + verbose: Whether to print timing information. + scale: The scale factor to use for rescaling the input volume before prediction. + mask: An optional mask that is used to restrict the segmentation. + compartment: + + Returns: + The foreground mask as a numpy array. + """ + if verbose: + print("Segmenting AZ in volume of shape", input_volume.shape) + # Create the scaler to handle prediction with a different scaling factor. + scaler = _Scaler(scale, verbose) + input_volume = scaler.scale_input(input_volume) + + # Rescale the mask if it was given and run prediction. + if mask is not None: + mask = scaler.scale_input(mask, is_segmentation=True) + pred = get_prediction(input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose) + + # Run segmentation and rescale the result if necessary. 
+ foreground = pred[0] + print(f"shape {foreground.shape}") + + segmentation = _run_segmentation(foreground, verbose=verbose, min_size=min_size) + + # returning prediciton and intersection not possible atm, but currently do not need prediction anyways + if return_predictions: + pred = scaler.rescale_output(pred, is_segmentation=False) + return segmentation, pred + + if compartment is not None: + intersection = find_intersection_boundary(segmentation, compartment) + return segmentation, intersection + + return segmentation diff --git a/synaptic_reconstruction/inference/compartments.py b/synaptic_reconstruction/inference/compartments.py index a822d9f..dd6adf7 100644 --- a/synaptic_reconstruction/inference/compartments.py +++ b/synaptic_reconstruction/inference/compartments.py @@ -157,8 +157,7 @@ def segment_compartments( n_slices_exclude: int = 0, **kwargs, ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: - """ - Segment synaptic compartments in an input volume. + """Segment synaptic compartments in an input volume. Args: input_volume: The input volume to segment. diff --git a/synaptic_reconstruction/inference/cristae.py b/synaptic_reconstruction/inference/cristae.py index 467bfd9..ce62dcb 100644 --- a/synaptic_reconstruction/inference/cristae.py +++ b/synaptic_reconstruction/inference/cristae.py @@ -43,8 +43,7 @@ def segment_cristae( scale: Optional[List[float]] = None, mask: Optional[np.ndarray] = None, ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: - """ - Segment cristae in an input volume. + """Segment cristae in an input volume. Args: input_volume: The input volume to segment. Expects 2 3D volumes: raw and mitochondria diff --git a/synaptic_reconstruction/inference/mitochondria.py b/synaptic_reconstruction/inference/mitochondria.py index 027b4ed..a95712d 100644 --- a/synaptic_reconstruction/inference/mitochondria.py +++ b/synaptic_reconstruction/inference/mitochondria.py @@ -66,8 +66,7 @@ def segment_mitochondria( scale: Optional[List[float]] = None, mask: Optional[np.ndarray] = None, ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: - """ - Segment mitochondria in an input volume. + """Segment mitochondria in an input volume. Args: input_volume: The input volume to segment. diff --git a/synaptic_reconstruction/inference/util.py b/synaptic_reconstruction/inference/util.py index cedfb07..434fb32 100644 --- a/synaptic_reconstruction/inference/util.py +++ b/synaptic_reconstruction/inference/util.py @@ -4,17 +4,18 @@ from glob import glob from typing import Dict, Optional, Tuple -# Suppress annoying import warnings. -with warnings.catch_warnings(): - warnings.simplefilter("ignore") - import bioimageio.core +# # Suppress annoying import warnings. +# with warnings.catch_warnings(): +# warnings.simplefilter("ignore") +# import bioimageio.core import imageio.v3 as imageio import elf.parallel as parallel +import mrcfile import numpy as np import torch import torch_em -import xarray +# import xarray from elf.io import open_file from scipy.ndimage import binary_closing @@ -80,9 +81,8 @@ def get_prediction( verbose: bool = True, with_channels: bool = False, mask: Optional[np.ndarray] = None, -): - """ - Run prediction on a given volume. +) -> np.ndarray: + """Run prediction on a given volume. This function will automatically choose the correct prediction implementation, depending on the model type. @@ -94,7 +94,8 @@ def get_prediction( tiling: The tiling configuration for the prediction. verbose: Whether to print timing information. 
with_channels: Whether to predict with channels. - mask: + mask: Optional binary mask. If given, the prediction will only be run in + the foreground region of the mask. Returns: The predicted volume. @@ -121,10 +122,9 @@ def get_prediction( # Run prediction with the bioimage.io library. if is_bioimageio: - # TODO determine if we use the old or new API and select the corresponding function if mask is not None: raise NotImplementedError - pred = get_prediction_bioimageio_old(input_volume, model_path, tiling, verbose) + raise NotImplementedError # Run prediction with the torch-em library. else: @@ -132,7 +132,7 @@ def get_prediction( # torch_em expects the root folder of a checkpoint path instead of the checkpoint itself. if model_path.endswith("best.pt"): model_path = os.path.split(model_path)[0] - print(f"tiling {tiling}") + # print(f"tiling {tiling}") # Create updated_tiling with the same structure updated_tiling = { "tile": {}, @@ -141,7 +141,7 @@ def get_prediction( # Update tile dimensions for dim in tiling["tile"]: updated_tiling["tile"][dim] = tiling["tile"][dim] - 2 * tiling["halo"][dim] - print(f"updated_tiling {updated_tiling}") + # print(f"updated_tiling {updated_tiling}") pred = get_prediction_torch_em( input_volume, updated_tiling, model_path, model, verbose, with_channels, mask=mask ) @@ -149,35 +149,6 @@ def get_prediction( return pred -def get_prediction_bioimageio_old( - input_volume: np.ndarray, # [z, y, x] - model_path: str, - tiling: Dict[str, Dict[str, int]], # {"tile": {"z": int, ...}, "halo": {"z": int, ...}} - verbose: bool = True, -): - """ - Run prediction using bioimage.io functionality on a given volume. - - Args: - input_volume: The input volume to predict on. - model_path: The path to the model checkpoint. - tiling: The tiling configuration for the prediction. - verbose: Whether to print timing information. - - Returns: - The predicted volume. - """ - # get foreground and boundary predictions from the model - t0 = time.time() - model = bioimageio.core.load_resource_description(model_path) - with bioimageio.core.create_prediction_pipeline(model) as pp: - input_ = xarray.DataArray(input_volume[None, None], dims=tuple("bczyx")) - pred = bioimageio.core.predict_with_tiling(pp, input_, tiling=tiling, verbose=verbose)[0].squeeze() - if verbose: - print("Prediction time in", time.time() - t0, "s") - return pred - - def get_prediction_torch_em( input_volume: np.ndarray, # [z, y, x] tiling: Dict[str, Dict[str, int]], # {"tile": {"z": int, ...}, "halo": {"z": int, ...}} @@ -187,8 +158,7 @@ def get_prediction_torch_em( with_channels: bool = False, mask: Optional[np.ndarray] = None, ) -> np.ndarray: - """ - Run prediction using torch-em on a given volume. + """Run prediction using torch-em on a given volume. Args: input_volume: The input volume to predict on. @@ -197,6 +167,8 @@ def get_prediction_torch_em( tiling: The tiling configuration for the prediction. verbose: Whether to print timing information. with_channels: Whether to predict with channels. + mask: Optional binary mask. If given, the prediction will only be run in + the foreground region of the mask. Returns: The predicted volume. 
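A small sketch of how this prediction entry point and the tiling helpers further down in this file fit together (the checkpoint path and the random input volume are placeholders, not assets shipped with the repository):

```python
import numpy as np
from synaptic_reconstruction.inference.util import get_default_tiling, get_prediction

volume = np.random.rand(48, 512, 512).astype("float32")  # placeholder tomogram
tiling = get_default_tiling()  # picks tile shape and halo based on available VRAM

# Tiled prediction with a torch_em checkpoint (the path is a placeholder).
prediction = get_prediction(volume, model_path="checkpoints/vesicle_model", tiling=tiling)
```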
@@ -281,6 +253,33 @@ def _load_input(img_path, extra_files, i): return input_volume +def _derive_scale(img_path, model_resolution): + try: + with mrcfile.open(img_path, "r") as f: + voxel_size = f.voxel_size + if len(model_resolution) == 2: + voxel_size = [voxel_size.y, voxel_size.x] + else: + voxel_size = [voxel_size.z, voxel_size.y, voxel_size.x] + + assert len(voxel_size) == len(model_resolution) + # The voxel size is given in Angstrom and we need to translate it to nanometer. + voxel_size = [vsize / 10 for vsize in voxel_size] + + # Compute the correct scale factor. + scale = tuple(vsize / res for vsize, res in zip(voxel_size, model_resolution)) + print("Rescaling the data at", img_path, "by", scale, "to match the training voxel size", model_resolution) + + except Exception: + warnings.warn( + f"The voxel size could not be read from the data for {img_path}. " + "This data will not be scaled for prediction." + ) + scale = None + + return scale + + def inference_helper( input_path: str, output_root: str, @@ -292,9 +291,10 @@ def inference_helper( mask_input_ext: str = ".tif", force: bool = False, output_key: Optional[str] = None, -): - """ - Helper function to run segmentation for mrc files. + model_resolution: Optional[Tuple[float, float, float]] = None, + scale: Optional[Tuple[float, float, float]] = None, +) -> None: + """Helper function to run segmentation for mrc files. Args: input_path: The path to the input data. @@ -312,7 +312,13 @@ def inference_helper( mask_input_ext: File extension for the mask inputs (by default .tif). force: Whether to rerun segmentation for output files that are already present. output_key: Output key for the prediction. If none will write an hdf5 file. + model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction. + If given, the scaling factor will automatically be determined based on the voxel_size of the input data. + scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'. """ + if (scale is not None) and (model_resolution is not None): + raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.") + # Get the input files. If input_path is a folder then this will load all # the mrc files beneath it. Otherwise we assume this is an mrc file already # and just return the path to this mrc file. @@ -332,7 +338,7 @@ def inference_helper( mask_files, _ = _get_file_paths(mask_input_path, mask_input_ext) assert len(input_files) == len(mask_files) - for i, img_path in tqdm(enumerate(input_files), total=len(input_files)): + for i, img_path in tqdm(enumerate(input_files), total=len(input_files), desc="Processing files"): # Determine the output file name. input_folder, input_name = os.path.split(img_path) @@ -350,7 +356,12 @@ def inference_helper( # Check if the output path is already present. # If it is we skip the prediction, unless force was set to true. if os.path.exists(output_path) and not force: - continue + if output_key is None: + continue + else: + with open_file(output_path, "r") as f: + if output_key in f: + continue # Load the input volume. If we have extra_files then this concatenates the # data across a new first axis (= channel axis). @@ -358,8 +369,18 @@ def inference_helper( # Load the mask (if given). mask = None if mask_files is None else imageio.imread(mask_files[i]) + # Determine the scale factor: + # If the neither the 'scale' nor 'model_resolution' arguments were passed then set it to None. 
+ if scale is None and model_resolution is None: + this_scale = None + elif scale is not None: # If 'scale' was passed then use it. + this_scale = scale + else: # Otherwise 'model_resolution' was passed, use it to derive the scaling from the data + assert model_resolution is not None + this_scale = _derive_scale(img_path, model_resolution) + # Run the segmentation. - segmentation = segmentation_function(input_volume, mask=mask) + segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale) # Write the result to tif or h5. os.makedirs(os.path.split(output_path)[0], exist_ok=True) @@ -373,12 +394,21 @@ def inference_helper( print(f"Saved segmentation to {output_path}.") -def get_default_tiling(): +def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]: """Determine the tile shape and halo depending on the available VRAM. + + Args: + is_2d: Whether to return tiling settings for 2d inference. + + Returns: + The default tiling settings for the available computational resources. """ - if torch.cuda.is_available(): - print("Determining suitable tiling") + if is_2d: + tile = {"x": 768, "y": 768, "z": 1} + halo = {"x": 128, "y": 128, "z": 0} + return {"tile": tile, "halo": halo} + if torch.cuda.is_available(): # We always use the same default halo. halo = {"x": 64, "y": 64, "z": 16} # before 64,64,8 @@ -410,19 +440,23 @@ def get_default_tiling(): return tiling -def parse_tiling(tile_shape, halo): - """ - Helper function to parse tiling parameter input from the command line. +def parse_tiling( + tile_shape: Tuple[int, int, int], + halo: Tuple[int, int, int], + is_2d: bool = False, +) -> Dict[str, Dict[str, int]]: + """Helper function to parse tiling parameter input from the command line. Args: tile_shape: The tile shape. If None the default tile shape is used. halo: The halo. If None the default halo is used. + is_2d: Whether to return tiling for a 2d model. Returns: - dict: the tiling specification + The tiling specification. """ - default_tiling = get_default_tiling() + default_tiling = get_default_tiling(is_2d=is_2d) if tile_shape is None: tile_shape = default_tiling["tile"] diff --git a/synaptic_reconstruction/inference/vesicles.py b/synaptic_reconstruction/inference/vesicles.py index 237d95a..884138a 100644 --- a/synaptic_reconstruction/inference/vesicles.py +++ b/synaptic_reconstruction/inference/vesicles.py @@ -134,8 +134,7 @@ def segment_vesicles( exclude_boundary: bool = False, mask: Optional[np.ndarray] = None, ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: - """ - Segment vesicles in an input volume. + """Segment vesicles in an input volume or image. Args: input_volume: The input volume to segment. 
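To illustrate how the new rescaling options fit together, here is a hypothetical usage sketch of `inference_helper`: either pass `model_resolution` to derive the scale from the mrc voxel size (as done by `_derive_scale`), or pass a fixed `scale`; passing both raises a `ValueError`. The model name and the resolution values below are placeholders, not the actual training resolutions.

```python
# Hypothetical usage of the new 'model_resolution' / 'scale' options of inference_helper.
from functools import partial

from synaptic_reconstruction.inference.util import inference_helper, parse_tiling
from synaptic_reconstruction.tools.util import get_model, run_segmentation

model_type = "vesicles_3d"  # assumed model name; see the CLI help for the available models
model = get_model(model_type)
tiling = parse_tiling(tile_shape=None, halo=None)  # fall back to the default tiling
segment = partial(run_segmentation, model=model, model_type=model_type, verbose=False, tiling=tiling)

# Option 1: derive the per-axis scale from the mrc voxel size and a target resolution in nm.
inference_helper("/path/to/tomograms", "/path/to/output", segment, model_resolution=(0.87, 0.87, 0.87))

# Option 2: pass a fixed scale factor instead. Passing both options raises a ValueError.
# inference_helper("/path/to/tomograms", "/path/to/output", segment, scale=(0.5, 0.5, 0.5))
```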
diff --git a/synaptic_reconstruction/morphology.py b/synaptic_reconstruction/morphology.py index 5cfd598..7fbce32 100644 --- a/synaptic_reconstruction/morphology.py +++ b/synaptic_reconstruction/morphology.py @@ -1,13 +1,17 @@ import multiprocessing as mp import warnings from concurrent import futures +from typing import Dict, List, Optional, Tuple import trimesh import numpy as np import pandas as pd -from scipy.ndimage import distance_transform_edt + +from scipy.ndimage import distance_transform_edt, convolve +from skimage.graph import MCP from skimage.measure import regionprops, marching_cubes, find_contours +from skimage.morphology import skeletonize, medial_axis, label from skimage.segmentation import find_boundaries @@ -54,9 +58,30 @@ def dist(input_): def compute_radii( - vesicles, resolution, ids=None, derive_radius_from_distances=True, derive_distances_2d=True, min_size=500 -): + vesicles: np.ndarray, + resolution: Tuple[float, float, float], + ids: Optional[List[int]] = None, + derive_radius_from_distances: bool = True, + derive_distances_2d: bool = True, + min_size: int = 500, +) -> Tuple[List[int], Dict[int, float]]: """Compute the radii for a vesicle segmentation. + + Args: + vesicles: The vesicle segmentation. + resolution: The pixel / voxel size of the data. + ids: Vesicle ids to restrict the radius computation to. + derive_radius_from_distances: Whether to derive the radii + from the distance to the vesicle boundaries, or from the + axis fitted to the vesicle by scikit-image regionprops. + derive_distances_2d: Whether to derive the radii individually in 2d + or in 3d. Deriving the radii in 3d is beneficial for anisotropic + data or data that suffers from the missing wedge effect. + min_size: The minimal size for extracting the radii. + + Returns: + The ids of the extracted radii. + The radii that were computed. """ if derive_radius_from_distances: ids, radii = _compute_radii_distances( @@ -69,65 +94,171 @@ def compute_radii( return ids, radii -# TODO adjust the surface for open vs. closed structures -def compute_object_morphology(object_, structure_name, resolution=None): - """ - Compute the morphology (volume and surface area) of a 2D or 3D object. +def compute_object_morphology( + object_: np.ndarray, + structure_name: str, + resolution: Tuple[float, float, float] = None +) -> pd.DataFrame: + """Compute the volume and surface area of a 2D or 3D object. Args: - object_ (np.ndarray): 2D or 3D binary object array. - structure_name (str): Name of the structure being analyzed. - resolution (tuple): Physical spacing between nm. + object_: 2D or 3D binary object array. + structure_name: Name of the structure being analyzed. + resolution: The pixel / voxel size of the data. Returns: - pd.DataFrame: Morphology information containing volume and surface area. + Morphology information containing volume and surface area. 
""" if object_.ndim == 2: # Use find_contours for 2D data contours = find_contours(object_, level=0.5) - + # Compute perimeter (total length of all contours) perimeter = sum( np.sqrt(np.sum(np.diff(contour, axis=0)**2, axis=1)).sum() for contour in contours ) - + # Compute area (number of positive pixels) area = np.sum(object_ > 0) - + # Adjust for resolution if provided if resolution is not None: area *= resolution[0] * resolution[1] perimeter *= resolution[0] - + morphology = pd.DataFrame({ "structure": [structure_name], "area [pixel^2]" if resolution is None else "area [nm^2]": [area], "perimeter [pixel]" if resolution is None else "perimeter [nm]": [perimeter], }) - + elif object_.ndim == 3: # Use marching_cubes for 3D data verts, faces, normals, _ = marching_cubes( object_, spacing=(1.0, 1.0, 1.0) if resolution is None else resolution, ) - + mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals) surface = mesh.area if mesh.is_watertight: volume = np.abs(mesh.volume) else: - warnings.warn("Could not compute mesh volume; setting it to NaN.") + warnings.warn("Could not compute mesh surface for the volume; setting it to NaN.") volume = np.nan - + morphology = pd.DataFrame({ "structure": [structure_name], "volume [pixel^3]" if resolution is None else "volume [nm^3]": [volume], "surface [pixel^2]" if resolution is None else "surface [nm^2]": [surface], }) - + else: raise ValueError("Input object must be a 2D or 3D numpy array.") - + return morphology + + +def _find_endpoints(component): + # Define a 3x3 kernel to count neighbors + kernel = np.ones((3, 3), dtype=int) + neighbor_count = convolve(component.astype(int), kernel, mode="constant", cval=0) + endpoints = np.argwhere((component == 1) & (neighbor_count == 2)) # Degree = 1 + return endpoints + + +def _compute_longest_path(component, endpoints): + # Use the first endpoint as the source + src = tuple(endpoints[0]) + cost = np.where(component, 1, np.inf) # Cost map: 1 for skeleton, inf for background + mcp = MCP(cost) + _, traceback = mcp.find_costs([src]) + + # Use the second endpoint as the destination + dst = tuple(endpoints[-1]) + + # Trace back the path + path = np.zeros_like(component, dtype=bool) + current = dst + + # Extract offsets from the MCP object + offsets = np.array(mcp.offsets) + nrows, ncols = component.shape + + while current != src: + path[current] = True + current_offset_index = traceback[current] + if current_offset_index < 0: + # No valid path found + break + offset = offsets[current_offset_index] + # Move to the predecessor + current = (current[0] - offset[0], current[1] - offset[1]) + # Ensure indices are within bounds + if not (0 <= current[0] < nrows and 0 <= current[1] < ncols): + break + + path[src] = True # Include the source + return path + + +def _prune_skeleton_longest_path(skeleton): + pruned_skeleton = np.zeros_like(skeleton, dtype=bool) + + # Label connected components in the skeleton + labeled_skeleton, num_labels = label(skeleton, return_num=True) + + for label_id in range(1, num_labels + 1): + # Isolate the current connected component + component = (labeled_skeleton == label_id) + + # Find the endpoints of the component + endpoints = _find_endpoints(component) + if len(endpoints) < 2: + continue # Skip if there are no valid endpoints + elif len(endpoints) == 2: # Nothing to prune + pruned_skeleton |= component + continue + + # Compute the longest path using MCP + longest_path = _compute_longest_path(component, endpoints) + pruned_skeleton |= longest_path + + return 
pruned_skeleton.astype(skeleton.dtype) + + +def skeletonize_object( + segmentation: np.ndarray, + method: str = "skeletonize", + prune: bool = True, + min_prune_size: int = 10, +) -> np.ndarray: + """Skeletonize a 3D object by individually skeletonizing each slice. + + Args: + segmentation: The segmented object. + method: The method to use for skeletonization. Either 'skeletonize' or 'medial_axis'. + prune: Whether to prune the skeleton. + min_prune_size: The minimal size of components after pruning. + + Returns: + The skeletonized object. + """ + assert method in ("skeletonize", "medial_axis") + seg_thin = np.zeros_like(segmentation) + skeletor = skeletonize if method == "skeletonize" else medial_axis + # Parallelize? + for z in range(segmentation.shape[0]): + skeleton = skeletor(segmentation[z]) + + if prune: + skeleton = _prune_skeleton_longest_path(skeleton) + if min_prune_size > 0: + skeleton = label(skeleton) + ids, sizes = np.unique(skeleton, return_counts=True) + ids, sizes = ids[1:], sizes[1:] + skeleton = np.isin(skeleton, ids[sizes >= min_prune_size]) + + seg_thin[z] = skeleton + return seg_thin diff --git a/synaptic_reconstruction/sample_data.py b/synaptic_reconstruction/sample_data.py new file mode 100644 index 0000000..c0a3e47 --- /dev/null +++ b/synaptic_reconstruction/sample_data.py @@ -0,0 +1,34 @@ +import os +import pooch + + +def get_sample_data(name: str) -> str: + """Get the filepath to SynapseNet sample data, stored as an mrc file. + + Args: + name: The name of the sample data. Currently, we only provide the 'tem_2d' sample data. + + Returns: + The filepath to the downloaded sample data. + """ + registry = { + "tem_2d.mrc": "3c6f9ff6d7673d9bf2fd46c09750c3c7dbb8fa1aa59dcdb3363b65cc774dcf28", + } + urls = { + "tem_2d.mrc": "https://owncloud.gwdg.de/index.php/s/5sAQ0U4puAspcHg/download", + } + key = f"{name}.mrc" + + if key not in registry: + valid_names = [k[:-4] for k in registry.keys()] + raise ValueError(f"Invalid sample name {name}, please choose one of {valid_names}.") + + cache_dir = os.path.expanduser(pooch.os_cache("synapse-net")) + data_registry = pooch.create( + path=os.path.join(cache_dir, "sample_data"), + base_url="", + registry=registry, + urls=urls, + ) + file_path = data_registry.fetch(key) + return file_path diff --git a/synaptic_reconstruction/tools/cli.py b/synaptic_reconstruction/tools/cli.py index e6482a7..a103cb2 100644 --- a/synaptic_reconstruction/tools/cli.py +++ b/synaptic_reconstruction/tools/cli.py @@ -1,12 +1,103 @@ import argparse from functools import partial -from .util import run_segmentation, get_model +from .util import ( + run_segmentation, get_model, get_model_registry, get_model_training_resolution, load_custom_model +) +from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod from ..inference.util import inference_helper, parse_tiling +def imod_point_cli(): + parser = argparse.ArgumentParser( + description="Convert a vesicle segmentation to an IMOD point model, " + "corresponding to a sphere for each vesicle in the segmentation." + ) + parser.add_argument( + "--input_path", "-i", required=True, + help="The filepath to the mrc file or the directory containing the tomogram data." + ) + parser.add_argument( + "--segmentation_path", "-s", required=True, + help="The filepath to the file or the directory containing the segmentations." + ) + parser.add_argument( + "--output_path", "-o", required=True, + help="The filepath to the directory where the exported IMOD point models will be saved."
+ ) + parser.add_argument( + "--segmentation_key", "-k", + help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. " + "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset." + ) + parser.add_argument( + "--min_radius", type=float, default=10.0, + help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be excluded from the export." # noqa + ) + parser.add_argument( + "--radius_factor", type=float, default=1.0, + help="A factor for scaling the sphere radius for the export. " + "This can be used to fit the size of segmented vesicles to the best matching spheres.", + ) + parser.add_argument( + "--force", action="store_true", + help="Whether to overwrite already present export results." + ) + args = parser.parse_args() + + export_function = partial( + write_segmentation_to_imod_as_points, + min_radius=args.min_radius, + radius_factor=args.radius_factor, + ) + + export_helper( + input_path=args.input_path, + segmentation_path=args.segmentation_path, + output_root=args.output_path, + export_function=export_function, + force=args.force, + segmentation_key=args.segmentation_key, + ) + + +def imod_object_cli(): + parser = argparse.ArgumentParser( + description="Convert segmented objects to closed contour IMOD models." + ) + parser.add_argument( + "--input_path", "-i", required=True, + help="The filepath to the mrc file or the directory containing the tomogram data." + ) + parser.add_argument( + "--segmentation_path", "-s", required=True, + help="The filepath to the file or the directory containing the segmentations." + ) + parser.add_argument( + "--output_path", "-o", required=True, + help="The filepath to the directory where the exported IMOD models will be saved." + ) + parser.add_argument( + "--segmentation_key", "-k", + help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. " + "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset." + ) + parser.add_argument( + "--force", action="store_true", + help="Whether to overwrite already present export results." + ) + args = parser.parse_args() + export_helper( + input_path=args.input_path, + segmentation_path=args.segmentation_path, + output_root=args.output_path, + export_function=write_segmentation_to_imod, + force=args.force, + segmentation_key=args.segmentation_key, + ) + + # TODO: handle kwargs -# TODO: add custom model path def segmentation_cli(): parser = argparse.ArgumentParser(description="Run segmentation.") parser.add_argument( @@ -17,9 +108,11 @@ def segmentation_cli(): "--output_path", "-o", required=True, help="The filepath to directory where the segmentations will be saved." ) - # TODO: list the availabel models here by parsing the keys of the model registry + model_names = list(get_model_registry().urls.keys()) + model_names = ", ".join(model_names) parser.add_argument( - "--model", "-m", required=True, help="The model type." + "--model", "-m", required=True, + help=f"The model type. The following models are currently available: {model_names}" ) parser.add_argument( "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation." @@ -41,10 +134,40 @@ def segmentation_cli(): parser.add_argument( "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc." ) + parser.add_argument( + "--checkpoint", "-c", help="Path to a custom model, e.g.
from domain adaptation.", + ) + parser.add_argument( + "--segmentation_key", "-s", + help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.", + ) + parser.add_argument( + "--scale", type=float, + help="The factor for rescaling the data before inference. " + "By default, the scaling factor will be derived from the voxel size of the input data. " + "If this parameter is given it will over-ride the default behavior. " + ) args = parser.parse_args() - model = get_model(args.model) - tiling = parse_tiling(args.tile_shape, args.halo) + if args.checkpoint is None: + model = get_model(args.model) + else: + model = load_custom_model(args.checkpoint) + assert model is not None, f"The model from {args.checkpoint} could not be loaded." + + is_2d = "2d" in args.model + tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d) + + # If the scale argument is not passed, then we get the average training resolution for the model. + # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files. + if args.scale is None: + model_resolution = get_model_training_resolution(args.model) + model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx")) + scale = None + # Otherwise, we set the model resolution to None and use the scaling factor provided by the user. + else: + model_resolution = None + scale = (2 if is_2d else 3) * (args.scale,) segmentation_function = partial( run_segmentation, model=model, model_type=args.model, verbose=False, tiling=tiling, @@ -52,4 +175,5 @@ def segmentation_cli(): inference_helper( args.input_path, args.output_path, segmentation_function, mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext, + output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale, ) diff --git a/synaptic_reconstruction/tools/segmentation_widget.py b/synaptic_reconstruction/tools/segmentation_widget.py index 0b63642..548a465 100644 --- a/synaptic_reconstruction/tools/segmentation_widget.py +++ b/synaptic_reconstruction/tools/segmentation_widget.py @@ -136,7 +136,6 @@ def _create_settings_widget(self): setting_values.layout().addLayout(layout) # Create UI for the halo. - self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param( ("halo_x", "halo_y", "halo_z"), (self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]), @@ -145,7 +144,7 @@ def _create_settings_widget(self): ) setting_values.layout().addLayout(layout) - # read voxel size from layer metadata + # Read voxel size from layer metadata. 
self.voxel_size_param, layout = self._add_float_param( "voxel_size", 0.0, min_val=0.0, max_val=100.0, ) diff --git a/synaptic_reconstruction/tools/util.py b/synaptic_reconstruction/tools/util.py index 2d135cc..edb51a1 100644 --- a/synaptic_reconstruction/tools/util.py +++ b/synaptic_reconstruction/tools/util.py @@ -6,8 +6,10 @@ import numpy as np import pooch -from ..inference.vesicles import segment_vesicles +from ..inference.active_zone import segment_active_zone +from ..inference.compartments import segment_compartments from ..inference.mitochondria import segment_mitochondria +from ..inference.vesicles import segment_vesicles def _save_table(save_path, data): @@ -52,7 +54,7 @@ def get_model_path(model_type: str) -> str: model_path = model_registry.fetch(model_type) return model_path - + def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module: """Get the model for the given segmentation type. @@ -98,14 +100,14 @@ def run_segmentation( The segmentation. """ if model_type.startswith("vesicles"): - segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose) + segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) elif model_type == "mitochondria": - segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose) + segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) elif model_type == "active_zone": - raise NotImplementedError + segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) elif model_type == "compartments": - raise NotImplementedError - elif model_type == "inner_ear_structures": + segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) + elif model_type == "ribbon_synapse_structures": raise NotImplementedError else: raise ValueError(f"Unknown model type: {model_type}") diff --git a/synaptic_reconstruction/training/domain_adaptation.py b/synaptic_reconstruction/training/domain_adaptation.py index c4cb892..215d7fa 100644 --- a/synaptic_reconstruction/training/domain_adaptation.py +++ b/synaptic_reconstruction/training/domain_adaptation.py @@ -6,7 +6,7 @@ import torch_em.self_training as self_training from .semisupervised_training import get_unsupervised_loader -from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, determine_ndim +from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim def mean_teacher_adaptation( @@ -28,12 +28,15 @@ def mean_teacher_adaptation( n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, sampler: Optional[callable] = None, -): +) -> None: """Run domain adapation to transfer a network trained on a source domain for a supervised segmentation task to perform this task on a different target domain. We support different domain adaptation settings: - - + - unsupervised domain adaptation: the default mode when 'supervised_train_paths' and + 'supervised_val_paths' are not given. + - semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data, + when 'supervised_train_paths' and 'supervised_val_paths' are given. Args: name: The name for the checkpoint to be trained. @@ -71,7 +74,7 @@ def mean_teacher_adaptation( based on the patch_shape and size of the volumes used for validation. 
""" assert (supervised_train_paths is None) == (supervised_val_paths is None) - is_2d, _ = determine_ndim(patch_shape) + is_2d, _ = _determine_ndim(patch_shape) if source_checkpoint is None: # training from scratch only makes sense if we have supervised training data diff --git a/synaptic_reconstruction/training/semisupervised_training.py b/synaptic_reconstruction/training/semisupervised_training.py index 8c2d0f2..1c9c0b8 100644 --- a/synaptic_reconstruction/training/semisupervised_training.py +++ b/synaptic_reconstruction/training/semisupervised_training.py @@ -6,7 +6,7 @@ import torch_em.self_training as self_training from torchvision import transforms -from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, determine_ndim +from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim def weak_augmentations(p: float = 0.75) -> callable: @@ -61,7 +61,7 @@ def get_unsupervised_loader( else: roi = None - _, ndim = determine_ndim(patch_shape) + _, ndim = _determine_ndim(patch_shape) raw_transform = torch_em.transform.get_raw_transform() transform = torch_em.transform.get_augmentations(ndim=ndim) @@ -99,7 +99,7 @@ def semisupervised_training( n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, check: bool = False, -): +) -> None: """Run semi-supervised segmentation training. Args: diff --git a/synaptic_reconstruction/training/supervised_training.py b/synaptic_reconstruction/training/supervised_training.py index 72a32f0..16cd1cb 100644 --- a/synaptic_reconstruction/training/supervised_training.py +++ b/synaptic_reconstruction/training/supervised_training.py @@ -19,6 +19,7 @@ def get_3d_model( initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level. final_activation: The activation applied to the last output layer. + Returns: The U-Net. """ @@ -60,14 +61,14 @@ def get_2d_model( return model -def adjust_patch_shape(data_shape, patch_shape): +def _adjust_patch_shape(data_shape, patch_shape): # If data is 2D and patch_shape is 3D, drop the extra dimension in patch_shape if data_shape == 2 and len(patch_shape) == 3: return patch_shape[1:] # Remove the leading dimension in patch_shape return patch_shape # Return the original patch_shape for 3D data -def determine_ndim(patch_shape): +def _determine_ndim(patch_shape): # Check for 2D or 3D training try: z, y, x = patch_shape @@ -120,7 +121,7 @@ def get_supervised_loader( Returns: The PyTorch dataloader. """ - _, ndim = determine_ndim(patch_shape) + _, ndim = _determine_ndim(patch_shape) if label_transform is not None: # A specific label transform was passed, do nothing. 
pass elif add_boundary_transform: @@ -137,7 +138,7 @@ def get_supervised_loader( label_transform = torch_em.transform.label.connected_components if ndim == 2: - adjusted_patch_shape = adjust_patch_shape(ndim, patch_shape) + adjusted_patch_shape = _adjust_patch_shape(ndim, patch_shape) transform = torch_em.transform.Compose( torch_em.transform.PadIfNecessary(adjusted_patch_shape), torch_em.transform.get_augmentations(2) ) @@ -239,7 +240,7 @@ def supervised_training( check_loader(val_loader, n_samples=4) return - is_2d, _ = determine_ndim(patch_shape) + is_2d, _ = _determine_ndim(patch_shape) if is_2d: model = get_2d_model(out_channels=out_channels) else: diff --git a/test/test_cli.py b/test/test_cli.py new file mode 100644 index 0000000..6b0d1fb --- /dev/null +++ b/test/test_cli.py @@ -0,0 +1,68 @@ +import os +import unittest +from subprocess import run +from shutil import rmtree + +import imageio.v3 as imageio +import mrcfile +import pooch +from synaptic_reconstruction.sample_data import get_sample_data + + +class TestCLI(unittest.TestCase): + tmp_dir = "./tmp" + + def setUp(self): + self.data_path = get_sample_data("tem_2d") + os.makedirs(self.tmp_dir, exist_ok=True) + + def tearDown(self): + try: + rmtree(self.tmp_dir) + except OSError: + pass + + def check_segmentation_result(self): + output_path = os.path.join(self.tmp_dir, "tem_2d_prediction.tif") + self.assertTrue(os.path.exists(output_path)) + + prediction = imageio.imread(output_path) + with mrcfile.open(self.data_path, "r") as f: + data = f.data[:] + self.assertEqual(prediction.shape, data.shape) + + num_labels = prediction.max() + self.assertGreater(num_labels, 1) + + # import napari + # v = napari.Viewer() + # v.add_image(data) + # v.add_labels(prediction) + # napari.run() + + def test_segmentation_cli(self): + cmd = ["synapse_net.run_segmentation", "-i", self.data_path, "-o", self.tmp_dir, "-m", "vesicles_2d"] + run(cmd) + self.check_segmentation_result() + + def test_segmentation_cli_with_scale(self): + cmd = [ + "synapse_net.run_segmentation", "-i", self.data_path, "-o", self.tmp_dir, "-m", "vesicles_2d", + "--scale", "0.5" + ] + run(cmd) + self.check_segmentation_result() + + def test_segmentation_cli_with_checkpoint(self): + cache_dir = os.path.expanduser(pooch.os_cache("synapse-net")) + model_path = os.path.join(cache_dir, "models", "vesicles_2d") + cmd = [ + "synapse_net.run_segmentation", "-i", self.data_path, "-o", self.tmp_dir, "-m", "vesicles_2d", + "-c", model_path, + ] + run(cmd) + self.check_segmentation_result() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_distance_measurement.py b/test/test_distance_measurement.py index d3595c9..f3ed296 100644 --- a/test/test_distance_measurement.py +++ b/test/test_distance_measurement.py @@ -4,8 +4,8 @@ class TestDistanceMeasurement(unittest.TestCase): - def test_compute_boundary_distances(self): - from synaptic_reconstruction.distance_measurements import compute_boundary_distances + def test_measure_pairwise_object_distances(self): + from synaptic_reconstruction.distance_measurements import measure_pairwise_object_distances shape = (4, 64, 64) seg = np.zeros(shape, dtype="uint32") @@ -17,7 +17,7 @@ def test_compute_boundary_distances(self): seg[1, 16, 63] = 5 for resolution in (None, 2.3, 4.4): - distances, _, _, seg_ids = compute_boundary_distances(seg, resolution, n_threads=1) + distances, _, _, seg_ids = measure_pairwise_object_distances(seg, resolution=resolution, n_threads=1) factor = 1 if resolution is None else resolution
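To close, here is a Python-level counterpart to the CLI test above: download the 'tem_2d' sample data and segment it with the 2d vesicle model through the library. This is a sketch only, under the same assumptions as `test_cli.py` (the 'vesicles_2d' model is downloaded on first use); it is not part of the test suite.

```python
# Sketch: segment the tem_2d sample data via the python library instead of the CLI.
import mrcfile

from synaptic_reconstruction.inference.util import parse_tiling
from synaptic_reconstruction.sample_data import get_sample_data
from synaptic_reconstruction.tools.util import get_model, run_segmentation

# Download the sample image and load it from the mrc file.
data_path = get_sample_data("tem_2d")
with mrcfile.open(data_path, "r") as f:
    image = f.data[:]

# Run 2d vesicle segmentation with the default 2d tiling.
model = get_model("vesicles_2d")
tiling = parse_tiling(tile_shape=None, halo=None, is_2d=True)
segmentation = run_segmentation(
    image, model=model, model_type="vesicles_2d", tiling=tiling, scale=None, verbose=False
)
print("Number of segmented vesicles:", int(segmentation.max()))
```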