Skip to content

Commit

Permalink
add n_jobs support for pairwise distance computation (#305)
Browse files Browse the repository at this point in the history
* add n_jobs support for pairwise distance

* update the n_jobs default to 1

* pre-commit

Signed-off-by: zethson <[email protected]>

---------

Signed-off-by: zethson <[email protected]>
Co-authored-by: Tu <[email protected]>
Co-authored-by: Lukas Heumos <[email protected]>
  • Loading branch information
3 people authored Jul 9, 2023
1 parent 75ba14c commit c5f211d
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions pertpy/tools/_distances/_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def pairwise(
groupby: str,
groups: list[str] | None = None,
verbose: bool = True,
n_jobs: int = 1,
**kwargs,
) -> pd.DataFrame:
"""Get pairwise distances between groups of cells.
Expand Down Expand Up @@ -148,7 +149,7 @@ def pairwise(
if self.metric_fct.accepts_precomputed:
# Precompute the pairwise distances if needed
if f"{self.obsm_key}_predistances" not in adata.obsp.keys():
self.precompute_distances(adata, **kwargs)
self.precompute_distances(adata, n_jobs=n_jobs, **kwargs)
pwd = adata.obsp[f"{self.obsm_key}_predistances"]
for index_x, group_x in enumerate(fct(groups)):
idx_x = grouping == group_x
Expand Down Expand Up @@ -179,7 +180,7 @@ def pairwise(
df.name = f"pairwise {self.metric}"
return df

def precompute_distances(self, adata: AnnData, cell_wise_metric: str = "euclidean") -> None:
def precompute_distances(self, adata: AnnData, cell_wise_metric: str = "euclidean", n_jobs: int = None) -> None:
"""Precompute pairwise distances between all cells, writes to adata.obsp.
The precomputed distances are stored in adata.obsp under the key
Expand All @@ -192,7 +193,7 @@ def precompute_distances(self, adata: AnnData, cell_wise_metric: str = "euclidea
"""
# Precompute the pairwise distances
cells = adata.obsm[self.obsm_key].copy()
pwd = pairwise_distances(cells, cells, metric=cell_wise_metric)
pwd = pairwise_distances(cells, cells, metric=cell_wise_metric, n_jobs=n_jobs)
# Write to adata.obsp
adata.obsp[f"{self.obsm_key}_predistances"] = pwd

Expand Down

0 comments on commit c5f211d

Please sign in to comment.