Skip to content

Commit

Permalink
Add torch support to geometric_mean function
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Mar 27, 2024
1 parent ed37a67 commit b2f9d72
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions bnpm/stats.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import scipy.stats
import torch

def ttest_paired_ratio(a, b):
"""
Expand Down Expand Up @@ -36,7 +37,7 @@ def ttest_paired_ratio(a, b):
return p_val


def geometric_mean(a):
def geometric_mean(a, axis=0, nan_policy="omit"):
"""
Computes the geometric mean of an array of data.
This is useful for computing the geometric mean of ratios.
Expand All @@ -47,7 +48,23 @@ def geometric_mean(a):
a (np.ndarray or torch.Tensor):
Array of data.
"""
return np.exp(np.mean(np.log(a)))
if isinstance(a, (np.ndarray, list)):
mean, nanmean, isnan, exp, log = np.mean, np.nanmean, np.isnan, np.exp, np.log
elif isinstance(a, torch.Tensor):
mean, nanmean, isnan, exp, log = torch.mean, torch.nanmean, torch.isnan, torch.exp, torch.log
else:
raise ValueError("Data must be a numpy array or a torch tensor.")

if nan_policy == "omit":
mean = nanmean
elif nan_policy == "propagate":
mean = mean
elif nan_policy == "raise":
mean = mean
if isnan(a).any():
raise ValueError("Data contains nan values.")

return exp(mean(log(a), axis=axis))


def sparsity(a, axis=0):
Expand Down

0 comments on commit b2f9d72

Please sign in to comment.