Skip to content

Commit

Permalink
fix paths in testing and ruff linting
Browse files Browse the repository at this point in the history
  • Loading branch information
DSilva27 committed Aug 5, 2024
1 parent e24b34b commit c5b2276
Show file tree
Hide file tree
Showing 28 changed files with 324 additions and 209 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main_merge_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ jobs:
if: github.base_ref == 'main' && github.head_ref != 'dev'
run: |
echo "ERROR: You can only merge to main from dev."
exit 1
exit 1
2 changes: 1 addition & 1 deletion .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1
- uses: chartboost/ruff-action@v1
1 change: 0 additions & 1 deletion .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,3 @@ jobs:
pytest tests/test_svd.py
pytest tests/test_map_to_map.py
pytest tests/test_distribution_to_distribution.py
2 changes: 1 addition & 1 deletion config_files/config_distribution_to_distribution.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ cvxpy_solver: ECOS
optimal_q_kl:
n_iter: 100000
break_atol: 0.0001
output_fname: results/distribution_to_distribution_submission_0.pkl
output_fname: results/distribution_to_distribution_submission_0.pkl
10 changes: 5 additions & 5 deletions config_files/config_map_to_map_distance_matrix.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
data:
n_pix: 224
psize: 2.146
psize: 2.146
submission:
fname: data/dataset_2_ground_truth/submission_0.pt
volume_key: volumes
metadata_key: populations
label_key: id
ground_truth:
volumes: data/dataset_2_ground_truth/maps_gt_flat.pt
metadata: data/dataset_2_ground_truth/metadata.csv
mask:
volumes: data/dataset_2_ground_truth/maps_gt_flat.pt
metadata: data/dataset_2_ground_truth/metadata.csv
mask:
do: true
volume: data/dataset_2_ground_truth/mask_dilated_wide_224x224.mrc
analysis:
Expand All @@ -23,4 +23,4 @@ analysis:
normalize:
do: true
method: median_zscore
output: results/map_to_map_distance_matrix_submission_0.pkl
output: results/map_to_map_distance_matrix_submission_0.pkl
4 changes: 3 additions & 1 deletion src/cryo_challenge/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from cryo_challenge.__about__ import __version__
from cryo_challenge.__about__ import __version__

__all__ = ["__version__"]
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import numpy as np
import pickle
from scipy.stats import rankdata
import yaml
import argparse
import torch
import ot

Expand All @@ -14,10 +12,12 @@


def sort_by_transport(cost):
    """
    Reorder the columns of a cost matrix using a transport plan computed
    between two uniform weight vectors.

    cost: 2-D array of shape (m, n); it is passed unchanged as the cost to
        compute_wasserstein_between_distributions_from_weights_and_cost
        together with the uniform weights np.ones(m) / m and np.ones(n) / n.

    Returns a tuple (sorted_cost, indices, transport):
        sorted_cost: cost[:, indices], i.e. the columns of cost reordered.
        indices: argsort of the per-column sort key described below.
        transport: second value returned by the transport computation.
    """
    m, n = cost.shape
    _, transport = compute_wasserstein_between_distributions_from_weights_and_cost(
        np.ones(m) / m, np.ones(n) / n, cost
    )
    # Sort key for column j is sum_i transport[i, j] * i, the transport-weighted
    # sum of row indices. Assumes transport is (m, n) -- TODO confirm against
    # the helper's return shape.
    indices = np.argsort((transport * np.arange(m)[..., None]).sum(0))
    return cost[:, indices], indices, transport


def compute_wasserstein_between_distributions_from_weights_and_cost(
Expand Down Expand Up @@ -65,15 +65,14 @@ def make_assignment_matrix(cost_matrix):


def run(config):

metadata_df = pd.read_csv(config["gt_metadata_fname"])
metadata_df.sort_values("pc1", inplace=True)

with open(config["input_fname"], "rb") as f:
data = pickle.load(f)

# user_submitted_populations = np.ones(80)/80
user_submitted_populations = data["user_submitted_populations"]#.numpy()
user_submitted_populations = data["user_submitted_populations"] # .numpy()
id = torch.load(data["config"]["data"]["submission"]["fname"])["id"]

results_dict = {}
Expand Down Expand Up @@ -213,5 +212,5 @@ def optimal_q_kl(n_iter, x_start, A, Window, prob_gt, break_atol):
DistributionToDistributionResultsValidator.from_dict(results_dict)
with open(config["output_fname"], "wb") as f:
pickle.dump(results_dict, f)

return results_dict
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def run(config):
user_submission_label = submission[label_key]

# n_trunc = 10
metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"])#[:n_trunc]
metadata_gt = pd.read_csv(config["data"]["ground_truth"]["metadata"]) # [:n_trunc]

results_dict = {}
results_dict["config"] = config
Expand Down
5 changes: 3 additions & 2 deletions src/cryo_challenge/_ploting/plotting_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np


def res_at_fsc_threshold(fscs, threshold=0.5):
    """
    Find, for each curve in fscs, the first position along the last axis
    whose value is not strictly greater than threshold.

    fscs: array with a .shape attribute; the last axis is scanned, and any
        leading axes are handled independently (presumably one FSC curve
        per leading index -- confirm with callers).
    threshold: cutoff compared against the values of fscs (default 0.5).

    Returns a tuple (res_fsc_half, fraction_nyquist):
        res_fsc_half: index (per curve) of the first entry with
            fscs <= threshold.
        fraction_nyquist: 0.5 * res_fsc_half / fscs.shape[-1].

    Note: when every value of a curve is above threshold there is no
    crossing, and np.argmin over the all-True mask yields 0 for that curve.
    """
    # argmin over a boolean mask returns the first False, i.e. the first
    # entry that fails `fscs > threshold`.
    res_fsc_half = np.argmin(fscs > threshold, axis=-1)
    fraction_nyquist = 0.5 * res_fsc_half / fscs.shape[-1]
    return res_fsc_half, fraction_nyquist
24 changes: 18 additions & 6 deletions src/cryo_challenge/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
from ._validation.config_validators import validate_input_config_disttodist as validate_input_config_disttodist
from ._validation.config_validators import validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl
from cryo_challenge.data._validation.output_validators import DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator
from cryo_challenge.data._validation.output_validators import MetricDistToDistValidator as MetricDistToDistValidator
from cryo_challenge.data._validation.output_validators import ReplicateValidatorEMD as ReplicateValidatorEMD
from cryo_challenge.data._validation.output_validators import ReplicateValidatorKL as ReplicateValidatorKL
from ._validation.config_validators import (
validate_input_config_disttodist as validate_input_config_disttodist,
)
from ._validation.config_validators import (
validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl,
)
from cryo_challenge.data._validation.output_validators import (
DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator,
)
from cryo_challenge.data._validation.output_validators import (
MetricDistToDistValidator as MetricDistToDistValidator,
)
from cryo_challenge.data._validation.output_validators import (
ReplicateValidatorEMD as ReplicateValidatorEMD,
)
from cryo_challenge.data._validation.output_validators import (
ReplicateValidatorKL as ReplicateValidatorKL,
)
10 changes: 6 additions & 4 deletions src/cryo_challenge/data/_io/svd_io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,16 @@ def load_ref_vols(box_size_ds: int, path_to_volumes: str, dtype=torch.float32):

# Reshape volumes to correct size
if volumes.dim() == 2:
box_size = int(round((float(volumes.shape[-1]) ** (1. / 3.))))
box_size = int(round((float(volumes.shape[-1]) ** (1.0 / 3.0))))
volumes = torch.reshape(volumes, (-1, box_size, box_size, box_size))
elif volumes.dim() == 4:
pass
else:
raise ValueError(f"The shape of the volumes stored in {path_to_volumes} have the unexpected shape "
f"{torch.shape}. Please, review the file and regenerate it so that volumes stored hasve the "
f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size).")
raise ValueError(
f"The shape of the volumes stored in {path_to_volumes} have the unexpected shape "
f"{torch.shape}. Please, review the file and regenerate it so that volumes stored hasve the "
f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size)."
)

volumes_ds = torch.empty(
(volumes.shape[0], box_size_ds, box_size_ds, box_size_ds), dtype=dtype
Expand Down
2 changes: 1 addition & 1 deletion src/cryo_challenge/data/_validation/config_validators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from numbers import Number
import pandas as pd
import os
from typing import List


def validate_generic_config(config: dict, reference: dict) -> None:
"""
Expand Down
21 changes: 13 additions & 8 deletions src/cryo_challenge/data/_validation/output_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
@dataclass_json
@dataclass
class MapToMapResultsValidator:
'''
"""
Validate the output dictionary of the map-to-map distance matrix computation.
config: dict, input config dictionary.
Expand All @@ -22,7 +22,8 @@ class MapToMapResultsValidator:
l2: dict, L2 results.
bioem: dict, BioEM results.
fsc: dict, FSC results.
'''
"""

config: dict
user_submitted_populations: torch.Tensor
corr: Optional[dict] = None
Expand All @@ -49,7 +50,7 @@ class ReplicateValidatorEMD:
Validate the output dictionary of one EMD in the the distribution-to-distribution pipeline.
q_opt: List[float], optimal user submitted distribution, which sums to 1.
EMD_opt: float, EMD between the ground truth distribution (p) and the (optimized) user submitted distribution (q_opt).
EMD_opt: float, EMD between the ground truth distribution (p) and the (optimized) user submitted distribution (q_opt).
The transport plan is a joint distribution, such that:
summing over the rows gives the (optimized) user submitted distribution, and summing over the columns gives the ground truth distribution.
transport_plan_opt: List[List[float]], transport plan between the ground truth distribution (p, rows) and the (optimized) user submitted distribution (q_opt, columns).
Expand All @@ -61,6 +62,7 @@ class ReplicateValidatorEMD:
The transport plan is a joint distribution, such that:
summing over the rows gives the user submitted distribution, and summing over the columns gives the ground truth distribution.
"""

q_opt: List[float]
EMD_opt: float
transport_plan_opt: List[List[float]]
Expand All @@ -87,8 +89,9 @@ class ReplicateValidatorKL:
iter_stop: int, number of iterations until convergence.
eps_stop: float, stopping criterion.
klpq_submitted: float, KL divergence between the ground truth distribution (p) and the user submitted distribution (q).
klqp_submitted: float, KL divergence between the user submitted distribution (q) and the ground truth distribution (p).
klqp_submitted: float, KL divergence between the user submitted distribution (q) and the ground truth distribution (p).
"""

q_opt: List[float]
klpq_opt: float
klqp_opt: float
Expand All @@ -106,11 +109,12 @@ def __post_init__(self):
@dataclass_json
@dataclass
class MetricDistToDistValidator:
'''
"""
Validate the output dictionary of one map to map metric in the the distribution-to-distribution pipeline.
replicates: dict, dictionary of replicates.
'''
"""

replicates: dict

def validate_replicates(self, n_replicates):
Expand All @@ -126,7 +130,7 @@ def validate_replicates(self, n_replicates):
@dataclass_json
@dataclass
class DistributionToDistributionResultsValidator:
'''
"""
Validate the output dictionary of the distribution-to-distribution pipeline.
config: dict, input config dictionary.
Expand All @@ -136,7 +140,8 @@ class DistributionToDistributionResultsValidator:
bioem: dict, BioEM distance results.
l2: dict, L2 distance results.
corr: dict, correlation distance results.
'''
"""

config: dict
user_submitted_populations: torch.Tensor
id: str
Expand Down
2 changes: 1 addition & 1 deletion tests/config_files/test_config_map_to_map.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ data:
n_pix: 224
psize: 2.146
submission:
fname: tests/data/dataset_2_submissions/test_submission_0_n8.pt
fname: tests/data/dataset_2_submissions/submission_10000.pt
volume_key: volumes
metadata_key: populations
label_key: id
Expand Down
2 changes: 1 addition & 1 deletion tests/config_files/test_config_svd.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
path_to_volumes: tests/data/dataset_2_submissions/
box_size_ds: 32
submission_list: [0]
submission_list: [10000]
experiment_mode: "all_vs_ref" # options are "all_vs_all", "all_vs_ref"
# optional unless experiment_mode is "all_vs_ref"
path_to_reference: tests/data/Ground_truth/test_maps_gt_flat_10.pt
Expand Down
1 change: 0 additions & 1 deletion tests/data/dataset_2_submissions/submission_0.pt

This file was deleted.

12 changes: 0 additions & 12 deletions tests/scripts/fetch_test_data.sh

This file was deleted.

8 changes: 5 additions & 3 deletions tests/test_distribution_to_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from cryo_challenge._commands import run_distribution2distribution_pipeline


def test_run_distribution2distribution_pipeline():
    """
    Smoke test: call run_distribution2distribution_pipeline.main with the
    test config file.

    There are no assertions; the test passes when main returns without
    raising. The config path is relative, so the working directory must be
    one from which tests/config_files/ resolves.
    """
    args = OmegaConf.create(
        {"config": "tests/config_files/test_config_distribution_to_distribution.yaml"}
    )
    run_distribution2distribution_pipeline.main(args)
8 changes: 5 additions & 3 deletions tests/test_map_to_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from cryo_challenge._commands import run_map2map_pipeline


def test_run_map2map_pipeline():
    """
    Smoke test: call run_map2map_pipeline.main with the test config file.

    There are no assertions; the test passes when main returns without
    raising. The config path is relative, so the working directory must be
    one from which tests/config_files/ resolves.
    """
    args = OmegaConf.create(
        {"config": "tests/config_files/test_config_map_to_map.yaml"}
    )
    run_map2map_pipeline.main(args)
6 changes: 3 additions & 3 deletions tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from cryo_challenge._commands import run_preprocessing


def test_run_preprocessing():
    """
    Smoke test: call run_preprocessing.main with the test config file.

    There are no assertions; the test passes when main returns without
    raising. The config path is relative, so the working directory must be
    one from which tests/config_files/ resolves.
    """
    args = OmegaConf.create({"config": "tests/config_files/test_config_preproc.yaml"})
    run_preprocessing.main(args)
6 changes: 3 additions & 3 deletions tests/test_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from cryo_challenge._commands import run_svd


def test_run_preprocessing():
    """
    Smoke test: call run_svd.main with the test SVD config file.

    There are no assertions; the test passes when main returns without
    raising. The config path is relative, so the working directory must be
    one from which tests/config_files/ resolves.
    """
    # NOTE(review): the function name says "preprocessing" but the body runs
    # the SVD pipeline; consider renaming to test_run_svd in a follow-up.
    args = OmegaConf.create({"config": "tests/config_files/test_config_svd.yaml"})
    run_svd.main(args)
4 changes: 2 additions & 2 deletions tutorials/1_tutorial_preprocessing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@
"# Select path to Config file\n",
"# An example of this file is available in the path ../config_files/config_preproc.yaml\n",
"config_preproc_path = FileChooser(os.path.expanduser(\"~\"))\n",
"config_preproc_path.filter_pattern = '*.yaml'\n",
"config_preproc_path.filter_pattern = \"*.yaml\"\n",
"display(config_preproc_path)"
]
},
Expand All @@ -226,7 +226,7 @@
"if os.path.isabs(output_path):\n",
" full_output_path = output_path\n",
"else:\n",
" full_output_path = os.path.join(os.getcwd(), '..', output_path)"
" full_output_path = os.path.join(os.getcwd(), \"..\", output_path)"
]
},
{
Expand Down
6 changes: 3 additions & 3 deletions tutorials/2_tutorial_svd.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
"# Select path to SVD config file\n",
"# An example of this file is available in the path ../config_files/config_svd.yaml\n",
"config_svd_path = FileChooser(os.path.expanduser(\"~\"))\n",
"config_svd_path.filter_pattern = '*.yaml'\n",
"config_svd_path.filter_pattern = \"*.yaml\"\n",
"display(config_svd_path)"
]
},
Expand Down Expand Up @@ -125,7 +125,7 @@
"source": [
"# Select path to SVD results\n",
"svd_results_path = FileChooser(os.path.expanduser(\"~\"))\n",
"svd_results_path.filter_pattern = '*.pt'\n",
"svd_results_path.filter_pattern = \"*.pt\"\n",
"display(svd_results_path)"
]
},
Expand Down Expand Up @@ -316,7 +316,7 @@
"source": [
"# Select path to SVD results\n",
"svd_all_vs_all_results_path = FileChooser(os.path.expanduser(\"~\"))\n",
"svd_all_vs_all_results_path.filter_pattern = '*.pt'\n",
"svd_all_vs_all_results_path.filter_pattern = \"*.pt\"\n",
"display(svd_all_vs_all_results_path)"
]
},
Expand Down
Loading

0 comments on commit c5b2276

Please sign in to comment.