Skip to content

Commit

Permalink
Refactor EV function to handle both numpy and torch tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Apr 14, 2024
1 parent bc92087 commit 3fca066
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions bnpm/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def pairwise_orthogonalization(v1, v2, center:bool=True):
assert v1.ndim == v2.ndim
if v1.ndim==1:
v1 = v1[:,None]
if v2.ndim==1:
v2 = v2[:,None]
assert v1.shape[1] == v2.shape[1]
assert v1.shape[0] == v2.shape[0]
Expand Down Expand Up @@ -468,11 +469,15 @@ def EV(y_true, y_pred):
average of all EV values. Same as
sklearn.metrics.explained_variance_score(y_true, y_pred, multioutput='uniform_average')
'''

EV = 1 - np.sum((y_true - y_pred)**2, axis=0) / np.sum((y_true - np.mean(y_true, axis=0))**2, axis=0)
y_true_var = np.var(y_true, axis=0)
EV_total_weighted = np.sum( y_true_var* EV ) / np.sum(y_true_var)
EV_total_unweighted = np.mean(EV)
if isinstance(y_true, np.ndarray):
sum, mean, var = np.sum, np.mean, np.var
elif isinstance(y_true, torch.Tensor):
sum, mean, var = torch.sum, torch.mean, torch.var

EV = 1 - sum((y_true - y_pred)**2, axis=0) / sum((y_true - mean(y_true, axis=0))**2, axis=0)
y_true_var = var(y_true, axis=0)
EV_total_weighted = sum( y_true_var* EV ) / sum(y_true_var)
EV_total_unweighted = mean(EV)

return EV , EV_total_weighted , EV_total_unweighted

Expand Down

0 comments on commit 3fca066

Please sign in to comment.