From 3fca066648ec3c19323ecbadb5ffab087d66076b Mon Sep 17 00:00:00 2001 From: RichieHakim Date: Sun, 14 Apr 2024 16:34:10 -0400 Subject: [PATCH] Refactor EV function to handle both numpy and torch tensors --- bnpm/similarity.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/bnpm/similarity.py b/bnpm/similarity.py index 48cf3cb..c97ceb7 100644 --- a/bnpm/similarity.py +++ b/bnpm/similarity.py @@ -339,6 +339,7 @@ def pairwise_orthogonalization(v1, v2, center:bool=True): assert v1.ndim == v2.ndim if v1.ndim==1: v1 = v1[:,None] + if v2.ndim==1: v2 = v2[:,None] assert v1.shape[1] == v2.shape[1] assert v1.shape[0] == v2.shape[0] @@ -468,11 +469,15 @@ def EV(y_true, y_pred): average of all EV values. Same as sklearn.metrics.explained_variance_score(y_true, y_pred, multioutput='uniform_average') ''' - - EV = 1 - np.sum((y_true - y_pred)**2, axis=0) / np.sum((y_true - np.mean(y_true, axis=0))**2, axis=0) - y_true_var = np.var(y_true, axis=0) - EV_total_weighted = np.sum( y_true_var* EV ) / np.sum(y_true_var) - EV_total_unweighted = np.mean(EV) + if isinstance(y_true, np.ndarray): + sum, mean, var = np.sum, np.mean, np.var + elif isinstance(y_true, torch.Tensor): + sum, mean, var = torch.sum, torch.mean, torch.var + + EV = 1 - sum((y_true - y_pred)**2, axis=0) / sum((y_true - mean(y_true, axis=0))**2, axis=0) + y_true_var = var(y_true, axis=0) + EV_total_weighted = sum( y_true_var* EV ) / sum(y_true_var) + EV_total_unweighted = mean(EV) return EV , EV_total_weighted , EV_total_unweighted