Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add n_jobs support for pairwise distance computation #305

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions pertpy/tools/_distances/_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def pairwise(
groupby: str,
groups: list[str] | None = None,
verbose: bool = True,
n_jobs: int = 1,
**kwargs,
) -> pd.DataFrame:
"""Get pairwise distances between groups of cells.
Expand Down Expand Up @@ -148,7 +149,7 @@ def pairwise(
if self.metric_fct.accepts_precomputed:
# Precompute the pairwise distances if needed
if f"{self.obsm_key}_predistances" not in adata.obsp.keys():
self.precompute_distances(adata, **kwargs)
self.precompute_distances(adata, n_jobs=n_jobs, **kwargs)
pwd = adata.obsp[f"{self.obsm_key}_predistances"]
for index_x, group_x in enumerate(fct(groups)):
idx_x = grouping == group_x
Expand Down Expand Up @@ -179,7 +180,7 @@ def pairwise(
df.name = f"pairwise {self.metric}"
return df

def precompute_distances(self, adata: AnnData, cell_wise_metric: str = "euclidean") -> None:
def precompute_distances(self, adata: AnnData, cell_wise_metric: str = "euclidean", n_jobs: int = None) -> None:
"""Precompute pairwise distances between all cells, writes to adata.obsp.

The precomputed distances are stored in adata.obsp under the key
Expand All @@ -192,7 +193,7 @@ def precompute_distances(self, adata: AnnData, cell_wise_metric: str = "euclidea
"""
# Precompute the pairwise distances
cells = adata.obsm[self.obsm_key].copy()
pwd = pairwise_distances(cells, cells, metric=cell_wise_metric)
pwd = pairwise_distances(cells, cells, metric=cell_wise_metric, n_jobs=n_jobs)
# Write to adata.obsp
adata.obsp[f"{self.obsm_key}_predistances"] = pwd

Expand Down
Loading