diff --git a/bnpm/stats.py b/bnpm/stats.py index aca5053..8de7892 100644 --- a/bnpm/stats.py +++ b/bnpm/stats.py @@ -1,5 +1,6 @@ import numpy as np import scipy.stats +import torch def ttest_paired_ratio(a, b): """ @@ -36,7 +37,7 @@ def ttest_paired_ratio(a, b): return p_val -def geometric_mean(a): +def geometric_mean(a, axis=0, nan_policy="omit"): """ Computes the geometric mean of an array of data. This is useful for computing the geometric mean of ratios. @@ -47,7 +48,23 @@ def geometric_mean(a): a (np.ndarray or torch.Tensor): Array of data. """ - return np.exp(np.mean(np.log(a))) + if isinstance(a, (np.ndarray, list)): + mean, nanmean, isnan, exp, log = np.mean, np.nanmean, np.isnan, np.exp, np.log + elif isinstance(a, torch.Tensor): + mean, nanmean, isnan, exp, log = torch.mean, torch.nanmean, torch.isnan, torch.exp, torch.log + else: + raise ValueError("Data must be a numpy array or a torch tensor.") + + if nan_policy == "omit": + mean = nanmean + elif nan_policy == "propagate": + mean = mean + elif nan_policy == "raise": + mean = mean + if isnan(a).any(): + raise ValueError("Data contains nan values.") + + return exp(mean(log(a), axis=axis)) def sparsity(a, axis=0):