Skip to content

Commit

Permalink
Refactor pairwise_orthogonalization_torch function to remove device p…
Browse files Browse the repository at this point in the history
…arameter and handle numpy arrays
  • Loading branch information
RichieHakim committed Apr 15, 2024
1 parent ab495b8 commit 6c06e30
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
1 change: 0 additions & 1 deletion bnpm/ca2p_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,6 @@ def trace_quality_metrics(
v1=F.T,
v2=Fneu.T,
center=True,
device=device,
)

# F_baseline = torch.quantile(F, percentile_baseline/100, dim=1, keepdim=True)
Expand Down
13 changes: 7 additions & 6 deletions bnpm/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def pairwise_orthogonalization_torch_helper(v1, v2, center:bool=True):
EVR_total_weighted = torch.nansum(v1_var * EVR) / torch.sum(v1_var)
EVR_total_unweighted = torch.nanmean(EVR)
return v1_orth, EVR, EVR_total_weighted, EVR_total_unweighted
def pairwise_orthogonalization_torch(v1, v2, center:bool=True, device='cpu'):
def pairwise_orthogonalization_torch(v1, v2, center:bool=True):
"""
Orthogonalizes columns of v2 off of the columns of v1
and returns the orthogonalized v1 and the explained
Expand Down Expand Up @@ -425,13 +425,14 @@ def pairwise_orthogonalization_torch(v1, v2, center:bool=True, device='cpu'):
"""
if isinstance(v1, np.ndarray):
v1 = torch.from_numpy(v1)
if isinstance(v2, np.ndarray):
v2 = torch.from_numpy(v2)
else:
raise ValueError("v2 must be a numpy array if v1 is a numpy array")
return_numpy = True
else:
elif isinstance(v1, torch.Tensor):
return_numpy = False
if isinstance(v2, np.ndarray):
v2 = torch.from_numpy(v2)
v1 = v1.to(device)
v2 = v2.to(device)

v1_orth, EVR, EVR_total_weighted, EVR_total_unweighted = pairwise_orthogonalization_torch_helper(v1, v2, center=center)
if return_numpy:
v1_orth = v1_orth.cpu().numpy()
Expand Down

0 comments on commit 6c06e30

Please sign in to comment.