Skip to content

Commit

Permalink
Merge pull request #84 from flatironinstitute/83-add-linting-test-enf…
Browse files Browse the repository at this point in the history
…orce-pre-commit-definitions

add linting check
  • Loading branch information
DSilva27 authored Aug 12, 2024
2 parents 18712d6 + c516088 commit 62563ef
Show file tree
Hide file tree
Showing 23 changed files with 398 additions and 217 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main_merge_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ jobs:
if: github.base_ref == 'main' && github.head_ref != 'dev'
run: |
echo "ERROR: You can only merge to main from dev."
exit 1
exit 1
12 changes: 12 additions & 0 deletions .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Runs the Ruff linter and formatter.

name: Lint

on: [push]

jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1
22 changes: 21 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,28 @@ The "-e" flag will install the package in editable mode, which means you can edi

## Things to do before pushing to GitHub

In this project we use Ruff for linting, and pre-commit to make sure that the code being pushed is not broken or goes against PEP8 guidelines. When you run `git commit` the pre-commit pipeline should rune automatically. In the near future we will start using pytest and mypy to perform more checks.
### Using pre-commit hooks for code formatting and linting

When you install in developer mode with `".[dev]` you will install the [pre-commit](https://pre-commit.com/) package. To set up this package simply run

```bash
pre-commit install
```

Then, everytime before doing a commit (that is before `git add` and `git commit`) run the following command:

```bash
pre-commit run --all-files
```

This will run `ruff` linting and formatting. If there is anything that cannot be automatically fixed, the command will let you know the file and line that needs to be fixed before being able to commit. Once you have fixed everything, you will be able to run `git add` and `git commit` without issue.


### Make sure tests run

```bash
python -m pytest tests/
```

## Best practices for contributing

Expand Down
2 changes: 1 addition & 1 deletion config_files/config_distribution_to_distribution.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ cvxpy_solver: ECOS
optimal_q_kl:
n_iter: 100000
break_atol: 0.0001
output_fname: results/distribution_to_distribution_submission_0.pkl
output_fname: results/distribution_to_distribution_submission_0.pkl
10 changes: 5 additions & 5 deletions config_files/config_map_to_map.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
data:
n_pix: 224
psize: 2.146
psize: 2.146
submission:
fname: data/dataset_2_ground_truth/submission_0.pt
volume_key: volumes
metadata_key: populations
label_key: id
ground_truth:
volumes: data/dataset_2_ground_truth/maps_gt_flat.pt
metadata: data/dataset_2_ground_truth/metadata.csv
mask:
volumes: data/dataset_2_ground_truth/maps_gt_flat.pt
metadata: data/dataset_2_ground_truth/metadata.csv
mask:
do: true
volume: data/dataset_2_ground_truth/mask_dilated_wide_224x224.mrc
analysis:
Expand All @@ -23,4 +23,4 @@ analysis:
normalize:
do: true
method: median_zscore
output: results/map_to_map_distance_matrix_submission_0.pkl
output: results/map_to_map_distance_matrix_submission_0.pkl
4 changes: 3 additions & 1 deletion src/cryo_challenge/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from cryo_challenge.__about__ import __version__
from cryo_challenge.__about__ import __version__

__all__ = ["__version__"]
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import numpy as np
import pickle
from scipy.stats import rankdata
import yaml
import argparse
import torch
import ot

Expand All @@ -14,10 +12,12 @@


def sort_by_transport(cost):
m,n = cost.shape
_, transport = compute_wasserstein_between_distributions_from_weights_and_cost(np.ones(m) / m, np.ones(n)/n, cost)
indices = np.argsort((transport * np.arange(m)[...,None]).sum(0))
return cost[:,indices], indices, transport
m, n = cost.shape
_, transport = compute_wasserstein_between_distributions_from_weights_and_cost(
np.ones(m) / m, np.ones(n) / n, cost
)
indices = np.argsort((transport * np.arange(m)[..., None]).sum(0))
return cost[:, indices], indices, transport


def compute_wasserstein_between_distributions_from_weights_and_cost(
Expand Down Expand Up @@ -65,15 +65,14 @@ def make_assignment_matrix(cost_matrix):


def run(config):

metadata_df = pd.read_csv(config["gt_metadata_fname"])
metadata_df.sort_values("pc1", inplace=True)

with open(config["input_fname"], "rb") as f:
data = pickle.load(f)

# user_submitted_populations = np.ones(80)/80
user_submitted_populations = data["user_submitted_populations"]#.numpy()
user_submitted_populations = data["user_submitted_populations"] # .numpy()
id = torch.load(data["config"]["data"]["submission"]["fname"])["id"]

results_dict = {}
Expand Down Expand Up @@ -213,5 +212,5 @@ def optimal_q_kl(n_iter, x_start, A, Window, prob_gt, break_atol):
DistributionToDistributionResultsValidator.from_dict(results_dict)
with open(config["output_fname"], "wb") as f:
pickle.dump(results_dict, f)

return results_dict
53 changes: 34 additions & 19 deletions src/cryo_challenge/_map_to_map/map_to_map_distance.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import math
import torch
from typing import Optional, Sequence
from typing_extensions import override
from typing_extensions import override
import mrcfile


class MapToMapDistance:
def __init__(self, config):
self.config = config
Expand All @@ -25,40 +26,44 @@ def get_distance_matrix(self, maps1, maps2):
)(maps1)

return distance_matrix

def get_computed_assets(self, maps1, maps2):
"""Return any computed assets that are needed for (downstream) analysis."""
return {}


class L2DistanceNorm(MapToMapDistance):
def __init__(self, config):
super().__init__(config)

@override
def get_distance(self, map1, map2):
return torch.norm(map1 - map2)**2

return torch.norm(map1 - map2) ** 2


class L2DistanceSum(MapToMapDistance):
def __init__(self, config):
super().__init__(config)

def compute_cost_l2(self, map_1, map_2):
return ((map_1 - map_2) ** 2).sum()
@override

@override
def get_distance(self, map1, map2):
return self.compute_cost_l2(map1, map2)



class Correlation(MapToMapDistance):
def __init__(self, config):
super().__init__(config)

def compute_cost_corr(self, map_1, map_2):
return (map_1 * map_2).sum()

@override
@override
def get_distance(self, map1, map2):
return self.compute_cost_corr(map1, map2)
return self.compute_cost_corr(map1, map2)


class BioEM3dDistance(MapToMapDistance):
def __init__(self, config):
Expand Down Expand Up @@ -94,7 +99,12 @@ def compute_bioem3d_cost(self, map1, map2):
N = len(m1)

t1 = 2 * torch.pi * math.exp(1)
t2 = N * (ccc * coo - coc * coc) + 2 * co * coc * cc - ccc * co * co - coo * cc * cc
t2 = (
N * (ccc * coo - coc * coc)
+ 2 * co * coc * cc
- ccc * co * co
- coo * cc * cc
)
t3 = (N - 2) * (N * ccc - cc * cc)

smallest_float = torch.finfo(m1.dtype).tiny
Expand All @@ -107,10 +117,11 @@ def compute_bioem3d_cost(self, map1, map2):
cost = -log_prob
return cost

@override
@override
def get_distance(self, map1, map2):
return self.compute_bioem3d_cost(map1, map2)

return self.compute_bioem3d_cost(map1, map2)


class FSCDistance(MapToMapDistance):
def __init__(self, config):
super().__init__(config)
Expand Down Expand Up @@ -203,22 +214,26 @@ def compute_cost_fsc_chunk(self, maps_gt_flat, maps_user_flat, n_pix):
return cost_matrix, fsc_matrix

@override
def get_distance_matrix(self, maps1, maps2): # custom method
def get_distance_matrix(self, maps1, maps2): # custom method
maps_gt_flat = maps1
maps_user_flat = maps2
n_pix = self.config["data"]["n_pix"]
maps_gt_flat_cube = torch.zeros(len(maps_gt_flat), n_pix**3)
mask = (
mrcfile.open(self.config["data"]["mask"]["volume"]).data.astype(bool).flatten()
mrcfile.open(self.config["data"]["mask"]["volume"])
.data.astype(bool)
.flatten()
)
maps_gt_flat_cube[:, mask] = maps_gt_flat
maps_user_flat_cube = torch.zeros(len(maps_user_flat), n_pix**3)
maps_user_flat_cube[:, mask] = maps_user_flat

cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk(maps_gt_flat_cube, maps_user_flat_cube, n_pix)
self.stored_computed_assets = {'fsc_matrix': fsc_matrix}

cost_matrix, fsc_matrix = self.compute_cost_fsc_chunk(
maps_gt_flat_cube, maps_user_flat_cube, n_pix
)
self.stored_computed_assets = {"fsc_matrix": fsc_matrix}
return cost_matrix

@override
def get_computed_assets(self, maps1, maps2):
return self.stored_computed_assets # must run get_distance_matrix first
return self.stored_computed_assets # must run get_distance_matrix first
19 changes: 12 additions & 7 deletions src/cryo_challenge/_map_to_map/map_to_map_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,21 @@
import torch

from ..data._validation.output_validators import MapToMapResultsValidator
from .._map_to_map.map_to_map_distance import FSCDistance, Correlation, L2DistanceSum, BioEM3dDistance
from .._map_to_map.map_to_map_distance import (
FSCDistance,
Correlation,
L2DistanceSum,
BioEM3dDistance,
)


AVAILABLE_MAP2MAP_DISTANCES = {
"fsc": FSCDistance,
"corr": Correlation,
"l2": L2DistanceSum,
"bioem": BioEM3dDistance,
}
"fsc": FSCDistance,
"corr": Correlation,
"l2": L2DistanceSum,
"bioem": BioEM3dDistance,
}


def run(config):
"""
Expand Down Expand Up @@ -40,7 +46,6 @@ def run(config):
submission[submission_metadata_key] / submission[submission_metadata_key].sum()
)


maps_user_flat = submission[submission_volume_key].reshape(
len(submission["volumes"]), -1
)
Expand Down
5 changes: 3 additions & 2 deletions src/cryo_challenge/_ploting/plotting_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np


def res_at_fsc_threshold(fscs, threshold=0.5):
res_fsc_half = np.argmin(fscs > threshold, axis=-1)
fraction_nyquist = 0.5*res_fsc_half / fscs.shape[-1]
return res_fsc_half, fraction_nyquist
fraction_nyquist = 0.5 * res_fsc_half / fscs.shape[-1]
return res_fsc_half, fraction_nyquist
24 changes: 18 additions & 6 deletions src/cryo_challenge/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
from ._validation.config_validators import validate_input_config_disttodist as validate_input_config_disttodist
from ._validation.config_validators import validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl
from cryo_challenge.data._validation.output_validators import DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator
from cryo_challenge.data._validation.output_validators import MetricDistToDistValidator as MetricDistToDistValidator
from cryo_challenge.data._validation.output_validators import ReplicateValidatorEMD as ReplicateValidatorEMD
from cryo_challenge.data._validation.output_validators import ReplicateValidatorKL as ReplicateValidatorKL
from ._validation.config_validators import (
validate_input_config_disttodist as validate_input_config_disttodist,
)
from ._validation.config_validators import (
validate_config_dtd_optimal_q_kl as validate_config_dtd_optimal_q_kl,
)
from cryo_challenge.data._validation.output_validators import (
DistributionToDistributionResultsValidator as DistributionToDistributionResultsValidator,
)
from cryo_challenge.data._validation.output_validators import (
MetricDistToDistValidator as MetricDistToDistValidator,
)
from cryo_challenge.data._validation.output_validators import (
ReplicateValidatorEMD as ReplicateValidatorEMD,
)
from cryo_challenge.data._validation.output_validators import (
ReplicateValidatorKL as ReplicateValidatorKL,
)
10 changes: 6 additions & 4 deletions src/cryo_challenge/data/_io/svd_io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,16 @@ def load_ref_vols(box_size_ds: int, path_to_volumes: str, dtype=torch.float32):

# Reshape volumes to correct size
if volumes.dim() == 2:
box_size = int(round((float(volumes.shape[-1]) ** (1. / 3.))))
box_size = int(round((float(volumes.shape[-1]) ** (1.0 / 3.0))))
volumes = torch.reshape(volumes, (-1, box_size, box_size, box_size))
elif volumes.dim() == 4:
pass
else:
raise ValueError(f"The shape of the volumes stored in {path_to_volumes} have the unexpected shape "
f"{torch.shape}. Please, review the file and regenerate it so that volumes stored hasve the "
f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size).")
raise ValueError(
f"The shape of the volumes stored in {path_to_volumes} have the unexpected shape "
f"{torch.shape}. Please, review the file and regenerate it so that volumes stored hasve the "
f"shape (num_vols, box_size ** 3) or (num_vols, box_size, box_size, box_size)."
)

volumes_ds = torch.empty(
(volumes.shape[0], box_size_ds, box_size_ds, box_size_ds), dtype=dtype
Expand Down
Loading

0 comments on commit 62563ef

Please sign in to comment.