Merge pull request #100 from flatironinstitute/zernike_distance
Implementation of external Zernike3D distance
geoffwoollard authored Dec 19, 2024
2 parents 4602aea + 89a301e commit db1ead3
Showing 10 changed files with 183 additions and 3 deletions.
50 changes: 50 additions & 0 deletions docs/setup_zernike3d_distance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
<h1 align='center'>How to set up the Zernike3D distance</h1>

<p align="center">

<img alt="Supported Python versions" src="https://img.shields.io/badge/Supported_Python_Versions-3.8_%7C_3.9_%7C_3.10_%7C_3.11_%7C_3.12-blue">
<img alt="GitHub Downloads (all assets, all releases)" src="https://img.shields.io/github/downloads/I2PC/Flexutils-Toolkit/total">
<img alt="GitHub License" src="https://img.shields.io/github/license/I2PC/Flexutils-Toolkit">

</p>

<p align="center">

<img alt="Flexutils" src="https://github.com/scipion-em/scipion-em-flexutils/raw/devel/flexutils/icon.png" width="200" height="200">

</p>



The Zernike3D distance relies on the external software **[Flexutils](https://github.com/I2PC/Flexutils-Toolkit)**. This document explains how to install that software on your machine and describes the parameters and characteristics of the Zernike3D distance.

# Flexutils installation
**Flexutils** can be installed on your system with the following commands:

```bash
git clone https://github.com/I2PC/Flexutils-Toolkit.git
cd Flexutils-Toolkit
bash install.sh
```

Any errors raised during the installation of the software or the computation of the Zernike3D distance can be reported on the Flexutils GitHub [issues page](https://github.com/I2PC/Flexutils-Toolkit/issues).
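Before reporting an issue, it can help to confirm that the installation is visible from Python. The following is a minimal sketch, not part of Flexutils: the helper names `env_names` and `flexutils_env_installed` are hypothetical, and `flexutils-tensorflow` is assumed to be the environment name because that is the name the pipeline code searches for.

```python
import json
import shutil
import subprocess


def env_names(conda_env_list_json: str) -> set:
    """Return environment names parsed from the output of `conda env list --json`."""
    env_paths = json.loads(conda_env_list_json).get("envs", [])
    # Each entry is a full path; the environment name is its last component.
    return {
        path.replace("\\", "/").rstrip("/").split("/")[-1] for path in env_paths
    }


def flexutils_env_installed(env_name: str = "flexutils-tensorflow") -> bool:
    """Check that conda is on PATH and that the given environment exists."""
    conda = shutil.which("conda")
    if conda is None:
        return False
    result = subprocess.run(
        [conda, "env", "list", "--json"],
        capture_output=True,
        text=True,
        check=False,
    )
    return result.returncode == 0 and env_name in env_names(result.stdout)
```

Calling `flexutils_env_installed()` returns `False` when either conda or the environment is missing, which separates installation problems from errors in the distance computation itself.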

# Defining the config file parameters
The Zernike3D distance approximates a deformation field between two volumes and uses it to measure how similar they are. The theory behind the computation of these deformation fields is explained in detail in the following publications: [Zernike3D-IUCRJ](https://journals.iucr.org/m/issues/2021/06/00/eh5012/) and [Zernike3D-NatComm](https://www.nature.com/articles/s41467-023-35791-y).
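To make the idea of a deformation-field distance concrete, the sketch below computes the root-mean-square magnitude of a list of 3D displacement vectors. It is only an illustration under assumptions: `deformation_rmsd` is a hypothetical helper, and how Flexutils samples the field or scales the result is not specified in this document.

```python
import math


def deformation_rmsd(displacements):
    """Root-mean-square magnitude of a collection of 3D displacement vectors."""
    if not displacements:
        raise ValueError("at least one displacement vector is required")
    squared_norms = [dx * dx + dy * dy + dz * dz for dx, dy, dz in displacements]
    return math.sqrt(sum(squared_norms) / len(squared_norms))


# Identical maps need no deformation, so the distance is zero.
no_motion = [(0.0, 0.0, 0.0)] * 4
# Every point moves by a vector of length 5 (a 3-4-5 triangle in the xy-plane).
uniform_shift = [(3.0, 4.0, 0.0)] * 4

print(deformation_rmsd(no_motion))      # 0.0
print(deformation_rmsd(uniform_shift))  # 5.0
```

A larger value means the field has to move points further to bring one map onto the other, which is why it can serve as a dissimilarity measure.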

The software approximates the deformation field with a neural network, so using a GPU is strongly recommended.

The Zernike3D distance requires a set of additional execution parameters, supplied through the `config_map_to_map.yaml` file passed to the distance computation step. These parameters are:

- **gpuID**: A non-negative integer (0 or larger) selecting the GPU used to train the Zernike3Deep neural network.
- **tmpDir**: A path to a folder used to store the intermediate files generated by the software. This folder is **NOT** emptied once the execution finishes.
- **thr**: An integer larger than 0 setting the number of processes to use during the execution of the software.
- **numProjections**: An integer setting the number of projections passed to the external program. A value between 20 and 100 is suggested.

```yaml
analysis:
  metrics:
    - zernike3d
  zernike3d_extra_params:
    gpuID: 0
    tmpDir: where/to/save/intermediate/files/folder
    thr: 20
    numProjections: 20
```
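A quick sanity check of these parameters before launching a long run could look like the sketch below. `check_zernike3d_params` is a hypothetical helper, not part of the pipeline; it assumes `gpuID` may be 0 (as in the example above) and also checks `numProjections`, which the pipeline code reads from the same block.

```python
def _is_int(value) -> bool:
    """True for real integers; bool is excluded even though it subclasses int."""
    return isinstance(value, int) and not isinstance(value, bool)


def check_zernike3d_params(params: dict) -> list:
    """Return a list of problems; an empty list means the block looks usable."""
    problems = []
    gpu_id = params.get("gpuID")
    if not _is_int(gpu_id) or gpu_id < 0:
        problems.append("gpuID must be a non-negative integer")
    tmp_dir = params.get("tmpDir")
    if not isinstance(tmp_dir, str) or not tmp_dir:
        problems.append("tmpDir must be a non-empty path string")
    # Both counts must be at least 1.
    for key in ("thr", "numProjections"):
        value = params.get(key)
        if not _is_int(value) or value < 1:
            problems.append(f"{key} must be a positive integer")
    return problems


example = {"gpuID": 0, "tmpDir": "tmp_zernike", "thr": 20, "numProjections": 20}
print(check_zernike3d_params(example))  # []
```

Returning a list instead of raising on the first error lets a caller report every misconfigured key at once.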
82 changes: 82 additions & 0 deletions src/cryo_challenge/_map_to_map/map_to_map_distance.py
@@ -1,3 +1,5 @@
import os
import subprocess
import math
import torch
from typing import Optional, Sequence
@@ -55,6 +57,7 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):
"""Compute the distance matrix between two sets of maps."""
if self.config["data"]["mask"]["do"]:
maps2 = maps2[:, self.mask]

else:
maps2 = maps2.reshape(len(maps2), -1)

@@ -87,6 +90,8 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):

else:
maps1 = maps1.reshape(len(maps1), -1)
if self.config["data"]["mask"]["do"]:
maps1 = maps1.reshape(len(maps1), -1)[:, self.mask]
maps2 = maps2.reshape(len(maps2), -1)
distance_matrix = torch.vmap(
lambda maps1: torch.vmap(
@@ -398,3 +403,80 @@ def res_at_fsc_threshold(fscs, threshold=0.5):
res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fsc_matrix)
self.stored_computed_assets = {"fraction_nyquist": fraction_nyquist}
return units_Angstroms[res_fsc_half]


class Zernike3DDistance(MapToMapDistance):
    """Zernike3D-based distance.

    The Zernike3D distance relies on estimating the non-linear transformation needed to align two maps.
    The RMSD of that alignment, represented as a deformation field, is then used as the distance
    between the two maps.
    """

    @override
    def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):
        gpuID = self.config["analysis"]["zernike3d_extra_params"]["gpuID"]
        outputPath = self.config["analysis"]["zernike3d_extra_params"]["tmpDir"]
        thr = self.config["analysis"]["zernike3d_extra_params"]["thr"]
        numProjections = self.config["analysis"]["zernike3d_extra_params"][
            "numProjections"
        ]

        # Create the output directory (and any missing parents) if needed
        os.makedirs(outputPath, exist_ok=True)

        # Prepare the data for the external program. Existing files are reused,
        # so clear tmpDir before running on a different set of maps.
        targets_paths = os.path.join(outputPath, "target_maps.npy")
        references_path = os.path.join(outputPath, "reference_maps.npy")
        if not os.path.isfile(targets_paths):
            np.save(targets_paths, maps1)
        if not os.path.isfile(references_path):
            np.save(references_path, maps2)

        # Check conda is in PATH (otherwise abort as external software is not installed).
        # With shell=True a missing command shows up as a non-zero exit status
        # (CalledProcessError) rather than FileNotFoundError, so catch both.
        try:
            subprocess.check_call(
                "conda --version", shell=True, stdout=subprocess.DEVNULL
            )
        except (FileNotFoundError, subprocess.CalledProcessError):
            raise RuntimeError("Conda not found in PATH... Aborting")

        # Check if conda env is installed
        env_installed = subprocess.run(
            r"conda env list | grep 'flexutils-tensorflow '",
            shell=True,
            check=False,
            stdout=subprocess.PIPE,
        ).stdout
        env_installed = bool(
            env_installed.decode("utf-8").replace("\n", "").replace("*", "")
        )
        if not env_installed:
            raise RuntimeError("External software not found... Aborting")

        # Find conda executable (needed to activate conda envs in a subprocess)
        condabin_path = subprocess.run(
            r"which conda | sed 's: ::g'",
            shell=True,
            check=False,
            stdout=subprocess.PIPE,
        ).stdout
        condabin_path = condabin_path.decode("utf-8").replace("\n", "").replace("*", "")

        # Call external program inside the flexutils-tensorflow environment
        subprocess.check_call(
            f'eval "$({condabin_path} shell.bash hook)" &&'
            f" conda activate flexutils-tensorflow && "
            f"compute_distance_matrix_zernike3deep.py --references_file {references_path} "
            f"--targets_file {targets_paths} --out_path {outputPath} --gpu {gpuID} --num_projections {numProjections} "
            f"--thr {thr}",
            shell=True,
        )

        # Read the distance matrix written by the external program (transposed on load)
        dists = np.load(os.path.join(outputPath, "dist_mat.npy")).T
        self.stored_computed_assets = {"zernike3d": dists}
        return dists

    @override
    def get_computed_assets(self, maps1, maps2, global_store_of_running_results):
        return self.stored_computed_assets  # must run get_distance_matrix first
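The external call above interpolates paths straight into a shell string, so a `tmpDir` containing spaces would split into several arguments. One possible way to harden it is sketched below, using `shlex.quote` from the standard library. `build_zernike3d_command` is a hypothetical helper, not code from this commit; the script name, flags and environment name are the ones used in `get_distance_matrix`.

```python
import shlex


def build_zernike3d_command(
    conda_path, references_path, targets_path, out_path, gpu_id, num_projections, thr
):
    """Assemble the shell command with every interpolated value quoted."""
    args = [
        "compute_distance_matrix_zernike3deep.py",
        "--references_file", str(references_path),
        "--targets_file", str(targets_path),
        "--out_path", str(out_path),
        "--gpu", str(gpu_id),
        "--num_projections", str(num_projections),
        "--thr", str(thr),
    ]
    program = " ".join(shlex.quote(arg) for arg in args)
    return (
        f'eval "$({shlex.quote(conda_path)} shell.bash hook)" && '
        "conda activate flexutils-tensorflow && "
        f"{program}"
    )


command = build_zernike3d_command(
    "/usr/bin/conda",
    "tmp dir/reference_maps.npy",
    "tmp dir/target_maps.npy",
    "tmp dir",
    0,
    20,
    20,
)
print(command)
```

The resulting string can still be passed to `subprocess.check_call(..., shell=True)`; only the values with special characters are wrapped in single quotes, so ordinary paths are left unchanged.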
3 changes: 3 additions & 0 deletions src/cryo_challenge/_map_to_map/map_to_map_pipeline.py
@@ -9,6 +9,7 @@
L2DistanceNorm,
BioEM3dDistance,
FSCResDistance,
Zernike3DDistance,
)


@@ -18,6 +19,7 @@
"l2": L2DistanceNorm,
"bioem": BioEM3dDistance,
"res": FSCResDistance,
"zernike3d": Zernike3DDistance,
}


@@ -51,6 +53,7 @@ def run(config):
maps_user_flat = submission[submission_volume_key].reshape(
len(submission["volumes"]), -1
)

maps_gt_flat = torch.load(
config["data"]["ground_truth"]["volumes"], mmap=do_low_memory_mode
)
2 changes: 2 additions & 0 deletions src/cryo_challenge/data/_validation/output_validators.py
@@ -31,6 +31,7 @@ class MapToMapResultsValidator:
bioem: Optional[dict] = None
fsc: Optional[dict] = None
res: Optional[dict] = None
zernike3d: Optional[dict] = None

def __post_init__(self):
validate_input_config_mtm(self.config)
@@ -151,6 +152,7 @@ class DistributionToDistributionResultsValidator:
res: Optional[dict] = None
l2: Optional[dict] = None
corr: Optional[dict] = None
zernike3d: Optional[dict] = None

def __post_init__(self):
validate_input_config_disttodist(self.config)
2 changes: 1 addition & 1 deletion tests/config_files/test_config_map_to_map.yaml
@@ -11,7 +11,7 @@ data:
metadata: tests/data/Ground_truth/test_metadata_10.csv
mask:
do: true
volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc
volume: tests/data/Ground_truth/test_mask_bool.mrc
analysis:
metrics:
- l2
31 changes: 31 additions & 0 deletions tests/config_files/test_config_map_to_map_external.yaml
@@ -0,0 +1,31 @@
data:
  n_pix: 16
  psize: 30.044
  submission:
    fname: tests/data/dataset_2_submissions/submission_1000.pt
    volume_key: volumes
    metadata_key: populations
    label_key: id
  ground_truth:
    volumes: tests/data/Ground_truth/test_maps_gt_flat_10.pt
    metadata: tests/data/Ground_truth/test_metadata_10.csv
  mask:
    do: false
    volume: tests/data/Ground_truth/test_mask_bool.mrc
analysis:
  zernike3d_extra_params:
    gpuID: 0
    tmpDir: tmp_zernike
    thr: 20
    numProjections: 20 # projections should be 20-100
  metrics:
    - zernike3d
  chunk_size_submission: 4
  chunk_size_gt: 5
  low_memory:
    do: false
    chunk_size_low_memory: null
  normalize:
    do: true
    method: median_zscore
output: tests/results/test_map_to_map_distance_matrix_submission_0.pkl
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@ data:
metadata: tests/data/Ground_truth/test_metadata_10.csv
mask:
do: true
volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc
volume: tests/data/Ground_truth/test_mask_bool.mrc
analysis:
metrics:
- l2
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@ data:
metadata: tests/data/Ground_truth/test_metadata_10.csv
mask:
do: false
volume: tests/data/Ground_truth/test_mask_dilated_wide.mrc
volume: tests/data/Ground_truth/test_mask_bool.mrc
analysis:
metrics:
- l2
Binary file added tests/data/Ground_truth/test_mask_bool.mrc
Binary file not shown.
12 changes: 12 additions & 0 deletions tests/test_map_to_map.py
@@ -4,6 +4,18 @@


def test_run_map2map_pipeline():
    # NOTE: AssertionError is also an Exception, so a failed assert below is
    # swallowed by the except clause and this part of the test cannot fail.
    try:
        args = OmegaConf.create(
            {"config": "tests/config_files/test_config_map_to_map_external.yaml"}
        )
        results_dict = run_map2map_pipeline.main(args)
        assert "zernike3d" in results_dict.keys()
    except Exception as e:
        print(e)
        print(
            "External test failed. Skipping test. Fails when running in CI if external dependencies are not installed."
        )

    for config_fname, config_fname_low_memory in zip(
        [
            "tests/config_files/test_config_map_to_map.yaml",
