From c5f211d7a48ac0f88debea387c04b8f5168e253c Mon Sep 17 00:00:00 2001 From: XinmingTu Date: Sat, 8 Jul 2023 19:34:01 -0700 Subject: [PATCH] add n_jobs support for pairwise distance computation (#305) * add n_jobs support for pairwise distance * update the n_jobs default to 1 * pre-commit Signed-off-by: zethson --------- Signed-off-by: zethson Co-authored-by: Tu Co-authored-by: Lukas Heumos --- pertpy/tools/_distances/_distances.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pertpy/tools/_distances/_distances.py b/pertpy/tools/_distances/_distances.py index fd8b7fe9..1e3b967a 100644 --- a/pertpy/tools/_distances/_distances.py +++ b/pertpy/tools/_distances/_distances.py @@ -117,6 +117,7 @@ def pairwise( groupby: str, groups: list[str] | None = None, verbose: bool = True, + n_jobs: int = 1, **kwargs, ) -> pd.DataFrame: """Get pairwise distances between groups of cells. @@ -148,7 +149,7 @@ def pairwise( if self.metric_fct.accepts_precomputed: # Precompute the pairwise distances if needed if f"{self.obsm_key}_predistances" not in adata.obsp.keys(): - self.precompute_distances(adata, **kwargs) + self.precompute_distances(adata, n_jobs=n_jobs, **kwargs) pwd = adata.obsp[f"{self.obsm_key}_predistances"] for index_x, group_x in enumerate(fct(groups)): idx_x = grouping == group_x @@ -179,7 +180,7 @@ def pairwise( df.name = f"pairwise {self.metric}" return df - def precompute_distances(self, adata: AnnData, cell_wise_metric: str = "euclidean") -> None: + def precompute_distances(self, adata: AnnData, cell_wise_metric: str = "euclidean", n_jobs: int = None) -> None: """Precompute pairwise distances between all cells, writes to adata.obsp. The precomputed distances are stored in adata.obsp under the key @@ -192,7 +193,7 @@ def precompute_distances(self, adata: AnnData, cell_wise_metric: str = "euclidea """ # Precompute the pairwise distances cells = adata.obsm[self.obsm_key].copy() - pwd = pairwise_distances(cells, cells, metric=cell_wise_metric) + pwd = pairwise_distances(cells, cells, metric=cell_wise_metric, n_jobs=n_jobs) # Write to adata.obsp adata.obsp[f"{self.obsm_key}_predistances"] = pwd