From 6c06e30a12021deb5dc58fe0dd10a18bd0e13abc Mon Sep 17 00:00:00 2001 From: RichieHakim Date: Mon, 15 Apr 2024 00:23:46 -0400 Subject: [PATCH] Refactor pairwise_orthogonalization_torch function to remove device parameter and handle numpy arrays --- bnpm/ca2p_preprocessing.py | 1 - bnpm/similarity.py | 13 +++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/bnpm/ca2p_preprocessing.py b/bnpm/ca2p_preprocessing.py index b5b18af..9f51b08 100644 --- a/bnpm/ca2p_preprocessing.py +++ b/bnpm/ca2p_preprocessing.py @@ -442,7 +442,6 @@ def trace_quality_metrics( v1=F.T, v2=Fneu.T, center=True, - device=device, ) # F_baseline = torch.quantile(F, percentile_baseline/100, dim=1, keepdim=True) diff --git a/bnpm/similarity.py b/bnpm/similarity.py index c97ceb7..b30ff41 100644 --- a/bnpm/similarity.py +++ b/bnpm/similarity.py @@ -381,7 +381,7 @@ def pairwise_orthogonalization_torch_helper(v1, v2, center:bool=True): EVR_total_weighted = torch.nansum(v1_var * EVR) / torch.sum(v1_var) EVR_total_unweighted = torch.nanmean(EVR) return v1_orth, EVR, EVR_total_weighted, EVR_total_unweighted -def pairwise_orthogonalization_torch(v1, v2, center:bool=True, device='cpu'): +def pairwise_orthogonalization_torch(v1, v2, center:bool=True): """ Orthogonalizes columns of v2 off of the columns of v1 and returns the orthogonalized v1 and the explained @@ -425,13 +425,14 @@ def pairwise_orthogonalization_torch(v1, v2, center:bool=True, device='cpu'): """ if isinstance(v1, np.ndarray): v1 = torch.from_numpy(v1) + if isinstance(v2, np.ndarray): + v2 = torch.from_numpy(v2) + else: + raise ValueError("v2 must be a numpy array if v1 is a numpy array") return_numpy = True - else: + elif isinstance(v1, torch.Tensor): return_numpy = False - if isinstance(v2, np.ndarray): - v2 = torch.from_numpy(v2) - v1 = v1.to(device) - v2 = v2.to(device) + v1_orth, EVR, EVR_total_weighted, EVR_total_unweighted = pairwise_orthogonalization_torch_helper(v1, v2, center=center) if return_numpy: v1_orth = v1_orth.cpu().numpy()